
Commit af03f73

Merge branch 'main' into wan22-lightx2v
2 parents b09fc48 + 4a9dbd5 commit af03f73

21 files changed: +467 −913 lines

src/diffusers/models/transformers/transformer_qwenimage.py

Lines changed: 31 additions & 24 deletions
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 
+import functools
 import math
 from typing import Any, Dict, List, Optional, Tuple, Union
 
@@ -162,15 +163,15 @@ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
         self.axes_dim = axes_dim
         pos_index = torch.arange(1024)
         neg_index = torch.arange(1024).flip(0) * -1 - 1
-        self.pos_freqs = torch.cat(
+        pos_freqs = torch.cat(
             [
                 self.rope_params(pos_index, self.axes_dim[0], self.theta),
                 self.rope_params(pos_index, self.axes_dim[1], self.theta),
                 self.rope_params(pos_index, self.axes_dim[2], self.theta),
             ],
             dim=1,
         )
-        self.neg_freqs = torch.cat(
+        neg_freqs = torch.cat(
             [
                 self.rope_params(neg_index, self.axes_dim[0], self.theta),
                 self.rope_params(neg_index, self.axes_dim[1], self.theta),
@@ -179,6 +180,8 @@ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
             dim=1,
         )
         self.rope_cache = {}
+        self.register_buffer("pos_freqs", pos_freqs, persistent=False)
+        self.register_buffer("neg_freqs", neg_freqs, persistent=False)
 
         # 是否使用 scale rope (whether to use scale rope)
         self.scale_rope = scale_rope
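
Why the buffers: a minimal sketch (hypothetical RopeTable module, not the diffusers class) of what register_buffer(..., persistent=False) buys here. The tables follow the module across .to(device) moves but are excluded from the state dict, so the manual device bookkeeping in forward() can go away — which is exactly what the next hunk removes.

import torch
import torch.nn as nn


class RopeTable(nn.Module):
    def __init__(self):
        super().__init__()
        freqs = torch.arange(8, dtype=torch.float32)
        # Non-persistent: moves with the module, never written to the checkpoint.
        self.register_buffer("freqs", freqs, persistent=False)

    def forward(self, x):
        # self.freqs is already on the same device as x after module.to(device).
        return x + self.freqs


module = RopeTable().to("cuda" if torch.cuda.is_available() else "cpu")
print("freqs" in module.state_dict())  # False -> not serialized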
@@ -198,33 +201,17 @@ def forward(self, video_fhw, txt_seq_lens, device):
         Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args:
             txt_length: [bs] a list of 1 integers representing the length of the text
         """
-        if self.pos_freqs.device != device:
-            self.pos_freqs = self.pos_freqs.to(device)
-            self.neg_freqs = self.neg_freqs.to(device)
-
         if isinstance(video_fhw, list):
             video_fhw = video_fhw[0]
         frame, height, width = video_fhw
         rope_key = f"{frame}_{height}_{width}"
 
-        if rope_key not in self.rope_cache:
-            seq_lens = frame * height * width
-            freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
-            freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
-            freqs_frame = freqs_pos[0][:frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
-            if self.scale_rope:
-                freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
-                freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
-                freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
-                freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
-
-            else:
-                freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
-                freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
-
-            freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
-            self.rope_cache[rope_key] = freqs.clone().contiguous()
-        vid_freqs = self.rope_cache[rope_key]
+        if not torch.compiler.is_compiling():
+            if rope_key not in self.rope_cache:
+                self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width)
+            vid_freqs = self.rope_cache[rope_key]
+        else:
+            vid_freqs = self._compute_video_freqs(frame, height, width)
 
         if self.scale_rope:
             max_vid_index = max(height // 2, width // 2)
@@ -236,6 +223,25 @@ def forward(self, video_fhw, txt_seq_lens, device):
 
         return vid_freqs, txt_freqs
 
+    @functools.lru_cache(maxsize=None)
+    def _compute_video_freqs(self, frame, height, width):
+        seq_lens = frame * height * width
+        freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+        freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+
+        freqs_frame = freqs_pos[0][:frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
+        if self.scale_rope:
+            freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
+            freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
+            freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
+            freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
+        else:
+            freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
+            freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
+
+        freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
+        return freqs.clone().contiguous()
+
 
 class QwenDoubleStreamAttnProcessor2_0:
     """
@@ -482,6 +488,7 @@ class QwenImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fro
     _supports_gradient_checkpointing = True
     _no_split_modules = ["QwenImageTransformerBlock"]
     _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
+    _repeated_blocks = ["QwenImageTransformerBlock"]
 
     @register_to_config
     def __init__(
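
_repeated_blocks marks the block class that makes up the bulk of the model so it can be targeted by diffusers' regional torch.compile path. A hedged usage sketch — the Qwen/Qwen-Image checkpoint id and the compile_repeated_blocks() helper are assumptions about the installed diffusers version, hence the hasattr guard:

import torch

from diffusers import QwenImageTransformer2DModel

transformer = QwenImageTransformer2DModel.from_pretrained(
    "Qwen/Qwen-Image", subfolder="transformer", torch_dtype=torch.bfloat16
)
if hasattr(transformer, "compile_repeated_blocks"):
    # Compile only the repeated QwenImageTransformerBlock instances.
    transformer.compile_repeated_blocks(fullgraph=True)
else:
    # Fallback: compile the whole transformer.
    transformer = torch.compile(transformer)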

src/diffusers/modular_pipelines/__init__.py

Lines changed: 0 additions & 2 deletions
@@ -25,7 +25,6 @@
     _import_structure["modular_pipeline"] = [
         "ModularPipelineBlocks",
         "ModularPipeline",
-        "PipelineBlock",
         "AutoPipelineBlocks",
         "SequentialPipelineBlocks",
         "LoopSequentialPipelineBlocks",
@@ -59,7 +58,6 @@
         LoopSequentialPipelineBlocks,
         ModularPipeline,
         ModularPipelineBlocks,
-        PipelineBlock,
         PipelineState,
         SequentialPipelineBlocks,
     )
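
With PipelineBlock dropped from the public import structure, downstream code that imported it needs a one-line migration; the Flux blocks below switch their base class the same way.

# Before this commit:
# from diffusers.modular_pipelines import PipelineBlock
# After:
from diffusers.modular_pipelines import ModularPipelineBlocks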

src/diffusers/modular_pipelines/flux/before_denoise.py

Lines changed: 7 additions & 31 deletions
@@ -22,7 +22,7 @@
 from ...schedulers import FlowMatchEulerDiscreteScheduler
 from ...utils import logging
 from ...utils.torch_utils import randn_tensor
-from ..modular_pipeline import PipelineBlock, PipelineState
+from ..modular_pipeline import ModularPipelineBlocks, PipelineState
 from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
 from .modular_pipeline import FluxModularPipeline
 
@@ -231,7 +231,7 @@ def _get_initial_timesteps_and_optionals(
     return timesteps, num_inference_steps, sigmas, guidance
 
 
-class FluxInputStep(PipelineBlock):
+class FluxInputStep(ModularPipelineBlocks):
     model_name = "flux"
 
     @property
@@ -249,11 +249,6 @@ def description(self) -> str:
     def inputs(self) -> List[InputParam]:
         return [
             InputParam("num_images_per_prompt", default=1),
-        ]
-
-    @property
-    def intermediate_inputs(self) -> List[str]:
-        return [
             InputParam(
                 "prompt_embeds",
                 required=True,
@@ -322,7 +317,7 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip
         return components, state
 
 
-class FluxSetTimestepsStep(PipelineBlock):
+class FluxSetTimestepsStep(ModularPipelineBlocks):
     model_name = "flux"
 
     @property
@@ -340,14 +335,10 @@ def inputs(self) -> List[InputParam]:
             InputParam("timesteps"),
             InputParam("sigmas"),
             InputParam("guidance_scale", default=3.5),
+            InputParam("latents", type_hint=torch.Tensor),
             InputParam("num_images_per_prompt", default=1),
             InputParam("height", type_hint=int),
             InputParam("width", type_hint=int),
-        ]
-
-    @property
-    def intermediate_inputs(self) -> List[str]:
-        return [
             InputParam(
                 "batch_size",
                 required=True,
@@ -398,7 +389,7 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip
         return components, state
 
 
-class FluxImg2ImgSetTimestepsStep(PipelineBlock):
+class FluxImg2ImgSetTimestepsStep(ModularPipelineBlocks):
     model_name = "flux"
 
     @property
@@ -420,11 +411,6 @@ def inputs(self) -> List[InputParam]:
             InputParam("num_images_per_prompt", default=1),
             InputParam("height", type_hint=int),
             InputParam("width", type_hint=int),
-        ]
-
-    @property
-    def intermediate_inputs(self) -> List[str]:
-        return [
             InputParam(
                 "batch_size",
                 required=True,
@@ -497,7 +483,7 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip
         return components, state
 
 
-class FluxPrepareLatentsStep(PipelineBlock):
+class FluxPrepareLatentsStep(ModularPipelineBlocks):
     model_name = "flux"
 
     @property
@@ -515,11 +501,6 @@ def inputs(self) -> List[InputParam]:
             InputParam("width", type_hint=int),
             InputParam("latents", type_hint=Optional[torch.Tensor]),
             InputParam("num_images_per_prompt", type_hint=int, default=1),
-        ]
-
-    @property
-    def intermediate_inputs(self) -> List[InputParam]:
-        return [
             InputParam("generator"),
             InputParam(
                 "batch_size",
@@ -621,7 +602,7 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip
         return components, state
 
 
-class FluxImg2ImgPrepareLatentsStep(PipelineBlock):
+class FluxImg2ImgPrepareLatentsStep(ModularPipelineBlocks):
     model_name = "flux"
 
     @property
@@ -639,11 +620,6 @@ def inputs(self) -> List[Tuple[str, Any]]:
             InputParam("width", type_hint=int),
             InputParam("latents", type_hint=Optional[torch.Tensor]),
             InputParam("num_images_per_prompt", type_hint=int, default=1),
-        ]
-
-    @property
-    def intermediate_inputs(self) -> List[InputParam]:
-        return [
             InputParam("generator"),
             InputParam(
                 "image_latents",

src/diffusers/modular_pipelines/flux/decoders.py

Lines changed: 3 additions & 8 deletions
@@ -22,7 +22,7 @@
 from ...models import AutoencoderKL
 from ...utils import logging
 from ...video_processor import VaeImageProcessor
-from ..modular_pipeline import PipelineBlock, PipelineState
+from ..modular_pipeline import ModularPipelineBlocks, PipelineState
 from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
 
 
@@ -45,7 +45,7 @@ def _unpack_latents(latents, height, width, vae_scale_factor):
     return latents
 
 
-class FluxDecodeStep(PipelineBlock):
+class FluxDecodeStep(ModularPipelineBlocks):
     model_name = "flux"
 
     @property
@@ -70,17 +70,12 @@ def inputs(self) -> List[Tuple[str, Any]]:
             InputParam("output_type", default="pil"),
             InputParam("height", default=1024),
             InputParam("width", default=1024),
-        ]
-
-    @property
-    def intermediate_inputs(self) -> List[str]:
-        return [
             InputParam(
                 "latents",
                 required=True,
                 type_hint=torch.Tensor,
                 description="The denoised latents from the denoising step",
-            )
+            ),
         ]
 
     @property

src/diffusers/modular_pipelines/flux/denoise.py

Lines changed: 5 additions & 8 deletions
@@ -22,7 +22,7 @@
 from ..modular_pipeline import (
     BlockState,
     LoopSequentialPipelineBlocks,
-    PipelineBlock,
+    ModularPipelineBlocks,
     PipelineState,
 )
 from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
@@ -32,7 +32,7 @@
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-class FluxLoopDenoiser(PipelineBlock):
+class FluxLoopDenoiser(ModularPipelineBlocks):
     model_name = "flux"
 
     @property
@@ -49,11 +49,8 @@ def description(self) -> str:
 
     @property
     def inputs(self) -> List[Tuple[str, Any]]:
-        return [InputParam("joint_attention_kwargs")]
-
-    @property
-    def intermediate_inputs(self) -> List[str]:
         return [
+            InputParam("joint_attention_kwargs"),
             InputParam(
                 "latents",
                 required=True,
@@ -113,7 +110,7 @@ def __call__(
         return components, block_state
 
 
-class FluxLoopAfterDenoiser(PipelineBlock):
+class FluxLoopAfterDenoiser(ModularPipelineBlocks):
     model_name = "flux"
 
     @property
@@ -175,7 +172,7 @@ def loop_expected_components(self) -> List[ComponentSpec]:
         ]
 
     @property
-    def loop_intermediate_inputs(self) -> List[InputParam]:
+    def loop_inputs(self) -> List[InputParam]:
         return [
             InputParam(
                 "timesteps",

src/diffusers/modular_pipelines/flux/encoders.py

Lines changed: 6 additions & 7 deletions
@@ -24,7 +24,7 @@
 from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL
 from ...utils import USE_PEFT_BACKEND, is_ftfy_available, logging, scale_lora_layers, unscale_lora_layers
-from ..modular_pipeline import PipelineBlock, PipelineState
+from ..modular_pipeline import ModularPipelineBlocks, PipelineState
 from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
 from .modular_pipeline import FluxModularPipeline
 
@@ -67,7 +67,7 @@ def retrieve_latents(
     raise AttributeError("Could not access latents of provided encoder_output")
 
 
-class FluxVaeEncoderStep(PipelineBlock):
+class FluxVaeEncoderStep(ModularPipelineBlocks):
     model_name = "flux"
 
     @property
@@ -88,11 +88,10 @@ def expected_components(self) -> List[ComponentSpec]:
 
     @property
     def inputs(self) -> List[InputParam]:
-        return [InputParam("image", required=True), InputParam("height"), InputParam("width")]
-
-    @property
-    def intermediate_inputs(self) -> List[InputParam]:
         return [
+            InputParam("image", required=True),
+            InputParam("height"),
+            InputParam("width"),
             InputParam("generator"),
             InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
             InputParam(
@@ -157,7 +156,7 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip
         return components, state
 
 
-class FluxTextEncoderStep(PipelineBlock):
+class FluxTextEncoderStep(ModularPipelineBlocks):
     model_name = "flux"
 
     @property
