
Commit 1c4be59

Merge branch 'main' into qwen
2 parents: 4ead0a5 + 4d9b822

21 files changed: +1502 −390 lines

docs/source/en/_toctree.yml

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
   - local: installation
     title: Installation
   - local: quicktour
-    title: Quicktour
+    title: Quickstart
   - local: stable_diffusion
     title: Basic performance

docs/source/en/api/pipelines/qwenimage.md

Lines changed: 16 additions & 5 deletions
@@ -16,7 +16,12 @@

 Qwen-Image from the Qwen team is an image generation foundation model in the Qwen series that achieves significant advances in complex text rendering and precise image editing. Experiments show strong general capabilities in both image generation and editing, with exceptional performance in text rendering, especially for Chinese.

-Check out the model card [here](https://huggingface.co/Qwen/Qwen-Image) to learn more.
+Qwen-Image comes in the following variants:
+
+| model type | model id |
+|:----------:|:--------:|
+| Qwen-Image | [`Qwen/Qwen-Image`](https://huggingface.co/Qwen/Qwen-Image) |
+| Qwen-Image-Edit | [`Qwen/Qwen-Image-Edit`](https://huggingface.co/Qwen/Qwen-Image-Edit) |

 <Tip>

@@ -87,10 +92,6 @@ image.save("qwen_fewsteps.png")
 - all
 - __call__

-## QwenImagePipelineOutput
-
-[[autodoc]] pipelines.qwenimage.pipeline_output.QwenImagePipelineOutput
-
 ## QwenImageImg2ImgPipeline

 [[autodoc]] QwenImageImg2ImgPipeline

@@ -102,3 +103,13 @@ image.save("qwen_fewsteps.png")
 [[autodoc]] QwenImageInpaintPipeline
 - all
 - __call__
+
+## QwenImageEditPipeline
+
+[[autodoc]] QwenImageEditPipeline
+- all
+- __call__
+
+## QwenImagePipelineOutput
+
+[[autodoc]] pipelines.qwenimage.pipeline_output.QwenImagePipelineOutput
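The newly documented `QwenImageEditPipeline` pairs a source image with an edit instruction. A minimal usage sketch, assuming the edit pipeline follows the usual diffusers image-plus-prompt calling convention (argument names may differ from the final API):

```py
import torch
from PIL import Image
from diffusers import QwenImageEditPipeline

# Load the new edit variant in half precision on the GPU.
pipeline = QwenImageEditPipeline.from_pretrained(
    "Qwen/Qwen-Image-Edit", torch_dtype=torch.bfloat16
).to("cuda")

# Edit an existing image according to a text instruction.
source = Image.open("input.png").convert("RGB")
edited = pipeline(
    image=source,
    prompt="Replace the sign text with 'Hello World'",
    num_inference_steps=50,
).images[0]
edited.save("qwen_edit.png")
```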

docs/source/en/quicktour.md

Lines changed: 155 additions & 249 deletions
Large diffs are not rendered by default.

docs/source/en/stable_diffusion.md

Lines changed: 15 additions & 9 deletions
@@ -22,14 +22,17 @@ This guide recommends some basic performance tips for using the [`DiffusionPipel

 Reducing the amount of memory used indirectly speeds up generation and can help a model fit on device.

+The [`~DiffusionPipeline.enable_model_cpu_offload`] method moves a model to the CPU when it is not in use to save GPU memory.
+
 ```py
 import torch
 from diffusers import DiffusionPipeline

 pipeline = DiffusionPipeline.from_pretrained(
     "stabilityai/stable-diffusion-xl-base-1.0",
-    torch_dtype=torch.bfloat16
-).to("cuda")
+    torch_dtype=torch.bfloat16,
+    device_map="cuda"
+)
 pipeline.enable_model_cpu_offload()

 prompt = """

@@ -44,7 +47,7 @@ print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} G

 Denoising is the most computationally demanding process during diffusion. Methods that optimize this process accelerate inference. Try the following methods for a speed up.

-- Add `.to("cuda")` to place the pipeline on a GPU. Placing a model on an accelerator, like a GPU, increases speed because it performs computations in parallel.
+- Add `device_map="cuda"` to place the pipeline on a GPU. Placing a model on an accelerator, like a GPU, increases speed because it performs computations in parallel.
 - Set `torch_dtype=torch.bfloat16` to execute the pipeline in half-precision. Reducing the data type precision increases speed because it takes less time to perform computations in a lower precision.

 ```py
@@ -54,8 +57,9 @@ from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler

 pipeline = DiffusionPipeline.from_pretrained(
     "stabilityai/stable-diffusion-xl-base-1.0",
-    torch_dtype=torch.bfloat16
-).to("cuda")
+    torch_dtype=torch.bfloat16,
+    device_map="cuda"
+)
 ```

 - Use a faster scheduler, such as [`DPMSolverMultistepScheduler`], which only requires ~20-25 steps.

@@ -88,8 +92,9 @@ Many modern diffusion models deliver high-quality images out-of-the-box. However

 pipeline = DiffusionPipeline.from_pretrained(
     "stabilityai/stable-diffusion-xl-base-1.0",
-    torch_dtype=torch.bfloat16
-).to("cuda")
+    torch_dtype=torch.bfloat16,
+    device_map="cuda"
+)

 prompt = """
 cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California

@@ -109,8 +114,9 @@ Many modern diffusion models deliver high-quality images out-of-the-box. However

 pipeline = DiffusionPipeline.from_pretrained(
     "stabilityai/stable-diffusion-xl-base-1.0",
-    torch_dtype=torch.bfloat16
-).to("cuda")
+    torch_dtype=torch.bfloat16,
+    device_map="cuda"
+)
 pipeline.scheduler = HeunDiscreteScheduler.from_config(pipeline.scheduler.config)

 prompt = """

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -489,6 +489,7 @@
         "PixArtAlphaPipeline",
         "PixArtSigmaPAGPipeline",
         "PixArtSigmaPipeline",
+        "QwenImageEditPipeline",
         "QwenImageImg2ImgPipeline",
         "QwenImageInpaintPipeline",
         "QwenImagePipeline",
@@ -1123,6 +1124,7 @@
         PixArtAlphaPipeline,
         PixArtSigmaPAGPipeline,
         PixArtSigmaPipeline,
+        QwenImageEditPipeline,
         QwenImageImg2ImgPipeline,
         QwenImageInpaintPipeline,
         QwenImagePipeline,

src/diffusers/models/model_loading_utils.py

Lines changed: 11 additions & 6 deletions
@@ -17,7 +17,6 @@
 import functools
 import importlib
 import inspect
-import math
 import os
 from array import array
 from collections import OrderedDict, defaultdict

@@ -717,27 +716,33 @@ def _expand_device_map(device_map, param_names):


 # Adapted from: https://github.com/huggingface/transformers/blob/0687d481e2c71544501ef9cb3eef795a6e79b1de/src/transformers/modeling_utils.py#L5859
-def _caching_allocator_warmup(model, expanded_device_map: Dict[str, torch.device], dtype: torch.dtype) -> None:
+def _caching_allocator_warmup(
+    model, expanded_device_map: Dict[str, torch.device], dtype: torch.dtype, hf_quantizer: Optional[DiffusersQuantizer]
+) -> None:
     """
     This function warms up the caching allocator based on the size of the model tensors that will reside on each
     device. It allows one large call to malloc, instead of recursively calling it later when loading the model,
     which is actually the loading speed bottleneck. Calling this function can cut the model loading time by a
     very large margin.
     """
+    factor = 2 if hf_quantizer is None else hf_quantizer.get_cuda_warm_up_factor()
     # Remove disk and cpu devices, and cast to proper torch.device
     accelerator_device_map = {
         param: torch.device(device)
         for param, device in expanded_device_map.items()
         if str(device) not in ["cpu", "disk"]
     }
-    parameter_count = defaultdict(lambda: 0)
+    total_byte_count = defaultdict(lambda: 0)
     for param_name, device in accelerator_device_map.items():
         try:
             param = model.get_parameter(param_name)
         except AttributeError:
             param = model.get_buffer(param_name)
-        parameter_count[device] += math.prod(param.shape)
+        # The dtype of different parameters may differ with composite models or `keep_in_fp32_modules`
+        param_byte_count = param.numel() * param.element_size()
+        # TODO: account for TP when needed.
+        total_byte_count[device] += param_byte_count

     # This will kick off the caching allocator to avoid having to Malloc afterwards
-    for device, param_count in parameter_count.items():
-        _ = torch.empty(param_count, dtype=dtype, device=device, requires_grad=False)
+    for device, byte_count in total_byte_count.items():
+        _ = torch.empty(byte_count // factor, dtype=dtype, device=device, requires_grad=False)
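The switch from element counts to byte counts matters because a composite model can hold parameters in several dtypes (for example via `keep_in_fp32_modules`), so a single `parameter_count * dtype_size` estimate would be wrong. A standalone sketch of the idea, not the library function itself (it walks `model.parameters()` instead of an expanded device map):

```py
import torch
from collections import defaultdict

def warmup_allocator(model: torch.nn.Module, dtype=torch.bfloat16, factor: int = 2) -> None:
    # Sum the actual byte footprint per accelerator device. Counting bytes
    # (numel * element_size) stays correct even when parameters mix dtypes.
    total_bytes = defaultdict(int)
    for param in model.parameters():
        if param.device.type != "cpu":
            total_bytes[param.device] += param.numel() * param.element_size()

    # One large allocation per device primes the caching allocator, so the
    # many small allocations made while loading weights can reuse this pool
    # instead of each triggering a separate malloc.
    for device, byte_count in total_bytes.items():
        _ = torch.empty(byte_count // factor, dtype=dtype, device=device, requires_grad=False)
```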

src/diffusers/models/modeling_utils.py

Lines changed: 2 additions & 3 deletions
@@ -1532,10 +1532,9 @@ def _load_pretrained_model(
         # tensors using their expected shape and not performing any initialization of the memory (empty data).
         # When the actual device allocations happen, the allocator already has a pool of unused device memory
         # that it can re-use for faster loading of the model.
-        # TODO: add support for warmup with hf_quantizer
-        if device_map is not None and hf_quantizer is None:
+        if device_map is not None:
             expanded_device_map = _expand_device_map(device_map, expected_keys)
-            _caching_allocator_warmup(model, expanded_device_map, dtype)
+            _caching_allocator_warmup(model, expanded_device_map, dtype, hf_quantizer)

         offload_index = {} if device_map is not None and "disk" in device_map.values() else None
         state_dict_folder, state_dict_index = None, None

src/diffusers/models/transformers/transformer_cogview4.py

Lines changed: 34 additions & 2 deletions
@@ -28,7 +28,7 @@
 from ..embeddings import CogView3CombinedTimestepSizeEmbeddings
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
-from ..normalization import AdaLayerNormContinuous
+from ..normalization import LayerNorm, RMSNorm


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

@@ -584,6 +584,38 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens
         return (freqs.cos(), freqs.sin())


+class CogView4AdaLayerNormContinuous(nn.Module):
+    """
+    CogView4-only final AdaLN: LN(x) -> Linear(cond) -> chunk -> affine. Matches Megatron: **no activation** before the
+    Linear on the conditioning embedding.
+    """
+
+    def __init__(
+        self,
+        embedding_dim: int,
+        conditioning_embedding_dim: int,
+        elementwise_affine: bool = True,
+        eps: float = 1e-5,
+        bias: bool = True,
+        norm_type: str = "layer_norm",
+    ):
+        super().__init__()
+        self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
+        if norm_type == "layer_norm":
+            self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
+        elif norm_type == "rms_norm":
+            self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
+        else:
+            raise ValueError(f"unknown norm_type {norm_type}")
+
+    def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
+        # *** NO SiLU here ***
+        emb = self.linear(conditioning_embedding.to(x.dtype))
+        scale, shift = torch.chunk(emb, 2, dim=1)
+        x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
+        return x
+
+
 class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
     r"""
     Args:

@@ -666,7 +698,7 @@ def __init__(
         )

         # 4. Output projection
-        self.norm_out = AdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False)
+        self.norm_out = CogView4AdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False)
         self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels, bias=True)

         self.gradient_checkpointing = False
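The motivation for the new class is the missing non-linearity: diffusers' generic `AdaLayerNormContinuous` runs the conditioning embedding through SiLU before the projection, which does not match the Megatron-trained CogView4 checkpoint. A self-contained sketch of the difference (shapes and the scale/shift order follow the new class; this is illustrative, not the module itself):

```py
import torch
import torch.nn as nn
import torch.nn.functional as F

dim, cond_dim, batch, seq = 8, 16, 2, 4
x = torch.randn(batch, seq, dim)
cond = torch.randn(batch, cond_dim)

linear = nn.Linear(cond_dim, dim * 2)
norm = nn.LayerNorm(dim, elementwise_affine=False)

# Generic AdaLayerNormContinuous: activation before the projection.
scale_g, shift_g = linear(F.silu(cond)).chunk(2, dim=1)

# CogView4 variant: no activation, so weights converted from the Megatron
# checkpoint reproduce the original outputs.
scale_c, shift_c = linear(cond).chunk(2, dim=1)

out = norm(x) * (1 + scale_c)[:, None, :] + shift_c[:, None, :]
print(out.shape)  # torch.Size([2, 4, 8])
```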

src/diffusers/models/transformers/transformer_qwenimage.py

Lines changed: 33 additions & 23 deletions
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-
 import functools
 import math
 from typing import Any, Dict, List, Optional, Tuple, Union

@@ -161,17 +160,17 @@ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
         super().__init__()
         self.theta = theta
         self.axes_dim = axes_dim
-        pos_index = torch.arange(1024)
-        neg_index = torch.arange(1024).flip(0) * -1 - 1
-        pos_freqs = torch.cat(
+        pos_index = torch.arange(4096)
+        neg_index = torch.arange(4096).flip(0) * -1 - 1
+        self.pos_freqs = torch.cat(
             [
                 self.rope_params(pos_index, self.axes_dim[0], self.theta),
                 self.rope_params(pos_index, self.axes_dim[1], self.theta),
                 self.rope_params(pos_index, self.axes_dim[2], self.theta),
             ],
             dim=1,
         )
-        neg_freqs = torch.cat(
+        self.neg_freqs = torch.cat(
             [
                 self.rope_params(neg_index, self.axes_dim[0], self.theta),
                 self.rope_params(neg_index, self.axes_dim[1], self.theta),

@@ -180,10 +179,8 @@ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
             dim=1,
         )
         self.rope_cache = {}
-        self.register_buffer("pos_freqs", pos_freqs, persistent=False)
-        self.register_buffer("neg_freqs", neg_freqs, persistent=False)

-        # whether to use scale rope
+        # DO NOT use register_buffer here; it would cause the complex numbers to lose their imaginary part
         self.scale_rope = scale_rope

     def rope_params(self, index, dim, theta=10000):

@@ -201,35 +198,48 @@ def forward(self, video_fhw, txt_seq_lens, device):
         Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args:
         txt_length: [bs] a list of 1 integers representing the length of the text
         """
+        if self.pos_freqs.device != device:
+            self.pos_freqs = self.pos_freqs.to(device)
+            self.neg_freqs = self.neg_freqs.to(device)
+
         if isinstance(video_fhw, list):
             video_fhw = video_fhw[0]
-        frame, height, width = video_fhw
-        rope_key = f"{frame}_{height}_{width}"
-
-        if not torch.compiler.is_compiling():
-            if rope_key not in self.rope_cache:
-                self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width)
-            vid_freqs = self.rope_cache[rope_key]
-        else:
-            vid_freqs = self._compute_video_freqs(frame, height, width)
+        if not isinstance(video_fhw, list):
+            video_fhw = [video_fhw]
+
+        vid_freqs = []
+        max_vid_index = 0
+        for idx, fhw in enumerate(video_fhw):
+            frame, height, width = fhw
+            rope_key = f"{idx}_{height}_{width}"
+
+            if not torch.compiler.is_compiling():
+                if rope_key not in self.rope_cache:
+                    self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width, idx)
+                video_freq = self.rope_cache[rope_key]
+            else:
+                video_freq = self._compute_video_freqs(frame, height, width, idx)
+            video_freq = video_freq.to(device)
+            vid_freqs.append(video_freq)

-        if self.scale_rope:
-            max_vid_index = max(height // 2, width // 2)
-        else:
-            max_vid_index = max(height, width)
+            if self.scale_rope:
+                max_vid_index = max(height // 2, width // 2, max_vid_index)
+            else:
+                max_vid_index = max(height, width, max_vid_index)

         max_len = max(txt_seq_lens)
         txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
+        vid_freqs = torch.cat(vid_freqs, dim=0)

         return vid_freqs, txt_freqs

     @functools.lru_cache(maxsize=None)
-    def _compute_video_freqs(self, frame, height, width):
+    def _compute_video_freqs(self, frame, height, width, idx=0):
         seq_lens = frame * height * width
         freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
         freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)

-        freqs_frame = freqs_pos[0][:frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
+        freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
         if self.scale_rope:
             freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
             freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
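The comment about `register_buffer` is terse, but the failure mode is real: buffers follow module-wide dtype casts, and casting a complex tensor to a real floating dtype discards its imaginary part, which would corrupt the RoPE frequencies when the transformer is loaded in bf16/fp16. A small illustration of the behavior (not library code):

```py
import torch
import torch.nn as nn

class Rope(nn.Module):
    def __init__(self):
        super().__init__()
        freqs = torch.polar(torch.ones(4), torch.arange(4.0))  # complex64 frequencies
        self.register_buffer("freqs_buffer", freqs, persistent=False)  # follows .to(dtype)
        self.freqs_attr = freqs  # plain attribute, untouched by module casts

module = Rope().to(torch.bfloat16)  # e.g. a pipeline loaded in half precision
print(module.freqs_buffer.dtype)  # real dtype: imaginary part was dropped (with a warning)
print(module.freqs_attr.dtype)    # still torch.complex64
```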

src/diffusers/pipelines/__init__.py

Lines changed: 7 additions & 1 deletion
@@ -391,6 +391,7 @@
         "QwenImagePipeline",
         "QwenImageImg2ImgPipeline",
         "QwenImageInpaintPipeline",
+        "QwenImageEditPipeline",
     ]
     try:
         if not is_onnx_available():

@@ -708,7 +709,12 @@
         from .paint_by_example import PaintByExamplePipeline
         from .pia import PIAPipeline
         from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
-        from .qwenimage import QwenImageImg2ImgPipeline, QwenImageInpaintPipeline, QwenImagePipeline
+        from .qwenimage import (
+            QwenImageEditPipeline,
+            QwenImageImg2ImgPipeline,
+            QwenImageInpaintPipeline,
+            QwenImagePipeline,
+        )
         from .sana import SanaControlNetPipeline, SanaPipeline, SanaSprintImg2ImgPipeline, SanaSprintPipeline
         from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
         from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
