
Commit 626e339

Author: J石页

NPU Adaptation for Sana

Parent commit: 55ac1db

File tree

3 files changed: 26 additions, 7 deletions

examples/dreambooth/train_dreambooth_lora_sana.py

Lines changed: 7 additions & 4 deletions

@@ -63,6 +63,7 @@
     is_wandb_available,
 )
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.import_utils import is_torch_npu_available
 from diffusers.utils.torch_utils import is_compiled_module


@@ -74,6 +75,9 @@

 logger = get_logger(__name__)

+if is_torch_npu_available():
+    torch.npu.config.allow_internal_format = False
+

 def save_model_card(
     repo_id: str,
@@ -920,8 +924,7 @@ def main(args):
                 image.save(image_filename)

         del pipeline
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+        free_memory()

     # Handle the repository creation
     if accelerator.is_main_process:
@@ -979,10 +982,10 @@
     )

     # VAE should always be kept in fp32 for SANA (?)
-    vae.to(dtype=torch.float32)
+    vae.to(accelerator.device, dtype=torch.float32)
     transformer.to(accelerator.device, dtype=weight_dtype)
     # because Gemma2 is particularly suited for bfloat16.
-    text_encoder.to(dtype=torch.bfloat16)
+    text_encoder.to(accelerator.device, dtype=torch.bfloat16)

     # Initialize a text encoding pipeline and keep it to CPU for now.
     text_encoding_pipeline = SanaPipeline.from_pretrained(
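
The hunks above make the preprocessing path device-agnostic: NPU private internal formats are switched off at import time when Ascend hardware is detected, the CUDA-only cache flush is replaced with diffusers' backend-aware free_memory() helper, and the VAE and text encoder are moved to accelerator.device explicitly. A minimal sketch of the same pattern, assuming torch_npu is installed when running on an NPU (the pipeline work in the middle is a placeholder, not part of the commit):

import torch

from diffusers.training_utils import free_memory
from diffusers.utils.import_utils import is_torch_npu_available

if is_torch_npu_available():
    # Same guard as the script: only touch torch.npu when torch_npu is importable.
    torch.npu.config.allow_internal_format = False

# ... build a pipeline, generate class images, etc. (placeholder) ...

# free_memory() runs garbage collection and empties the cache of whichever
# backend is active (CUDA, NPU, MPS, or XPU), so no CUDA-only branch is needed.
free_memory()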

src/diffusers/models/attention_processor.py

Lines changed: 10 additions & 1 deletion

@@ -3147,7 +3147,16 @@ def __call__(
             attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
             # scaled_dot_product_attention expects attention_mask shape to be
             # (batch, heads, source_length, target_length)
-            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+            attn_mask = attention_mask[0]
+            seq_len = hidden_states.shape[1]
+            attention_mask = attn_mask.repeat_interleave(seq_len * batch_size, dim=0)
+            attention_mask = attention_mask.view(batch_size, 1, -1, attention_mask.shape[-1])
+
+            if attention_mask.dtype != torch.uint8:
+                if attention_mask.dtype == torch.bool:
+                    attention_mask = torch.logical_not(attention_mask.bool())
+                else:
+                    attention_mask = attention_mask.to(torch.uint8)

         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
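
For the NPU fused-attention path, the rewritten block broadcasts the first batch entry's key mask across every query position and then normalizes its dtype: bool masks are inverted, and anything that is neither bool nor uint8 is cast to uint8. The toy snippet below runs on CPU with placeholder values and only illustrates that dtype branch; the mask convention expected by the underlying NPU kernel is assumed, not taken from this commit.

import torch

# Toy "keep" mask in the SDPA convention: True = attend to this key (placeholder values).
npu_mask = torch.tensor([[True, True, False, False]])

# Same dtype handling as the hunk above: bool masks are inverted,
# other non-uint8 dtypes are cast to uint8, uint8 passes through unchanged.
if npu_mask.dtype != torch.uint8:
    if npu_mask.dtype == torch.bool:
        npu_mask = torch.logical_not(npu_mask)
    else:
        npu_mask = npu_mask.to(torch.uint8)

print(npu_mask)  # tensor([[False, False,  True,  True]])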

src/diffusers/models/transformers/sana_transformer.py

Lines changed: 9 additions & 2 deletions

@@ -19,11 +19,12 @@

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers, is_torch_npu_available
 from ..attention_processor import (
     Attention,
     AttentionProcessor,
     AttnProcessor2_0,
+    AttnProcessorNPU,
     SanaLinearAttnProcessor2_0,
 )
 from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
@@ -119,6 +120,12 @@ def __init__(
         # 2. Cross Attention
         if cross_attention_dim is not None:
             self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+
+            if is_torch_npu_available():
+                attn_processor = AttnProcessorNPU()
+            else:
+                attn_processor = AttnProcessor2_0()
+
             self.attn2 = Attention(
                 query_dim=dim,
                 cross_attention_dim=cross_attention_dim,
@@ -127,7 +134,7 @@ def __init__(
                 dropout=dropout,
                 bias=True,
                 out_bias=attention_out_bias,
-                processor=AttnProcessor2_0(),
+                processor=attn_processor,
             )

         # 3. Feed-forward
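
With this change the cross-attention processor is chosen when the block is constructed. The same selection could also be applied to an already-loaded model; the sketch below is illustrative only, not part of the commit: the checkpoint path is a placeholder, it assumes SanaTransformer2DModel exposes the usual attn_processors / set_attn_processor API, and it swaps only the cross-attention ("attn2") processors so the linear self-attention layers keep SanaLinearAttnProcessor2_0.

import torch

from diffusers import SanaTransformer2DModel
from diffusers.models.attention_processor import AttnProcessor2_0, AttnProcessorNPU
from diffusers.utils import is_torch_npu_available

# Placeholder checkpoint path.
transformer = SanaTransformer2DModel.from_pretrained(
    "path/to/sana-checkpoint", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Fused NPU attention on Ascend hardware, plain SDPA elsewhere.
cross_attn_cls = AttnProcessorNPU if is_torch_npu_available() else AttnProcessor2_0

# Replace only the cross-attention processors; every other layer keeps
# its existing processor (e.g. the linear self-attention ones).
processors = {
    name: cross_attn_cls() if "attn2" in name else proc
    for name, proc in transformer.attn_processors.items()
}
transformer.set_attn_processor(processors)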
