8 changes: 8 additions & 0 deletions examples/dreambooth/train_dreambooth_flux.py
@@ -642,6 +642,7 @@ def parse_args(input_args=None):
],
help="The image interpolation method to use for resizing images.",
)
parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU")

if input_args is not None:
args = parser.parse_args(input_args)
@@ -1182,6 +1183,13 @@ def main(args):
text_encoder_one.requires_grad_(False)
text_encoder_two.requires_grad_(False)

if args.enable_npu_flash_attention:
    if is_torch_npu_available():
        logger.info("NPU flash attention enabled.")
        transformer.set_attention_backend("_native_npu")
    else:
        raise ValueError("NPU flash attention requires the torch_npu extension and is supported only on NPU devices.")

# For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
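Note: the flag is wired identically in all three DreamBooth Flux scripts in this PR; when it is passed and torch_npu is available, the transformer switches to the native NPU attention backend before training starts, otherwise the script raises immediately. A hedged launch sketch (model id and paths are illustrative placeholders, and other required arguments such as --instance_prompt are omitted):

accelerate launch train_dreambooth_flux.py \
    --pretrained_model_name_or_path black-forest-labs/FLUX.1-dev \
    --instance_data_dir ./dog \
    --output_dir ./flux-dreambooth \
    --mixed_precision bf16 \
    --enable_npu_flash_attention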
9 changes: 9 additions & 0 deletions examples/dreambooth/train_dreambooth_lora_flux.py
@@ -80,6 +80,7 @@
is_wandb_available,
)
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_torch_npu_available
from diffusers.utils.torch_utils import is_compiled_module


@@ -686,6 +687,7 @@ def parse_args(input_args=None):
),
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU")

if input_args is not None:
args = parser.parse_args(input_args)
@@ -1213,6 +1215,13 @@ def main(args):
text_encoder_one.requires_grad_(False)
text_encoder_two.requires_grad_(False)

if args.enable_npu_flash_attention:
    if is_torch_npu_available():
        logger.info("NPU flash attention enabled.")
        transformer.set_attention_backend("_native_npu")
    else:
        raise ValueError("NPU flash attention requires the torch_npu extension and is supported only on NPU devices.")

# For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
8 changes: 8 additions & 0 deletions examples/dreambooth/train_dreambooth_lora_flux_kontext.py
@@ -706,6 +706,7 @@ def parse_args(input_args=None):
),
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU")

if input_args is not None:
args = parser.parse_args(input_args)
@@ -1354,6 +1355,13 @@ def main(args):
text_encoder_one.requires_grad_(False)
text_encoder_two.requires_grad_(False)

if args.enable_npu_flash_attention:
    if is_torch_npu_available():
        logger.info("NPU flash attention enabled.")
        transformer.set_attention_backend("_native_npu")
    else:
        raise ValueError("NPU flash attention requires the torch_npu extension and is supported only on NPU devices.")

# For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
14 changes: 1 addition & 13 deletions src/diffusers/models/transformers/transformer_flux.py
@@ -354,25 +354,13 @@ def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int,
self.act_mlp = nn.GELU(approximate="tanh")
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)

if is_torch_npu_available():
from ..attention_processor import FluxAttnProcessor2_0_NPU

deprecation_message = (
"Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors "
"should be set explicitly using the `set_attn_processor` method."
)
deprecate("npu_processor", "0.34.0", deprecation_message)
processor = FluxAttnProcessor2_0_NPU()
else:
processor = FluxAttnProcessor()

self.attn = FluxAttention(
query_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
bias=True,
processor=processor,
processor=FluxAttnProcessor(),
eps=1e-6,
pre_only=True,
)
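With the implicit FluxAttnProcessor2_0_NPU default removed above, NPU users now opt in explicitly, either via the new --enable_npu_flash_attention training flag or directly on the model. A minimal sketch, assuming a diffusers build that exposes set_attention_backend and is_torch_npu_available as referenced in this PR (the checkpoint id is illustrative):

from diffusers import FluxTransformer2DModel
from diffusers.utils.import_utils import is_torch_npu_available

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer"
)
if is_torch_npu_available():
    # Select the native NPU flash-attention backend explicitly instead of
    # relying on the removed implicit NPU processor default.
    transformer.set_attention_backend("_native_npu")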