
Commit d8467e1
Author: J石页
Parent: 91a151b

NPU attention refactor for FLUX transformer

File tree

4 files changed: +26 -14 lines changed

examples/dreambooth/train_dreambooth_flux.py

Lines changed: 8 additions & 0 deletions
@@ -642,6 +642,7 @@ def parse_args(input_args=None):
         ],
         help="The image interpolation method to use for resizing images.",
     )
+    parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enable Flash Attention for NPU")

     if input_args is not None:
         args = parser.parse_args(input_args)

@@ -1182,6 +1183,13 @@ def main(args):
     text_encoder_one.requires_grad_(False)
     text_encoder_two.requires_grad_(False)

+    if args.enable_npu_flash_attention:
+        if is_torch_npu_available():
+            logger.info("npu flash attention enabled.")
+            transformer.set_attention_backend("_native_npu")
+        else:
+            raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu device")
+
     # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
     # as these weights are only used for inference, keeping weights in full precision is not required.
     weight_dtype = torch.float32
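
The new flag ultimately just switches the transformer's attention backend, so the same opt-in can be written directly against the model. A minimal sketch follows, assuming the diffusers version from this commit: `set_attention_backend("_native_npu")` and the `is_torch_npu_available` import are taken from the diff above, while the `FluxTransformer2DModel.from_pretrained` call and the "black-forest-labs/FLUX.1-dev" checkpoint id are illustrative assumptions.

from diffusers import FluxTransformer2DModel
from diffusers.utils.import_utils import is_torch_npu_available

# Load only the FLUX transformer (assumed checkpoint id; any FLUX checkpoint
# with a "transformer" subfolder would do).
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer"
)

if is_torch_npu_available():
    # Same call the training scripts make when --enable_npu_flash_attention is passed.
    transformer.set_attention_backend("_native_npu")
else:
    # The backend requires the torch_npu extension and an NPU device.
    raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu device")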

examples/dreambooth/train_dreambooth_lora_flux.py

Lines changed: 9 additions & 0 deletions
@@ -80,6 +80,7 @@
     is_wandb_available,
 )
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.import_utils import is_torch_npu_available
 from diffusers.utils.torch_utils import is_compiled_module

@@ -686,6 +687,7 @@ def parse_args(input_args=None):
         ),
     )
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+    parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enable Flash Attention for NPU")

     if input_args is not None:
         args = parser.parse_args(input_args)

@@ -1213,6 +1215,13 @@ def main(args):
     text_encoder_one.requires_grad_(False)
     text_encoder_two.requires_grad_(False)

+    if args.enable_npu_flash_attention:
+        if is_torch_npu_available():
+            logger.info("npu flash attention enabled.")
+            transformer.set_attention_backend("_native_npu")
+        else:
+            raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu device")
+
     # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
     # as these weights are only used for inference, keeping weights in full precision is not required.
     weight_dtype = torch.float32

examples/dreambooth/train_dreambooth_lora_flux_kontext.py

Lines changed: 8 additions & 0 deletions
@@ -706,6 +706,7 @@ def parse_args(input_args=None):
         ),
     )
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+    parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enable Flash Attention for NPU")

     if input_args is not None:
         args = parser.parse_args(input_args)

@@ -1354,6 +1355,13 @@ def main(args):
     text_encoder_one.requires_grad_(False)
     text_encoder_two.requires_grad_(False)

+    if args.enable_npu_flash_attention:
+        if is_torch_npu_available():
+            logger.info("npu flash attention enabled.")
+            transformer.set_attention_backend("_native_npu")
+        else:
+            raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu device")
+
     # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
     # as these weights are only used for inference, keeping weights in full precision is not required.
     weight_dtype = torch.float32

src/diffusers/models/transformers/transformer_flux.py

Lines changed: 1 addition & 14 deletions
@@ -74,7 +74,6 @@ def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_st

 class FluxAttnProcessor:
     _attention_backend = None
-
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")

@@ -354,25 +353,13 @@ def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int,
         self.act_mlp = nn.GELU(approximate="tanh")
         self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)

-        if is_torch_npu_available():
-            from ..attention_processor import FluxAttnProcessor2_0_NPU
-
-            deprecation_message = (
-                "Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors "
-                "should be set explicitly using the `set_attn_processor` method."
-            )
-            deprecate("npu_processor", "0.34.0", deprecation_message)
-            processor = FluxAttnProcessor2_0_NPU()
-        else:
-            processor = FluxAttnProcessor()
-
         self.attn = FluxAttention(
             query_dim=dim,
             dim_head=attention_head_dim,
             heads=num_attention_heads,
             out_dim=dim,
             bias=True,
-            processor=processor,
+            processor=FluxAttnProcessor(),
             eps=1e-6,
             pre_only=True,
         )
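
Net effect of this last change: FluxSingleTransformerBlock no longer selects FluxAttnProcessor2_0_NPU implicitly on NPU devices (the path that emitted the "npu_processor" deprecation warning) and always constructs a plain FluxAttnProcessor. NPU flash attention is now an explicit opt-in, either through the new --enable_npu_flash_attention flag in the training scripts or by calling set_attention_backend("_native_npu") on the model, as sketched after the first file above.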
