Skip to content

Commit 446e28d

Browse files
committed
Merge branch 'main' into qwen-image-training
2 parents 67dfa47 + 69a9828 commit 446e28d

File tree

2 files changed

+10
-1
lines changed

2 files changed

+10
-1
lines changed

src/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2025 The Qwen-Image Team and The HuggingFace Team. All rights reserved.
1+
# Copyright 2025 The Qwen-Image Team, Wan Team and The HuggingFace Team. All rights reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -11,6 +11,12 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
#
15+
# We gratefully acknowledge the Wan Team for their outstanding contributions.
16+
# QwenImageVAE is further fine-tuned from the Wan Video VAE to achieve improved performance.
17+
# For more information about the Wan VAE, please refer to:
18+
# - GitHub: https://github.com/Wan-Video/Wan2.1
19+
# - arXiv: https://arxiv.org/abs/2503.20314
1420

1521
from typing import List, Optional, Tuple, Union
1622

src/diffusers/models/transformers/transformer_wan.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ def __init__(
180180
added_kv_proj_dim: Optional[int] = None,
181181
cross_attention_dim_head: Optional[int] = None,
182182
processor=None,
183+
is_cross_attention=None,
183184
):
184185
super().__init__()
185186

@@ -207,6 +208,8 @@ def __init__(
207208
self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
208209
self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
209210

211+
self.is_cross_attention = cross_attention_dim_head is not None
212+
210213
self.set_processor(processor)
211214

212215
def fuse_projections(self):

0 commit comments

Comments
 (0)