Skip to content

Commit 7a14bd5

Browse files
authored
Merge branch 'main' into xla_sana
2 parents 4eabcf3 + 825979d commit 7a14bd5

11 files changed

+32
-18
lines changed

examples/flux-control/train_control_flux.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -795,7 +795,7 @@ def main(args):
795795
flux_transformer.x_embedder = new_linear
796796

797797
assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0)
798-
flux_transformer.register_to_config(in_channels=initial_input_channels * 2)
798+
flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels)
799799

800800
def unwrap_model(model):
801801
model = accelerator.unwrap_model(model)
@@ -1166,6 +1166,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
11661166
flux_transformer.to(torch.float32)
11671167
flux_transformer.save_pretrained(args.output_dir)
11681168

1169+
del flux_transformer
1170+
del text_encoding_pipeline
1171+
del vae
1172+
free_memory()
1173+
11691174
# Run a final round of validation.
11701175
image_logs = None
11711176
if args.validation_prompt is not None:

examples/flux-control/train_control_lora_flux.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -830,7 +830,7 @@ def main(args):
830830
flux_transformer.x_embedder = new_linear
831831

832832
assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0)
833-
flux_transformer.register_to_config(in_channels=initial_input_channels * 2)
833+
flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels)
834834

835835
if args.train_norm_layers:
836836
for name, param in flux_transformer.named_parameters():
@@ -1319,6 +1319,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
13191319
transformer_lora_layers=transformer_lora_layers,
13201320
)
13211321

1322+
del flux_transformer
1323+
del text_encoding_pipeline
1324+
del vae
1325+
free_memory()
1326+
13221327
# Run a final round of validation.
13231328
image_logs = None
13241329
if args.validation_prompt is not None:

