Commit 7881392

Merge branch 'unet-1d-time-embed-debugging' of https://github.com/SammyAgrawal/diffusers into unet-1d-time-embed-debugging
2 parents: 9335c36 + ff2df6e

29 files changed: +1099 -1163 lines

docs/source/en/api/pipelines/qwenimage.md

Lines changed: 57 additions & 0 deletions
@@ -24,6 +24,63 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers)
 
 </Tip>
 
+## LoRA for faster inference
+
+Use a LoRA from `lightx2v/Qwen-Image-Lightning` to speed up inference by reducing the
+number of steps. Refer to the code snippet below:
+
+<details>
+<summary>Code</summary>
+
+```py
+from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler
+import torch
+import math
+
+ckpt_id = "Qwen/Qwen-Image"
+
+# From
+# https://github.com/ModelTC/Qwen-Image-Lightning/blob/342260e8f5468d2f24d084ce04f55e101007118b/generate_with_diffusers.py#L82C9-L97C10
+scheduler_config = {
+    "base_image_seq_len": 256,
+    "base_shift": math.log(3),  # We use shift=3 in distillation
+    "invert_sigmas": False,
+    "max_image_seq_len": 8192,
+    "max_shift": math.log(3),  # We use shift=3 in distillation
+    "num_train_timesteps": 1000,
+    "shift": 1.0,
+    "shift_terminal": None,  # set shift_terminal to None
+    "stochastic_sampling": False,
+    "time_shift_type": "exponential",
+    "use_beta_sigmas": False,
+    "use_dynamic_shifting": True,
+    "use_exponential_sigmas": False,
+    "use_karras_sigmas": False,
+}
+scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)
+pipe = DiffusionPipeline.from_pretrained(
+    ckpt_id, scheduler=scheduler, torch_dtype=torch.bfloat16
+).to("cuda")
+pipe.load_lora_weights(
+    "lightx2v/Qwen-Image-Lightning", weight_name="Qwen-Image-Lightning-8steps-V1.0.safetensors"
+)
+
+prompt = "a tiny astronaut hatching from an egg on the moon, Ultra HD, 4K, cinematic composition."
+negative_prompt = " "
+image = pipe(
+    prompt=prompt,
+    negative_prompt=negative_prompt,
+    width=1024,
+    height=1024,
+    num_inference_steps=8,
+    true_cfg_scale=1.0,
+    generator=torch.manual_seed(0),
+).images[0]
+image.save("qwen_fewsteps.png")
+```
+
+</details>
+
 ## QwenImagePipeline
 
 [[autodoc]] QwenImagePipeline

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 36 additions & 0 deletions
@@ -2077,3 +2077,39 @@ def _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict, non_diffusers_pref
     converted_state_dict = {k.removeprefix(f"{non_diffusers_prefix}."): v for k, v in state_dict.items()}
     converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
     return converted_state_dict
+
+
+def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict):
+    converted_state_dict = {}
+    all_keys = list(state_dict.keys())
+    down_key = ".lora_down.weight"
+    up_key = ".lora_up.weight"
+
+    def get_alpha_scales(down_weight, alpha_key):
+        rank = down_weight.shape[0]
+        alpha = state_dict.pop(alpha_key).item()
+        scale = alpha / rank  # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
+        scale_down = scale
+        scale_up = 1.0
+        while scale_down * 2 < scale_up:
+            scale_down *= 2
+            scale_up /= 2
+        return scale_down, scale_up
+
+    for k in all_keys:
+        if k.endswith(down_key):
+            diffusers_down_key = k.replace(down_key, ".lora_A.weight")
+            diffusers_up_key = k.replace(down_key, up_key).replace(up_key, ".lora_B.weight")
+            alpha_key = k.replace(down_key, ".alpha")
+
+            down_weight = state_dict.pop(k)
+            up_weight = state_dict.pop(k.replace(down_key, up_key))
+            scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+            converted_state_dict[diffusers_down_key] = down_weight * scale_down
+            converted_state_dict[diffusers_up_key] = up_weight * scale_up
+
+    if len(state_dict) > 0:
+        raise ValueError(f"`state_dict` should be empty at this point but has {state_dict.keys()=}")
+
+    converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
+    return converted_state_dict
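
The converter rescales each LoRA pair by `alpha / rank` and splits that factor between the `lora_A` (down) and `lora_B` (up) matrices. Below is a minimal sketch of what it produces for a single toy layer; the layer name, rank, alpha, and feature size are invented, and the import assumes the private module shown above.

```py
# Toy illustration of _convert_non_diffusers_qwen_lora_to_diffusers; names and shapes are made up.
import torch

from diffusers.loaders.lora_conversion_utils import _convert_non_diffusers_qwen_lora_to_diffusers

rank, alpha, features = 4, 8.0, 64
state_dict = {
    "transformer_blocks.0.attn.to_q.lora_down.weight": torch.randn(rank, features),
    "transformer_blocks.0.attn.to_q.lora_up.weight": torch.randn(features, rank),
    "transformer_blocks.0.attn.to_q.alpha": torch.tensor(alpha),
}

converted = _convert_non_diffusers_qwen_lora_to_diffusers(state_dict)
print(sorted(converted))
# ['transformer.transformer_blocks.0.attn.to_q.lora_A.weight',
#  'transformer.transformer_blocks.0.attn.to_q.lora_B.weight']

# alpha / rank = 2.0 here; get_alpha_scales leaves the whole factor on the down
# weight (scale_down=2.0, scale_up=1.0) because the balancing loop only runs
# while scale_down * 2 < scale_up.
```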

src/diffusers/loaders/lora_pipeline.py

Lines changed: 5 additions & 1 deletion
@@ -49,6 +49,7 @@
     _convert_non_diffusers_lora_to_diffusers,
     _convert_non_diffusers_ltxv_lora_to_diffusers,
     _convert_non_diffusers_lumina2_lora_to_diffusers,
+    _convert_non_diffusers_qwen_lora_to_diffusers,
     _convert_non_diffusers_wan_lora_to_diffusers,
     _convert_xlabs_flux_lora_to_diffusers,
     _maybe_map_sgm_blocks_to_diffusers,
@@ -6548,7 +6549,6 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
 
     @classmethod
     @validate_hf_hub_args
-    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
     def lora_state_dict(
         cls,
         pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
@@ -6642,6 +6642,10 @@ def lora_state_dict(
             logger.warning(warn_msg)
             state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
 
+        has_alphas_in_sd = any(k.endswith(".alpha") for k in state_dict)
+        if has_alphas_in_sd:
+            state_dict = _convert_non_diffusers_qwen_lora_to_diffusers(state_dict)
+
         out = (state_dict, metadata) if return_lora_metadata else state_dict
         return out
 
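
With this check in place, a checkpoint stored in the original `lora_down`/`lora_up`/`.alpha` layout is converted on the fly when it is read through the Qwen loader. A minimal sketch follows; it assumes the Lightning repository and weight name from the docs above and downloads that checkpoint.

```py
# Sketch: lora_state_dict routes alpha-style checkpoints through the Qwen converter.
from diffusers import QwenImagePipeline

state_dict = QwenImagePipeline.lora_state_dict(
    "lightx2v/Qwen-Image-Lightning",
    weight_name="Qwen-Image-Lightning-8steps-V1.0.safetensors",
)
# Keys come back in diffusers form, e.g. "transformer.<...>.lora_A.weight" / ".lora_B.weight".
print(next(iter(state_dict)))
```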

src/diffusers/models/transformers/transformer_qwenimage.py

Lines changed: 31 additions & 24 deletions
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 
+import functools
 import math
 from typing import Any, Dict, List, Optional, Tuple, Union
 
@@ -162,15 +163,15 @@ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
         self.axes_dim = axes_dim
         pos_index = torch.arange(1024)
         neg_index = torch.arange(1024).flip(0) * -1 - 1
-        self.pos_freqs = torch.cat(
+        pos_freqs = torch.cat(
             [
                 self.rope_params(pos_index, self.axes_dim[0], self.theta),
                 self.rope_params(pos_index, self.axes_dim[1], self.theta),
                 self.rope_params(pos_index, self.axes_dim[2], self.theta),
             ],
             dim=1,
         )
-        self.neg_freqs = torch.cat(
+        neg_freqs = torch.cat(
             [
                 self.rope_params(neg_index, self.axes_dim[0], self.theta),
                 self.rope_params(neg_index, self.axes_dim[1], self.theta),
@@ -179,6 +180,8 @@ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
             dim=1,
         )
         self.rope_cache = {}
+        self.register_buffer("pos_freqs", pos_freqs, persistent=False)
+        self.register_buffer("neg_freqs", neg_freqs, persistent=False)
 
         # whether to use scale rope
         self.scale_rope = scale_rope
@@ -198,33 +201,17 @@ def forward(self, video_fhw, txt_seq_lens, device):
         Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args:
         txt_length: [bs] a list of 1 integers representing the length of the text
         """
-        if self.pos_freqs.device != device:
-            self.pos_freqs = self.pos_freqs.to(device)
-            self.neg_freqs = self.neg_freqs.to(device)
-
         if isinstance(video_fhw, list):
             video_fhw = video_fhw[0]
         frame, height, width = video_fhw
         rope_key = f"{frame}_{height}_{width}"
 
-        if rope_key not in self.rope_cache:
-            seq_lens = frame * height * width
-            freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
-            freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
-            freqs_frame = freqs_pos[0][:frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
-            if self.scale_rope:
-                freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
-                freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
-                freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
-                freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
-
-            else:
-                freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
-                freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
-
-            freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
-            self.rope_cache[rope_key] = freqs.clone().contiguous()
-        vid_freqs = self.rope_cache[rope_key]
+        if not torch.compiler.is_compiling():
+            if rope_key not in self.rope_cache:
+                self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width)
+            vid_freqs = self.rope_cache[rope_key]
+        else:
+            vid_freqs = self._compute_video_freqs(frame, height, width)
 
         if self.scale_rope:
             max_vid_index = max(height // 2, width // 2)
@@ -236,6 +223,25 @@ def forward(self, video_fhw, txt_seq_lens, device):
 
         return vid_freqs, txt_freqs
 
+    @functools.lru_cache(maxsize=None)
+    def _compute_video_freqs(self, frame, height, width):
+        seq_lens = frame * height * width
+        freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+        freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+
+        freqs_frame = freqs_pos[0][:frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
+        if self.scale_rope:
+            freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
+            freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
+            freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
+            freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
+        else:
+            freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
+            freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
+
+        freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
+        return freqs.clone().contiguous()
+
 
 class QwenDoubleStreamAttnProcessor2_0:
     """
@@ -482,6 +488,7 @@ class QwenImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fro
     _supports_gradient_checkpointing = True
     _no_split_modules = ["QwenImageTransformerBlock"]
     _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
+    _repeated_blocks = ["QwenImageTransformerBlock"]
 
     @register_to_config
    def __init__(
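
The RoPE refactor replaces manual `.to(device)` bookkeeping with non-persistent buffers, so `pos_freqs`/`neg_freqs` follow the module's device automatically without entering the checkpoint, and it skips the Python dict cache while `torch.compile` is tracing, recomputing the frequencies instead so the compiled graph carries no cache mutation. A minimal standalone sketch of that pattern is below; the module and shapes are invented for illustration and are not part of the diffusers API.

```py
import torch
import torch.nn as nn


class CachedFreqs(nn.Module):
    """Toy module illustrating the buffer + compile-aware caching pattern above."""

    def __init__(self, dim: int = 16):
        super().__init__()
        # Non-persistent buffers move with .to(device) but are not written to the state dict.
        self.register_buffer("freqs", torch.randn(1024, dim), persistent=False)
        self.cache = {}

    def _compute(self, length: int) -> torch.Tensor:
        # Derived tensor that depends only on `length`.
        return self.freqs[:length].clone().contiguous()

    def forward(self, length: int) -> torch.Tensor:
        if not torch.compiler.is_compiling():
            # Eager mode: memoize per length in a plain dict.
            if length not in self.cache:
                self.cache[length] = self._compute(length)
            return self.cache[length]
        # Under torch.compile, recompute instead of mutating Python state.
        return self._compute(length)


m = CachedFreqs()
assert m(64) is m(64)  # second call is served from the dict cache
```

In the diff itself, `_compute_video_freqs` is additionally wrapped in `functools.lru_cache(maxsize=None)`, which memoizes per `(self, frame, height, width)` tuple.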

src/diffusers/modular_pipelines/__init__.py

Lines changed: 0 additions & 2 deletions
@@ -25,7 +25,6 @@
     _import_structure["modular_pipeline"] = [
         "ModularPipelineBlocks",
         "ModularPipeline",
-        "PipelineBlock",
         "AutoPipelineBlocks",
         "SequentialPipelineBlocks",
         "LoopSequentialPipelineBlocks",
@@ -59,7 +58,6 @@
         LoopSequentialPipelineBlocks,
         ModularPipeline,
         ModularPipelineBlocks,
-        PipelineBlock,
         PipelineState,
         SequentialPipelineBlocks,
     )
