examples/dreambooth/train_dreambooth_flux.py (14 additions, 2 deletions)
@@ -58,7 +58,7 @@
)
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.torch_utils import is_compiled_module

from diffusers.utils.import_utils import is_torch_npu_available

if is_wandb_available():
import wandb
@@ -68,6 +68,10 @@

logger = get_logger(__name__)

if is_torch_npu_available():
import torch_npu
torch.npu.config.allow_internal_format = False
torch.npu.set_compile_mode(jit_compile=False)

def save_model_card(
repo_id: str,
@@ -1073,6 +1077,9 @@ def main(args):
del pipeline
if torch.cuda.is_available():
torch.cuda.empty_cache()
elif is_torch_npu_available():
torch_npu.npu.empty_cache()


# Handle the repository creation
if accelerator.is_main_process:
@@ -1359,6 +1366,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
elif is_torch_npu_available():
torch_npu.npu.empty_cache()

# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
# pack the statically computed variables appropriately here. This is so that we don't
@@ -1722,7 +1731,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
)
if not args.train_text_encoder:
del text_encoder_one, text_encoder_two
torch.cuda.empty_cache()
if torch.cuda.is_available():
torch.cuda.empty_cache()
elif is_torch_npu_available():
torch_npu.npu.empty_cache()
gc.collect()

# Save the lora layers
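The training-script changes repeat a single pattern: free accelerator memory on whichever backend is present, CUDA or Ascend NPU. A minimal sketch of that branching as a standalone helper; `free_device_memory` is an illustrative name and not part of this PR, which inlines the branch at each call site instead:

```python
import gc

import torch
from diffusers.utils.import_utils import is_torch_npu_available

if is_torch_npu_available():
    import torch_npu  # importing torch_npu registers the Ascend NPU backend


def free_device_memory():
    """Illustrative helper mirroring the cache-clearing pattern used in this PR."""
    # Drop Python-level references first, then release cached device blocks.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif is_torch_npu_available():
        torch_npu.npu.empty_cache()
```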
src/diffusers/models/attention_processor.py (44 additions, 2 deletions)
@@ -1778,7 +1778,28 @@ def __call__(
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)

hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
if is_torch_npu_available():
if query.dtype in (torch.float16, torch.bfloat16):
hidden_states = torch_npu.npu_fusion_attention(
query,
key,
value,
attn.heads,
input_layout="BNSD",
pse=None,
scale=1.0 / math.sqrt(query.shape[-1]),
pre_tockens=65536,
next_tockens=65536,
keep_prob=1.0,
sync=False,
inner_precise=0,
)[0]
else:
hidden_states = F.scaled_dot_product_attention(
query, key, value, dropout_p=0.0, is_causal=False
)
else:
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)

@@ -1872,7 +1893,28 @@ def __call__(
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)

hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
if is_torch_npu_available():
if query.dtype in (torch.float16, torch.bfloat16):
hidden_states = torch_npu.npu_fusion_attention(
query,
key,
value,
attn.heads,
input_layout="BNSD",
pse=None,
scale=1.0 / math.sqrt(query.shape[-1]),
pre_tockens=65536,
next_tockens=65536,
keep_prob=1.0,
sync=False,
inner_precise=0,
)[0]
else:
hidden_states = F.scaled_dot_product_attention(
query, key, value, dropout_p=0.0, is_causal=False
)
else:
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)

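Both attention processors touched here now share the same dispatch: half-precision attention on Ascend NPUs goes to the fused kernel, everything else falls back to PyTorch SDPA. A minimal sketch of that dispatch, assuming 4-D query/key/value tensors in (batch, heads, seq_len, head_dim) layout; `dispatch_attention` is an illustrative name, not part of this PR:

```python
import math

import torch
import torch.nn.functional as F
from diffusers.utils.import_utils import is_torch_npu_available

if is_torch_npu_available():
    import torch_npu


def dispatch_attention(query, key, value, heads):
    # Illustrative helper (not part of this PR): route fp16/bf16 tensors on an
    # Ascend NPU to the fused kernel; use PyTorch SDPA everywhere else.
    if is_torch_npu_available() and query.dtype in (torch.float16, torch.bfloat16):
        return torch_npu.npu_fusion_attention(
            query,
            key,
            value,
            heads,
            input_layout="BNSD",  # (batch, num_heads, seq_len, head_dim)
            pse=None,
            scale=1.0 / math.sqrt(query.shape[-1]),
            pre_tockens=65536,  # spelling follows the torch_npu API
            next_tockens=65536,
            keep_prob=1.0,  # keep_prob=1.0 disables attention dropout
            sync=False,
            inner_precise=0,
        )[0]
    return F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
```

Since the fused kernel is called with `scale=1.0 / math.sqrt(query.shape[-1])` and `keep_prob=1.0`, it matches SDPA's default scaling and zero dropout, so both branches compute the same attention.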