
Commit b295b69

Merge branch 'main' into enable-hotswap-testing-ci
2 parents 4e8dffe + fb29132 commit b295b69

38 files changed, +2427 -104 lines changed

.github/workflows/pr_tests.yml

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@ on:
       - "tests/**.py"
       - ".github/**.yml"
       - "utils/**.py"
+      - "setup.py"
   push:
     branches:
       - ci-*

docs/source/en/api/pipelines/hunyuan_video.md

Lines changed: 1 addition & 0 deletions
@@ -52,6 +52,7 @@ The following models are available for the image-to-video pipeline:
 | [`Skywork/SkyReels-V1-Hunyuan-I2V`](https://huggingface.co/Skywork/SkyReels-V1-Hunyuan-I2V) | Skywork's custom finetune of HunyuanVideo (de-distilled). Performs best with `97x544x960` resolution. Performs best at `97x544x960` resolution, `guidance_scale=1.0`, `true_cfg_scale=6.0` and a negative prompt. |
 | [`hunyuanvideo-community/HunyuanVideo-I2V-33ch`](https://huggingface.co/hunyuanvideo-community/HunyuanVideo-I2V) | Tecent's official HunyuanVideo 33-channel I2V model. Performs best at resolutions of 480, 720, 960, 1280. A higher `shift` value when initializing the scheduler is recommended (good values are between 7 and 20). |
 | [`hunyuanvideo-community/HunyuanVideo-I2V`](https://huggingface.co/hunyuanvideo-community/HunyuanVideo-I2V) | Tecent's official HunyuanVideo 16-channel I2V model. Performs best at resolutions of 480, 720, 960, 1280. A higher `shift` value when initializing the scheduler is recommended (good values are between 7 and 20) |
+- [`lllyasviel/FramePackI2V_HY`](https://huggingface.co/lllyasviel/FramePackI2V_HY) | lllyasviel's paper introducing a new technique for long-context video generation called [Framepack](https://arxiv.org/abs/2504.12626). |
 
 ## Quantization
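For context on the `shift` recommendation in the rows above, the value is usually passed when re-creating the scheduler. A minimal sketch, not part of this diff, assuming the `HunyuanVideoImageToVideoPipeline` / `FlowMatchEulerDiscreteScheduler` pairing and picking `shift=7.0` as one illustrative value from the recommended 7-20 range:

```py
import torch
from diffusers import FlowMatchEulerDiscreteScheduler, HunyuanVideoImageToVideoPipeline

repo = "hunyuanvideo-community/HunyuanVideo-I2V"
# Override the flow shift when loading the scheduler (the table suggests values between 7 and 20).
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(repo, subfolder="scheduler", shift=7.0)
pipe = HunyuanVideoImageToVideoPipeline.from_pretrained(
    repo, scheduler=scheduler, torch_dtype=torch.bfloat16
)
```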

docs/source/en/quantization/bitsandbytes.md

Lines changed: 9 additions & 3 deletions
@@ -48,7 +48,7 @@ For Ada and higher-series GPUs. we recommend changing `torch_dtype` to `torch.bf
 ```py
 from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
 from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
-
+import torch
 from diffusers import AutoModel
 from transformers import T5EncoderModel
 
@@ -88,6 +88,8 @@ Setting `device_map="auto"` automatically fills all available space on the GPU(s
 CPU, and finally, the hard drive (the absolute slowest option) if there is still not enough memory.
 
 ```py
+from diffusers import FluxPipeline
+
 pipe = FluxPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-dev",
     transformer=transformer_8bit,
@@ -132,7 +134,7 @@ For Ada and higher-series GPUs. we recommend changing `torch_dtype` to `torch.bf
 ```py
 from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
 from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
-
+import torch
 from diffusers import AutoModel
 from transformers import T5EncoderModel
 
@@ -171,6 +173,8 @@ Let's generate an image using our quantized models.
 Setting `device_map="auto"` automatically fills all available space on the GPU(s) first, then the CPU, and finally, the hard drive (the absolute slowest option) if there is still not enough memory.
 
 ```py
+from diffusers import FluxPipeline
+
 pipe = FluxPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-dev",
     transformer=transformer_4bit,
@@ -214,6 +218,8 @@ Check your memory footprint with the `get_memory_footprint` method:
 print(model.get_memory_footprint())
 ```
 
+Note that this only tells you the memory footprint of the model params and does _not_ estimate the inference memory requirements.
+
 Quantized models can be loaded from the [`~ModelMixin.from_pretrained`] method without needing to specify the `quantization_config` parameters:
 
 ```py
@@ -413,4 +419,4 @@ transformer_4bit.dequantize()
 ## Resources
 
 * [End-to-end notebook showing Flux.1 Dev inference in a free-tier Colab](https://gist.github.com/sayakpaul/c76bd845b48759e11687ac550b99d8b4)
-* [Training](https://gist.github.com/sayakpaul/05afd428bc089b47af7c016e42004527)
+* [Training](https://github.com/huggingface/diffusers/blob/8c661ea586bf11cb2440da740dd3c4cf84679b85/examples/dreambooth/README_hidream.md#using-quantization)
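Taken together, the corrected snippets compose into roughly the following end-to-end 4-bit flow. This sketch is assembled here for illustration only (it is not part of the diff); the NF4 settings and the FLUX.1-dev checkpoint mirror the surrounding docs, and the exact argument set is an assumption:

```py
import torch
from diffusers import AutoModel, FluxPipeline
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig

# 4-bit NF4 quantization for the transformer, computing in bfloat16 (Ada+ GPUs).
quant_config = DiffusersBitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
transformer_4bit = AutoModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
# Parameter memory only; this does not estimate inference memory (see the note added above).
print(transformer_4bit.get_memory_footprint())

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=transformer_4bit,
    torch_dtype=torch.bfloat16,
)
```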

docs/source/en/quantization/torchao.md

Lines changed: 1 addition & 1 deletion
@@ -85,7 +85,7 @@ The quantization methods supported are as follows:
 | **Category** | **Full Function Names** | **Shorthands** |
 |--------------|-------------------------|----------------|
 | **Integer quantization** | `int4_weight_only`, `int8_dynamic_activation_int4_weight`, `int8_weight_only`, `int8_dynamic_activation_int8_weight` | `int4wo`, `int4dq`, `int8wo`, `int8dq` |
-| **Floating point 8-bit quantization** | `float8_weight_only`, `float8_dynamic_activation_float8_weight`, `float8_static_activation_float8_weight` | `float8wo`, `float8wo_e5m2`, `float8wo_e4m3`, `float8dq`, `float8dq_e4m3`, `float8_e4m3_tensor`, `float8_e4m3_row` |
+| **Floating point 8-bit quantization** | `float8_weight_only`, `float8_dynamic_activation_float8_weight`, `float8_static_activation_float8_weight` | `float8wo`, `float8wo_e5m2`, `float8wo_e4m3`, `float8dq`, `float8dq_e4m3`, `float8dq_e4m3_tensor`, `float8dq_e4m3_row` |
 | **Floating point X-bit quantization** | `fpx_weight_only` | `fpX_eAwB` where `X` is the number of bits (1-7), `A` is exponent bits, and `B` is mantissa bits. Constraint: `X == A + B + 1` |
 | **Unsigned Integer quantization** | `uintx_weight_only` | `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo` |
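For reference, these shorthand strings are what `TorchAoConfig` accepts. A minimal sketch, not part of this diff, assuming the FLUX.1-dev transformer as the target (the table itself does not name a model):

```py
import torch
from diffusers import AutoModel, TorchAoConfig

# "int8wo" is one of the shorthands from the table; the corrected float8
# shorthands (e.g. "float8dq_e4m3_row") would be passed the same way.
quant_config = TorchAoConfig("int8wo")
transformer = AutoModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
```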

examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py

Lines changed: 14 additions & 1 deletion
@@ -673,6 +673,15 @@ def parse_args(input_args=None):
         default=False,
         help="Cache the VAE latents",
     )
+    parser.add_argument(
+        "--image_interpolation_mode",
+        type=str,
+        default="lanczos",
+        choices=[
+            f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
+        ],
+        help="The image interpolation method to use for resizing images.",
+    )
 
     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -907,6 +916,10 @@ def __init__(
         self.num_instance_images = len(self.instance_images)
         self._length = self.num_instance_images
 
+        interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
+        if interpolation is None:
+            raise ValueError(f"Unsupported interpolation mode {interpolation=}.")
+
         if class_data_root is not None:
             self.class_data_root = Path(class_data_root)
             self.class_data_root.mkdir(parents=True, exist_ok=True)
@@ -921,7 +934,7 @@ def __init__(
 
         self.image_transforms = transforms.Compose(
             [
-                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+                transforms.Resize(size, interpolation=interpolation),
                 transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                 transforms.ToTensor(),
                 transforms.Normalize([0.5], [0.5]),
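The `--image_interpolation_mode` flag added here (and in the Lumina2 and SDXL scripts further down) exposes the lowercased member names of `torchvision.transforms.InterpolationMode` as CLI choices and resolves the chosen name back to the enum with `getattr`. A standalone sketch of that lookup, for illustration only:

```py
from torchvision import transforms

image_interpolation_mode = "lanczos"  # e.g. the default added by this commit

# Resolve the CLI string back to the torchvision enum; None means the name is invalid.
interpolation = getattr(transforms.InterpolationMode, image_interpolation_mode.upper(), None)
if interpolation is None:
    raise ValueError(f"Unsupported interpolation mode: {image_interpolation_mode}")

resize = transforms.Resize(512, interpolation=interpolation)
print(interpolation)  # InterpolationMode.LANCZOS
```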

examples/dreambooth/README_hidream.md

Lines changed: 27 additions & 0 deletions
@@ -117,3 +117,30 @@ We provide several options for optimizing memory optimization:
 * `--use_8bit_adam`: When enabled, we will use the 8bit version of AdamW provided by the `bitsandbytes` library.
 
 Refer to the [official documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/) of the `HiDreamImagePipeline` to know more about the model.
+
+## Using quantization
+
+You can quantize the base model with [`bitsandbytes`](https://huggingface.co/docs/bitsandbytes/index) to reduce memory usage. To do so, pass a JSON file path to `--bnb_quantization_config_path`. This file should hold the configuration to initialize `BitsAndBytesConfig`. Below is an example JSON file:
+
+```json
+{
+    "load_in_4bit": true,
+    "bnb_4bit_quant_type": "nf4"
+}
+```
+
+Below, we provide some numbers with and without the use of NF4 quantization when training:
+
+```
+(with quantization)
+Memory (before device placement): 9.085089683532715 GB.
+Memory (after device placement): 34.59585428237915 GB.
+Memory (after backward): 36.90267467498779 GB.
+
+(without quantization)
+Memory (before device placement): 0.0 GB.
+Memory (after device placement): 57.6400408744812 GB.
+Memory (after backward): 59.932212829589844 GB.
+```
+
+The reason why we see some memory before device placement in the case of quantization is because, by default bnb quantized models are placed on the GPU first.
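For context, the training-script changes later in this commit consume that JSON roughly as follows. This is a condensed sketch of the logic added to `train_dreambooth_lora_hidream.py`, not a verbatim excerpt; the file path and base checkpoint are hypothetical, and `weight_dtype` stands in for the mixed-precision dtype chosen at runtime:

```py
import json

import torch
from diffusers import BitsAndBytesConfig, HiDreamImageTransformer2DModel

weight_dtype = torch.bfloat16  # assumed mixed-precision dtype

# Hypothetical path passed via --bnb_quantization_config_path.
with open("bnb_config.json", "r") as f:
    config_kwargs = json.load(f)
if config_kwargs.get("load_in_4bit"):
    # The compute dtype follows the training precision, as in the script.
    config_kwargs["bnb_4bit_compute_dtype"] = weight_dtype

transformer = HiDreamImageTransformer2DModel.from_pretrained(
    "HiDream-ai/HiDream-I1-Full",  # hypothetical base checkpoint
    subfolder="transformer",
    quantization_config=BitsAndBytesConfig(**config_kwargs),
    torch_dtype=weight_dtype,
)
```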

examples/dreambooth/train_dreambooth_lora_hidream.py

Lines changed: 40 additions & 16 deletions
@@ -16,6 +16,7 @@
 import argparse
 import copy
 import itertools
+import json
 import logging
 import math
 import os
@@ -27,14 +28,13 @@
 
 import numpy as np
 import torch
-import torch.utils.checkpoint
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
 from huggingface_hub import create_repo, upload_folder
 from huggingface_hub.utils import insecure_hashlib
-from peft import LoraConfig, set_peft_model_state_dict
+from peft import LoraConfig, prepare_model_for_kbit_training, set_peft_model_state_dict
 from peft.utils import get_peft_model_state_dict
 from PIL import Image
 from PIL.ImageOps import exif_transpose
@@ -47,6 +47,7 @@
 import diffusers
 from diffusers import (
     AutoencoderKL,
+    BitsAndBytesConfig,
     FlowMatchEulerDiscreteScheduler,
     HiDreamImagePipeline,
     HiDreamImageTransformer2DModel,
@@ -282,6 +283,12 @@ def parse_args(input_args=None):
         default="meta-llama/Meta-Llama-3.1-8B-Instruct",
         help="Path to pretrained model or model identifier from huggingface.co/models.",
     )
+    parser.add_argument(
+        "--bnb_quantization_config_path",
+        type=str,
+        default=None,
+        help="Quantization config in a JSON file that will be used to define the bitsandbytes quant config of the DiT.",
+    )
     parser.add_argument(
         "--revision",
         type=str,
@@ -1056,6 +1063,14 @@ def main(args):
         args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_3"
     )
 
+    # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
+    # as these weights are only used for inference, keeping weights in full precision is not required.
+    weight_dtype = torch.float32
+    if accelerator.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif accelerator.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
     # Load scheduler and models
     noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="scheduler", revision=args.revision, shift=3.0
@@ -1064,20 +1079,31 @@
     text_encoder_one, text_encoder_two, text_encoder_three, text_encoder_four = load_text_encoders(
         text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three
     )
-
     vae = AutoencoderKL.from_pretrained(
         args.pretrained_model_name_or_path,
         subfolder="vae",
         revision=args.revision,
         variant=args.variant,
     )
+    quantization_config = None
+    if args.bnb_quantization_config_path is not None:
+        with open(args.bnb_quantization_config_path, "r") as f:
+            config_kwargs = json.load(f)
+        if "load_in_4bit" in config_kwargs and config_kwargs["load_in_4bit"]:
+            config_kwargs["bnb_4bit_compute_dtype"] = weight_dtype
+        quantization_config = BitsAndBytesConfig(**config_kwargs)
+
     transformer = HiDreamImageTransformer2DModel.from_pretrained(
         args.pretrained_model_name_or_path,
         subfolder="transformer",
         revision=args.revision,
         variant=args.variant,
+        quantization_config=quantization_config,
+        torch_dtype=weight_dtype,
         force_inference_output=True,
     )
+    if args.bnb_quantization_config_path is not None:
+        transformer = prepare_model_for_kbit_training(transformer, use_gradient_checkpointing=False)
 
     # We only train the additional adapter LoRA layers
     transformer.requires_grad_(False)
@@ -1087,14 +1113,6 @@
     text_encoder_three.requires_grad_(False)
     text_encoder_four.requires_grad_(False)
 
-    # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
-    # as these weights are only used for inference, keeping weights in full precision is not required.
-    weight_dtype = torch.float32
-    if accelerator.mixed_precision == "fp16":
-        weight_dtype = torch.float16
-    elif accelerator.mixed_precision == "bf16":
-        weight_dtype = torch.bfloat16
-
     if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
         # due to pytorch#99272, MPS does not yet support bfloat16.
         raise ValueError(
@@ -1109,7 +1127,12 @@
     text_encoder_three.to(**to_kwargs)
     text_encoder_four.to(**to_kwargs)
     # we never offload the transformer to CPU, so we can just use the accelerator device
-    transformer.to(accelerator.device, dtype=weight_dtype)
+    transformer_to_kwargs = (
+        {"device": accelerator.device}
+        if args.bnb_quantization_config_path is not None
+        else {"device": accelerator.device, "dtype": weight_dtype}
+    )
+    transformer.to(**transformer_to_kwargs)
 
     # Initialize a text encoding pipeline and keep it to CPU for now.
     text_encoding_pipeline = HiDreamImagePipeline.from_pretrained(
@@ -1695,10 +1718,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
         transformer = unwrap_model(transformer)
-        if args.upcast_before_saving:
-            transformer.to(torch.float32)
-        else:
-            transformer = transformer.to(weight_dtype)
+        if args.bnb_quantization_config_path is None:
+            if args.upcast_before_saving:
+                transformer.to(torch.float32)
+            else:
+                transformer = transformer.to(weight_dtype)
         transformer_lora_layers = get_peft_model_state_dict(transformer)
 
         HiDreamImagePipeline.save_lora_weights(

examples/dreambooth/train_dreambooth_lora_lumina2.py

Lines changed: 15 additions & 2 deletions
@@ -599,6 +599,15 @@ def parse_args(input_args=None):
             "Defaults to precision dtype used for training to save memory"
         ),
     )
+    parser.add_argument(
+        "--image_interpolation_mode",
+        type=str,
+        default="lanczos",
+        choices=[
+            f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
+        ],
+        help="The image interpolation method to use for resizing images.",
+    )
     parser.add_argument(
         "--offload",
         action="store_true",
@@ -724,7 +733,11 @@ def __init__(
             self.instance_images.extend(itertools.repeat(img, repeats))
 
         self.pixel_values = []
-        train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
+        interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
+        if interpolation is None:
+            raise ValueError(f"Unsupported interpolation mode: {args.image_interpolation_mode}")
+
+        train_resize = transforms.Resize(size, interpolation=interpolation)
         train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
         train_flip = transforms.RandomHorizontalFlip(p=1.0)
        train_transforms = transforms.Compose(
@@ -768,7 +781,7 @@ def __init__(
 
         self.image_transforms = transforms.Compose(
             [
-                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+                transforms.Resize(size, interpolation=interpolation),
                 transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                 transforms.ToTensor(),
                 transforms.Normalize([0.5], [0.5]),

examples/dreambooth/train_dreambooth_lora_sdxl.py

Lines changed: 1 addition & 1 deletion
@@ -852,7 +852,7 @@ def __init__(
 
         self.image_transforms = transforms.Compose(
             [
-                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+                transforms.Resize(size, interpolation=interpolation),
                 transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                 transforms.ToTensor(),
                 transforms.Normalize([0.5], [0.5]),

examples/text_to_image/train_text_to_image_lora_sdxl.py

Lines changed: 16 additions & 1 deletion
@@ -480,6 +480,15 @@ def parse_args(input_args=None):
         action="store_true",
         help="debug loss for each image, if filenames are available in the dataset",
     )
+    parser.add_argument(
+        "--image_interpolation_mode",
+        type=str,
+        default="lanczos",
+        choices=[
+            f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
+        ],
+        help="The image interpolation method to use for resizing images.",
+    )
 
     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -913,8 +922,14 @@ def tokenize_captions(examples, is_train=True):
         tokens_two = tokenize_prompt(tokenizer_two, captions)
         return tokens_one, tokens_two
 
+    # Get the specified interpolation method from the args
+    interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
+
+    # Raise an error if the interpolation method is invalid
+    if interpolation is None:
+        raise ValueError(f"Unsupported interpolation mode {args.image_interpolation_mode}.")
     # Preprocessing the datasets.
-    train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR)
+    train_resize = transforms.Resize(args.resolution, interpolation=interpolation)  # Use dynamic interpolation method
     train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution)
     train_flip = transforms.RandomHorizontalFlip(p=1.0)
     train_transforms = transforms.Compose(
