 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, is_torch_npu_available, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
 from ..attention_processor import (
     Attention,
     AttentionProcessor,
     AttnProcessor2_0,
+    AttnProcessorNPU,
     SanaLinearAttnProcessor2_0,
 )
 from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
@@ -119,6 +120,13 @@ def __init__(
         # 2. Cross Attention
         if cross_attention_dim is not None:
             self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+
+            # if an NPU is available, use the NPU fused attention processor instead
+            if is_torch_npu_available():
+                attn_processor = AttnProcessorNPU()
+            else:
+                attn_processor = AttnProcessor2_0()
+
             self.attn2 = Attention(
                 query_dim=dim,
                 cross_attention_dim=cross_attention_dim,
@@ -127,7 +135,7 @@ def __init__(
                 dropout=dropout,
                 bias=True,
                 out_bias=attention_out_bias,
-                processor=AttnProcessor2_0(),
+                processor=attn_processor,
             )
 
         # 3. Feed-forward
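
For context, here is a minimal standalone sketch of the same selection pattern outside the transformer block. It is not part of the PR; the dimension values are arbitrary illustrative choices, while `Attention`, `AttnProcessor2_0`, `AttnProcessorNPU`, and `is_torch_npu_available` are the diffusers objects referenced in the diff above.

# Minimal sketch of the NPU-aware processor selection introduced in this diff.
# Assumes a diffusers version that ships AttnProcessorNPU; the sizes below are
# example values, not the model's real configuration.
from diffusers.models.attention_processor import Attention, AttnProcessor2_0, AttnProcessorNPU
from diffusers.utils import is_torch_npu_available

# Only construct the fused NPU processor when an Ascend NPU backend is present;
# on other machines fall back to the standard scaled-dot-product processor.
attn_processor = AttnProcessorNPU() if is_torch_npu_available() else AttnProcessor2_0()

attn = Attention(
    query_dim=64,              # example hidden size (illustrative)
    cross_attention_dim=64,    # example cross-attention width (illustrative)
    dropout=0.0,
    bias=True,
    out_bias=True,
    processor=attn_processor,
)
print(type(attn.processor).__name__)  # AttnProcessorNPU on NPU hosts, else AttnProcessor2_0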