Skip to content

Commit 7938b42

Browse files
authored
Merge branch 'main' into ipadapter-flux
2 parents cab0dd8 + 9020086 commit 7938b42

File tree

5 files changed

+235
-4
lines changed

5 files changed

+235
-4
lines changed

docs/source/en/api/pipelines/ltx_video.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,6 @@ transformer = LTXVideoTransformer3DModel.from_single_file(
7979
pipe = LTXPipeline.from_pretrained(
8080
"Lightricks/LTX-Video",
8181
transformer=transformer,
82-
generator=torch.manual_seed(0),
8382
torch_dtype=torch.bfloat16,
8483
)
8584
pipe.enable_model_cpu_offload()

docs/source/en/api/pipelines/mochi.md

Lines changed: 196 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414
-->
1515

16-
# Mochi
16+
# Mochi 1 Preview
1717

1818
[Mochi 1 Preview](https://huggingface.co/genmo/mochi-1-preview) from Genmo.
1919

@@ -25,6 +25,201 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m
2525

2626
</Tip>
2727

28+
## Generating videos with Mochi-1 Preview
29+
30+
The following example will download the full precision `mochi-1-preview` weights and produce the highest quality results but will require at least 42GB VRAM to run.
31+
32+
```python
33+
import torch
34+
from diffusers import MochiPipeline
35+
from diffusers.utils import export_to_video
36+
37+
pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview")
38+
39+
# Enable memory savings
40+
pipe.enable_model_cpu_offload()
41+
pipe.enable_vae_tiling()
42+
43+
prompt = "Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k."
44+
45+
with torch.autocast("cuda", torch.bfloat16, cache_enabled=False):
46+
frames = pipe(prompt, num_frames=85).frames[0]
47+
48+
export_to_video(frames, "mochi.mp4", fps=30)
49+
```
50+
51+
## Using a lower precision variant to save memory
52+
53+
The following example will use the `bfloat16` variant of the model and requires 22GB VRAM to run. There is a slight drop in the quality of the generated video as a result.
54+
55+
```python
56+
import torch
57+
from diffusers import MochiPipeline
58+
from diffusers.utils import export_to_video
59+
60+
pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", variant="bf16", torch_dtype=torch.bfloat16)
61+
62+
# Enable memory savings
63+
pipe.enable_model_cpu_offload()
64+
pipe.enable_vae_tiling()
65+
66+
prompt = "Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k."
67+
frames = pipe(prompt, num_frames=85).frames[0]
68+
69+
export_to_video(frames, "mochi.mp4", fps=30)
70+
```
71+
72+
## Reproducing the results from the Genmo Mochi repo
73+
74+
The [Genmo Mochi implementation](https://github.com/genmoai/mochi/tree/main) uses different precision values for each stage in the inference process. The text encoder and VAE use `torch.float32`, while the DiT uses `torch.bfloat16` with the [attention kernel](https://pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html#torch.nn.attention.sdpa_kernel) set to `EFFICIENT_ATTENTION`. Diffusers pipelines currently do not support setting different `dtypes` for different stages of the pipeline. In order to run inference in the same way as the the original implementation, please refer to the following example.
75+
76+
<Tip>
77+
The original Mochi implementation zeros out empty prompts. However, enabling this option and placing the entire pipeline under autocast can lead to numerical overflows with the T5 text encoder.
78+
79+
When enabling `force_zeros_for_empty_prompt`, it is recommended to run the text encoding step outside the autocast context in full precision.
80+
</Tip>
81+
82+
<Tip>
83+
Decoding the latents in full precision is very memory intensive. You will need at least 70GB VRAM to generate the 163 frames in this example. To reduce memory, either reduce the number of frames or run the decoding step in `torch.bfloat16`.
84+
</Tip>
85+
86+
```python
87+
import torch
88+
from torch.nn.attention import SDPBackend, sdpa_kernel
89+
90+
from diffusers import MochiPipeline
91+
from diffusers.utils import export_to_video
92+
from diffusers.video_processor import VideoProcessor
93+
94+
pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", force_zeros_for_empty_prompt=True)
95+
pipe.enable_vae_tiling()
96+
pipe.enable_model_cpu_offload()
97+
98+
prompt = "An aerial shot of a parade of elephants walking across the African savannah. The camera showcases the herd and the surrounding landscape."
99+
100+
with torch.no_grad():
101+
prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = (
102+
pipe.encode_prompt(prompt=prompt)
103+
)
104+
105+
with torch.autocast("cuda", torch.bfloat16):
106+
with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
107+
frames = pipe(
108+
prompt_embeds=prompt_embeds,
109+
prompt_attention_mask=prompt_attention_mask,
110+
negative_prompt_embeds=negative_prompt_embeds,
111+
negative_prompt_attention_mask=negative_prompt_attention_mask,
112+
guidance_scale=4.5,
113+
num_inference_steps=64,
114+
height=480,
115+
width=848,
116+
num_frames=163,
117+
generator=torch.Generator("cuda").manual_seed(0),
118+
output_type="latent",
119+
return_dict=False,
120+
)[0]
121+
122+
video_processor = VideoProcessor(vae_scale_factor=8)
123+
has_latents_mean = hasattr(pipe.vae.config, "latents_mean") and pipe.vae.config.latents_mean is not None
124+
has_latents_std = hasattr(pipe.vae.config, "latents_std") and pipe.vae.config.latents_std is not None
125+
if has_latents_mean and has_latents_std:
126+
latents_mean = (
127+
torch.tensor(pipe.vae.config.latents_mean).view(1, 12, 1, 1, 1).to(frames.device, frames.dtype)
128+
)
129+
latents_std = (
130+
torch.tensor(pipe.vae.config.latents_std).view(1, 12, 1, 1, 1).to(frames.device, frames.dtype)
131+
)
132+
frames = frames * latents_std / pipe.vae.config.scaling_factor + latents_mean
133+
else:
134+
frames = frames / pipe.vae.config.scaling_factor
135+
136+
with torch.no_grad():
137+
video = pipe.vae.decode(frames.to(pipe.vae.dtype), return_dict=False)[0]
138+
139+
video = video_processor.postprocess_video(video)[0]
140+
export_to_video(video, "mochi.mp4", fps=30)
141+
```
142+
143+
## Running inference with multiple GPUs
144+
145+
It is possible to split the large Mochi transformer across multiple GPUs using the `device_map` and `max_memory` options in `from_pretrained`. In the following example we split the model across two GPUs, each with 24GB of VRAM.
146+
147+
```python
148+
import torch
149+
from diffusers import MochiPipeline, MochiTransformer3DModel
150+
from diffusers.utils import export_to_video
151+
152+
model_id = "genmo/mochi-1-preview"
153+
transformer = MochiTransformer3DModel.from_pretrained(
154+
model_id,
155+
subfolder="transformer",
156+
device_map="auto",
157+
max_memory={0: "24GB", 1: "24GB"}
158+
)
159+
160+
pipe = MochiPipeline.from_pretrained(model_id, transformer=transformer)
161+
pipe.enable_model_cpu_offload()
162+
pipe.enable_vae_tiling()
163+
164+
with torch.autocast(device_type="cuda", dtype=torch.bfloat16, cache_enabled=False):
165+
frames = pipe(
166+
prompt="Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k.",
167+
negative_prompt="",
168+
height=480,
169+
width=848,
170+
num_frames=85,
171+
num_inference_steps=50,
172+
guidance_scale=4.5,
173+
num_videos_per_prompt=1,
174+
generator=torch.Generator(device="cuda").manual_seed(0),
175+
max_sequence_length=256,
176+
output_type="pil",
177+
).frames[0]
178+
179+
export_to_video(frames, "output.mp4", fps=30)
180+
```
181+
182+
## Using single file loading with the Mochi Transformer
183+
184+
You can use `from_single_file` to load the Mochi transformer in its original format.
185+
186+
<Tip>
187+
Diffusers currently doesn't support using the FP8 scaled versions of the Mochi single file checkpoints.
188+
</Tip>
189+
190+
```python
191+
import torch
192+
from diffusers import MochiPipeline, MochiTransformer3DModel
193+
from diffusers.utils import export_to_video
194+
195+
model_id = "genmo/mochi-1-preview"
196+
197+
ckpt_path = "https://huggingface.co/Comfy-Org/mochi_preview_repackaged/blob/main/split_files/diffusion_models/mochi_preview_bf16.safetensors"
198+
199+
transformer = MochiTransformer3DModel.from_pretrained(ckpt_path, torch_dtype=torch.bfloat16)
200+
201+
pipe = MochiPipeline.from_pretrained(model_id, transformer=transformer)
202+
pipe.enable_model_cpu_offload()
203+
pipe.enable_vae_tiling()
204+
205+
with torch.autocast(device_type="cuda", dtype=torch.bfloat16, cache_enabled=False):
206+
frames = pipe(
207+
prompt="Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k.",
208+
negative_prompt="",
209+
height=480,
210+
width=848,
211+
num_frames=85,
212+
num_inference_steps=50,
213+
guidance_scale=4.5,
214+
num_videos_per_prompt=1,
215+
generator=torch.Generator(device="cuda").manual_seed(0),
216+
max_sequence_length=256,
217+
output_type="pil",
218+
).frames[0]
219+
220+
export_to_video(frames, "output.mp4", fps=30)
221+
```
222+
28223
## MochiPipeline
29224

30225
[[autodoc]] MochiPipeline

src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
143143
Args:
144144
text_encoder ([`LlamaModel`]):
145145
[Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers).
146-
tokenizer_2 (`LlamaTokenizer`):
146+
tokenizer (`LlamaTokenizer`):
147147
Tokenizer from [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers).
148148
transformer ([`HunyuanVideoTransformer3DModel`]):
149149
Conditional Transformer to denoise the encoded image latents.

src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -446,7 +446,7 @@ def prepare_latents(
446446
f"`initial_audio_waveforms` must be of shape `(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)` but has `{initial_audio_waveforms.ndim}` dimensions"
447447
)
448448

449-
audio_vae_length = self.transformer.config.sample_size * self.vae.hop_length
449+
audio_vae_length = int(self.transformer.config.sample_size) * self.vae.hop_length
450450
audio_shape = (batch_size // num_waveforms_per_prompt, audio_channels, audio_vae_length)
451451

452452
# check num_channels

tests/lora/test_lora_layers_flux.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -825,3 +825,40 @@ def test_lora(self, lora_ckpt_id):
825825
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
826826

827827
assert max_diff < 1e-3
828+
829+
@parameterized.expand(["black-forest-labs/FLUX.1-Canny-dev-lora", "black-forest-labs/FLUX.1-Depth-dev-lora"])
830+
def test_lora_with_turbo(self, lora_ckpt_id):
831+
self.pipeline.load_lora_weights(lora_ckpt_id)
832+
self.pipeline.load_lora_weights("ByteDance/Hyper-SD", weight_name="Hyper-FLUX.1-dev-8steps-lora.safetensors")
833+
self.pipeline.fuse_lora()
834+
self.pipeline.unload_lora_weights()
835+
836+
if "Canny" in lora_ckpt_id:
837+
control_image = load_image(
838+
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/canny_condition_image.png"
839+
)
840+
else:
841+
control_image = load_image(
842+
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/depth_condition_image.png"
843+
)
844+
845+
image = self.pipeline(
846+
prompt=self.prompt,
847+
control_image=control_image,
848+
height=1024,
849+
width=1024,
850+
num_inference_steps=self.num_inference_steps,
851+
guidance_scale=30.0 if "Canny" in lora_ckpt_id else 10.0,
852+
output_type="np",
853+
generator=torch.manual_seed(self.seed),
854+
).images
855+
856+
out_slice = image[0, -3:, -3:, -1].flatten()
857+
if "Canny" in lora_ckpt_id:
858+
expected_slice = np.array([0.6562, 0.7266, 0.7578, 0.6367, 0.6758, 0.7031, 0.6172, 0.6602, 0.6484])
859+
else:
860+
expected_slice = np.array([0.6680, 0.7344, 0.7656, 0.6484, 0.6875, 0.7109, 0.6328, 0.6719, 0.6562])
861+
862+
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
863+
864+
assert max_diff < 1e-3

0 commit comments

Comments
 (0)