
Commit 6bafe8f

Merge branch 'main' into Add-Lora-Alpha-For-HiDream-Lora

2 parents 3b9dad8 + 9c13f86

22 files changed: +346 -243 lines

.github/workflows/pr_tests_gpu.yml

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ on:
       - "src/diffusers/loaders/peft.py"
       - "tests/pipelines/test_pipelines_common.py"
       - "tests/models/test_modeling_common.py"
+      - "examples/**/*.py"
   workflow_dispatch:

 concurrency:

examples/controlnet/train_controlnet_sd3.py

Lines changed: 1 addition & 1 deletion
@@ -1330,7 +1330,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 # controlnet(s) inference
                 controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)
                 controlnet_image = vae.encode(controlnet_image).latent_dist.sample()
-                controlnet_image = controlnet_image * vae.config.scaling_factor
+                controlnet_image = (controlnet_image - vae.config.shift_factor) * vae.config.scaling_factor

                 control_block_res_samples = controlnet(
                     hidden_states=noisy_model_input,
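Why this fix matters: SD3's VAE has a nonzero shift_factor, so latents must be recentered before scaling; multiplying by scaling_factor alone leaves the ControlNet conditioning latents on a different distribution than the model inputs. Below is a minimal sketch of the normalization and its inverse, with illustrative constants standing in for the real values in vae.config:

import torch

# Illustrative stand-ins for vae.config.scaling_factor / vae.config.shift_factor;
# read the actual values from the loaded VAE's config.
scaling_factor = 1.53
shift_factor = 0.06

def normalize_latents(latents: torch.Tensor) -> torch.Tensor:
    # Recenter, then scale: the transform applied after vae.encode().
    return (latents - shift_factor) * scaling_factor

def denormalize_latents(latents: torch.Tensor) -> torch.Tensor:
    # Exact inverse, applied before vae.decode().
    return latents / scaling_factor + shift_factor

x = torch.randn(1, 16, 64, 64)
assert torch.allclose(denormalize_latents(normalize_latents(x)), x, atol=1e-6)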

examples/dreambooth/train_dreambooth_lora_hidream.py

Lines changed: 26 additions & 36 deletions
@@ -59,6 +59,7 @@
     compute_density_for_timestep_sampling,
     compute_loss_weighting_for_sd3,
     free_memory,
+    offload_models,
 )
 from diffusers.utils import (
     check_min_version,

@@ -1375,43 +1376,34 @@ def compute_text_embeddings(prompt, text_encoding_pipeline):
     # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
     # the redundant encoding.
     if not train_dataset.custom_instance_prompts:
-        if args.offload:
-            text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
-        (
-            instance_prompt_hidden_states_t5,
-            instance_prompt_hidden_states_llama3,
-            instance_pooled_prompt_embeds,
-            _,
-            _,
-            _,
-        ) = compute_text_embeddings(args.instance_prompt, text_encoding_pipeline)
-        if args.offload:
-            text_encoding_pipeline = text_encoding_pipeline.to("cpu")
+        with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
+            (
+                instance_prompt_hidden_states_t5,
+                instance_prompt_hidden_states_llama3,
+                instance_pooled_prompt_embeds,
+                _,
+                _,
+                _,
+            ) = compute_text_embeddings(args.instance_prompt, text_encoding_pipeline)

     # Handle class prompt for prior-preservation.
     if args.with_prior_preservation:
-        if args.offload:
-            text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
-        (class_prompt_hidden_states_t5, class_prompt_hidden_states_llama3, class_pooled_prompt_embeds, _, _, _) = (
-            compute_text_embeddings(args.class_prompt, text_encoding_pipeline)
-        )
-        if args.offload:
-            text_encoding_pipeline = text_encoding_pipeline.to("cpu")
+        with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
+            (class_prompt_hidden_states_t5, class_prompt_hidden_states_llama3, class_pooled_prompt_embeds, _, _, _) = (
+                compute_text_embeddings(args.class_prompt, text_encoding_pipeline)
+            )

     validation_embeddings = {}
     if args.validation_prompt is not None:
-        if args.offload:
-            text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
-        (
-            validation_embeddings["prompt_embeds_t5"],
-            validation_embeddings["prompt_embeds_llama3"],
-            validation_embeddings["pooled_prompt_embeds"],
-            validation_embeddings["negative_prompt_embeds_t5"],
-            validation_embeddings["negative_prompt_embeds_llama3"],
-            validation_embeddings["negative_pooled_prompt_embeds"],
-        ) = compute_text_embeddings(args.validation_prompt, text_encoding_pipeline)
-        if args.offload:
-            text_encoding_pipeline = text_encoding_pipeline.to("cpu")
+        with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
+            (
+                validation_embeddings["prompt_embeds_t5"],
+                validation_embeddings["prompt_embeds_llama3"],
+                validation_embeddings["pooled_prompt_embeds"],
+                validation_embeddings["negative_prompt_embeds_t5"],
+                validation_embeddings["negative_prompt_embeds_llama3"],
+                validation_embeddings["negative_pooled_prompt_embeds"],
+            ) = compute_text_embeddings(args.validation_prompt, text_encoding_pipeline)

     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
     # pack the statically computed variables appropriately here. This is so that we don't

@@ -1593,12 +1585,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 if args.cache_latents:
                     model_input = latents_cache[step].sample()
                 else:
-                    if args.offload:
-                        vae = vae.to(accelerator.device)
-                    pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
+                    with offload_models(vae, device=accelerator.device, offload=args.offload):
+                        pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
                         model_input = vae.encode(pixel_values).latent_dist.sample()
-                    if args.offload:
-                        vae = vae.to("cpu")
+
                 model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
                 model_input = model_input.to(dtype=weight_dtype)

examples/server/requirements.txt

Lines changed: 9 additions & 6 deletions
@@ -1,10 +1,10 @@
 # This file was autogenerated by uv via the following command:
 #    uv pip compile requirements.in -o requirements.txt
-aiohappyeyeballs==2.4.3
+aiohappyeyeballs==2.6.1
     # via aiohttp
-aiohttp==3.10.10
+aiohttp==3.12.14
     # via -r requirements.in
-aiosignal==1.3.1
+aiosignal==1.4.0
     # via aiohttp
 annotated-types==0.7.0
     # via pydantic

@@ -29,7 +29,6 @@ filelock==3.16.1
     #   huggingface-hub
     #   torch
     #   transformers
-    #   triton
 frozenlist==1.5.0
     # via
     #   aiohttp

@@ -111,7 +110,9 @@ prometheus-client==0.21.0
 prometheus-fastapi-instrumentator==7.0.0
     # via -r requirements.in
 propcache==0.2.0
-    # via yarl
+    # via
+    #   aiohttp
+    #   yarl
 py-consul==1.5.3
     # via -r requirements.in
 pydantic==2.9.2

@@ -155,7 +156,9 @@ triton==3.3.0
     # via torch
 typing-extensions==4.12.2
     # via
+    #   aiosignal
     #   anyio
+    #   exceptiongroup
     #   fastapi
     #   huggingface-hub
     #   multidict

@@ -168,5 +171,5 @@ urllib3==2.5.0
     # via requests
 uvicorn==0.32.0
     # via -r requirements.in
-yarl==1.16.0
+yarl==1.18.3
     # via aiohttp

src/diffusers/configuration_utils.py

Lines changed: 4 additions & 1 deletion
@@ -763,4 +763,7 @@ def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_un
         # resolve remapping
         remapped_class = _fetch_remapped_cls_from_config(config, cls)

-        return remapped_class.from_config(config, return_unused_kwargs, **kwargs)
+        if remapped_class is cls:
+            return super(LegacyConfigMixin, remapped_class).from_config(config, return_unused_kwargs, **kwargs)
+        else:
+            return remapped_class.from_config(config, return_unused_kwargs, **kwargs)
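The `is cls` check guards against infinite recursion: when the config does not remap to a different class, calling remapped_class.from_config would re-enter this same override forever, so the call is routed past LegacyConfigMixin to the base implementation instead. A stripped-down sketch of the pattern, using hypothetical classes rather than the real diffusers hierarchy:

class ConfigMixin:
    @classmethod
    def from_config(cls, config):
        # Base loader: construct an instance of cls from the config dict.
        obj = cls.__new__(cls)
        obj.config = config
        return obj

def fetch_remapped_cls(config, cls):
    # Hypothetical stand-in for _fetch_remapped_cls_from_config.
    return config.get("remap_to", cls)

class LegacyConfigMixin(ConfigMixin):
    @classmethod
    def from_config(cls, config):
        remapped_class = fetch_remapped_cls(config, cls)
        if remapped_class is cls:
            # No remap: bypass this override, or from_config would call itself.
            return super(LegacyConfigMixin, remapped_class).from_config(config)
        # Remapped to a different class: let that class run its own loader.
        return remapped_class.from_config(config)

class LegacyModel(LegacyConfigMixin):
    pass

print(type(LegacyModel.from_config({})))  # <class '__main__.LegacyModel'>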

src/diffusers/loaders/single_file_model.py

Lines changed: 1 addition & 4 deletions
@@ -24,7 +24,7 @@
 from .. import __version__
 from ..quantizers import DiffusersAutoQuantizer
 from ..utils import deprecate, is_accelerate_available, logging
-from ..utils.torch_utils import device_synchronize, empty_device_cache
+from ..utils.torch_utils import empty_device_cache
 from .single_file_utils import (
     SingleFileComponentError,
     convert_animatediff_checkpoint_to_diffusers,

@@ -431,10 +431,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
                 keep_in_fp32_modules=keep_in_fp32_modules,
                 unexpected_keys=unexpected_keys,
             )
-            # Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
-            # required because we move tensors with non_blocking=True, which is slightly faster for model loading.
             empty_device_cache()
-            device_synchronize()
         else:
             _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)

src/diffusers/loaders/single_file_utils.py

Lines changed: 1 addition & 7 deletions
@@ -46,7 +46,7 @@
 )
 from ..utils.constants import DIFFUSERS_REQUEST_TIMEOUT
 from ..utils.hub_utils import _get_model_file
-from ..utils.torch_utils import device_synchronize, empty_device_cache
+from ..utils.torch_utils import empty_device_cache


 if is_transformers_available():

@@ -1690,10 +1690,7 @@ def create_diffusers_clip_model_from_ldm(

     if is_accelerate_available():
         load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
-        # Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
-        # required because we move tensors with non_blocking=True, which is slightly faster for model loading.
         empty_device_cache()
-        device_synchronize()
     else:
         model.load_state_dict(diffusers_format_checkpoint, strict=False)

@@ -2153,10 +2150,7 @@ def create_diffusers_t5_model_from_checkpoint(

     if is_accelerate_available():
         load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
-        # Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
-        # required because we move tensors with non_blocking=True, which is slightly faster for model loading.
         empty_device_cache()
-        device_synchronize()
     else:
         model.load_state_dict(diffusers_format_checkpoint)

src/diffusers/loaders/transformer_flux.py

Lines changed: 1 addition & 3 deletions
@@ -19,7 +19,7 @@
 )
 from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
 from ..utils import is_accelerate_available, is_torch_version, logging
-from ..utils.torch_utils import device_synchronize, empty_device_cache
+from ..utils.torch_utils import empty_device_cache


 if is_accelerate_available():

@@ -82,7 +82,6 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us
         device_map = {"": self.device}
         load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
         empty_device_cache()
-        device_synchronize()

         return image_projection

@@ -158,7 +157,6 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_
             key_id += 1

         empty_device_cache()
-        device_synchronize()

         return attn_procs

src/diffusers/loaders/transformer_sd3.py

Lines changed: 1 addition & 3 deletions
@@ -18,7 +18,7 @@
 from ..models.embeddings import IPAdapterTimeImageProjection
 from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
 from ..utils import is_accelerate_available, is_torch_version, logging
-from ..utils.torch_utils import device_synchronize, empty_device_cache
+from ..utils.torch_utils import empty_device_cache


 logger = logging.get_logger(__name__)

@@ -82,7 +82,6 @@ def _convert_ip_adapter_attn_to_diffusers(
         )

         empty_device_cache()
-        device_synchronize()

         return attn_procs

@@ -152,7 +151,6 @@ def _convert_ip_adapter_image_proj_to_diffusers(
         device_map = {"": self.device}
         load_model_dict_into_meta(image_proj, updated_state_dict, device_map=device_map, dtype=self.dtype)
         empty_device_cache()
-        device_synchronize()

         return image_proj

src/diffusers/loaders/unet.py

Lines changed: 1 addition & 3 deletions
@@ -43,7 +43,7 @@
     is_torch_version,
     logging,
 )
-from ..utils.torch_utils import device_synchronize, empty_device_cache
+from ..utils.torch_utils import empty_device_cache
 from .lora_base import _func_optionally_disable_offloading
 from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME
 from .utils import AttnProcsLayers

@@ -755,7 +755,6 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us
         device_map = {"": self.device}
         load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
         empty_device_cache()
-        device_synchronize()

         return image_projection

@@ -854,7 +853,6 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_
         key_id += 2

         empty_device_cache()
-        device_synchronize()

         return attn_procs
