Commit a1965dd

Author: J石页 (committed)

NPU Adaptation for Sana
1 parent 1a72a00 commit a1965dd

File tree

2 files changed (+12 −9 lines):
  examples/dreambooth/train_dreambooth_lora_sana.py
  src/diffusers/models/transformers/sana_transformer.py

examples/dreambooth/train_dreambooth_lora_sana.py

Lines changed: 10 additions & 0 deletions
@@ -601,6 +601,9 @@ def parse_args(input_args=None):
         help="Whether to offload the VAE and the text encoder to CPU when they are not used.",
     )
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+    parser.add_argument(
+        "--enable_npu_flash_attention", action="store_true", help="Whether or not to use npu flash attention."
+    )
 
     if input_args is not None:
         args = parser.parse_args(input_args)

@@ -967,6 +970,13 @@ def main(args):
     vae.requires_grad_(False)
     text_encoder.requires_grad_(False)
 
+    if args.enable_npu_flash_attention:
+        if is_torch_npu_available():
+            logger.info("npu flash attention enabled.")
+            transformer.enable_npu_flash_attention()
+        else:
+            raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu devices.")
+
     # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
     # as these weights are only used for inference, keeping weights in full precision is not required.
     weight_dtype = torch.float32
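Because the new option is declared with `action="store_true"`, `args.enable_npu_flash_attention` defaults to `False` and only flips to `True` when the flag is passed on the command line. A minimal standalone sketch of that behavior (only the flag name and help text come from the diff; the bare parser around it is illustrative):

```python
import argparse

# Reproduce just the new flag in isolation; everything except the flag itself is illustrative.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--enable_npu_flash_attention", action="store_true", help="Whether or not to use npu flash attention."
)

print(parser.parse_args([]).enable_npu_flash_attention)  # False: flag omitted
print(parser.parse_args(["--enable_npu_flash_attention"]).enable_npu_flash_attention)  # True: flag passed
```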

src/diffusers/models/transformers/sana_transformer.py

Lines changed: 2 additions & 9 deletions
@@ -19,12 +19,11 @@
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers, is_torch_npu_available
+from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
 from ..attention_processor import (
     Attention,
     AttentionProcessor,
     AttnProcessor2_0,
-    AttnProcessorNPU,
     SanaLinearAttnProcessor2_0,
 )
 from ..embeddings import PatchEmbed, PixArtAlphaTextProjection

@@ -120,12 +119,6 @@ def __init__(
         # 2. Cross Attention
         if cross_attention_dim is not None:
             self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
-
-            if is_torch_npu_available():
-                attn_processor = AttnProcessorNPU()
-            else:
-                attn_processor = AttnProcessor2_0()
-
             self.attn2 = Attention(
                 query_dim=dim,
                 cross_attention_dim=cross_attention_dim,
@@ -134,7 +127,7 @@ def __init__(
                 dropout=dropout,
                 bias=True,
                 out_bias=attention_out_bias,
-                processor=attn_processor,
+                processor=AttnProcessor2_0(),
             )
 
             # 3. Feed-forward
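With the constructor no longer auto-selecting `AttnProcessorNPU`, NPU flash attention becomes an explicit opt-in on the loaded model, using the same `enable_npu_flash_attention()` call the training script now makes. A hedged inference-side sketch, assuming a torch_npu install on an Ascend NPU device and using a placeholder checkpoint path:

```python
import torch

from diffusers import SanaTransformer2DModel
from diffusers.utils import is_torch_npu_available

# Placeholder checkpoint path; substitute a real Sana checkpoint repo or local directory.
transformer = SanaTransformer2DModel.from_pretrained(
    "path/to/sana-checkpoint", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Mirror the guard added to the training script: only opt in when torch_npu
# and an NPU device are actually available.
if is_torch_npu_available():
    transformer.enable_npu_flash_attention()
    transformer.to("npu")
```

Keeping the model's default on `AttnProcessor2_0` and moving the NPU branch into user-facing code keeps device-specific logic out of the model definition.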
