Skip to content

Commit ba7bb57

Browse files
authored
Merge branch 'main' into cogvideo_dataset_resize__crop
2 parents 8115e41 + 63a5c87 commit ba7bb57

File tree

10 files changed

+69
-18
lines changed

10 files changed

+69
-18
lines changed

docs/source/en/training/distributed_inference.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ transformer = FluxTransformer2DModel.from_pretrained(
177177
```
178178

179179
> [!TIP]
180-
> At any point, you can try `print(pipeline.hf_device_map)` to see how the various models are distributed across devices. This is useful for tracking the device placement of the models.
180+
> At any point, you can try `print(pipeline.hf_device_map)` to see how the various models are distributed across devices. This is useful for tracking the device placement of the models. You can also try `print(transformer.hf_device_map)` to see how the transformer model is sharded across devices.
181181
182182
Add the transformer model to the pipeline for denoising, but set the other model-level components like the text encoders and VAE to `None` because you don't need them yet.
183183

examples/controlnet/train_controlnet_sd3.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,11 @@ def parse_args(input_args=None):
357357
action="store_true",
358358
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
359359
)
360+
parser.add_argument(
361+
"--upcast_vae",
362+
action="store_true",
363+
help="Whether or not to upcast vae to fp32",
364+
)
360365
parser.add_argument(
361366
"--learning_rate",
362367
type=float,
@@ -1094,7 +1099,10 @@ def load_model_hook(models, input_dir):
10941099
weight_dtype = torch.bfloat16
10951100

10961101
# Move vae, transformer and text_encoder to device and cast to weight_dtype
1097-
vae.to(accelerator.device, dtype=torch.float32)
1102+
if args.upcast_vae:
1103+
vae.to(accelerator.device, dtype=torch.float32)
1104+
else:
1105+
vae.to(accelerator.device, dtype=weight_dtype)
10981106
transformer.to(accelerator.device, dtype=weight_dtype)
10991107
text_encoder_one.to(accelerator.device, dtype=weight_dtype)
11001108
text_encoder_two.to(accelerator.device, dtype=weight_dtype)

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -632,7 +632,7 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
632632
new_key += ".lora_B.weight"
633633

634634
# Handle single_blocks
635-
elif old_key.startswith("diffusion_model.single_blocks", "single_blocks"):
635+
elif old_key.startswith(("diffusion_model.single_blocks", "single_blocks")):
636636
block_num = re.search(r"single_blocks\.(\d+)", old_key).group(1)
637637
new_key = f"transformer.single_transformer_blocks.{block_num}"
638638

src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,9 @@ def __init__(
251251
if hasattr(self, "transformer") and self.transformer is not None
252252
else 128
253253
)
254+
self.patch_size = (
255+
self.transformer.config.patch_size if hasattr(self, "transformer") and self.transformer is not None else 2
256+
)
254257

255258
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_t5_prompt_embeds
256259
def _get_t5_prompt_embeds(
@@ -577,8 +580,14 @@ def check_inputs(
577580
callback_on_step_end_tensor_inputs=None,
578581
max_sequence_length=None,
579582
):
580-
if height % 8 != 0 or width % 8 != 0:
581-
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
583+
if (
584+
height % (self.vae_scale_factor * self.patch_size) != 0
585+
or width % (self.vae_scale_factor * self.patch_size) != 0
586+
):
587+
raise ValueError(
588+
f"`height` and `width` have to be divisible by {self.vae_scale_factor * self.patch_size} but are {height} and {width}. "
589+
f"You can use height {height - height % (self.vae_scale_factor * self.patch_size)} and width {width - width % (self.vae_scale_factor * self.patch_size)}."
590+
)
582591

583592
if callback_on_step_end_tensor_inputs is not None and not all(
584593
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs

src/diffusers/pipelines/pag/pipeline_pag_sd_3.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,9 @@ def __init__(
212212
if hasattr(self, "transformer") and self.transformer is not None
213213
else 128
214214
)
215+
self.patch_size = (
216+
self.transformer.config.patch_size if hasattr(self, "transformer") and self.transformer is not None else 2
217+
)
215218

216219
self.set_pag_applied_layers(
217220
pag_applied_layers, pag_attn_processors=(PAGCFGJointAttnProcessor2_0(), PAGJointAttnProcessor2_0())
@@ -542,8 +545,14 @@ def check_inputs(
542545
callback_on_step_end_tensor_inputs=None,
543546
max_sequence_length=None,
544547
):
545-
if height % 8 != 0 or width % 8 != 0:
546-
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
548+
if (
549+
height % (self.vae_scale_factor * self.patch_size) != 0
550+
or width % (self.vae_scale_factor * self.patch_size) != 0
551+
):
552+
raise ValueError(
553+
f"`height` and `width` have to be divisible by {self.vae_scale_factor * self.patch_size} but are {height} and {width}. "
554+
f"You can use height {height - height % (self.vae_scale_factor * self.patch_size)} and width {width - width % (self.vae_scale_factor * self.patch_size)}."
555+
)
547556

548557
if callback_on_step_end_tensor_inputs is not None and not all(
549558
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs

src/diffusers/pipelines/pipeline_loading_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -601,6 +601,7 @@ def load_sub_model(
601601
variant: str,
602602
low_cpu_mem_usage: bool,
603603
cached_folder: Union[str, os.PathLike],
604+
use_safetensors: bool,
604605
):
605606
"""Helper method to load the module `name` from `library_name` and `class_name`"""
606607

@@ -670,6 +671,7 @@ def load_sub_model(
670671
loading_kwargs["offload_folder"] = offload_folder
671672
loading_kwargs["offload_state_dict"] = offload_state_dict
672673
loading_kwargs["variant"] = model_variants.pop(name, None)
674+
loading_kwargs["use_safetensors"] = use_safetensors
673675

674676
if from_flax:
675677
loading_kwargs["from_flax"] = True

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -905,6 +905,7 @@ def load_module(name, value):
905905
variant=variant,
906906
low_cpu_mem_usage=low_cpu_mem_usage,
907907
cached_folder=cached_folder,
908+
use_safetensors=use_safetensors,
908909
)
909910
logger.info(
910911
f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."

src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,9 @@ def __init__(
203203
if hasattr(self, "transformer") and self.transformer is not None
204204
else 128
205205
)
206+
self.patch_size = (
207+
self.transformer.config.patch_size if hasattr(self, "transformer") and self.transformer is not None else 2
208+
)
206209

207210
def _get_t5_prompt_embeds(
208211
self,
@@ -525,8 +528,14 @@ def check_inputs(
525528
callback_on_step_end_tensor_inputs=None,
526529
max_sequence_length=None,
527530
):
528-
if height % 8 != 0 or width % 8 != 0:
529-
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
531+
if (
532+
height % (self.vae_scale_factor * self.patch_size) != 0
533+
or width % (self.vae_scale_factor * self.patch_size) != 0
534+
):
535+
raise ValueError(
536+
f"`height` and `width` have to be divisible by {self.vae_scale_factor * self.patch_size} but are {height} and {width}. "
537+
f"You can use height {height - height % (self.vae_scale_factor * self.patch_size)} and width {width - width % (self.vae_scale_factor * self.patch_size)}."
538+
)
530539

531540
if callback_on_step_end_tensor_inputs is not None and not all(
532541
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs

src/diffusers/utils/import_utils.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -668,8 +668,9 @@ def __getattr__(cls, key):
668668
# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L319
669669
def compare_versions(library_or_version: Union[str, Version], operation: str, requirement_version: str):
670670
"""
671-
Args:
672671
Compares a library version to some requirement using a given operation.
672+
673+
Args:
673674
library_or_version (`str` or `packaging.version.Version`):
674675
A library name or a version to check.
675676
operation (`str`):
@@ -688,8 +689,9 @@ def compare_versions(library_or_version: Union[str, Version], operation: str, re
688689
# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L338
689690
def is_torch_version(operation: str, version: str):
690691
"""
691-
Args:
692692
Compares the current PyTorch version to a given reference with an operation.
693+
694+
Args:
693695
operation (`str`):
694696
A string representation of an operator, such as `">"` or `"<="`
695697
version (`str`):
@@ -700,8 +702,9 @@ def is_torch_version(operation: str, version: str):
700702

701703
def is_transformers_version(operation: str, version: str):
702704
"""
703-
Args:
704705
Compares the current Transformers version to a given reference with an operation.
706+
707+
Args:
705708
operation (`str`):
706709
A string representation of an operator, such as `">"` or `"<="`
707710
version (`str`):
@@ -714,8 +717,9 @@ def is_transformers_version(operation: str, version: str):
714717

715718
def is_accelerate_version(operation: str, version: str):
716719
"""
717-
Args:
718720
Compares the current Accelerate version to a given reference with an operation.
721+
722+
Args:
719723
operation (`str`):
720724
A string representation of an operator, such as `">"` or `"<="`
721725
version (`str`):
@@ -728,8 +732,9 @@ def is_accelerate_version(operation: str, version: str):
728732

729733
def is_peft_version(operation: str, version: str):
730734
"""
731-
Args:
732735
Compares the current PEFT version to a given reference with an operation.
736+
737+
Args:
733738
operation (`str`):
734739
A string representation of an operator, such as `">"` or `"<="`
735740
version (`str`):
@@ -742,8 +747,9 @@ def is_peft_version(operation: str, version: str):
742747

743748
def is_k_diffusion_version(operation: str, version: str):
744749
"""
745-
Args:
746750
Compares the current k-diffusion version to a given reference with an operation.
751+
752+
Args:
747753
operation (`str`):
748754
A string representation of an operator, such as `">"` or `"<="`
749755
version (`str`):
@@ -756,8 +762,9 @@ def is_k_diffusion_version(operation: str, version: str):
756762

757763
def get_objects_from_module(module):
758764
"""
759-
Args:
760765
Returns a dict of object names and values in a module, while skipping private/internal objects
766+
767+
Args:
761768
module (ModuleType):
762769
Module to extract the objects from.
763770
@@ -775,7 +782,9 @@ def get_objects_from_module(module):
775782

776783

777784
class OptionalDependencyNotAvailable(BaseException):
778-
"""An error indicating that an optional dependency of Diffusers was not found in the environment."""
785+
"""
786+
An error indicating that an optional dependency of Diffusers was not found in the environment.
787+
"""
779788

780789

781790
class _LazyModule(ModuleType):

tests/lora/test_lora_layers_flux.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,11 @@ def test_modify_padding_mode(self):
169169
@unittest.skip("We cannot run inference on this model with the current CI hardware")
170170
# TODO (DN6, sayakpaul): move these tests to a beefier GPU
171171
class FluxLoRAIntegrationTests(unittest.TestCase):
172-
"""internal note: The integration slices were obtained on audace."""
172+
"""internal note: The integration slices were obtained on audace.
173+
174+
torch: 2.6.0.dev20241006+cu124 with CUDA 12.5. Need the same setup for the
175+
assertions to pass.
176+
"""
173177

174178
num_inference_steps = 10
175179
seed = 0

0 commit comments

Comments
 (0)