diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py
index b803babdc827..c24d16c6005a 100644
--- a/examples/dreambooth/train_dreambooth_flux.py
+++ b/examples/dreambooth/train_dreambooth_flux.py
@@ -642,6 +642,7 @@ def parse_args(input_args=None):
         ],
         help="The image interpolation method to use for resizing images.",
     )
+    parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enable Flash Attention for NPU")
 
     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -1182,6 +1183,13 @@ def main(args):
     text_encoder_one.requires_grad_(False)
     text_encoder_two.requires_grad_(False)
 
+    if args.enable_npu_flash_attention:
+        if is_torch_npu_available():
+            logger.info("NPU flash attention enabled.")
+            transformer.set_attention_backend("_native_npu")
+        else:
+            raise ValueError("NPU flash attention requires the torch_npu extension and is supported only on NPU devices.")
+
     # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
     # as these weights are only used for inference, keeping weights in full precision is not required.
     weight_dtype = torch.float32
diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py
index a8a76097f3c3..2353625c3878 100644
--- a/examples/dreambooth/train_dreambooth_lora_flux.py
+++ b/examples/dreambooth/train_dreambooth_lora_flux.py
@@ -80,6 +80,7 @@
     is_wandb_available,
 )
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.import_utils import is_torch_npu_available
 from diffusers.utils.torch_utils import is_compiled_module
 
 
@@ -686,6 +687,7 @@ def parse_args(input_args=None):
         ),
     )
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+    parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enable Flash Attention for NPU")
 
     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -1213,6 +1215,13 @@ def main(args):
     text_encoder_one.requires_grad_(False)
     text_encoder_two.requires_grad_(False)
 
+    if args.enable_npu_flash_attention:
+        if is_torch_npu_available():
+            logger.info("NPU flash attention enabled.")
+            transformer.set_attention_backend("_native_npu")
+        else:
+            raise ValueError("NPU flash attention requires the torch_npu extension and is supported only on NPU devices.")
+
     # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
     # as these weights are only used for inference, keeping weights in full precision is not required.
     weight_dtype = torch.float32
diff --git a/examples/dreambooth/train_dreambooth_lora_flux_kontext.py b/examples/dreambooth/train_dreambooth_lora_flux_kontext.py
index 6aa165ed20b3..ffeef7b4b34b 100644
--- a/examples/dreambooth/train_dreambooth_lora_flux_kontext.py
+++ b/examples/dreambooth/train_dreambooth_lora_flux_kontext.py
@@ -706,6 +706,7 @@ def parse_args(input_args=None):
         ),
     )
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+    parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enable Flash Attention for NPU")
 
     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -1354,6 +1355,13 @@ def main(args):
     text_encoder_one.requires_grad_(False)
     text_encoder_two.requires_grad_(False)
 
+    if args.enable_npu_flash_attention:
+        if is_torch_npu_available():
+            logger.info("NPU flash attention enabled.")
+            transformer.set_attention_backend("_native_npu")
+        else:
+            raise ValueError("NPU flash attention requires the torch_npu extension and is supported only on NPU devices.")
+
     # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
     # as these weights are only used for inference, keeping weights in full precision is not required.
     weight_dtype = torch.float32
diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index 60c7eb1dbabe..7ab371a1a18e 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -22,8 +22,7 @@
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
-from ...utils.import_utils import is_torch_npu_available
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
 from ..attention_dispatch import dispatch_attention_fn
@@ -354,25 +353,13 @@ def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int,
         self.act_mlp = nn.GELU(approximate="tanh")
         self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
 
-        if is_torch_npu_available():
-            from ..attention_processor import FluxAttnProcessor2_0_NPU
-
-            deprecation_message = (
-                "Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors "
-                "should be set explicitly using the `set_attn_processor` method."
-            )
-            deprecate("npu_processor", "0.34.0", deprecation_message)
-            processor = FluxAttnProcessor2_0_NPU()
-        else:
-            processor = FluxAttnProcessor()
-
         self.attn = FluxAttention(
             query_dim=dim,
             dim_head=attention_head_dim,
             heads=num_attention_heads,
             out_dim=dim,
             bias=True,
-            processor=processor,
+            processor=FluxAttnProcessor(),
             eps=1e-6,
             pre_only=True,
         )
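
For reviewers who want to exercise the new code path outside the example scripts, below is a minimal sketch of the backend switch this diff wires up. It assumes a machine with torch_npu installed and a recent diffusers build that exposes set_attention_backend; the checkpoint name and dtype are illustrative choices, not something this PR pins down.

# Minimal sketch (not part of this diff): perform the same backend switch the
# training scripts now do when --enable_npu_flash_attention is passed.
# Assumes torch_npu is installed; checkpoint name and dtype are illustrative.
import torch

from diffusers import FluxTransformer2DModel
from diffusers.utils.import_utils import is_torch_npu_available

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # illustrative checkpoint
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)

if is_torch_npu_available():
    # Route attention through the native NPU flash-attention kernels.
    transformer.set_attention_backend("_native_npu")
else:
    raise ValueError("NPU flash attention requires the torch_npu extension and an NPU device.")

On the training side, the same path is reached by passing --enable_npu_flash_attention to any of the three modified example scripts.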