
Commit 756931f

Merge branch 'main' into cogvideox-fun/control
2 parents f747754 + 29a2c5d commit 756931f

File tree

- docs/source/en/_toctree.yml
- docs/source/en/training/cogvideox.md
- docs/source/en/using-diffusers/cogvideox.md
- docs/source/en/using-diffusers/text-img2vid.md
- src/diffusers/models/controlnet_flux.py
- src/diffusers/models/transformers/transformer_flux.py
- src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
- src/diffusers/training_utils.py

8 files changed: +536 -37 lines

docs/source/en/_toctree.yml

Lines changed: 4 additions & 0 deletions
@@ -75,6 +75,8 @@
     title: Outpainting
   title: Advanced inference
 - sections:
+  - local: using-diffusers/cogvideox
+    title: CogVideoX
   - local: using-diffusers/sdxl
     title: Stable Diffusion XL
   - local: using-diffusers/sdxl_turbo
@@ -129,6 +131,8 @@
     title: T2I-Adapters
   - local: training/instructpix2pix
     title: InstructPix2Pix
+  - local: training/cogvideox
+    title: CogVideoX
   title: Models
 - isExpanded: false
   sections:

docs/source/en/training/cogvideox.md

Lines changed: 291 additions & 0 deletions
Large diffs are not rendered by default.

docs/source/en/using-diffusers/cogvideox.md

Lines changed: 120 additions & 0 deletions
@@ -0,0 +1,120 @@
+<!--Copyright 2024 The HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+specific language governing permissions and limitations under the License.
+-->
+
+# CogVideoX
+
+CogVideoX is a text-to-video generation model focused on creating more coherent videos that stay aligned with a prompt. It achieves this with several techniques:
+
+- a 3D variational autoencoder that compresses videos spatially and temporally, improving the compression rate and video fidelity.
+- an expert transformer block to help align text and video, and a 3D full attention module to capture and create spatially and temporally accurate videos.
+
+## Load model checkpoints
+
+Model weights may be stored in separate subfolders on the Hub or locally, in which case you should use the [`~DiffusionPipeline.from_pretrained`] method.
+
+```py
+import torch
+from diffusers import CogVideoXPipeline, CogVideoXImageToVideoPipeline
+
+pipe = CogVideoXPipeline.from_pretrained(
+    "THUDM/CogVideoX-2b",
+    torch_dtype=torch.float16
+)
+
+pipe = CogVideoXImageToVideoPipeline.from_pretrained(
+    "THUDM/CogVideoX-5b-I2V",
+    torch_dtype=torch.bfloat16
+)
+```
+
+## Text-to-Video
+
+For text-to-video, pass a text prompt. By default, CogVideoX generates a 720x480 video for the best results.
+
+```py
+import torch
+from diffusers import CogVideoXPipeline
+from diffusers.utils import export_to_video
+
+prompt = "An elderly gentleman, with a serene expression, sits at the water's edge, a steaming cup of tea by his side. He is engrossed in his artwork, brush in hand, as he renders an oil painting on a canvas that's propped up against a small, weathered table. The sea breeze whispers through his silver hair, gently billowing his loose-fitting white shirt, while the salty air adds an intangible element to his masterpiece in progress. The scene is one of tranquility and inspiration, with the artist's canvas capturing the vibrant hues of the setting sun reflecting off the tranquil sea."
+
+pipe = CogVideoXPipeline.from_pretrained(
+    "THUDM/CogVideoX-5b",
+    torch_dtype=torch.bfloat16
+)
+
+pipe.enable_model_cpu_offload()
+pipe.vae.enable_tiling()
+
+video = pipe(
+    prompt=prompt,
+    num_videos_per_prompt=1,
+    num_inference_steps=50,
+    num_frames=49,
+    guidance_scale=6,
+    generator=torch.Generator(device="cuda").manual_seed(42),
+).frames[0]
+
+export_to_video(video, "output.mp4", fps=8)
+```
+
+<div class="flex justify-center">
+  <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cogvideox/cogvideox_out.gif" alt="generated video of an elderly artist painting by the sea"/>
+</div>
+
+## Image-to-Video
+
+For image-to-video, pass a text prompt along with an initial image. You'll use the [THUDM/CogVideoX-5b-I2V](https://huggingface.co/THUDM/CogVideoX-5b-I2V) checkpoint for this guide.
+
+```py
+import torch
+from diffusers import CogVideoXImageToVideoPipeline
+from diffusers.utils import export_to_video, load_image
+
+prompt = "A vast, shimmering ocean flows gracefully under a twilight sky, its waves undulating in a mesmerizing dance of blues and greens. The surface glints with the last rays of the setting sun, casting golden highlights that ripple across the water. Seagulls soar above, their cries blending with the gentle roar of the waves. The horizon stretches infinitely, where the ocean meets the sky in a seamless blend of hues. Close-ups reveal the intricate patterns of the waves, capturing the fluidity and dynamic beauty of the sea in motion."
+image = load_image(image="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cogvideox/cogvideox_rocket.png")
+
+pipe = CogVideoXImageToVideoPipeline.from_pretrained(
+    "THUDM/CogVideoX-5b-I2V",
+    torch_dtype=torch.bfloat16
+)
+
+pipe.vae.enable_tiling()
+pipe.vae.enable_slicing()
+
+video = pipe(
+    prompt=prompt,
+    image=image,
+    num_videos_per_prompt=1,
+    num_inference_steps=50,
+    num_frames=49,
+    guidance_scale=6,
+    generator=torch.Generator(device="cuda").manual_seed(42),
+).frames[0]
+
+export_to_video(video, "output.mp4", fps=8)
+```
+
+<div class="flex gap-4">
+  <div>
+    <img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cogvideox/cogvideox_rocket.png"/>
+    <figcaption class="mt-2 text-center text-sm text-gray-500">initial image</figcaption>
+  </div>
+  <div>
+    <img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cogvideox/cogvideox_outrocket.gif"/>
+    <figcaption class="mt-2 text-center text-sm text-gray-500">generated video</figcaption>
+  </div>
+</div>

docs/source/en/using-diffusers/text-img2vid.md

Lines changed: 53 additions & 0 deletions
@@ -23,6 +23,59 @@ This guide will show you how to generate videos, how to configure video model pa
 
 [Stable Video Diffusion (SVD)](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid), [I2VGen-XL](https://huggingface.co/ali-vilab/i2vgen-xl/), [AnimateDiff](https://huggingface.co/guoyww/animatediff), and [ModelScopeT2V](https://huggingface.co/ali-vilab/text-to-video-ms-1.7b) are popular models used for video diffusion. Each model is distinct. For example, AnimateDiff inserts a motion modeling module into a frozen text-to-image model to generate personalized animated images, whereas SVD is entirely pretrained from scratch with a three-stage training process to generate short high-quality videos.
 
+[CogVideoX](https://huggingface.co/collections/THUDM/cogvideo-66c08e62f1685a3ade464cce) is another popular video generation model. The model is a multidimensional transformer that integrates text, time, and space. It employs full attention in the attention module and includes an expert block at the layer level to spatially align text and video.
+
+### CogVideoX
+
+[CogVideoX](../api/pipelines/cogvideox) uses a 3D Variational Autoencoder (VAE) to compress videos along the spatial and temporal dimensions.
+
+Begin by loading the [`CogVideoXPipeline`] and passing an initial text or image to generate a video.
+
+<Tip>
+
+CogVideoX is available for image-to-video and text-to-video. [THUDM/CogVideoX-5b-I2V](https://huggingface.co/THUDM/CogVideoX-5b-I2V) uses the [`CogVideoXImageToVideoPipeline`] for image-to-video. [THUDM/CogVideoX-5b](https://huggingface.co/THUDM/CogVideoX-5b) and [THUDM/CogVideoX-2b](https://huggingface.co/THUDM/CogVideoX-2b) are available for text-to-video with the [`CogVideoXPipeline`].
+
+</Tip>
+
+```py
+import torch
+from diffusers import CogVideoXImageToVideoPipeline
+from diffusers.utils import export_to_video, load_image
+
+prompt = "A vast, shimmering ocean flows gracefully under a twilight sky, its waves undulating in a mesmerizing dance of blues and greens. The surface glints with the last rays of the setting sun, casting golden highlights that ripple across the water. Seagulls soar above, their cries blending with the gentle roar of the waves. The horizon stretches infinitely, where the ocean meets the sky in a seamless blend of hues. Close-ups reveal the intricate patterns of the waves, capturing the fluidity and dynamic beauty of the sea in motion."
+image = load_image(image="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cogvideox/cogvideox_rocket.png")
+
+pipe = CogVideoXImageToVideoPipeline.from_pretrained(
+    "THUDM/CogVideoX-5b-I2V",
+    torch_dtype=torch.bfloat16
+)
+
+pipe.vae.enable_tiling()
+pipe.vae.enable_slicing()
+
+video = pipe(
+    prompt=prompt,
+    image=image,
+    num_videos_per_prompt=1,
+    num_inference_steps=50,
+    num_frames=49,
+    guidance_scale=6,
+    generator=torch.Generator(device="cuda").manual_seed(42),
+).frames[0]
+
+export_to_video(video, "output.mp4", fps=8)
+```
+
+<div class="flex gap-4">
+  <div>
+    <img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cogvideox/cogvideox_rocket.png"/>
+    <figcaption class="mt-2 text-center text-sm text-gray-500">initial image</figcaption>
+  </div>
+  <div>
+    <img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cogvideox/cogvideox_outrocket.gif"/>
+    <figcaption class="mt-2 text-center text-sm text-gray-500">generated video</figcaption>
+  </div>
+</div>
+
 ### Stable Video Diffusion
 
 [SVD](../api/pipelines/svd) is based on the Stable Diffusion 2.1 model and it is trained on images, then low-resolution videos, and finally a smaller dataset of high-resolution videos. This model generates a short 2-4 second video from an initial image. You can learn more details about the model, like micro-conditioning, in the [Stable Video Diffusion](../using-diffusers/svd) guide.

src/diffusers/models/controlnet_flux.py

Lines changed: 20 additions & 2 deletions
@@ -23,7 +23,7 @@
 from ..models.attention_processor import AttentionProcessor
 from ..models.modeling_utils import ModelMixin
 from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
-from .controlnet import BaseOutput, zero_module
+from .controlnet import BaseOutput, ControlNetConditioningEmbedding, zero_module
 from .embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
 from .modeling_outputs import Transformer2DModelOutput
 from .transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
@@ -55,6 +55,7 @@ def __init__(
         guidance_embeds: bool = False,
         axes_dims_rope: List[int] = [16, 56, 56],
         num_mode: int = None,
+        conditioning_embedding_channels: int = None,
     ):
         super().__init__()
         self.out_channels = in_channels
@@ -106,7 +107,14 @@ def __init__(
         if self.union:
             self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim)
 
-        self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim))
+        if conditioning_embedding_channels is not None:
+            self.input_hint_block = ControlNetConditioningEmbedding(
+                conditioning_embedding_channels=conditioning_embedding_channels, block_out_channels=(16, 16, 16, 16)
+            )
+            self.controlnet_x_embedder = torch.nn.Linear(in_channels, self.inner_dim)
+        else:
+            self.input_hint_block = None
+            self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim))
 
         self.gradient_checkpointing = False
 
@@ -269,6 +277,16 @@ def forward(
             )
         hidden_states = self.x_embedder(hidden_states)
 
+        if self.input_hint_block is not None:
+            controlnet_cond = self.input_hint_block(controlnet_cond)
+            batch_size, channels, height_pw, width_pw = controlnet_cond.shape
+            height = height_pw // self.config.patch_size
+            width = width_pw // self.config.patch_size
+            controlnet_cond = controlnet_cond.reshape(
+                batch_size, channels, height, self.config.patch_size, width, self.config.patch_size
+            )
+            controlnet_cond = controlnet_cond.permute(0, 2, 4, 1, 3, 5)
+            controlnet_cond = controlnet_cond.reshape(batch_size, height * width, -1)
         # add
         hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond)
 
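The forward-pass addition above runs the raw hint through `ControlNetConditioningEmbedding` and then patchifies it so it lines up token-for-token with the packed latents that Flux operates on, which is what lets the final line add `controlnet_x_embedder(controlnet_cond)` directly to `hidden_states`. Below is a rough standalone sketch of just that reshape/permute step; the tensor sizes are invented for illustration and are not taken from any model config.

```py
import torch

# Invented sizes for illustration: a 2-sample batch of 16-channel feature maps,
# folded into 2x2 patches the same way the forward pass above folds the hint.
batch_size, channels, height_pw, width_pw, patch_size = 2, 16, 64, 64, 2

controlnet_cond = torch.randn(batch_size, channels, height_pw, width_pw)

height = height_pw // patch_size
width = width_pw // patch_size

# (B, C, H*p, W*p) -> (B, C, H, p, W, p)
patches = controlnet_cond.reshape(batch_size, channels, height, patch_size, width, patch_size)
# (B, C, H, p, W, p) -> (B, H, W, C, p, p)
patches = patches.permute(0, 2, 4, 1, 3, 5)
# (B, H, W, C, p, p) -> (B, H*W, C*p*p): one flattened token per patch
tokens = patches.reshape(batch_size, height * width, -1)

print(tokens.shape)  # torch.Size([2, 1024, 64])
```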

src/diffusers/models/transformers/transformer_flux.py

Lines changed: 8 additions & 1 deletion
@@ -402,6 +402,7 @@ def forward(
         controlnet_block_samples=None,
         controlnet_single_block_samples=None,
         return_dict: bool = True,
+        controlnet_blocks_repeat: bool = False,
     ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
         """
         The [`FluxTransformer2DModel`] forward method.
@@ -508,7 +509,13 @@ def custom_forward(*inputs):
             if controlnet_block_samples is not None:
                 interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
                 interval_control = int(np.ceil(interval_control))
-                hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
+                # For Xlabs ControlNet.
+                if controlnet_blocks_repeat:
+                    hidden_states = (
+                        hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
+                    )
+                else:
+                    hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
 
         hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
 
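The two branches differ only in how a transformer block index is mapped to a ControlNet residual: the existing interval-based lookup (kept for InstantX-style controlnets) spreads a small number of samples evenly across the blocks, while the new modulo path cycles through the samples for XLabs-style controlnets; with `controlnet_blocks_repeat=False` the previous behavior is unchanged. A small sketch of just that indexing logic, with assumed block counts, shows the difference.

```py
import numpy as np

def controlnet_sample_index(index_block: int, num_transformer_blocks: int,
                            num_controlnet_samples: int, repeat: bool) -> int:
    """Map a transformer block index to a ControlNet residual index.

    repeat=True mirrors the modulo cycling added for XLabs-style controlnets;
    repeat=False mirrors the interval lookup kept for InstantX-style controlnets.
    """
    if repeat:
        return index_block % num_controlnet_samples
    interval_control = int(np.ceil(num_transformer_blocks / num_controlnet_samples))
    return index_block // interval_control

# Assumed sizes for illustration only: 19 transformer blocks, 2 controlnet samples.
print([controlnet_sample_index(i, 19, 2, repeat=True) for i in range(6)])   # [0, 1, 0, 1, 0, 1]
print([controlnet_sample_index(i, 19, 2, repeat=False) for i in range(6)])  # [0, 0, 0, 0, 0, 0]
```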

src/diffusers/pipelines/flux/pipeline_flux_controlnet.py

Lines changed: 34 additions & 29 deletions
@@ -754,19 +754,22 @@ def __call__(
             )
             height, width = control_image.shape[-2:]
 
-            # vae encode
-            control_image = self.vae.encode(control_image).latent_dist.sample()
-            control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
-
-            # pack
-            height_control_image, width_control_image = control_image.shape[2:]
-            control_image = self._pack_latents(
-                control_image,
-                batch_size * num_images_per_prompt,
-                num_channels_latents,
-                height_control_image,
-                width_control_image,
-            )
+            # xlab controlnet has a input_hint_block and instantx controlnet does not
+            controlnet_blocks_repeat = False if self.controlnet.input_hint_block is None else True
+            if self.controlnet.input_hint_block is None:
+                # vae encode
+                control_image = self.vae.encode(control_image).latent_dist.sample()
+                control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+                # pack
+                height_control_image, width_control_image = control_image.shape[2:]
+                control_image = self._pack_latents(
+                    control_image,
+                    batch_size * num_images_per_prompt,
+                    num_channels_latents,
+                    height_control_image,
+                    width_control_image,
+                )
 
             # Here we ensure that `control_mode` has the same length as the control_image.
             if control_mode is not None:
@@ -777,8 +780,9 @@ def __call__(
 
         elif isinstance(self.controlnet, FluxMultiControlNetModel):
             control_images = []
-
-            for control_image_ in control_image:
+            # xlab controlnet has a input_hint_block and instantx controlnet does not
+            controlnet_blocks_repeat = False if self.controlnet.nets[0].input_hint_block is None else True
+            for i, control_image_ in enumerate(control_image):
                 control_image_ = self.prepare_image(
                     image=control_image_,
                     width=width,
@@ -790,20 +794,20 @@ def __call__(
                 )
                 height, width = control_image_.shape[-2:]
 
-                # vae encode
-                control_image_ = self.vae.encode(control_image_).latent_dist.sample()
-                control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
-
-                # pack
-                height_control_image, width_control_image = control_image_.shape[2:]
-                control_image_ = self._pack_latents(
-                    control_image_,
-                    batch_size * num_images_per_prompt,
-                    num_channels_latents,
-                    height_control_image,
-                    width_control_image,
-                )
-
+                if self.controlnet.nets[0].input_hint_block is None:
+                    # vae encode
+                    control_image_ = self.vae.encode(control_image_).latent_dist.sample()
+                    control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+                    # pack
+                    height_control_image, width_control_image = control_image_.shape[2:]
+                    control_image_ = self._pack_latents(
+                        control_image_,
+                        batch_size * num_images_per_prompt,
+                        num_channels_latents,
+                        height_control_image,
+                        width_control_image,
+                    )
                 control_images.append(control_image_)
 
             control_image = control_images
@@ -927,6 +931,7 @@ def __call__(
                     img_ids=latent_image_ids,
                     joint_attention_kwargs=self.joint_attention_kwargs,
                     return_dict=False,
+                    controlnet_blocks_repeat=controlnet_blocks_repeat,
                 )[0]
 
                 # compute the previous noisy sample x_t -> x_t-1
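Put differently, the pipeline now keys its preprocessing off the controlnet flavor: InstantX-style controlnets (no `input_hint_block`) still receive VAE-encoded, packed latents, while XLabs-style controlnets receive the resized pixel-space hint and embed it themselves. The condensed sketch below captures that dispatch; `vae_encode_and_pack` is a stand-in name, not a real pipeline method.

```py
# Sketch only: the branch structure mirrors the pipeline change above, with the
# diffusers-specific calls collapsed into a single stand-in callable.
def prepare_control_input(controlnet, control_image, vae_encode_and_pack):
    # XLabs controlnets ship an input_hint_block; InstantX controlnets do not.
    controlnet_blocks_repeat = controlnet.input_hint_block is not None

    if controlnet.input_hint_block is None:
        # InstantX path: hand the controlnet packed VAE latents.
        control_image = vae_encode_and_pack(control_image)
    # XLabs path: pass the pixel-space hint through untouched; the controlnet's
    # own input_hint_block embeds and patchifies it in its forward pass.

    return control_image, controlnet_blocks_repeat
```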

src/diffusers/training_utils.py

Lines changed: 6 additions & 5 deletions
@@ -24,6 +24,9 @@
 if is_transformers_available():
     import transformers
 
+    if transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
+        import deepspeed
+
 if is_peft_available():
     from peft import set_peft_model_state_dict
 
@@ -442,15 +445,13 @@ def step(self, parameters: Iterable[torch.nn.Parameter]):
         self.cur_decay_value = decay
         one_minus_decay = 1 - decay
 
-        context_manager = contextlib.nullcontext
-        if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
-            import deepspeed
+        context_manager = contextlib.nullcontext()
 
         if self.foreach:
             if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
                 context_manager = deepspeed.zero.GatheredParameters(parameters, modifier_rank=None)
 
-            with context_manager():
+            with context_manager:
                 params_grad = [param for param in parameters if param.requires_grad]
                 s_params_grad = [
                     s_param for s_param, param in zip(self.shadow_params, parameters) if param.requires_grad
@@ -472,7 +473,7 @@ def step(self, parameters: Iterable[torch.nn.Parameter]):
                 if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
                     context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None)
 
-                with context_manager():
+                with context_manager:
                     if param.requires_grad:
                         s_param.sub_(one_minus_decay * (s_param - param))
                     else:
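The EMA fix is subtle: `contextlib.nullcontext` (a class) and `deepspeed.zero.GatheredParameters(...)` (an already-constructed instance) were both being invoked with `()` inside the `with` statement, which only works for the former. Assigning a constructed `nullcontext()` up front and writing `with context_manager:` handles both cases uniformly. A tiny standalone illustration of the pattern, with a stub in place of DeepSpeed:

```py
import contextlib

class GatheredParamsStub:
    """Stand-in for deepspeed.zero.GatheredParameters; it exists only to show
    that an already-constructed context manager works with `with obj:`."""
    def __enter__(self):
        print("gathering sharded params")
        return self
    def __exit__(self, *exc):
        print("releasing gathered params")
        return False

def ema_step(use_zero3: bool):
    # Default to a no-op context; note this is an *instance*, not the class.
    context_manager = contextlib.nullcontext()
    if use_zero3:
        context_manager = GatheredParamsStub()

    # Works in both cases because neither object is called again here.
    with context_manager:
        print("updating shadow params")

ema_step(use_zero3=False)  # -> updating shadow params
ema_step(use_zero3=True)   # -> gathering, updating, releasing
```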
