@@ -19,11 +19,12 @@
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers, is_torch_npu_available
 from ..attention_processor import (
     Attention,
     AttentionProcessor,
     AttnProcessor2_0,
+    AttnProcessorNPU,
     SanaLinearAttnProcessor2_0,
 )
 from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
@@ -119,6 +120,12 @@ def __init__(
         # 2. Cross Attention
         if cross_attention_dim is not None:
             self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+
+            if is_torch_npu_available():
+                attn_processor = AttnProcessorNPU()
+            else:
+                attn_processor = AttnProcessor2_0()
+
             self.attn2 = Attention(
                 query_dim=dim,
                 cross_attention_dim=cross_attention_dim,
@@ -127,7 +134,7 @@ def __init__(
                 dropout=dropout,
                 bias=True,
                 out_bias=attention_out_bias,
-                processor=AttnProcessor2_0(),
+                processor=attn_processor,
             )
 
         # 3. Feed-forward
@@ -250,14 +257,14 @@ def __init__(
         inner_dim = num_attention_heads * attention_head_dim
 
         # 1. Patch Embedding
+        interpolation_scale = interpolation_scale if interpolation_scale is not None else max(sample_size // 64, 1)
         self.patch_embed = PatchEmbed(
             height=sample_size,
             width=sample_size,
             patch_size=patch_size,
             in_channels=in_channels,
             embed_dim=inner_dim,
             interpolation_scale=interpolation_scale,
-            pos_embed_type="sincos" if interpolation_scale is not None else None,
         )
 
         # 2. Additional condition embeddings
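
The patch above does two things: Sana's cross-attention now uses AttnProcessorNPU when an Ascend NPU is detected (falling back to AttnProcessor2_0 otherwise), and the patch embedding always receives an interpolation_scale, defaulting to max(sample_size // 64, 1), instead of switching pos_embed_type on whether one was provided. Below is a minimal standalone sketch of the processor-selection pattern, using only the public diffusers exports that appear in the diff; the dimensions in the usage line are placeholders, not values from the patch.

# Sketch: choose the NPU-optimized attention processor when torch_npu is usable,
# otherwise fall back to the PyTorch 2.0 scaled-dot-product-attention processor.
from diffusers.models.attention_processor import Attention, AttnProcessor2_0, AttnProcessorNPU
from diffusers.utils import is_torch_npu_available

attn_processor = AttnProcessorNPU() if is_torch_npu_available() else AttnProcessor2_0()

# Hypothetical standalone module with placeholder dims; the diff wires the same
# selection result into the block's cross-attention via Attention(..., processor=attn_processor).
attn = Attention(query_dim=64, cross_attention_dim=64, processor=attn_processor)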