
Commit 5823736

Authored by HelloWorldBeginner, mhh001, and sayakpaul

Add Ascend NPU support for SDXL fine-tuning and fix the model saving bug when using DeepSpeed. (#7816)

* Add Ascend NPU support for SDXL fine-tuning and fix the model saving bug when using DeepSpeed.
* fix check code quality
* Decouple the NPU flash attention and make it an independent module.
* add doc and unit tests for npu flash attention.

Co-authored-by: mhh001 <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>

1 parent 3e35628 · commit 5823736

File tree: 7 files changed, +261 -12 lines

docs/source/en/api/attnprocessor.md (3 additions, 0 deletions)

@@ -55,3 +55,6 @@ An attention processor is a class for applying different types of attention mech
 
 ## XFormersAttnProcessor
 [[autodoc]] models.attention_processor.XFormersAttnProcessor
+
+## AttnProcessorNPU
+[[autodoc]] models.attention_processor.AttnProcessorNPU
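The new `AttnProcessorNPU` entry is user-facing, so a short usage sketch may help: the snippet below loads an SDXL UNet and turns the processor on via the model-level toggle this commit adds in `modeling_utils.py`. It is a minimal sketch, assuming a machine with `torch_npu` installed; the checkpoint ID is only an example.

```python
import torch
from diffusers import UNet2DConditionModel
from diffusers.utils.import_utils import is_torch_npu_available

# Example checkpoint; any SDXL-style UNet works the same way.
unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", torch_dtype=torch.float16
)

if is_torch_npu_available():
    # Recursively swaps every attention module's processor for AttnProcessorNPU.
    unet.enable_npu_flash_attention()
```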

examples/controlnet/train_controlnet_sdxl.py (16 additions, 3 deletions)

@@ -32,7 +32,7 @@
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
-from accelerate.utils import ProjectConfiguration, set_seed
+from accelerate.utils import DistributedType, ProjectConfiguration, set_seed
 from datasets import load_dataset
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
@@ -53,7 +53,7 @@
 from diffusers.optimization import get_scheduler
 from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
-from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
 from diffusers.utils.torch_utils import is_compiled_module
 
 
@@ -64,6 +64,8 @@
 check_min_version("0.28.0.dev0")
 
 logger = get_logger(__name__)
+if is_torch_npu_available():
+    torch.npu.config.allow_internal_format = False
 
 
 def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step, is_final_validation=False):
@@ -471,6 +473,9 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
     )
+    parser.add_argument(
+        "--enable_npu_flash_attention", action="store_true", help="Whether or not to use npu flash attention."
+    )
     parser.add_argument(
         "--set_grads_to_none",
         action="store_true",
@@ -936,6 +941,13 @@ def load_model_hook(models, input_dir):
     text_encoder_two.requires_grad_(False)
     controlnet.train()
 
+    if args.enable_npu_flash_attention:
+        if is_torch_npu_available():
+            logger.info("npu flash attention enabled.")
+            unet.enable_npu_flash_attention()
+        else:
+            raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu devices.")
+
     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
             import xformers
@@ -1235,7 +1247,8 @@ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizer
                 progress_bar.update(1)
                 global_step += 1
 
-                if accelerator.is_main_process:
+                # DeepSpeed requires saving weights on every device; saving weights only on the main process would cause issues.
+                if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
                     if global_step % args.checkpointing_steps == 0:
                         # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                         if args.checkpoints_total_limit is not None:
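The last hunk is the DeepSpeed fix: as the diff comment notes, DeepSpeed requires saving weights on every device, since the model/optimizer state is typically partitioned across ranks, so saving only on the main process causes issues. A small helper expressing the same condition might look like this (the function name and commented call site are illustrative, not part of the commit):

```python
from accelerate import Accelerator
from accelerate.utils import DistributedType


def should_enter_checkpoint_block(accelerator: Accelerator) -> bool:
    # Under DeepSpeed every rank participates in saving; otherwise only the
    # main process writes checkpoints, exactly as the new condition in the diff.
    return accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process


# Usage inside the training loop (mirrors the diff):
# if should_enter_checkpoint_block(accelerator):
#     if global_step % args.checkpointing_steps == 0:
#         accelerator.save_state(save_path)
```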

examples/text_to_image/train_text_to_image_lora_sdxl.py (16 additions, 3 deletions)

@@ -32,7 +32,7 @@
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
-from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
+from accelerate.utils import DistributedDataParallelKwargs, DistributedType, ProjectConfiguration, set_seed
 from datasets import load_dataset
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
@@ -60,14 +60,16 @@
     is_wandb_available,
 )
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
-from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
 from diffusers.utils.torch_utils import is_compiled_module
 
 
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
 check_min_version("0.28.0.dev0")
 
 logger = get_logger(__name__)
+if is_torch_npu_available():
+    torch.npu.config.allow_internal_format = False
 
 
 def save_model_card(
@@ -419,6 +421,9 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
     )
+    parser.add_argument(
+        "--enable_npu_flash_attention", action="store_true", help="Whether or not to use npu flash attention."
+    )
     parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
     parser.add_argument(
         "--rank",
@@ -623,6 +628,13 @@ def main(args):
     text_encoder_one.to(accelerator.device, dtype=weight_dtype)
     text_encoder_two.to(accelerator.device, dtype=weight_dtype)
 
+    if args.enable_npu_flash_attention:
+        if is_torch_npu_available():
+            logger.info("npu flash attention enabled.")
+            unet.enable_npu_flash_attention()
+        else:
+            raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu devices.")
+
     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
            import xformers
@@ -1149,7 +1161,8 @@ def compute_time_ids(original_size, crops_coords_top_left):
                 accelerator.log({"train_loss": train_loss}, step=global_step)
                 train_loss = 0.0
 
-                if accelerator.is_main_process:
+                # DeepSpeed requires saving weights on every device; saving weights only on the main process would cause issues.
+                if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
                     if global_step % args.checkpointing_steps == 0:
                         # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                         if args.checkpoints_total_limit is not None:
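The SDXL LoRA script mirrors the same two changes. One quick way to confirm the new flag actually took effect is to inspect the UNet's `attn_processors` mapping after the enable call; this check is not part of the commit, just an illustration, and the helper name is hypothetical:

```python
from collections import Counter

from diffusers import UNet2DConditionModel


def summarize_attention_processors(unet: UNet2DConditionModel) -> Counter:
    """Count the processor classes currently installed on the UNet (illustrative helper)."""
    return Counter(type(proc).__name__ for proc in unet.attn_processors.values())


# With --enable_npu_flash_attention on an NPU machine, the summary should contain only
# "AttnProcessorNPU"; otherwise the default AttnProcessor2_0 / AttnProcessor remains.
```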

src/diffusers/models/activations.py (11 additions, 3 deletions)

@@ -18,8 +18,12 @@
 from torch import nn
 
 from ..utils import deprecate
+from ..utils.import_utils import is_torch_npu_available
 
 
+if is_torch_npu_available():
+    import torch_npu
+
 ACTIVATION_FUNCTIONS = {
     "swish": nn.SiLU(),
     "silu": nn.SiLU(),
@@ -98,9 +102,13 @@ def forward(self, hidden_states, *args, **kwargs):
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
-
-        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
-        return hidden_states * self.gelu(gate)
+        hidden_states = self.proj(hidden_states)
+        if is_torch_npu_available():
+            # using torch_npu.npu_geglu can run faster and save memory on NPU.
+            return torch_npu.npu_geglu(hidden_states, dim=-1, approximate=1)[0]
+        else:
+            hidden_states, gate = hidden_states.chunk(2, dim=-1)
+            return hidden_states * self.gelu(gate)
 
 
 class ApproximateGELU(nn.Module):
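For context, both branches of the changed GEGLU forward compute the same gated activation; the NPU path simply fuses the chunk-and-gate into one `npu_geglu` kernel. Below is a plain-PyTorch reference of what the non-NPU branch computes (the exact GELU approximation selected by `approximate=1` inside `npu_geglu` may differ slightly; this sketch is only for intuition):

```python
import torch
import torch.nn.functional as F


def geglu_reference(proj_out: torch.Tensor) -> torch.Tensor:
    # proj_out is the output of GEGLU's linear projection, with 2 * dim_out features:
    # the first half is the value, the second half gates it through GELU.
    hidden, gate = proj_out.chunk(2, dim=-1)
    return hidden * F.gelu(gate)


x = torch.randn(2, 16, 8)        # (batch, seq_len, 2 * dim_out) with dim_out = 4
print(geglu_reference(x).shape)  # torch.Size([2, 16, 4])
```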

src/diffusers/models/attention_processor.py (131 additions, 1 deletion)

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import inspect
+import math
 from importlib import import_module
 from typing import Callable, List, Optional, Union
 
@@ -21,13 +22,15 @@
 
 from ..image_processor import IPAdapterMaskProcessor
 from ..utils import deprecate, logging
-from ..utils.import_utils import is_xformers_available
+from ..utils.import_utils import is_torch_npu_available, is_xformers_available
 from ..utils.torch_utils import maybe_allow_in_graph
 from .lora import LoRALinearLayer
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
+if is_torch_npu_available():
+    import torch_npu
 
 if is_xformers_available():
     import xformers
@@ -209,6 +212,23 @@ def __init__(
         )
         self.set_processor(processor)
 
+    def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
+        r"""
+        Set whether to use npu flash attention from `torch_npu` or not.
+
+        """
+        if use_npu_flash_attention:
+            processor = AttnProcessorNPU()
+        else:
+            # set attention processor
+            # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
+            # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
+            # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
+            processor = (
+                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
+            )
+        self.set_processor(processor)
+
     def set_use_memory_efficient_attention_xformers(
         self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
     ) -> None:
@@ -1207,6 +1227,116 @@ def __call__(
         return hidden_states
 
 
+class AttnProcessorNPU:
+
+    r"""
+    Processor for implementing flash attention using torch_npu. Torch_npu supports only fp16 and bf16 data types. If
+    fp32 is used, F.scaled_dot_product_attention will be used for computation, but the acceleration effect on NPU is
+    not significant.
+
+    """
+
+    def __init__(self):
+        if not is_torch_npu_available():
+            raise ImportError("AttnProcessorNPU requires torch_npu extensions and is supported only on npu devices.")
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        temb: Optional[torch.FloatTensor] = None,
+        *args,
+        **kwargs,
+    ) -> torch.FloatTensor:
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)
+
+        residual = hidden_states
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        if query.dtype in (torch.float16, torch.bfloat16):
+            hidden_states = torch_npu.npu_fusion_attention(
+                query,
+                key,
+                value,
+                attn.heads,
+                input_layout="BNSD",
+                pse=None,
+                atten_mask=attention_mask,
+                scale=1.0 / math.sqrt(query.shape[-1]),
+                pre_tockens=65536,
+                next_tockens=65536,
+                keep_prob=1.0,
+                sync=False,
+                inner_precise=0,
+            )[0]
+        else:
+            # TODO: add support for attn.scale when we move to Torch 2.1
+            hidden_states = F.scaled_dot_product_attention(
+                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+            )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
 class AttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
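Besides the model-level toggle added in modeling_utils.py below, the processor can also be assigned directly through the existing `set_attn_processor` API, in the same way other processors in this file are used. A hedged sketch; the helper name is illustrative:

```python
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessorNPU
from diffusers.utils.import_utils import is_torch_npu_available


def use_npu_attention(unet: UNet2DConditionModel) -> None:
    """Install AttnProcessorNPU on every attention layer of an already-loaded UNet."""
    if not is_torch_npu_available():
        # AttnProcessorNPU.__init__ raises ImportError without torch_npu, so bail out early.
        return
    unet.set_attn_processor(AttnProcessorNPU())
    # Note: the fused npu_fusion_attention path only runs for fp16/bf16 queries;
    # fp32 inputs fall back to F.scaled_dot_product_attention, as the class docstring states.
```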

src/diffusers/models/modeling_utils.py (30 additions, 0 deletions)

@@ -272,6 +272,36 @@ def disable_gradient_checkpointing(self) -> None:
         if self._supports_gradient_checkpointing:
             self.apply(partial(self._set_gradient_checkpointing, value=False))
 
+    def set_use_npu_flash_attention(self, valid: bool) -> None:
+        r"""
+        Set the switch for the npu flash attention.
+        """
+
+        def fn_recursive_set_npu_flash_attention(module: torch.nn.Module):
+            if hasattr(module, "set_use_npu_flash_attention"):
+                module.set_use_npu_flash_attention(valid)
+
+            for child in module.children():
+                fn_recursive_set_npu_flash_attention(child)
+
+        for module in self.children():
+            if isinstance(module, torch.nn.Module):
+                fn_recursive_set_npu_flash_attention(module)
+
+    def enable_npu_flash_attention(self) -> None:
+        r"""
+        Enable npu flash attention from torch_npu
+
+        """
+        self.set_use_npu_flash_attention(True)
+
+    def disable_npu_flash_attention(self) -> None:
+        r"""
+        disable npu flash attention from torch_npu
+
+        """
+        self.set_use_npu_flash_attention(False)
+
     def set_use_memory_efficient_attention_xformers(
         self, valid: bool, attention_op: Optional[Callable] = None
     ) -> None:
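Because the toggle lives on ModelMixin, any diffusers model (UNet, ControlNet, and so on) inherits `enable_npu_flash_attention` / `disable_npu_flash_attention`. The round-trip below uses a deliberately tiny UNet config (values chosen only for this example) and shows the failure mode on machines without `torch_npu`:

```python
from diffusers import UNet2DConditionModel
from diffusers.utils.import_utils import is_torch_npu_available

# Tiny, illustrative config so the toggle can be exercised quickly.
unet = UNet2DConditionModel(
    sample_size=8,
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
    layers_per_block=1,
)

if is_torch_npu_available():
    unet.enable_npu_flash_attention()   # recursively installs AttnProcessorNPU
    unet.disable_npu_flash_attention()  # restores the default processor selection
else:
    try:
        unet.enable_npu_flash_attention()
    except ImportError as err:
        # The constructor guard in AttnProcessorNPU makes misuse fail loudly.
        print(err)
```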
