Commit 963e290
Author: J石页 committed

NPU Adaptation for Sana

1 parent 326b98d commit 963e290

File tree

2 files changed: +10 -13 lines changed


examples/dreambooth/train_dreambooth_lora_sana.py

Lines changed: 0 additions & 10 deletions
@@ -601,9 +601,6 @@ def parse_args(input_args=None):
         help="Whether to offload the VAE and the text encoder to CPU when they are not used.",
     )
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
-    parser.add_argument(
-        "--enable_npu_flash_attention", action="store_true", help="Whether or not to use npu flash attention."
-    )
 
     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -970,13 +967,6 @@ def main(args):
     vae.requires_grad_(False)
     text_encoder.requires_grad_(False)
 
-    if args.enable_npu_flash_attention:
-        if is_torch_npu_available():
-            logger.info("npu flash attention enabled.")
-            transformer.enable_npu_flash_attention()
-        else:
-            raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu devices.")
-
     # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
     # as these weights are only used for inference, keeping weights in full precision is not required.
     weight_dtype = torch.float32
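
With the processor now chosen inside the transformer itself (see sana_transformer.py below), the training script no longer needs an explicit opt-in flag. For context, here is a minimal standalone sketch of the availability check the removed block relied on, using diffusers' is_torch_npu_available helper; the snippet is illustrative only and not part of the commit:

from diffusers.utils import is_torch_npu_available

# is_torch_npu_available() reports whether the optional torch_npu (Ascend NPU)
# extension can be used. The removed block only enabled NPU flash attention when
# this returned True, and raised a ValueError otherwise.
if is_torch_npu_available():
    print("torch_npu available: NPU flash attention can be used.")
else:
    print("torch_npu not available: default AttnProcessor2_0 (SDPA) path is used.")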

src/diffusers/models/transformers/sana_transformer.py

Lines changed: 10 additions & 3 deletions
@@ -19,11 +19,12 @@
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers, is_torch_npu_available
 from ..attention_processor import (
     Attention,
     AttentionProcessor,
     AttnProcessor2_0,
+    AttnProcessorNPU,
     SanaLinearAttnProcessor2_0,
 )
 from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
@@ -119,6 +120,12 @@ def __init__(
         # 2. Cross Attention
         if cross_attention_dim is not None:
             self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+
+            if is_torch_npu_available():
+                attn_processor = AttnProcessorNPU()
+            else:
+                attn_processor = AttnProcessor2_0()
+
             self.attn2 = Attention(
                 query_dim=dim,
                 cross_attention_dim=cross_attention_dim,
@@ -127,7 +134,7 @@ def __init__(
                 dropout=dropout,
                 bias=True,
                 out_bias=attention_out_bias,
-                processor=AttnProcessor2_0(),
+                processor=attn_processor,
             )
 
         # 3. Feed-forward
@@ -250,14 +257,14 @@ def __init__(
         inner_dim = num_attention_heads * attention_head_dim
 
         # 1. Patch Embedding
+        interpolation_scale = interpolation_scale if interpolation_scale is not None else max(sample_size // 64, 1)
         self.patch_embed = PatchEmbed(
             height=sample_size,
             width=sample_size,
             patch_size=patch_size,
             in_channels=in_channels,
             embed_dim=inner_dim,
             interpolation_scale=interpolation_scale,
-            pos_embed_type="sincos" if interpolation_scale is not None else None,
         )
 
         # 2. Additional condition embeddings
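
The net effect is that SanaTransformerBlock now selects its cross-attention processor at construction time instead of relying on a training-script flag. A minimal standalone sketch of that selection pattern, using the classes imported in the hunk above (the helper name pick_cross_attention_processor is illustrative, not part of the commit):

from diffusers.models.attention_processor import AttnProcessor2_0, AttnProcessorNPU
from diffusers.utils import is_torch_npu_available

def pick_cross_attention_processor():
    # On Ascend hardware (torch_npu installed), use the NPU flash-attention
    # processor; otherwise keep the default PyTorch 2.0 SDPA processor.
    if is_torch_npu_available():
        return AttnProcessorNPU()
    return AttnProcessor2_0()

processor = pick_cross_attention_processor()
print(type(processor).__name__)  # AttnProcessorNPU on NPU machines, AttnProcessor2_0 elsewhere

The same file change also defaults interpolation_scale to max(sample_size // 64, 1) and drops the conditional pos_embed_type argument, so PatchEmbed falls back to its default positional-embedding type.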
