
Commit c04b0d2

Merge branch 'main' into enable-hotswap-testing-ci
2 parents e2cd241 + b4be422 commit c04b0d2

23 files changed (+6657 −82 lines)

examples/community/pipeline_controlnet_xl_kolors.py

Lines changed: 1355 additions & 0 deletions
Large diffs are not rendered by default.

examples/community/pipeline_controlnet_xl_kolors_img2img.py

Lines changed: 1557 additions & 0 deletions
Large diffs are not rendered by default.

examples/community/pipeline_controlnet_xl_kolors_inpaint.py

Lines changed: 1871 additions & 0 deletions
Large diffs are not rendered by default.

examples/community/pipeline_kolors_inpainting.py

Lines changed: 1728 additions & 0 deletions
Large diffs are not rendered by default.

examples/dreambooth/README_hidream.md

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ pip install -e .
 
 Then cd in the `examples/dreambooth` folder and run
 ```bash
-pip install -r requirements_sana.txt
+pip install -r requirements_hidream.txt
 ```
 
 And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:

examples/dreambooth/train_dreambooth_flux.py

Lines changed: 13 additions & 1 deletion
@@ -618,6 +618,15 @@ def parse_args(input_args=None):
         ),
     )
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+    parser.add_argument(
+        "--image_interpolation_mode",
+        type=str,
+        default="lanczos",
+        choices=[
+            f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
+        ],
+        help="The image interpolation method to use for resizing images.",
+    )
 
     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -737,7 +746,10 @@ def __init__(
             self.instance_images.extend(itertools.repeat(img, repeats))
 
         self.pixel_values = []
-        train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
+        interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
+        if interpolation is None:
+            raise ValueError(f"Unsupported interpolation mode {interpolation=}.")
+        train_resize = transforms.Resize(size, interpolation=interpolation)
         train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
         train_flip = transforms.RandomHorizontalFlip(p=1.0)
         train_transforms = transforms.Compose(
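The new `--image_interpolation_mode` flag is resolved by upper-casing the CLI value and looking it up on torchvision's `transforms.InterpolationMode` enum. A minimal standalone sketch of the same lookup (only torchvision is assumed; the helper name is illustrative, not part of the script):

```python
from torchvision import transforms


def resolve_interpolation(name: str) -> transforms.InterpolationMode:
    # Mirrors the script: "lanczos" -> InterpolationMode.LANCZOS, "bilinear" -> BILINEAR, etc.
    mode = getattr(transforms.InterpolationMode, name.upper(), None)
    if mode is None:
        raise ValueError(f"Unsupported interpolation mode {name!r}.")
    return mode


# Example: rebuild the resize transform the dataset now constructs.
train_resize = transforms.Resize(512, interpolation=resolve_interpolation("lanczos"))
```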

src/diffusers/hooks/group_offloading.py

Lines changed: 9 additions & 9 deletions
@@ -57,7 +57,7 @@ def __init__(
         non_blocking: bool = False,
         stream: Optional[torch.cuda.Stream] = None,
         record_stream: Optional[bool] = False,
-        low_cpu_mem_usage=False,
+        low_cpu_mem_usage: bool = False,
         onload_self: bool = True,
     ) -> None:
         self.modules = modules
@@ -498,6 +498,8 @@ def _apply_group_offloading_block_level(
            option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
            the CPU memory is a bottleneck but may counteract the benefits of using streams.
    """
+    if stream is not None and num_blocks_per_group != 1:
+        raise ValueError(f"Using streams is only supported for num_blocks_per_group=1. Got {num_blocks_per_group=}.")
 
     # Create module groups for ModuleList and Sequential blocks
     modules_with_group_offloading = set()
@@ -521,20 +523,16 @@ def _apply_group_offloading_block_level(
             stream=stream,
             record_stream=record_stream,
             low_cpu_mem_usage=low_cpu_mem_usage,
-            onload_self=stream is None,
+            onload_self=True,
         )
         matched_module_groups.append(group)
         for j in range(i, i + len(current_modules)):
             modules_with_group_offloading.add(f"{name}.{j}")
 
     # Apply group offloading hooks to the module groups
     for i, group in enumerate(matched_module_groups):
-        next_group = (
-            matched_module_groups[i + 1] if i + 1 < len(matched_module_groups) and stream is not None else None
-        )
-
         for group_module in group.modules:
-            _apply_group_offloading_hook(group_module, group, next_group)
+            _apply_group_offloading_hook(group_module, group, None)
 
     # Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
     # when the forward pass of this module is called. This is because the top-level module is not
@@ -560,8 +558,10 @@ def _apply_group_offloading_block_level(
         record_stream=False,
         onload_self=True,
    )
-    next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None
-    _apply_group_offloading_hook(module, unmatched_group, next_group)
+    if stream is None:
+        _apply_group_offloading_hook(module, unmatched_group, None)
+    else:
+        _apply_lazy_group_offloading_hook(module, unmatched_group, None)
 
 
 def _apply_group_offloading_leaf_level(
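Two behavioral consequences of this hunk: streamed block-level offloading now insists on `num_blocks_per_group=1`, and prefetching is no longer wired eagerly through `next_group` but left to the lazy hook on the unmatched top-level group. A usage sketch under stated assumptions: the public entry point `apply_group_offloading` and the keyword names (`offload_type`, `num_blocks_per_group`, `use_stream`) are taken from the diffusers docs and should be checked against the installed version:

```python
import torch
from torch import nn

from diffusers.hooks import apply_group_offloading  # assumed import path


class TinyBlocks(nn.Module):
    # Stand-in model: block-level offloading groups children held in ModuleList/Sequential.
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([nn.Linear(64, 64) for _ in range(4)])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


model = TinyBlocks()
apply_group_offloading(
    model,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,  # with use_stream=True, any other value now raises a ValueError
    use_stream=True,
)
```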

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 2 additions & 2 deletions
@@ -433,7 +433,7 @@ def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
        ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
        if not is_sparse:
            # down_weight is copied to each split
-            ait_sd.update({k: down_weight for k in ait_down_keys})
+            ait_sd.update(dict.fromkeys(ait_down_keys, down_weight))
 
            # up_weight is split to each split
            ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))})  # noqa: C416
@@ -923,7 +923,7 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
        ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
 
        # down_weight is copied to each split
-        ait_sd.update({k: down_weight for k in ait_down_keys})
+        ait_sd.update(dict.fromkeys(ait_down_keys, down_weight))
 
        # up_weight is split to each split
        ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))})  # noqa: C416
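The `dict.fromkeys` change is behavior-preserving: both forms map every key to the same `down_weight` object, which is the intent here since the down projection is shared across the splits. A quick plain-Python illustration of the equivalence (the key names are placeholders):

```python
keys = ["to_q.lora_A.weight", "to_k.lora_A.weight", "to_v.lora_A.weight"]
down_weight = object()  # stand-in for the shared tensor

via_comprehension = {k: down_weight for k in keys}
via_fromkeys = dict.fromkeys(keys, down_weight)

assert via_comprehension == via_fromkeys
# dict.fromkeys reuses the single value object for every key; no copies are made.
assert all(v is down_weight for v in via_fromkeys.values())
```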

src/diffusers/models/controlnets/controlnet_flux.py

Lines changed: 2 additions & 2 deletions
@@ -20,12 +20,12 @@
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
-from ...models.attention_processor import AttentionProcessor
-from ...models.modeling_utils import ModelMixin
 from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
+from ..attention_processor import AttentionProcessor
 from ..controlnets.controlnet import ControlNetConditioningEmbedding, zero_module
 from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
 from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
 from ..transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
 

src/diffusers/models/controlnets/multicontrolnet.py

Lines changed: 2 additions & 2 deletions
@@ -4,9 +4,9 @@
 import torch
 from torch import nn
 
-from ...models.controlnets.controlnet import ControlNetModel, ControlNetOutput
-from ...models.modeling_utils import ModelMixin
 from ...utils import logging
+from ..controlnets.controlnet import ControlNetModel, ControlNetOutput
+from ..modeling_utils import ModelMixin
 
 
 logger = logging.get_logger(__name__)
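Both import edits in this commit only shorten the relative path; from modules inside `diffusers/models/controlnets/`, `...models.modeling_utils` and `..modeling_utils` name the same absolute module. A small check of that equivalence with `importlib` (assumes `diffusers` is installed):

```python
import importlib

pkg = "diffusers.models.controlnets"  # package containing controlnet_flux.py and multicontrolnet.py

# "...models.modeling_utils": three dots climb to "diffusers", then descend into models.
a = importlib.import_module("...models.modeling_utils", package=pkg)
# "..modeling_utils": two dots climb to "diffusers.models".
b = importlib.import_module("..modeling_utils", package=pkg)

assert a is b  # both resolve to diffusers.models.modeling_utils
```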
