From c074fe4d6d041bcebe9215e45c6ad78ff79f39c0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?J=E7=9F=B3=E9=A1=B5?=
Date: Mon, 6 Jan 2025 12:55:52 +0800
Subject: [PATCH] NPU adaptation for FLUX

---
 examples/controlnet/train_controlnet_flux.py           | 10 ----------
 src/diffusers/models/transformers/transformer_flux.py  | 11 +++++++++--
 2 files changed, 9 insertions(+), 12 deletions(-)

diff --git a/examples/controlnet/train_controlnet_flux.py b/examples/controlnet/train_controlnet_flux.py
index 6f472b3df62b..0bfffb85e571 100644
--- a/examples/controlnet/train_controlnet_flux.py
+++ b/examples/controlnet/train_controlnet_flux.py
@@ -473,9 +473,6 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
     )
-    parser.add_argument(
-        "--enable_npu_flash_attention", action="store_true", help="Whether or not to use npu flash attention."
-    )
     parser.add_argument(
         "--set_grads_to_none",
         action="store_true",
@@ -970,13 +967,6 @@ def load_model_hook(models, input_dir):
     accelerator.register_save_state_pre_hook(save_model_hook)
     accelerator.register_load_state_pre_hook(load_model_hook)
 
-    if args.enable_npu_flash_attention:
-        if is_torch_npu_available():
-            logger.info("npu flash attention enabled.")
-            flux_transformer.enable_npu_flash_attention()
-        else:
-            raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu devices.")
-
     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
             import xformers
diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index f5e92700b2f3..bcd458ac7506 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -29,6 +29,7 @@
     FluxAttnProcessor2_0,
     FluxAttnProcessor2_0_NPU,
     FusedFluxAttnProcessor2_0,
+    FusedFluxAttnProcessor2_0_NPU,
 )
 from ...models.modeling_utils import ModelMixin
 from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
@@ -140,7 +141,10 @@ def __init__(
         self.norm1_context = AdaLayerNormZero(dim)
 
         if hasattr(F, "scaled_dot_product_attention"):
-            processor = FluxAttnProcessor2_0()
+            if is_torch_npu_available():
+                processor = FluxAttnProcessor2_0_NPU()
+            else:
+                processor = FluxAttnProcessor2_0()
         else:
             raise ValueError(
                 "The current PyTorch version does not support the `scaled_dot_product_attention` function."
@@ -405,7 +409,10 @@ def fuse_qkv_projections(self):
             if isinstance(module, Attention):
                 module.fuse_projections(fuse=True)
 
-        self.set_attn_processor(FusedFluxAttnProcessor2_0())
+        if is_torch_npu_available():
+            self.set_attn_processor(FusedFluxAttnProcessor2_0_NPU())
+        else:
+            self.set_attn_processor(FusedFluxAttnProcessor2_0())
 
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
     def unfuse_qkv_projections(self):
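
A minimal usage sketch (not part of the patch) of the behavior after this change: the NPU attention processors are now selected automatically when torch_npu is available, so no --enable_npu_flash_attention flag or manual enable call is needed. The model id and dtype below are illustrative assumptions, not taken from the patch.

import torch
from diffusers import FluxTransformer2DModel

# Load the FLUX transformer; on a host where torch_npu is available, each
# transformer block is constructed with FluxAttnProcessor2_0_NPU automatically.
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Fusing the QKV projections now installs FusedFluxAttnProcessor2_0_NPU on NPU
# hosts and FusedFluxAttnProcessor2_0 everywhere else.
transformer.fuse_qkv_projections()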