src/diffusers/models/attention_processor.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4839,6 +4839,8 @@ def __call__(
48394839
)
48404840
else:
48414841
for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):
4842+
if mask is None:
4843+
continue
48424844
if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
48434845
raise ValueError(
48444846
"Each element of the ip_adapter_masks array should be a tensor with shape "
@@ -5056,6 +5058,8 @@ def __call__(
50565058
)
50575059
else:
50585060
for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):
5061+
if mask is None:
5062+
continue
50595063
if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
50605064
raise ValueError(
50615065
"Each element of the ip_adapter_masks array should be a tensor with shape "

tests/single_file/single_file_testing_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -378,14 +378,14 @@ def test_single_file_components_with_diffusers_config_local_files_only(
378378
def test_single_file_format_inference_is_same_as_pretrained(self, expected_max_diff=1e-4):
379379
sf_pipe = self.pipeline_class.from_single_file(self.ckpt_path, torch_dtype=torch.float16, safety_checker=None)
380380
sf_pipe.unet.set_default_attn_processor()
381-
sf_pipe.enable_model_cpu_offload()
381+
sf_pipe.enable_model_cpu_offload(device=torch_device)
382382

383383
inputs = self.get_inputs(torch_device)
384384
image_single_file = sf_pipe(**inputs).images[0]
385385

386386
pipe = self.pipeline_class.from_pretrained(self.repo_id, torch_dtype=torch.float16, safety_checker=None)
387387
pipe.unet.set_default_attn_processor()
388-
pipe.enable_model_cpu_offload()
388+
pipe.enable_model_cpu_offload(device=torch_device)
389389

390390
inputs = self.get_inputs(torch_device)
391391
image = pipe(**inputs).images[0]

tests/single_file/test_stable_diffusion_controlnet_img2img_single_file.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,14 +76,14 @@ def test_single_file_format_inference_is_same_as_pretrained(self):
7676
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
7777
pipe = self.pipeline_class.from_pretrained(self.repo_id, controlnet=controlnet)
7878
pipe.unet.set_default_attn_processor()
79-
pipe.enable_model_cpu_offload()
79+
pipe.enable_model_cpu_offload(device=torch_device)
8080

8181
pipe_sf = self.pipeline_class.from_single_file(
8282
self.ckpt_path,
8383
controlnet=controlnet,
8484
)
8585
pipe_sf.unet.set_default_attn_processor()
86-
pipe_sf.enable_model_cpu_offload()
86+
pipe_sf.enable_model_cpu_offload(device=torch_device)
8787

8888
inputs = self.get_inputs(torch_device)
8989
output = pipe(**inputs).images[0]

tests/single_file/test_stable_diffusion_controlnet_inpaint_single_file.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,11 @@ def test_single_file_format_inference_is_same_as_pretrained(self):
7373
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
7474
pipe = self.pipeline_class.from_pretrained(self.repo_id, controlnet=controlnet, safety_checker=None)
7575
pipe.unet.set_default_attn_processor()
76-
pipe.enable_model_cpu_offload()
76+
pipe.enable_model_cpu_offload(device=torch_device)
7777

7878
pipe_sf = self.pipeline_class.from_single_file(self.ckpt_path, controlnet=controlnet, safety_checker=None)
7979
pipe_sf.unet.set_default_attn_processor()
80-
pipe_sf.enable_model_cpu_offload()
80+
pipe_sf.enable_model_cpu_offload(device=torch_device)
8181

8282
inputs = self.get_inputs()
8383
output = pipe(**inputs).images[0]

tests/single_file/test_stable_diffusion_controlnet_single_file.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,14 @@ def test_single_file_format_inference_is_same_as_pretrained(self):
6767
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
6868
pipe = self.pipeline_class.from_pretrained(self.repo_id, controlnet=controlnet)
6969
pipe.unet.set_default_attn_processor()
70-
pipe.enable_model_cpu_offload()
70+
pipe.enable_model_cpu_offload(device=torch_device)
7171

7272
pipe_sf = self.pipeline_class.from_single_file(
7373
self.ckpt_path,
7474
controlnet=controlnet,
7575
)
7676
pipe_sf.unet.set_default_attn_processor()
77-
pipe_sf.enable_model_cpu_offload()
77+
pipe_sf.enable_model_cpu_offload(device=torch_device)
7878

7979
inputs = self.get_inputs()
8080
output = pipe(**inputs).images[0]

tests/single_file/test_stable_diffusion_upscale_single_file.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,14 +49,14 @@ def test_single_file_format_inference_is_same_as_pretrained(self):
4949

5050
prompt = "a cat sitting on a park bench"
5151
pipe = StableDiffusionUpscalePipeline.from_pretrained(self.repo_id)
52-
pipe.enable_model_cpu_offload()
52+
pipe.enable_model_cpu_offload(device=torch_device)
5353

5454
generator = torch.Generator("cpu").manual_seed(0)
5555
output = pipe(prompt=prompt, image=image, generator=generator, output_type="np", num_inference_steps=3)
5656
image_from_pretrained = output.images[0]
5757

5858
pipe_from_single_file = StableDiffusionUpscalePipeline.from_single_file(self.ckpt_path)
59-
pipe_from_single_file.enable_model_cpu_offload()
59+
pipe_from_single_file.enable_model_cpu_offload(device=torch_device)
6060

6161
generator = torch.Generator("cpu").manual_seed(0)
6262
output_from_single_file = pipe_from_single_file(

tests/single_file/test_stable_diffusion_xl_adapter_single_file.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def test_single_file_format_inference_is_same_as_pretrained(self):
7676
torch_dtype=torch.float16,
7777
safety_checker=None,
7878
)
79-
pipe_single_file.enable_model_cpu_offload()
79+
pipe_single_file.enable_model_cpu_offload(device=torch_device)
8080
pipe_single_file.set_progress_bar_config(disable=None)
8181

8282
inputs = self.get_inputs()
@@ -88,7 +88,7 @@ def test_single_file_format_inference_is_same_as_pretrained(self):
8888
torch_dtype=torch.float16,
8989
safety_checker=None,
9090
)
91-
pipe.enable_model_cpu_offload()
91+
pipe.enable_model_cpu_offload(device=torch_device)
9292

9393
inputs = self.get_inputs()
9494
images = pipe(**inputs).images[0]

tests/single_file/test_stable_diffusion_xl_controlnet_single_file.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,15 +69,15 @@ def test_single_file_format_inference_is_same_as_pretrained(self):
6969
self.ckpt_path, controlnet=controlnet, torch_dtype=torch.float16
7070
)
7171
pipe_single_file.unet.set_default_attn_processor()
72-
pipe_single_file.enable_model_cpu_offload()
72+
pipe_single_file.enable_model_cpu_offload(device=torch_device)
7373
pipe_single_file.set_progress_bar_config(disable=None)
7474

7575
inputs = self.get_inputs(torch_device)
7676
single_file_images = pipe_single_file(**inputs).images[0]
7777

7878
pipe = self.pipeline_class.from_pretrained(self.repo_id, controlnet=controlnet, torch_dtype=torch.float16)
7979
pipe.unet.set_default_attn_processor()
80-
pipe.enable_model_cpu_offload()
80+
pipe.enable_model_cpu_offload(device=torch_device)
8181

8282
inputs = self.get_inputs(torch_device)
8383
images = pipe(**inputs).images[0]

0 commit comments

Comments
 (0)