Skip to content

Commit 0ef20e7

Browse files
author
蒋硕
committed
Improve performance and make the code suitable for NPU
1 parent 4e3cc2c commit 0ef20e7

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

src/diffusers/models/attention_processor.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2274,7 +2274,8 @@ def __call__(
22742274
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
22752275
)
22762276

2277-
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim).to(query.dtype)
2277+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
2278+
hidden_states = hidden_states.to(query.dtype)
22782279

22792280
# linear proj
22802281
hidden_states = attn.to_out[0](hidden_states)
@@ -4276,7 +4277,6 @@ def __init__(self):
42764277
CROSS_ATTENTION_PROCESSORS = (
42774278
AttnProcessor,
42784279
AttnProcessor2_0,
4279-
AttnProcessorNPU,
42804280
XFormersAttnProcessor,
42814281
SlicedAttnProcessor,
42824282
IPAdapterAttnProcessor,
@@ -4286,7 +4286,6 @@ def __init__(self):
42864286
AttentionProcessor = Union[
42874287
AttnProcessor,
42884288
AttnProcessor2_0,
4289-
AttnProcessorNPU,
42904289
FusedAttnProcessor2_0,
42914290
XFormersAttnProcessor,
42924291
SlicedAttnProcessor,
@@ -4301,4 +4300,4 @@ def __init__(self):
43014300
PAGIdentitySelfAttnProcessor2_0,
43024301
PAGCFGHunyuanAttnProcessor2_0,
43034302
PAGHunyuanAttnProcessor2_0,
4304-
]
4303+
]

0 commit comments

Comments
 (0)