Commit 958f0ac

Author: 蒋硕 (committed)
NPU implementation for FLUX
1 parent 0d9d98f · commit 958f0ac

File tree: 2 files changed (+58, -4 lines)


examples/dreambooth/train_dreambooth_flux.py

Lines changed: 14 additions & 2 deletions
@@ -58,7 +58,7 @@
 )
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
 from diffusers.utils.torch_utils import is_compiled_module
-
+from diffusers.utils.import_utils import is_torch_npu_available
 
 if is_wandb_available():
     import wandb
@@ -68,6 +68,10 @@
 
 logger = get_logger(__name__)
 
+if is_torch_npu_available():
+    import torch_npu
+    torch.npu.config.allow_internal_format = False
+    torch.npu.set_compile_mode(jit_compile=False)
 
 def save_model_card(
     repo_id: str,
@@ -1073,6 +1077,9 @@ def main(args):
             del pipeline
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
+            elif is_torch_npu_available():
+                torch_npu.npu.empty_cache()
+
 
     # Handle the repository creation
     if accelerator.is_main_process:
@@ -1359,6 +1366,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
         gc.collect()
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
+        elif is_torch_npu_available():
+            torch_npu.npu.empty_cache()
 
     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
     # pack the statically computed variables appropriately here. This is so that we don't
@@ -1722,7 +1731,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 )
                 if not args.train_text_encoder:
                     del text_encoder_one, text_encoder_two
-                    torch.cuda.empty_cache()
+                    if torch.cuda.is_available():
+                        torch.cuda.empty_cache()
+                    elif is_torch_npu_available():
+                        torch_npu.npu.empty_cache()
                 gc.collect()
 
     # Save the lora layers
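
The training-script half of this commit repeats the same device-aware cache-clearing pattern at three call sites. Below is a minimal sketch of that pattern factored into a single helper; the helper is not part of the commit, and the name free_memory is hypothetical. The module-level block in the first hunk above additionally turns off NPU-private internal memory formats and operator JIT compilation before training starts.

import gc

import torch
from diffusers.utils.import_utils import is_torch_npu_available

if is_torch_npu_available():
    import torch_npu  # importing torch_npu registers the torch.npu backend


def free_memory():
    # Mirror of the pattern added by this commit: collect Python garbage,
    # then release cached allocator memory on CUDA or on an Ascend NPU.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif is_torch_npu_available():
        torch_npu.npu.empty_cache()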

src/diffusers/models/attention_processor.py

Lines changed: 44 additions & 2 deletions
@@ -1778,7 +1778,28 @@ def __call__(
             query = apply_rotary_emb(query, image_rotary_emb)
             key = apply_rotary_emb(key, image_rotary_emb)
 
-        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+        if is_torch_npu_available():
+            if query.dtype in (torch.float16, torch.bfloat16):
+                hidden_states = torch_npu.npu_fusion_attention(
+                    query,
+                    key,
+                    value,
+                    attn.heads,
+                    input_layout="BNSD",
+                    pse=None,
+                    scale=1.0 / math.sqrt(query.shape[-1]),
+                    pre_tockens=65536,
+                    next_tockens=65536,
+                    keep_prob=1.0,
+                    sync=False,
+                    inner_precise=0,
+                )[0]
+            else:
+                hidden_states = F.scaled_dot_product_attention(
+                    query, key, value, dropout_p=0.0, is_causal=False
+                )
+        else:
+            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
 
@@ -1872,7 +1893,28 @@ def __call__(
             query = apply_rotary_emb(query, image_rotary_emb)
             key = apply_rotary_emb(key, image_rotary_emb)
 
-        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+        if is_torch_npu_available():
+            if query.dtype in (torch.float16, torch.bfloat16):
+                hidden_states = torch_npu.npu_fusion_attention(
+                    query,
+                    key,
+                    value,
+                    attn.heads,
+                    input_layout="BNSD",
+                    pse=None,
+                    scale=1.0 / math.sqrt(query.shape[-1]),
+                    pre_tockens=65536,
+                    next_tockens=65536,
+                    keep_prob=1.0,
+                    sync=False,
+                    inner_precise=0,
+                )[0]
+            else:
+                hidden_states = F.scaled_dot_product_attention(
+                    query, key, value, dropout_p=0.0, is_causal=False
+                )
+        else:
+            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
 
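Both FLUX attention processors in attention_processor.py receive the identical branch, so the behavior can be summarized in one standalone sketch. This is not the library's API: the wrapper name npu_aware_attention is hypothetical, and it assumes query/key/value already have the (batch, heads, seq_len, head_dim) shape implied by input_layout="BNSD".

import math

import torch
import torch.nn.functional as F
from diffusers.utils.import_utils import is_torch_npu_available

if is_torch_npu_available():
    import torch_npu


def npu_aware_attention(query, key, value, heads):
    # On Ascend NPUs, half-precision inputs take the fused attention kernel;
    # fp32 inputs and all other devices fall back to PyTorch's SDPA.
    if is_torch_npu_available() and query.dtype in (torch.float16, torch.bfloat16):
        out = torch_npu.npu_fusion_attention(
            query,
            key,
            value,
            heads,
            input_layout="BNSD",  # (batch, heads, seq_len, head_dim)
            pse=None,  # no positional bias
            scale=1.0 / math.sqrt(query.shape[-1]),
            pre_tockens=65536,  # wide attention window: effectively unmasked
            next_tockens=65536,
            keep_prob=1.0,  # no attention dropout
            sync=False,
            inner_precise=0,
        )
        return out[0]  # the op returns a tuple; element 0 is the attention output
    return F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)

The pre_tockens/next_tockens spelling follows the torch_npu operator signature rather than being a typo introduced by the diff.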