
Commit 2052049

Author: J石页
Message: NPU Adaptation for Sanna
Parent: 0d9e1b3

1 file changed: 5 additions, 4 deletions

src/diffusers/models/attention_processor.py
@@ -294,10 +294,6 @@ def __init__(
         processor = (
             AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
         )
-
-        if is_torch_npu_available():
-            if isinstance(processor, AttnProcessor2_0):
-                processor = AttnProcessorNPU()
         self.set_processor(processor)
 
     def set_use_xla_flash_attention(
@@ -525,6 +521,11 @@ def set_processor(self, processor: "AttnProcessor") -> None:
             processor (`AttnProcessor`):
                 The attention processor to use.
         """
+        # set to use npu flash attention from 'torch_npu' if available
+        if is_torch_npu_available():
+            if isinstance(processor, AttnProcessor2_0):
+                processor = AttnProcessorNPU()
+
         # if current processor is in `self._modules` and if passed `processor` is not, we need to
         # pop `processor` from `self._modules`
         if (
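
This change moves the NPU substitution out of Attention.__init__ and into Attention.set_processor, so a processor assigned after construction is also upgraded to the torch_npu flash-attention processor, not only the default one chosen at init time. Below is a minimal sketch of the resulting behavior, not part of this commit; it assumes a machine where is_torch_npu_available() returns True, and the module sizes are arbitrary example values. Attention, AttnProcessor2_0, and AttnProcessorNPU are the classes touched by this diff.

    # Sketch: demonstrates the new set_processor behavior after this commit.
    from diffusers.models.attention_processor import (
        Attention,
        AttnProcessor2_0,
        AttnProcessorNPU,
    )

    # Arbitrary example sizes, chosen only for illustration.
    attn = Attention(query_dim=64, heads=4, dim_head=16)

    # Explicitly assigning the PyTorch 2.0 SDPA processor...
    attn.set_processor(AttnProcessor2_0())

    # ...now yields the NPU flash-attention processor on Ascend hardware,
    # because the isinstance check runs inside set_processor itself.
    print(type(attn.processor).__name__)
    # "AttnProcessorNPU" when torch_npu is available, "AttnProcessor2_0" otherwise

Before this commit, the swap happened only once in __init__, so the explicit set_processor call above would have silently discarded the NPU path.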
