Commit 758b1db

Merge branch 'main' into guidance-scale-docs

2 parents: 2185553 + 478df93

56 files changed (+10133, -1048 lines)

.github/workflows/pr_tests_gpu.yml

Lines changed: 1 addition & 0 deletions
```diff
@@ -13,6 +13,7 @@ on:
       - "src/diffusers/loaders/peft.py"
       - "tests/pipelines/test_pipelines_common.py"
       - "tests/models/test_modeling_common.py"
+      - "examples/**/*.py"
   workflow_dispatch:

 concurrency:
```

docs/source/en/_toctree.yml

Lines changed: 4 additions & 0 deletions
```diff
@@ -353,6 +353,8 @@
       title: SanaTransformer2DModel
     - local: api/models/sd3_transformer2d
       title: SD3Transformer2DModel
+    - local: api/models/skyreels_v2_transformer_3d
+      title: SkyReelsV2Transformer3DModel
     - local: api/models/stable_audio_transformer
       title: StableAudioDiTModel
     - local: api/models/transformer2d
@@ -547,6 +549,8 @@
       title: Semantic Guidance
     - local: api/pipelines/shap_e
       title: Shap-E
+    - local: api/pipelines/skyreels_v2
+      title: SkyReels-V2
     - local: api/pipelines/stable_audio
       title: Stable Audio
     - local: api/pipelines/stable_cascade
```

docs/source/en/api/loaders/lora.md

Lines changed: 7 additions & 2 deletions
```diff
@@ -26,6 +26,7 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
 - [`HunyuanVideoLoraLoaderMixin`] provides similar functions for [HunyuanVideo](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hunyuan_video).
 - [`Lumina2LoraLoaderMixin`] provides similar functions for [Lumina2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/lumina2).
 - [`WanLoraLoaderMixin`] provides similar functions for [Wan](https://huggingface.co/docs/diffusers/main/en/api/pipelines/wan).
+- [`SkyReelsV2LoraLoaderMixin`] provides similar functions for [SkyReels-V2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/skyreels_v2).
 - [`CogView4LoraLoaderMixin`] provides similar functions for [CogView4](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogview4).
 - [`AmusedLoraLoaderMixin`] is for the [`AmusedPipeline`].
 - [`HiDreamImageLoraLoaderMixin`] provides similar functions for [HiDream Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hidream)
@@ -92,6 +93,10 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse

 [[autodoc]] loaders.lora_pipeline.WanLoraLoaderMixin

+## SkyReelsV2LoraLoaderMixin
+
+[[autodoc]] loaders.lora_pipeline.SkyReelsV2LoraLoaderMixin
+
 ## AmusedLoraLoaderMixin

 [[autodoc]] loaders.lora_pipeline.AmusedLoraLoaderMixin
@@ -100,6 +105,6 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse

 [[autodoc]] loaders.lora_pipeline.HiDreamImageLoraLoaderMixin

-## WanLoraLoaderMixin
+## LoraBaseMixin

-[[autodoc]] loaders.lora_pipeline.WanLoraLoaderMixin
+[[autodoc]] loaders.lora_base.LoraBaseMixin
```
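The mixins above all load LoRA weights, which add a trained low-rank update on top of a frozen base weight matrix. A minimal numerical sketch of that idea in plain Python (toy 4x4 matrices, independent of diffusers):

```python
def matmul(a, b):
    # Naive matrix multiply for small lists-of-lists.
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def add(a, b):
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]

# Frozen base weight W (4x4, identity here) and a rank-1 LoRA update B @ A,
# where B is 4x1 and A is 1x4: far fewer trainable numbers than W itself.
W = [[1.0 if i == j else 0.0 for j in range(4)] for i in range(4)]
B = [[1.0], [0.0], [0.0], [0.0]]
A = [[0.0, 2.0, 0.0, 0.0]]

scale = 0.5  # LoRA scaling factor (alpha / rank in most implementations)
delta = [[scale * v for v in row] for row in matmul(B, A)]
W_adapted = add(W, delta)  # effective weight: W + scale * (B @ A)

print(W_adapted[0])  # first row picks up the update: [1.0, 1.0, 0.0, 0.0]
```

Loaders like these merge (or dynamically apply) exactly this kind of low-rank delta into the pipeline's attention and feed-forward weights.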
docs/source/en/api/models/skyreels_v2_transformer_3d.md

Lines changed: 30 additions & 0 deletions

````diff
@@ -0,0 +1,30 @@
+<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+specific language governing permissions and limitations under the License. -->
+
+# SkyReelsV2Transformer3DModel
+
+A Diffusion Transformer model for 3D video-like data, introduced in [SkyReels-V2](https://github.com/SkyworkAI/SkyReels-V2) by Skywork AI.
+
+The model can be loaded with the following code snippet.
+
+```python
+import torch
+
+from diffusers import SkyReelsV2Transformer3DModel
+
+transformer = SkyReelsV2Transformer3DModel.from_pretrained("Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
+```
+
+## SkyReelsV2Transformer3DModel
+
+[[autodoc]] SkyReelsV2Transformer3DModel
+
+## Transformer2DModelOutput
+
+[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
````
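The snippet above loads the transformer in `torch.bfloat16`, which halves the weight memory relative to `float32`. Rough back-of-the-envelope arithmetic for the parameter count implied by the "1.3B" checkpoint name (weights only; activations and framework overhead excluded):

```python
def param_memory_gib(n_params: int, bytes_per_param: int) -> float:
    # Weight-only memory footprint in GiB: parameter count times bytes
    # per element (float32 = 4 bytes, bfloat16 = 2 bytes).
    return n_params * bytes_per_param / 1024**3

n = 1_300_000_000  # "1.3B" parameters, taken from the checkpoint name
print(round(param_memory_gib(n, 4), 2))  # float32  -> 4.84
print(round(param_memory_gib(n, 2), 2))  # bfloat16 -> 2.42
```

The same halving applies to the other SkyReels-V2 checkpoint sizes; bfloat16 keeps float32's exponent range, which is why it is the usual low-precision choice for inference here.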

docs/source/en/api/pipelines/skyreels_v2.md

Lines changed: 367 additions & 0 deletions
Large diffs are not rendered by default.

examples/dreambooth/train_dreambooth_lora_hidream.py

Lines changed: 26 additions & 36 deletions
```diff
@@ -58,6 +58,7 @@
     compute_density_for_timestep_sampling,
     compute_loss_weighting_for_sd3,
     free_memory,
+    offload_models,
 )
 from diffusers.utils import (
     check_min_version,
@@ -1364,43 +1365,34 @@ def compute_text_embeddings(prompt, text_encoding_pipeline):
     # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
     # the redundant encoding.
     if not train_dataset.custom_instance_prompts:
-        if args.offload:
-            text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
-        (
-            instance_prompt_hidden_states_t5,
-            instance_prompt_hidden_states_llama3,
-            instance_pooled_prompt_embeds,
-            _,
-            _,
-            _,
-        ) = compute_text_embeddings(args.instance_prompt, text_encoding_pipeline)
-        if args.offload:
-            text_encoding_pipeline = text_encoding_pipeline.to("cpu")
+        with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
+            (
+                instance_prompt_hidden_states_t5,
+                instance_prompt_hidden_states_llama3,
+                instance_pooled_prompt_embeds,
+                _,
+                _,
+                _,
+            ) = compute_text_embeddings(args.instance_prompt, text_encoding_pipeline)

     # Handle class prompt for prior-preservation.
     if args.with_prior_preservation:
-        if args.offload:
-            text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
-        (class_prompt_hidden_states_t5, class_prompt_hidden_states_llama3, class_pooled_prompt_embeds, _, _, _) = (
-            compute_text_embeddings(args.class_prompt, text_encoding_pipeline)
-        )
-        if args.offload:
-            text_encoding_pipeline = text_encoding_pipeline.to("cpu")
+        with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
+            (class_prompt_hidden_states_t5, class_prompt_hidden_states_llama3, class_pooled_prompt_embeds, _, _, _) = (
+                compute_text_embeddings(args.class_prompt, text_encoding_pipeline)
+            )

     validation_embeddings = {}
     if args.validation_prompt is not None:
-        if args.offload:
-            text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
-        (
-            validation_embeddings["prompt_embeds_t5"],
-            validation_embeddings["prompt_embeds_llama3"],
-            validation_embeddings["pooled_prompt_embeds"],
-            validation_embeddings["negative_prompt_embeds_t5"],
-            validation_embeddings["negative_prompt_embeds_llama3"],
-            validation_embeddings["negative_pooled_prompt_embeds"],
-        ) = compute_text_embeddings(args.validation_prompt, text_encoding_pipeline)
-        if args.offload:
-            text_encoding_pipeline = text_encoding_pipeline.to("cpu")
+        with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
+            (
+                validation_embeddings["prompt_embeds_t5"],
+                validation_embeddings["prompt_embeds_llama3"],
+                validation_embeddings["pooled_prompt_embeds"],
+                validation_embeddings["negative_prompt_embeds_t5"],
+                validation_embeddings["negative_prompt_embeds_llama3"],
+                validation_embeddings["negative_pooled_prompt_embeds"],
+            ) = compute_text_embeddings(args.validation_prompt, text_encoding_pipeline)

     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
     # pack the statically computed variables appropriately here. This is so that we don't
@@ -1581,12 +1573,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 if args.cache_latents:
                     model_input = latents_cache[step].sample()
                 else:
-                    if args.offload:
-                        vae = vae.to(accelerator.device)
-                    pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
+                    with offload_models(vae, device=accelerator.device, offload=args.offload):
+                        pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
                     model_input = vae.encode(pixel_values).latent_dist.sample()
-                    if args.offload:
-                        vae = vae.to("cpu")
+
                 model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
                 model_input = model_input.to(dtype=weight_dtype)
```
