Commit 4cea819

Author: J石页
Message: NPU Adaptation for Sana
Parent: 3d3aae3

File tree: 2 files changed (+5, −7 lines)

src/diffusers/models/attention_processor.py
Lines changed: 4 additions & 0 deletions

@@ -294,6 +294,10 @@ def __init__(
             processor = (
                 AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
             )
+
+        if is_torch_npu_available():
+            if isinstance(processor, AttnProcessor2_0):
+                processor = AttnProcessorNPU()
         self.set_processor(processor)

     def set_use_xla_flash_attention(
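
For context, a minimal usage sketch (not part of the commit) of how the new selection in Attention.__init__ resolves. It assumes AttnProcessorNPU and is_torch_npu_available are importable as shown (exact import paths can differ between diffusers versions), and the constructor dimensions are illustrative only:

# Minimal sketch: check which processor an Attention block ends up with.
# Assumes a diffusers build that ships AttnProcessorNPU and is_torch_npu_available;
# import paths may vary by version.
import torch.nn.functional as F

from diffusers.models.attention_processor import (
    Attention,
    AttnProcessor2_0,
    AttnProcessorNPU,
)
from diffusers.utils import is_torch_npu_available

attn = Attention(query_dim=64, heads=4, dim_head=16)  # illustrative dimensions

if is_torch_npu_available():
    # On an Ascend NPU build, the base class now swaps the default
    # AttnProcessor2_0 for AttnProcessorNPU.
    assert isinstance(attn.processor, AttnProcessorNPU)
elif hasattr(F, "scaled_dot_product_attention"):
    # On PyTorch 2.x without an NPU, the SDPA processor is kept as before.
    assert isinstance(attn.processor, AttnProcessor2_0)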

src/diffusers/models/transformers/sana_transformer.py
Lines changed: 1 addition & 7 deletions

@@ -120,12 +120,6 @@ def __init__(
         # 2. Cross Attention
         if cross_attention_dim is not None:
             self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
-
-            if is_torch_npu_available():
-                attn_processor = AttnProcessorNPU()
-            else:
-                attn_processor = AttnProcessor2_0()
-
             self.attn2 = Attention(
                 query_dim=dim,
                 cross_attention_dim=cross_attention_dim,
@@ -134,7 +128,7 @@
                 dropout=dropout,
                 bias=True,
                 out_bias=attention_out_bias,
-                processor=attn_processor,
+                processor=AttnProcessor2_0(),
             )

         # 3. Feed-forward
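
For reference, a hedged sketch (not part of the commit) of the selection the Sana block previously performed inline and now delegates to the base Attention class; the helper name pick_sana_attn_processor is hypothetical:

# Hypothetical helper mirroring the branch removed from SanaTransformerBlock.__init__:
# prefer the Ascend NPU processor when available, otherwise the PyTorch 2.x SDPA one.
from diffusers.models.attention_processor import AttnProcessor2_0, AttnProcessorNPU
from diffusers.utils import is_torch_npu_available


def pick_sana_attn_processor():
    if is_torch_npu_available():
        return AttnProcessorNPU()
    return AttnProcessor2_0()


# After this commit the block simply passes processor=AttnProcessor2_0() and relies
# on Attention.__init__ to perform the equivalent swap on NPU builds.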
