
Commit cfbbb8f

Author: J石页
Commit message: NPU Adaptation for Sana
1 parent 2052049 · commit cfbbb8f

File tree: 2 files changed, +10 / -7 lines


src/diffusers/models/attention_processor.py

Lines changed: 0 additions & 5 deletions
@@ -521,11 +521,6 @@ def set_processor(self, processor: "AttnProcessor") -> None:
             processor (`AttnProcessor`):
                 The attention processor to use.
         """
-        # set to use npu flash attention from 'torch_npu' if available
-        if is_torch_npu_available():
-            if isinstance(processor, AttnProcessor2_0):
-                processor = AttnProcessorNPU()
-
         # if current processor is in `self._modules` and if passed `processor` is not, we need to
         # pop `processor` from `self._modules`
         if (
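
With this block removed, set_processor no longer swaps AttnProcessor2_0 for AttnProcessorNPU behind the caller's back; the processor must instead be chosen where the attention module is built, as the Sana change below does. For callers that set processors manually, here is a minimal sketch of the explicit selection; the helper name and the `attn` variable are illustrative, not part of this commit:

from diffusers.models.attention_processor import AttnProcessor2_0, AttnProcessorNPU
from diffusers.utils import is_torch_npu_available

def pick_attn_processor():
    # Prefer the fused torch_npu attention kernel when an NPU is present;
    # otherwise fall back to PyTorch 2.0 scaled_dot_product_attention.
    if is_torch_npu_available():
        return AttnProcessorNPU()
    return AttnProcessor2_0()

# Illustrative usage: `attn` is any diffusers `Attention` module.
# attn.set_processor(pick_attn_processor())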

src/diffusers/models/transformers/sana_transformer.py

Lines changed: 10 additions & 2 deletions
@@ -19,11 +19,12 @@
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, is_torch_npu_available, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
 from ..attention_processor import (
     Attention,
     AttentionProcessor,
     AttnProcessor2_0,
+    AttnProcessorNPU,
     SanaLinearAttnProcessor2_0,
 )
 from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
@@ -119,6 +120,13 @@ def __init__(
         # 2. Cross Attention
         if cross_attention_dim is not None:
             self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+
+            # if NPU is available, will use NPU fused attention instead
+            if is_torch_npu_available():
+                attn_processor = AttnProcessorNPU()
+            else:
+                attn_processor = AttnProcessor2_0()
+
             self.attn2 = Attention(
                 query_dim=dim,
                 cross_attention_dim=cross_attention_dim,
@@ -127,7 +135,7 @@
                 dropout=dropout,
                 bias=True,
                 out_bias=attention_out_bias,
-                processor=AttnProcessor2_0(),
+                processor=attn_processor,
             )
 
         # 3. Feed-forward

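After this change, the NPU-aware processor is picked once at construction time, so it should also be visible when the processors of a built model are inspected. A quick check of what the cross-attention layers ended up with, assuming SanaTransformer2DModel exposes the usual attn_processors mapping found on diffusers transformer models, and using an illustrative checkpoint path:

import torch
from diffusers import SanaTransformer2DModel

# Illustrative checkpoint path; substitute a real Sana transformer checkpoint.
model = SanaTransformer2DModel.from_pretrained(
    "path/to/sana/transformer", torch_dtype=torch.bfloat16
)

# `attn_processors` maps each attention submodule name to its processor instance;
# on a host with torch_npu, the cross-attention (`attn2`) entries should report AttnProcessorNPU.
for name, processor in model.attn_processors.items():
    if "attn2" in name:
        print(name, type(processor).__name__)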