diff --git a/examples/dreambooth/train_dreambooth_lora_sana.py b/examples/dreambooth/train_dreambooth_lora_sana.py
index dd10664ece18..9e69bd6a668b 100644
--- a/examples/dreambooth/train_dreambooth_lora_sana.py
+++ b/examples/dreambooth/train_dreambooth_lora_sana.py
@@ -63,6 +63,7 @@
     is_wandb_available,
 )
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.import_utils import is_torch_npu_available
 from diffusers.utils.torch_utils import is_compiled_module
 
 
@@ -74,6 +75,9 @@
 
 logger = get_logger(__name__)
 
+if is_torch_npu_available():
+    torch.npu.config.allow_internal_format = False
+
 
 def save_model_card(
     repo_id: str,
@@ -601,6 +605,7 @@ def parse_args(input_args=None):
     )
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
     parser.add_argument("--enable_vae_tiling", action="store_true", help="Enabla vae tiling in log validation")
+    parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enable Flash Attention for NPU")
 
     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -924,8 +929,7 @@ def main(args):
                     image.save(image_filename)
 
             del pipeline
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
+            free_memory()
 
     # Handle the repository creation
     if accelerator.is_main_process:
@@ -988,6 +992,13 @@ def main(args):
     # because Gemma2 is particularly suited for bfloat16.
     text_encoder.to(dtype=torch.bfloat16)
 
+    if args.enable_npu_flash_attention:
+        if is_torch_npu_available():
+            logger.info("NPU flash attention enabled.")
+            transformer.enable_npu_flash_attention()
+        else:
+            raise ValueError("NPU flash attention requires torch_npu extensions and is supported only on NPU devices.")
+
     # Initialize a text encoding pipeline and keep it to CPU for now.
     text_encoding_pipeline = SanaPipeline.from_pretrained(
         args.pretrained_model_name_or_path,
diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 30e160dd2408..26625753e4b6 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -3154,6 +3154,11 @@ def __call__(
             # scaled_dot_product_attention expects attention_mask shape to be
             # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+            attention_mask = attention_mask.repeat(1, 1, hidden_states.shape[1], 1)
+            if attention_mask.dtype == torch.bool:
+                attention_mask = torch.logical_not(attention_mask.bool())
+            else:
+                attention_mask = attention_mask.bool()
 
         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
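
For context on the attention_processor.py change above: the added lines broadcast the text attention mask across the query dimension and convert it to a boolean tensor before it reaches the NPU fused attention kernel. Below is a minimal, self-contained sketch of that transformation in plain PyTorch (no torch_npu required); the tensor shapes and the assumption that the incoming boolean mask follows the SDPA convention (True = attend) are illustrative only, not taken from the diff.

    import torch

    batch, heads, q_len, k_len = 2, 8, 16, 12

    # Boolean "keep" mask shaped (batch, heads, 1, k_len), i.e. roughly what the
    # existing .view(batch_size, attn.heads, -1, ...) call produces for a
    # per-token text mask.
    mask = torch.ones(batch, heads, 1, k_len, dtype=torch.bool)
    mask[..., -3:] = False  # pretend the last three text tokens are padding

    # First added line: repeat the mask across the query dimension.
    mask = mask.repeat(1, 1, q_len, 1)  # -> (batch, heads, q_len, k_len)

    # Remaining added lines: boolean masks are inverted (True now marks the
    # positions to drop, given the SDPA-style input above); non-bool masks are
    # simply cast to bool.
    if mask.dtype == torch.bool:
        mask = torch.logical_not(mask)
    else:
        mask = mask.bool()

    print(mask.shape)     # torch.Size([2, 8, 16, 12])
    print(mask[0, 0, 0])  # True only at the three padded key positions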