
Commit ad90fa1
Author: 蒋硕
NPU implementation for FLUX
1 parent: 958f0ac

File tree: 2 files changed (+228, -46 lines)


src/diffusers/models/attention_processor.py
Lines changed: 217 additions & 44 deletions

@@ -1778,28 +1778,7 @@ def __call__(
             query = apply_rotary_emb(query, image_rotary_emb)
             key = apply_rotary_emb(key, image_rotary_emb)
 
-        if is_torch_npu_available():
-            if query.dtype in (torch.float16, torch.bfloat16):
-                hidden_states = torch_npu.npu_fusion_attention(
-                    query,
-                    key,
-                    value,
-                    attn.heads,
-                    input_layout="BNSD",
-                    pse=None,
-                    scale=1.0 / math.sqrt(query.shape[-1]),
-                    pre_tockens=65536,
-                    next_tockens=65536,
-                    keep_prob=1.0,
-                    sync=False,
-                    inner_precise=0,
-                )[0]
-            else:
-                hidden_states = F.scaled_dot_product_attention(
-                    query, key, value, dropout_p=0.0, is_causal=False
-                )
-        else:
-            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
 
@@ -1819,6 +1798,110 @@ def __call__(
         else:
             return hidden_states
 
+class FluxAttnProcessor2_0_NPU:
+    """Attention processor used typically in processing the SD3-like self-attention projections."""
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.FloatTensor:
+        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+
+        # `sample` projections.
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+        if encoder_hidden_states is not None:
+            # `context` projections.
+            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            # attention
+            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+        if image_rotary_emb is not None:
+            from .embeddings import apply_rotary_emb
+
+            query = apply_rotary_emb(query, image_rotary_emb)
+            key = apply_rotary_emb(key, image_rotary_emb)
+
+        if query.dtype in (torch.float16, torch.bfloat16):
+            hidden_states = torch_npu.npu_fusion_attention(
+                query,
+                key,
+                value,
+                attn.heads,
+                input_layout="BNSD",
+                pse=None,
+                scale=1.0 / math.sqrt(query.shape[-1]),
+                pre_tockens=65536,
+                next_tockens=65536,
+                keep_prob=1.0,
+                sync=False,
+                inner_precise=0,
+            )[0]
+        else:
+            hidden_states = F.scaled_dot_product_attention(
+                query, key, value, dropout_p=0.0, is_causal=False
+            )
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        if encoder_hidden_states is not None:
+            encoder_hidden_states, hidden_states = (
+                hidden_states[:, : encoder_hidden_states.shape[1]],
+                hidden_states[:, encoder_hidden_states.shape[1] :],
+            )
+
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
 
 class FusedFluxAttnProcessor2_0:
     """Attention processor used typically in processing the SD3-like self-attention projections."""
@@ -1893,28 +1976,7 @@ def __call__(
             query = apply_rotary_emb(query, image_rotary_emb)
             key = apply_rotary_emb(key, image_rotary_emb)
 
-        if is_torch_npu_available():
-            if query.dtype in (torch.float16, torch.bfloat16):
-                hidden_states = torch_npu.npu_fusion_attention(
-                    query,
-                    key,
-                    value,
-                    attn.heads,
-                    input_layout="BNSD",
-                    pse=None,
-                    scale=1.0 / math.sqrt(query.shape[-1]),
-                    pre_tockens=65536,
-                    next_tockens=65536,
-                    keep_prob=1.0,
-                    sync=False,
-                    inner_precise=0,
-                )[0]
-            else:
-                hidden_states = F.scaled_dot_product_attention(
-                    query, key, value, dropout_p=0.0, is_causal=False
-                )
-        else:
-            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
 
@@ -1934,6 +1996,117 @@ def __call__(
         else:
             return hidden_states
 
+class FusedFluxAttnProcessor2_0_NPU:
+    """Attention processor used typically in processing the SD3-like self-attention projections."""
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "FusedFluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0."
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.FloatTensor:
+        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+
+        # `sample` projections.
+        qkv = attn.to_qkv(hidden_states)
+        split_size = qkv.shape[-1] // 3
+        query, key, value = torch.split(qkv, split_size, dim=-1)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+        # `context` projections.
+        if encoder_hidden_states is not None:
+            encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
+            split_size = encoder_qkv.shape[-1] // 3
+            (
+                encoder_hidden_states_query_proj,
+                encoder_hidden_states_key_proj,
+                encoder_hidden_states_value_proj,
+            ) = torch.split(encoder_qkv, split_size, dim=-1)
+
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            # attention
+            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+        if image_rotary_emb is not None:
+            from .embeddings import apply_rotary_emb
+
+            query = apply_rotary_emb(query, image_rotary_emb)
+            key = apply_rotary_emb(key, image_rotary_emb)
+
+        if query.dtype in (torch.float16, torch.bfloat16):
+            hidden_states = torch_npu.npu_fusion_attention(
+                query,
+                key,
+                value,
+                attn.heads,
+                input_layout="BNSD",
+                pse=None,
+                scale=1.0 / math.sqrt(query.shape[-1]),
+                pre_tockens=65536,
+                next_tockens=65536,
+                keep_prob=1.0,
+                sync=False,
+                inner_precise=0,
+            )[0]
+        else:
+            hidden_states = F.scaled_dot_product_attention(
+                query, key, value, dropout_p=0.0, is_causal=False
+            )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        if encoder_hidden_states is not None:
+            encoder_hidden_states, hidden_states = (
+                hidden_states[:, : encoder_hidden_states.shape[1]],
+                hidden_states[:, encoder_hidden_states.shape[1] :],
+            )
+
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
 
 class CogVideoXAttnProcessor2_0:
     r"""

src/diffusers/models/transformers/transformer_flux.py
Lines changed: 11 additions & 2 deletions

@@ -27,12 +27,15 @@
     Attention,
     AttentionProcessor,
     FluxAttnProcessor2_0,
+    FluxAttnProcessor2_0_NPU,
     FusedFluxAttnProcessor2_0,
+    FusedFluxAttnProcessor2_0_NPU,
 )
 from ...models.modeling_utils import ModelMixin
 from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
 from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
+from ...utils.import_utils import is_torch_npu_available
 from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
 from ..modeling_outputs import Transformer2DModelOutput
 
@@ -64,7 +67,10 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
         self.act_mlp = nn.GELU(approximate="tanh")
         self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
 
-        processor = FluxAttnProcessor2_0()
+        if is_torch_npu_available():
+            processor = FluxAttnProcessor2_0_NPU()
+        else:
+            processor = FluxAttnProcessor2_0()
         self.attn = Attention(
             query_dim=dim,
             cross_attention_dim=None,
@@ -369,7 +375,10 @@ def fuse_qkv_projections(self):
             if isinstance(module, Attention):
                 module.fuse_projections(fuse=True)
 
-        self.set_attn_processor(FusedFluxAttnProcessor2_0())
+        if is_torch_npu_available():
+            self.set_attn_processor(FusedFluxAttnProcessor2_0_NPU())
+        else:
+            self.set_attn_processor(FusedFluxAttnProcessor2_0())
 
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
     def unfuse_qkv_projections(self):