
Commit 4fae1c4

Merge branch 'main' into skyreels-v2
2 parents 3bdbad4 + 06fd427 commit 4fae1c4

23 files changed: +470 -303 lines changed

docs/source/en/optimization/fp16.md

Lines changed: 17 additions & 14 deletions
@@ -174,39 +174,36 @@ Feel free to open an issue if dynamic compilation doesn't work as expected for a

### Regional compilation

+[Regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) trims cold-start latency by compiling only the *small, frequently repeated block(s)* of a model - typically a transformer layer - and reusing the compiled artifacts for every subsequent occurrence.
+For many diffusion architectures, this delivers the same runtime speedups as full-graph compilation and reduces compile time by 8–10x.

-[Regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) trims cold-start latency by compiling **only the small, frequently-repeated block(s)** of a model, typically a Transformer layer, enabling reuse of compiled artifacts for every subsequent occurrence.
-For many diffusion architectures this delivers the *same* runtime speed-ups as full-graph compilation yet cuts compile time by **8–10 ×**.
-
-To make this effortless, [`ModelMixin`] exposes [`ModelMixin.compile_repeated_blocks`] API, a helper that wraps `torch.compile` around any sub-modules you designate as repeatable:
+Use the [`~ModelMixin.compile_repeated_blocks`] method, a helper that wraps `torch.compile`, on any component, such as the transformer model, as shown below.

```py
# pip install -U diffusers
import torch
from diffusers import StableDiffusionXLPipeline

-pipe = StableDiffusionXLPipeline.from_pretrained(
+pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
).to("cuda")

-# Compile only the repeated Transformer layers inside the UNet
-pipe.unet.compile_repeated_blocks(fullgraph=True)
+# compile only the repeated transformer layers inside the UNet
+pipeline.unet.compile_repeated_blocks(fullgraph=True)
```

-To enable a new model with regional compilation, add a `_repeated_blocks` attribute to your model class containing the class names (as strings) of the blocks you want compiled:
-
+To enable regional compilation for a new model, add a `_repeated_blocks` attribute to a model class containing the class names (as strings) of the blocks you want to compile.

```py
class MyUNet(ModelMixin):
    _repeated_blocks = ("Transformer2DModel",)  # ← compiled by default
```

-For more examples, see the reference [PR](https://github.com/huggingface/diffusers/pull/11705).
-
-**Relation to Accelerate compile_regions** There is also a separate API in [accelerate](https://huggingface.co/docs/accelerate/index) - [compile_regions](https://github.com/huggingface/accelerate/blob/273799c85d849a1954a4f2e65767216eb37fa089/src/accelerate/utils/other.py#L78). It takes a fully automatic approach: it walks the module, picks candidate blocks, then compiles the remaining graph separately. That hands-off experience is handy for quick experiments, but it also leaves fewer knobs when you want to fine-tune which blocks are compiled or adjust compilation flags.
-
+> [!TIP]
+> For more regional compilation examples, see the reference [PR](https://github.com/huggingface/diffusers/pull/11705).

+There is also a [compile_regions](https://github.com/huggingface/accelerate/blob/273799c85d849a1954a4f2e65767216eb37fa089/src/accelerate/utils/other.py#L78) method in [Accelerate](https://huggingface.co/docs/accelerate/index) that automatically selects candidate blocks in a model to compile and compiles the remaining graph separately. It is handy for quick experiments, but it offers fewer options for choosing which blocks to compile or for adjusting compilation flags.

```py
# pip install -U accelerate
@@ -219,8 +216,8 @@ pipeline = StableDiffusionXLPipeline.from_pretrained(
).to("cuda")
pipeline.unet = compile_regions(pipeline.unet, mode="reduce-overhead", fullgraph=True)
```
-`compile_repeated_blocks`, by contrast, is intentionally explicit. You list the repeated blocks once (via `_repeated_blocks`) and the helper compiles exactly those, nothing more. In practice this small dose of control hits a sweet spot for diffusion models: predictable behavior, easy reasoning about cache reuse, and still a one-liner for users.

+[`~ModelMixin.compile_repeated_blocks`] is intentionally explicit. List the repeated blocks in `_repeated_blocks` and the helper compiles only those blocks. It offers predictable behavior and easy reasoning about cache reuse in one line of code.

### Graph breaks

@@ -296,3 +293,9 @@ An input is projected into three subspaces, represented by the projection matric
```py
pipeline.fuse_qkv_projections()
```
+
+## Resources
+
+- Read the [Presenting Flux Fast: Making Flux go brrr on H100s](https://pytorch.org/blog/presenting-flux-fast-making-flux-go-brrr-on-h100s/) blog post to learn more about how you can combine all of these optimizations with [TorchInductor](https://docs.pytorch.org/docs/stable/torch.compiler.html) and [AOTInductor](https://docs.pytorch.org/docs/stable/torch.compiler_aot_inductor.html) for a ~2.5x speedup using recipes from [flux-fast](https://github.com/huggingface/flux-fast).
+
+These recipes also support AMD hardware and [Flux.1 Kontext Dev](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev).
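To complement the docs change above, here is a minimal sketch of applying `compile_repeated_blocks` to a transformer component rather than a UNet; the Flux checkpoint, dtype, and prompt are illustrative assumptions and are not part of this commit.

```py
# pip install -U diffusers
import torch
from diffusers import FluxPipeline

pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # assumed checkpoint, for illustration only
    torch_dtype=torch.bfloat16,
).to("cuda")

# compile only the repeated transformer blocks; every later occurrence of the
# block reuses the compiled artifact, which keeps cold-start latency low
pipeline.transformer.compile_repeated_blocks(fullgraph=True)

image = pipeline("an astronaut riding a horse", num_inference_steps=28).images[0]
```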

docs/source/en/optimization/speed-memory-optims.md

Lines changed: 4 additions & 1 deletion
@@ -14,6 +14,9 @@ specific language governing permissions and limitations under the License.

Optimizing models often involves trade-offs between [inference speed](./fp16) and [memory-usage](./memory). For instance, while [caching](./cache) can boost inference speed, it also increases memory consumption since it needs to store the outputs of intermediate attention layers. A more balanced optimization strategy combines quantizing a model, [torch.compile](./fp16#torchcompile) and various [offloading methods](./memory#offloading).

+> [!TIP]
+> Check the [torch.compile](./fp16#torchcompile) guide to learn more about compilation and how it can be applied here. For example, regional compilation can significantly reduce compilation time without giving up any speedups.
+
For image generation, combining quantization and [model offloading](./memory#model-offloading) can often give the best trade-off between quality, speed, and memory. Group offloading is not as effective for image generation because it is usually not possible to *fully* overlap data transfer if the compute kernel finishes faster. This results in some communication overhead between the CPU and GPU.

For video generation, combining quantization and [group-offloading](./memory#group-offloading) tends to be better because video models are more compute-bound.
@@ -25,7 +28,7 @@ The table below provides a comparison of optimization strategy combinations and
| quantization | 32.602 | 14.9453 |
| quantization, torch.compile | 25.847 | 14.9448 |
| quantization, torch.compile, model CPU offloading | 32.312 | 12.2369 |
-<small>These results are benchmarked on Flux with a RTX 4090. The transformer and text_encoder components are quantized. Refer to the <a href="https://gist.github.com/sayakpaul/0db9d8eeeb3d2a0e5ed7cf0d9ca19b7d" benchmarking script</a> if you're interested in evaluating your own model.</small>
+<small>These results are benchmarked on Flux with an RTX 4090. The transformer and text_encoder components are quantized. Refer to the [benchmarking script](https://gist.github.com/sayakpaul/0db9d8eeeb3d2a0e5ed7cf0d9ca19b7d) if you're interested in evaluating your own model.</small>

This guide will show you how to compile and offload a quantized model with [bitsandbytes](../quantization/bitsandbytes#torchcompile). Make sure you are using [PyTorch nightly](https://pytorch.org/get-started/locally/) and the latest version of bitsandbytes.
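To make the combined strategy above concrete, here is a minimal sketch that pairs 4-bit bitsandbytes quantization with model CPU offloading and regional compilation; the checkpoint, the quantized component, and the settings are illustrative assumptions rather than the exact configuration benchmarked in the table.

```py
# pip install -U diffusers transformers accelerate bitsandbytes
import torch
from diffusers import BitsAndBytesConfig, FluxPipeline, FluxTransformer2DModel

model_id = "black-forest-labs/FLUX.1-dev"  # assumed checkpoint for illustration

# quantize the memory-heavy transformer to 4-bit
quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
transformer = FluxTransformer2DModel.from_pretrained(
    model_id, subfolder="transformer", quantization_config=quant_config, torch_dtype=torch.bfloat16
)

pipeline = FluxPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)

# move whole components to the GPU only while they are in use
pipeline.enable_model_cpu_offload()

# regional compilation keeps compile time low while retaining most of the speedup
pipeline.transformer.compile_repeated_blocks()

image = pipeline("a photo of a cat", num_inference_steps=28).images[0]
```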

examples/controlnet/train_controlnet_sd3.py

Lines changed: 1 addition & 1 deletion
@@ -1330,7 +1330,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
# controlnet(s) inference
controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)
controlnet_image = vae.encode(controlnet_image).latent_dist.sample()
-controlnet_image = controlnet_image * vae.config.scaling_factor
+controlnet_image = (controlnet_image - vae.config.shift_factor) * vae.config.scaling_factor

control_block_res_samples = controlnet(
hidden_states=noisy_model_input,
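For context on this one-line fix: the SD3 VAE defines both a `shift_factor` and a `scaling_factor`, so latents are normalized as `(z - shift_factor) * scaling_factor` before being passed to the transformer or ControlNet, and the inverse transform is applied before decoding. A minimal sketch of that convention (the helper names are illustrative, not from this commit):

```py
import torch

def normalize_latents(latents: torch.Tensor, vae_config) -> torch.Tensor:
    # matches the corrected line above: subtract the shift, then scale
    return (latents - vae_config.shift_factor) * vae_config.scaling_factor

def denormalize_latents(latents: torch.Tensor, vae_config) -> torch.Tensor:
    # inverse transform, applied to latents before vae.decode()
    return latents / vae_config.scaling_factor + vae_config.shift_factor
```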

examples/server/requirements.txt

Lines changed: 9 additions & 6 deletions
@@ -1,10 +1,10 @@
# This file was autogenerated by uv via the following command:
#    uv pip compile requirements.in -o requirements.txt
-aiohappyeyeballs==2.4.3
+aiohappyeyeballs==2.6.1
    # via aiohttp
-aiohttp==3.10.10
+aiohttp==3.12.14
    # via -r requirements.in
-aiosignal==1.3.1
+aiosignal==1.4.0
    # via aiohttp
annotated-types==0.7.0
    # via pydantic
@@ -29,7 +29,6 @@ filelock==3.16.1
    #   huggingface-hub
    #   torch
    #   transformers
-    #   triton
frozenlist==1.5.0
    # via
    #   aiohttp
@@ -111,7 +110,9 @@ prometheus-client==0.21.0
prometheus-fastapi-instrumentator==7.0.0
    # via -r requirements.in
propcache==0.2.0
-    # via yarl
+    # via
+    #   aiohttp
+    #   yarl
py-consul==1.5.3
    # via -r requirements.in
pydantic==2.9.2
@@ -155,7 +156,9 @@ triton==3.3.0
    # via torch
typing-extensions==4.12.2
    # via
+    #   aiosignal
    #   anyio
+    #   exceptiongroup
    #   fastapi
    #   huggingface-hub
    #   multidict
@@ -168,5 +171,5 @@ urllib3==2.5.0
    # via requests
uvicorn==0.32.0
    # via -r requirements.in
-yarl==1.16.0
+yarl==1.18.3
    # via aiohttp

src/diffusers/loaders/single_file_model.py

Lines changed: 5 additions & 0 deletions
@@ -24,6 +24,7 @@
from .. import __version__
from ..quantizers import DiffusersAutoQuantizer
from ..utils import deprecate, is_accelerate_available, logging
+from ..utils.torch_utils import device_synchronize, empty_device_cache
from .single_file_utils import (
    SingleFileComponentError,
    convert_animatediff_checkpoint_to_diffusers,
@@ -430,6 +431,10 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
keep_in_fp32_modules=keep_in_fp32_modules,
unexpected_keys=unexpected_keys,
)
+# Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
+# required because we move tensors with non_blocking=True, which is slightly faster for model loading.
+empty_device_cache()
+device_synchronize()
else:
_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
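The new `empty_device_cache()` and `device_synchronize()` calls are needed because weights are moved with `non_blocking=True`, so host-to-device copies are queued asynchronously on the accelerator stream and may not have finished when control returns to the caller. A standalone sketch of the underlying PyTorch behavior (tensor names and sizes are illustrative):

```py
import torch

if torch.cuda.is_available():
    cpu_weights = torch.randn(1024, 1024, pin_memory=True)

    # the copy is queued on the current CUDA stream and returns immediately
    gpu_weights = cpu_weights.to("cuda", non_blocking=True)

    # roughly what empty_device_cache() and device_synchronize() do on CUDA
    torch.cuda.empty_cache()   # release cached blocks left over from loading
    torch.cuda.synchronize()   # wait for the asynchronous copies to finish

    # after synchronize, gpu_weights is guaranteed to hold the copied data
    assert torch.equal(gpu_weights.cpu(), cpu_weights)
```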

src/diffusers/loaders/single_file_utils.py

Lines changed: 9 additions & 0 deletions
@@ -46,6 +46,7 @@
)
from ..utils.constants import DIFFUSERS_REQUEST_TIMEOUT
from ..utils.hub_utils import _get_model_file
+from ..utils.torch_utils import device_synchronize, empty_device_cache


if is_transformers_available():
@@ -1689,6 +1690,10 @@ def create_diffusers_clip_model_from_ldm(

if is_accelerate_available():
load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
+# Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
+# required because we move tensors with non_blocking=True, which is slightly faster for model loading.
+empty_device_cache()
+device_synchronize()
else:
model.load_state_dict(diffusers_format_checkpoint, strict=False)

@@ -2148,6 +2153,10 @@ def create_diffusers_t5_model_from_checkpoint(

if is_accelerate_available():
load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
+# Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
+# required because we move tensors with non_blocking=True, which is slightly faster for model loading.
+empty_device_cache()
+device_synchronize()
else:
model.load_state_dict(diffusers_format_checkpoint)

src/diffusers/loaders/transformer_flux.py

Lines changed: 7 additions & 5 deletions
@@ -18,11 +18,8 @@
MultiIPAdapterImageProjection,
)
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
-from ..utils import (
-is_accelerate_available,
-is_torch_version,
-logging,
-)
+from ..utils import is_accelerate_available, is_torch_version, logging
+from ..utils.torch_utils import device_synchronize, empty_device_cache


if is_accelerate_available():
@@ -84,6 +81,8 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us
else:
device_map = {"": self.device}
load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
+empty_device_cache()
+device_synchronize()

return image_projection

@@ -158,6 +157,9 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_

key_id += 1

+empty_device_cache()
+device_synchronize()
+
return attn_procs

def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):

src/diffusers/loaders/transformer_sd3.py

Lines changed: 6 additions & 0 deletions
@@ -18,6 +18,7 @@
from ..models.embeddings import IPAdapterTimeImageProjection
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from ..utils import is_accelerate_available, is_torch_version, logging
+from ..utils.torch_utils import device_synchronize, empty_device_cache


logger = logging.get_logger(__name__)
@@ -80,6 +81,9 @@ def _convert_ip_adapter_attn_to_diffusers(
attn_procs[name], layer_state_dict[idx], device_map=device_map, dtype=self.dtype
)

+empty_device_cache()
+device_synchronize()
+
return attn_procs

def _convert_ip_adapter_image_proj_to_diffusers(
@@ -147,6 +151,8 @@ def _convert_ip_adapter_image_proj_to_diffusers(
else:
device_map = {"": self.device}
load_model_dict_into_meta(image_proj, updated_state_dict, device_map=device_map, dtype=self.dtype)
+empty_device_cache()
+device_synchronize()

return image_proj

src/diffusers/loaders/unet.py

Lines changed: 6 additions & 0 deletions
@@ -43,6 +43,7 @@
is_torch_version,
logging,
)
+from ..utils.torch_utils import device_synchronize, empty_device_cache
from .lora_base import _func_optionally_disable_offloading
from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME
from .utils import AttnProcsLayers
@@ -753,6 +754,8 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us
else:
device_map = {"": self.device}
load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
+empty_device_cache()
+device_synchronize()

return image_projection

@@ -850,6 +853,9 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_

key_id += 2

+empty_device_cache()
+device_synchronize()
+
return attn_procs

def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):

src/diffusers/models/controlnets/controlnet_union.py

Lines changed: 4 additions & 4 deletions
@@ -752,7 +752,7 @@ def forward(
condition = self.controlnet_cond_embedding(cond)
feat_seq = torch.mean(condition, dim=(2, 3))
feat_seq = feat_seq + self.task_embedding[control_idx]
-if from_multi:
+if from_multi or len(control_type_idx) == 1:
inputs.append(feat_seq.unsqueeze(1))
condition_list.append(condition)
else:
@@ -772,7 +772,7 @@ def forward(
for (idx, condition), scale in zip(enumerate(condition_list[:-1]), conditioning_scale):
alpha = self.spatial_ch_projs(x[:, idx])
alpha = alpha.unsqueeze(-1).unsqueeze(-1)
-if from_multi:
+if from_multi or len(control_type_idx) == 1:
controlnet_cond_fuser += condition + alpha
else:
controlnet_cond_fuser += condition + alpha * scale
@@ -819,11 +819,11 @@ def forward(
# 6. scaling
if guess_mode and not self.config.global_pool_conditions:
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
-if from_multi:
+if from_multi or len(control_type_idx) == 1:
scales = scales * conditioning_scale[0]
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
-elif from_multi:
+elif from_multi or len(control_type_idx) == 1:
down_block_res_samples = [sample * conditioning_scale[0] for sample in down_block_res_samples]
mid_block_res_sample = mid_block_res_sample * conditioning_scale[0]
