Commit c074fe4

Author: J石页
Commit message: NPU adaptation for FLUX
1 parent: b572635

File tree

2 files changed: +9 additions, -12 deletions

examples/controlnet/train_controlnet_flux.py

Lines changed: 0 additions & 10 deletions

@@ -473,9 +473,6 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
     )
-    parser.add_argument(
-        "--enable_npu_flash_attention", action="store_true", help="Whether or not to use npu flash attention."
-    )
     parser.add_argument(
         "--set_grads_to_none",
         action="store_true",

@@ -970,13 +967,6 @@ def load_model_hook(models, input_dir):
     accelerator.register_save_state_pre_hook(save_model_hook)
     accelerator.register_load_state_pre_hook(load_model_hook)

-    if args.enable_npu_flash_attention:
-        if is_torch_npu_available():
-            logger.info("npu flash attention enabled.")
-            flux_transformer.enable_npu_flash_attention()
-        else:
-            raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu devices.")
-
     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
             import xformers
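
With this change the ControlNet training script no longer needs an NPU-specific switch: the --enable_npu_flash_attention flag and its setup block are removed because the model now detects the device itself (see the transformer_flux.py hunks below). A minimal sketch of the resulting behaviour, in Python; the tiny num_layers/num_single_layers values are chosen only to keep the example small and are not part of the commit:

from diffusers import FluxTransformer2DModel

# Deliberately tiny configuration so the sketch runs quickly; real FLUX checkpoints
# use num_layers=19 and num_single_layers=38.
transformer = FluxTransformer2DModel(num_layers=1, num_single_layers=1)

# On CPU/GPU this prints {'FluxAttnProcessor2_0'}; on an Ascend NPU with torch_npu
# installed, the constructor now selects FluxAttnProcessor2_0_NPU automatically,
# which is why the training script no longer needs --enable_npu_flash_attention.
print({type(p).__name__ for p in transformer.attn_processors.values()})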

src/diffusers/models/transformers/transformer_flux.py

Lines changed: 9 additions & 2 deletions

@@ -29,6 +29,7 @@
     FluxAttnProcessor2_0,
     FluxAttnProcessor2_0_NPU,
     FusedFluxAttnProcessor2_0,
+    FusedFluxAttnProcessor2_0_NPU,
 )
 from ...models.modeling_utils import ModelMixin
 from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle

@@ -140,7 +141,10 @@ def __init__(
         self.norm1_context = AdaLayerNormZero(dim)

         if hasattr(F, "scaled_dot_product_attention"):
-            processor = FluxAttnProcessor2_0()
+            if is_torch_npu_available():
+                processor = FluxAttnProcessor2_0_NPU()
+            else:
+                processor = FluxAttnProcessor2_0()
         else:
             raise ValueError(
                 "The current PyTorch version does not support the `scaled_dot_product_attention` function."

@@ -405,7 +409,10 @@ def fuse_qkv_projections(self):
             if isinstance(module, Attention):
                 module.fuse_projections(fuse=True)

-        self.set_attn_processor(FusedFluxAttnProcessor2_0())
+        if is_torch_npu_available():
+            self.set_attn_processor(FusedFluxAttnProcessor2_0_NPU())
+        else:
+            self.set_attn_processor(FusedFluxAttnProcessor2_0())

     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
     def unfuse_qkv_projections(self):
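
After the second and third hunks, both the per-block constructor and fuse_qkv_projections() branch on is_torch_npu_available(): NPU machines get FluxAttnProcessor2_0_NPU / FusedFluxAttnProcessor2_0_NPU, everything else keeps the default SDPA processors. A minimal sketch of the fused path, in Python, under the same illustrative tiny configuration as above:

from diffusers import FluxTransformer2DModel

transformer = FluxTransformer2DModel(num_layers=1, num_single_layers=1)

# fuse_qkv_projections() merges the Q/K/V linear layers of every attention module
# and swaps in the fused processor; after this commit the NPU variant is picked
# automatically when is_torch_npu_available() returns True.
transformer.fuse_qkv_projections()
print({type(p).__name__ for p in transformer.attn_processors.values()})
# -> {'FusedFluxAttnProcessor2_0'} on CPU/GPU, {'FusedFluxAttnProcessor2_0_NPU'} on NPU

# unfuse_qkv_projections() restores the original, unfused processors.
transformer.unfuse_qkv_projections()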
