
Commit 330d190

Author: J石页
Commit message: NPU attention refactor for FLUX transformer
1 parent 91a151b · commit 330d190

File tree: 4 files changed (+39, −15 lines)


examples/dreambooth/train_dreambooth_flux.py

Lines changed: 11 additions & 0 deletions
@@ -642,6 +642,7 @@ def parse_args(input_args=None):
         ],
         help="The image interpolation method to use for resizing images.",
     )
+    parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enable Flash Attention for NPU")
 
     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -1182,6 +1183,16 @@ def main(args):
     text_encoder_one.requires_grad_(False)
     text_encoder_two.requires_grad_(False)
 
+    if args.enable_npu_flash_attention:
+        if is_torch_npu_available():
+            logger.info("npu flash attention enabled.")
+            for block in transformer.transformer_blocks:
+                block.attn.processor._attention_backend = "_native_npu"
+            for block in transformer.single_transformer_blocks:
+                block.attn.processor._attention_backend = "_native_npu"
+        else:
+            raise ValueError("npu flash attention requires torch_npu extensions and is supported only on NPU devices.")
+
     # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
     # as these weights are only used for inference, keeping weights in full precision is not required.
     weight_dtype = torch.float32

examples/dreambooth/train_dreambooth_lora_flux.py

Lines changed: 12 additions & 0 deletions
@@ -80,6 +80,7 @@
     is_wandb_available,
 )
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.import_utils import is_torch_npu_available
 from diffusers.utils.torch_utils import is_compiled_module
 
 
@@ -686,6 +687,7 @@ def parse_args(input_args=None):
         ),
     )
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+    parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enable Flash Attention for NPU")
 
     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -1213,6 +1215,16 @@ def main(args):
     text_encoder_one.requires_grad_(False)
     text_encoder_two.requires_grad_(False)
 
+    if args.enable_npu_flash_attention:
+        if is_torch_npu_available():
+            logger.info("npu flash attention enabled.")
+            for block in transformer.transformer_blocks:
+                block.attn.processor._attention_backend = "_native_npu"
+            for block in transformer.single_transformer_blocks:
+                block.attn.processor._attention_backend = "_native_npu"
+        else:
+            raise ValueError("npu flash attention requires torch_npu extensions and is supported only on NPU devices.")
+
     # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
     # as these weights are only used for inference, keeping weights in full precision is not required.
     weight_dtype = torch.float32

examples/dreambooth/train_dreambooth_lora_flux_kontext.py

Lines changed: 11 additions & 0 deletions
@@ -706,6 +706,7 @@ def parse_args(input_args=None):
         ),
     )
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+    parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enable Flash Attention for NPU")
 
     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -1354,6 +1355,16 @@ def main(args):
     text_encoder_one.requires_grad_(False)
     text_encoder_two.requires_grad_(False)
 
+    if args.enable_npu_flash_attention:
+        if is_torch_npu_available():
+            logger.info("npu flash attention enabled.")
+            for block in transformer.transformer_blocks:
+                block.attn.processor._attention_backend = "_native_npu"
+            for block in transformer.single_transformer_blocks:
+                block.attn.processor._attention_backend = "_native_npu"
+        else:
+            raise ValueError("npu flash attention requires torch_npu extensions and is supported only on NPU devices.")
+
     # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
     # as these weights are only used for inference, keeping weights in full precision is not required.
     weight_dtype = torch.float32
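
The gating block above is repeated verbatim in all three training scripts. As a reading aid only, here is a minimal sketch of the same pattern factored into a helper; the helper name is hypothetical, while the attribute names and the "_native_npu" backend string are taken directly from the diff.

# Hedged sketch, not part of the commit: the per-block backend switch from the
# three training scripts, written as a standalone helper. The function name is
# hypothetical; attribute names and the "_native_npu" string come from the diff.
from diffusers.utils.import_utils import is_torch_npu_available


def enable_npu_flash_attention(transformer):
    if not is_torch_npu_available():
        raise ValueError("NPU flash attention requires torch_npu and an NPU device.")
    # Route every FLUX attention processor (dual- and single-stream blocks)
    # to the native NPU attention backend.
    for block in transformer.transformer_blocks:
        block.attn.processor._attention_backend = "_native_npu"
    for block in transformer.single_transformer_blocks:
        block.attn.processor._attention_backend = "_native_npu"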

src/diffusers/models/transformers/transformer_flux.py

Lines changed: 5 additions & 15 deletions
@@ -73,9 +73,11 @@ def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_st
 
 
 class FluxAttnProcessor:
-    _attention_backend = None
+    def __init__(self, _attention_backend=None):
+        super().__init__()
+
+        self._attention_backend = _attention_backend
 
-    def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
 
@@ -354,25 +356,13 @@ def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int,
         self.act_mlp = nn.GELU(approximate="tanh")
         self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
 
-        if is_torch_npu_available():
-            from ..attention_processor import FluxAttnProcessor2_0_NPU
-
-            deprecation_message = (
-                "Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors "
-                "should be set explicitly using the `set_attn_processor` method."
-            )
-            deprecate("npu_processor", "0.34.0", deprecation_message)
-            processor = FluxAttnProcessor2_0_NPU()
-        else:
-            processor = FluxAttnProcessor()
-
         self.attn = FluxAttention(
             query_dim=dim,
             dim_head=attention_head_dim,
             heads=num_attention_heads,
             out_dim=dim,
             bias=True,
-            processor=processor,
+            processor=FluxAttnProcessor(),
             eps=1e-6,
             pre_only=True,
         )
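
With the implicit NPU processor selection removed from the block constructor, choosing the NPU backend becomes an explicit, opt-in step, which is what the removed deprecation message already recommended ("set explicitly using the set_attn_processor method"). Below is a minimal sketch of that explicit opt-in; it assumes the FLUX transformer model exposes set_attn_processor, and the checkpoint id is illustrative.

# Hedged sketch, not part of the commit: opting in to the NPU backend explicitly,
# as the removed deprecation message suggested. The checkpoint id is illustrative;
# FluxAttnProcessor and the "_native_npu" string come from this commit.
from diffusers import FluxTransformer2DModel
from diffusers.models.transformers.transformer_flux import FluxAttnProcessor

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer"
)
# set_attn_processor is the method the old deprecation message pointed to;
# passing a single processor applies it to every attention layer.
transformer.set_attn_processor(FluxAttnProcessor(_attention_backend="_native_npu"))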
