src/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py (40 changes: 4 additions & 36 deletions)

@@ -668,6 +668,7 @@ class AutoencoderKLQwenImage(ModelMixin, ConfigMixin, FromOriginalModelMixin):

     _supports_gradient_checkpointing = False
 
+    # fmt: off
     @register_to_config
     def __init__(
         self,
@@ -678,43 +679,10 @@ def __init__(
         attn_scales: List[float] = [],
         temperal_downsample: List[bool] = [False, True, True],
         dropout: float = 0.0,
-        latents_mean: List[float] = [
-            -0.7571,
-            -0.7089,
-            -0.9113,
-            0.1075,
-            -0.1745,
-            0.9653,
-            -0.1517,
-            1.5508,
-            0.4134,
-            -0.0715,
-            0.5517,
-            -0.3632,
-            -0.1922,
-            -0.9497,
-            0.2503,
-            -0.2921,
-        ],
-        latents_std: List[float] = [
-            2.8184,
-            1.4541,
-            2.3275,
-            2.6558,
-            1.2196,
-            1.7708,
-            2.6052,
-            2.0743,
-            3.2687,
-            2.1526,
-            2.8652,
-            1.5579,
-            1.6382,
-            1.1253,
-            2.8251,
-            1.9160,
-        ],
+        latents_mean: List[float] = [-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921],
+        latents_std: List[float] = [2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160],
     ) -> None:
+        # fmt: on
         super().__init__()
 
         self.z_dim = z_dim
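For context on what these (unchanged) defaults feed into: downstream pipelines typically normalize the VAE latents channel-wise with `latents_mean` / `latents_std` before denoising and invert the mapping before decoding. A minimal sketch, not part of this PR, assuming a `(batch, channels, frames, height, width)` latent layout:

    import torch

    # The 16 per-channel statistics from the config above, reshaped so they
    # broadcast over an assumed (B, C=16, T, H, W) latent tensor.
    latents_mean = torch.tensor(
        [-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
         0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921]
    ).view(1, 16, 1, 1, 1)
    latents_std = torch.tensor(
        [2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
         3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160]
    ).view(1, 16, 1, 1, 1)

    latents = torch.randn(1, 16, 1, 64, 64)                  # dummy encoder output
    normalized = (latents - latents_mean) / latents_std      # before the transformer
    restored = normalized * latents_std + latents_mean       # before vae.decode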
src/diffusers/models/transformers/transformer_qwenimage.py (28 changes: 11 additions & 17 deletions)

@@ -140,7 +140,7 @@ def apply_rotary_emb_qwen(


 class QwenTimestepProjEmbeddings(nn.Module):
-    def __init__(self, embedding_dim, pooled_projection_dim):
+    def __init__(self, embedding_dim):
         super().__init__()
 
         self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
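With `pooled_projection_dim` gone, the module reduces to a plain sinusoidal-timestep MLP. A hedged sketch of the resulting module follows; the `TimestepEmbedding` half and the `forward` are assumptions patterned on similar diffusers embedding modules, not copied from this PR:

    import torch.nn as nn
    from diffusers.models.embeddings import TimestepEmbedding, Timesteps

    class QwenTimestepProjEmbeddings(nn.Module):
        def __init__(self, embedding_dim):
            super().__init__()
            # Sinusoidal timestep features (wiring shown in the diff context above).
            self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
            # Projection up to the transformer width (assumed wiring).
            self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

        def forward(self, timestep, hidden_states):
            timesteps_proj = self.time_proj(timestep)
            # Match the dtype of the stream that consumes the embedding.
            return self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype))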
@@ -473,8 +473,6 @@ class QwenImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fro
         joint_attention_dim (`int`, defaults to `3584`):
             The number of dimensions to use for the joint attention (embedding/channel dimension of
             `encoder_hidden_states`).
-        pooled_projection_dim (`int`, defaults to `768`):
-            The number of dimensions to use for the pooled projection.
         guidance_embeds (`bool`, defaults to `False`):
             Whether to use guidance embeddings for guidance-distilled variant of the model.
         axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`):
@@ -495,8 +493,7 @@ def __init__(
         attention_head_dim: int = 128,
         num_attention_heads: int = 24,
         joint_attention_dim: int = 3584,
-        pooled_projection_dim: int = 768,
-        guidance_embeds: bool = False,
+        guidance_embeds: bool = False,  # TODO: this should probably be removed
[Collaborator comment on the line above] cc @naykun can you confirm if we need this config?
         axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
     ):
         super().__init__()
@@ -505,9 +502,7 @@
 
         self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True)
 
-        self.time_text_embed = QwenTimestepProjEmbeddings(
-            embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
-        )
+        self.time_text_embed = QwenTimestepProjEmbeddings(embedding_dim=self.inner_dim)
 
         self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6)
 
@@ -538,10 +533,9 @@ def forward(
         timestep: torch.LongTensor = None,
         img_shapes: Optional[List[Tuple[int, int, int]]] = None,
         txt_seq_lens: Optional[List[int]] = None,
-        guidance: torch.Tensor = None,
-        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+        guidance: torch.Tensor = None,  # TODO: this should probably be removed
+        attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
-        controlnet_blocks_repeat: bool = False,
     ) -> Union[torch.Tensor, Transformer2DModelOutput]:
         """
         The [`QwenTransformer2DModel`] forward method.
@@ -555,7 +549,7 @@ def forward(
                 Mask of the input conditions.
             timestep ( `torch.LongTensor`):
                 Used to indicate denoising step.
-            joint_attention_kwargs (`dict`, *optional*):
+            attention_kwargs (`dict`, *optional*):
                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                 `self.processor` in
                 [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
@@ -567,17 +561,17 @@ def forward(
             If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
             `tuple` where the first element is the sample tensor.
         """
-        if joint_attention_kwargs is not None:
-            joint_attention_kwargs = joint_attention_kwargs.copy()
-            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
+        if attention_kwargs is not None:
+            attention_kwargs = attention_kwargs.copy()
+            lora_scale = attention_kwargs.pop("scale", 1.0)
         else:
             lora_scale = 1.0
 
         if USE_PEFT_BACKEND:
             # weight the lora layers by setting `lora_scale` for each PEFT layer
             scale_lora_layers(self, lora_scale)
         else:
-            if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
+            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
                 logger.warning(
-                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
+                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
                 )
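The copy-then-pop pattern above, shown in isolation as a self-contained toy example: the caller's dict is never mutated, and `scale` is consumed before the kwargs reach any attention processor.

    caller_kwargs = {"scale": 0.5}
    attention_kwargs = caller_kwargs.copy()
    lora_scale = attention_kwargs.pop("scale", 1.0)  # 0.5
    assert caller_kwargs == {"scale": 0.5}           # caller's dict untouched
    assert attention_kwargs == {}                    # "scale" never forwarded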
@@ -617,7 +611,7 @@ def forward(
                     encoder_hidden_states_mask=encoder_hidden_states_mask,
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
-                    joint_attention_kwargs=joint_attention_kwargs,
+                    joint_attention_kwargs=attention_kwargs,
                 )
 
             # Use only the image part (hidden_states) from the dual-stream blocks
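A caller-side sketch of the renamed argument. Everything here is illustrative rather than taken from the PR: the tiny `num_layers=2` config, the packed-token shapes, and the `img_shapes` for a 64x64 latent grid are all assumptions; only the keyword names match the new signature.

    import torch
    from diffusers import QwenImageTransformer2DModel

    transformer = QwenImageTransformer2DModel(num_layers=2)  # tiny stack for illustration

    latents = torch.randn(1, 4096, 64)        # packed image tokens: 64x64 grid, 16 ch x 2x2 patch
    prompt_embeds = torch.randn(1, 77, 3584)  # text-encoder states (assumed length 77)
    prompt_mask = torch.ones(1, 77, dtype=torch.long)

    sample = transformer(
        hidden_states=latents,
        encoder_hidden_states=prompt_embeds,
        encoder_hidden_states_mask=prompt_mask,
        timestep=torch.tensor([500]),
        img_shapes=[(1, 64, 64)],             # (frames, latent_h // 2, latent_w // 2)
        txt_seq_lens=[77],
        attention_kwargs={"scale": 0.7},      # was `joint_attention_kwargs`
        return_dict=False,
    )[0]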