|
52 | 52 | from diffusers.optimization import get_scheduler |
53 | 53 | from diffusers.utils import check_min_version, is_wandb_available, make_image_grid |
54 | 54 | from diffusers.utils.import_utils import is_xformers_available |
| 55 | +from diffusers.utils.torch_utils import is_compiled_module |
55 | 56 |
|
56 | 57 |
|
57 | 58 | if is_wandb_available(): |
@@ -847,6 +848,11 @@ def main(args): |
847 | 848 | logger.info("Initializing controlnet weights from unet") |
848 | 849 | controlnet = ControlNetModel.from_unet(unet) |
849 | 850 |
|
| 851 | + def unwrap_model(model): |
| 852 | + model = accelerator.unwrap_model(model) |
| 853 | + model = model._orig_mod if is_compiled_module(model) else model |
| 854 | + return model |
| 855 | + |
850 | 856 | # `accelerate` 0.16.0 will have better support for customized saving |
851 | 857 | if version.parse(accelerate.__version__) >= version.parse("0.16.0"): |
852 | 858 | # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format |
@@ -908,9 +914,9 @@ def load_model_hook(models, input_dir): |
908 | 914 | " doing mixed precision training, copy of the weights should still be float32." |
909 | 915 | ) |
910 | 916 |
|
911 | | - if accelerator.unwrap_model(controlnet).dtype != torch.float32: |
| 917 | + if unwrap_model(controlnet).dtype != torch.float32: |
912 | 918 | raise ValueError( |
913 | | - f"Controlnet loaded as datatype {accelerator.unwrap_model(controlnet).dtype}. {low_precision_error_string}" |
| 919 | + f"Controlnet loaded as datatype {unwrap_model(controlnet).dtype}. {low_precision_error_string}" |
914 | 920 | ) |
915 | 921 |
|
916 | 922 | # Enable TF32 for faster training on Ampere GPUs, |
@@ -1158,7 +1164,8 @@ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizer |
1158 | 1164 | sample.to(dtype=weight_dtype) for sample in down_block_res_samples |
1159 | 1165 | ], |
1160 | 1166 | mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype), |
1161 | | - ).sample |
| 1167 | + return_dict=False, |
| 1168 | + )[0] |
1162 | 1169 |
|
1163 | 1170 | # Get the target for loss depending on the prediction type |
1164 | 1171 | if noise_scheduler.config.prediction_type == "epsilon": |
@@ -1223,7 +1230,7 @@ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizer |
1223 | 1230 | # Create the pipeline using using the trained modules and save it. |
1224 | 1231 | accelerator.wait_for_everyone() |
1225 | 1232 | if accelerator.is_main_process: |
1226 | | - controlnet = accelerator.unwrap_model(controlnet) |
| 1233 | + controlnet = unwrap_model(controlnet) |
1227 | 1234 | controlnet.save_pretrained(args.output_dir) |
1228 | 1235 |
|
1229 | 1236 | if args.push_to_hub: |
|
0 commit comments