Merged
Changes from 11 of 31 commits

Commits
626e339  NPU Adaption for Sanna (Dec 30, 2024)
1a72a00  NPU Adaption for Sanna (leisuzz, Dec 30, 2024)
4d67c54  Merge branch 'main' into main (sayakpaul, Jan 3, 2025)
a1965dd  NPU Adaption for Sanna (Jan 7, 2025)
2c3b117  Merge https://github.com/leisuzz/diffusers (Jan 7, 2025)
326b98d  NPU Adaption for Sanna (Jan 7, 2025)
715822f  Merge branch 'main' into main (leisuzz, Jan 7, 2025)
963e290  NPU Adaption for Sanna (Jan 7, 2025)
510e1d6  Merge https://github.com/leisuzz/diffusers (Jan 7, 2025)
3d3aae3  NPU Adaption for Sanna (Jan 8, 2025)
4cea819  NPU Adaption for Sanna (Jan 8, 2025)
0d9e1b3  NPU Adaption for Sanna (Jan 8, 2025)
2052049  NPU Adaption for Sanna (Jan 8, 2025)
487dd1a  Merge branch 'main' into main (leisuzz, Jan 13, 2025)
cfbbb8f  NPU Adaption for Sanna (Jan 14, 2025)
7b8ad74  Merge branch 'main' of https://github.com/leisuzz/diffusers (Jan 14, 2025)
d7d54d8  Merge branch 'main' into main (leisuzz, Jan 16, 2025)
ad4beaa  Merge branch 'main' into main (leisuzz, Jan 17, 2025)
52d8c71  Merge branch 'main' into main (leisuzz, Jan 22, 2025)
4c1d56d  NPU Adaption for Sanna (Jan 23, 2025)
63e3459  Merge https://github.com/leisuzz/diffusers (Jan 23, 2025)
d61d570  NPU Adaption for Sanna (Jan 23, 2025)
ab2d71b  NPU Adaption for Sanna (Jan 23, 2025)
a323229  Merge branch 'main' into main (leisuzz, Jan 23, 2025)
a456fb1  NPU Adaption for Sanna (Jan 23, 2025)
fedfdd4  NPU Adaption for Sanna (Jan 24, 2025)
7364276  Merge branch 'main' of https://github.com/leisuzz/diffusers (Jan 24, 2025)
3add6de  NPU Adaption for Sanna (Jan 24, 2025)
70cf529  NPU Adaption for Sanna (Jan 24, 2025)
8f18aae  NPU Adaption for Sanna (Jan 24, 2025)
feb8064  Merge branch 'main' into main (leisuzz, Jan 24, 2025)
examples/dreambooth/train_dreambooth_lora_sana.py (7 changed lines: 5 additions, 2 deletions)
@@ -63,6 +63,7 @@
     is_wandb_available,
 )
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.import_utils import is_torch_npu_available
 from diffusers.utils.torch_utils import is_compiled_module


@@ -74,6 +75,9 @@

 logger = get_logger(__name__)
 
+if is_torch_npu_available():
+    torch.npu.config.allow_internal_format = False
+
 
 def save_model_card(
     repo_id: str,
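
For context, the block added above gates an Ascend-specific setting: `torch.npu.config.allow_internal_format = False` keeps tensors in the standard ND layout rather than device-private internal formats such as NZ. A minimal sketch of the same guard pattern in a standalone script, assuming the optional `torch_npu` package is installed, could look like this (not part of the PR):

# Minimal sketch: guarding Ascend NPU setup in a user script.
import torch

from diffusers.utils.import_utils import is_torch_npu_available

if is_torch_npu_available():
    import torch_npu  # noqa: F401  (registers the "npu" device type with PyTorch)

    # Keep tensors in the standard ND layout instead of private internal formats,
    # mirroring the setting added in this training script.
    torch.npu.config.allow_internal_format = False
    device = torch.device("npu")
else:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
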
@@ -920,8 +924,7 @@ def main(args):
                 image.save(image_filename)
 
         del pipeline
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+        free_memory()
 
     # Handle the repository creation
     if accelerator.is_main_process:
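The script now calls the `free_memory()` helper from `diffusers.training_utils` instead of the CUDA-only `torch.cuda.empty_cache()`, so cache clearing also covers non-CUDA backends such as NPU. A rough sketch of what such a device-agnostic helper does (not the library's exact implementation) is:

# Rough sketch of a device-agnostic cache-release helper, in the spirit of
# diffusers.training_utils.free_memory; the function name here is illustrative.
import gc

import torch


def free_memory_sketch() -> None:
    gc.collect()  # drop Python-level references first
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif hasattr(torch, "npu") and torch.npu.is_available():  # Ascend NPU via torch_npu
        torch.npu.empty_cache()
    elif torch.backends.mps.is_available():
        torch.mps.empty_cache()
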
src/diffusers/models/attention_processor.py (15 changed lines: 14 additions, 1 deletion)
@@ -294,6 +294,10 @@ def __init__(
             processor = (
                 AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
             )
+
Review comment (Collaborator, on the added lines below): "umm I don't think we should change the default attention processor here, let's keep this logic in SANA :)" (A sketch of that alternative appears after this hunk.)

+            if is_torch_npu_available():
+                if isinstance(processor, AttnProcessor2_0):
+                    processor = AttnProcessorNPU()
         self.set_processor(processor)
 
     def set_use_xla_flash_attention(
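
Following the review comment above, here is a hedged sketch of the suggested alternative: select the NPU processor where SANA builds its attention layers rather than changing the default inside `Attention.__init__`. The helper name and argument list are illustrative, not the actual SANA constructor:

# Sketch only: choose the processor at the SANA call site instead of changing
# the Attention default. The surrounding helper is hypothetical.
from diffusers.models.attention_processor import Attention, AttnProcessor2_0, AttnProcessorNPU
from diffusers.utils.import_utils import is_torch_npu_available


def build_sana_cross_attention(query_dim: int, cross_attention_dim: int, heads: int, dim_head: int) -> Attention:
    # Pick the NPU fused-attention processor only when running on Ascend hardware.
    processor = AttnProcessorNPU() if is_torch_npu_available() else AttnProcessor2_0()
    return Attention(
        query_dim=query_dim,
        cross_attention_dim=cross_attention_dim,
        heads=heads,
        dim_head=dim_head,
        processor=processor,
    )
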
@@ -3147,7 +3151,16 @@ def __call__(
             attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
             # scaled_dot_product_attention expects attention_mask shape to be
             # (batch, heads, source_length, target_length)
-            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+            attn_mask = attention_mask[0]
+            seq_len = hidden_states.shape[1]
+            attention_mask = attn_mask.repeat_interleave(seq_len * batch_size, dim=0)
+            attention_mask = attention_mask.view(batch_size, 1, -1, attention_mask.shape[-1])
+
+            if attention_mask.dtype != torch.uint8:
+                if attention_mask.dtype == torch.bool:
+                    attention_mask = torch.logical_not(attention_mask.bool())
+                else:
+                    attention_mask = attention_mask.to(torch.uint8)
 
         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
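
To make the added mask handling concrete, here is a small toy example (shapes and the starting mask layout are assumptions, not taken from the processor) showing the reshape to `(batch, 1, query_len, key_len)` and the dtype handling; the boolean inversion in the diff suggests the NPU fused-attention kernel treats nonzero entries as positions to mask out:

# Toy illustration of the mask transform above; all shapes here are made up.
import torch

batch_size, heads, seq_len, kv_len = 2, 8, 16, 77

# Assume a boolean mask broadcast over heads, one row per (batch * heads) entry,
# with True meaning "attend to this key position".
attention_mask = torch.ones(batch_size * heads, 1, kv_len, dtype=torch.bool)

attn_mask = attention_mask[0]                                              # (1, kv_len)
attention_mask = attn_mask.repeat_interleave(seq_len * batch_size, dim=0)  # (batch * seq_len, kv_len)
attention_mask = attention_mask.view(batch_size, 1, -1, attention_mask.shape[-1])  # (batch, 1, seq_len, kv_len)

# Boolean masks are inverted, i.e. nonzero = masked out (assumption based on the diff);
# other dtypes are cast to uint8.
if attention_mask.dtype != torch.uint8:
    if attention_mask.dtype == torch.bool:
        attention_mask = torch.logical_not(attention_mask.bool())
    else:
        attention_mask = attention_mask.to(torch.uint8)

print(attention_mask.shape)  # torch.Size([2, 1, 16, 77])
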
src/diffusers/models/transformers/sana_transformer.py (3 changed lines: 2 additions, 1 deletion)
@@ -19,11 +19,12 @@

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers, is_torch_npu_available
 from ..attention_processor import (
     Attention,
     AttentionProcessor,
     AttnProcessor2_0,
+    AttnProcessorNPU,
     SanaLinearAttnProcessor2_0,
 )
 from ..embeddings import PatchEmbed, PixArtAlphaTextProjection