Commit 07d6649

Author: 蒋硕
Commit message: Improve performance and suitability for NPU
Parent: 01337da

File tree: 2 files changed, +17 −9 lines

examples/text_to_image/train_text_to_image_sdxl.py

Lines changed: 14 additions & 7 deletions
@@ -59,7 +59,9 @@
 
 logger = get_logger(__name__)
 if is_torch_npu_available():
+    import torch_npu
     torch.npu.config.allow_internal_format = False
+    torch.npu.set_compile_mode(jit_compile=False)
 
 DATASET_NAME_MAPPING = {
     "lambdalabs/naruto-blip-captions": ("image", "text"),
@@ -531,7 +533,7 @@ def encode_prompt(batch, text_encoders, tokenizers, proportion_empty_prompts, caption_column):
     return {"prompt_embeds": prompt_embeds.cpu(), "pooled_prompt_embeds": pooled_prompt_embeds.cpu()}
 
 
-def compute_vae_encodings(batch, vae):
+def compute_vae_encodings(batch, accelerator, vae):
     images = batch.pop("pixel_values")
     pixel_values = torch.stack(list(images))
     pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
@@ -540,7 +542,7 @@ def compute_vae_encodings(batch, vae):
     with torch.no_grad():
         model_input = vae.encode(pixel_values).latent_dist.sample()
     model_input = model_input * vae.config.scaling_factor
-    return {"model_input": model_input.cpu()}
+    return {"model_input": accelerator.gather(model_input)}
 
 
 def generate_timestep_weights(args, num_timesteps):
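Returning accelerator.gather(model_input) instead of model_input.cpu() keeps the precomputed VAE latents on the accelerator and collects them across processes in a single collective, rather than paying a device-to-host copy per batch; the matching signature change threads accelerator through and is wired up via functools.partial in the next hunk. A small sketch of the pattern, with fake_latents standing in for the VAE output:

```python
# Minimal sketch, assuming an Accelerator has been initialized.
import torch
from accelerate import Accelerator

accelerator = Accelerator()
fake_latents = torch.randn(2, 4, 64, 64, device=accelerator.device)

# gather() concatenates the tensor from every process along dim 0 and
# leaves the result on the accelerator device; .cpu() would instead
# force a device-to-host copy for every cached batch.
all_latents = accelerator.gather(fake_latents)
print(all_latents.shape)  # (2 * num_processes, 4, 64, 64)
```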
@@ -910,7 +912,7 @@ def preprocess_train(examples):
         proportion_empty_prompts=args.proportion_empty_prompts,
         caption_column=args.caption_column,
     )
-    compute_vae_encodings_fn = functools.partial(compute_vae_encodings, vae=vae)
+    compute_vae_encodings_fn = functools.partial(compute_vae_encodings, accelerator=accelerator, vae=vae)
     with accelerator.main_process_first():
         from datasets.fingerprint import Hasher
 
@@ -935,7 +937,10 @@ def preprocess_train(examples):
     del compute_vae_encodings_fn, compute_embeddings_fn, text_encoder_one, text_encoder_two
     del text_encoders, tokenizers, vae
     gc.collect()
-    torch.cuda.empty_cache()
+    if is_torch_npu_available():
+        torch_npu.npu.empty_cache()
+    else:
+        torch.cuda.empty_cache()
 
     def collate_fn(examples):
         model_input = torch.stack([torch.tensor(example["model_input"]) for example in examples])
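torch.cuda.empty_cache() does not release NPU allocator blocks, so the cleanup now branches on the backend; the same branch appears at a second cleanup site further down. A hypothetical helper (not part of this commit) that wraps the branch so call sites stay device-agnostic:

```python
import gc

import torch
from diffusers.utils import is_torch_npu_available


def free_device_cache() -> None:
    """Release cached allocator blocks on whichever backend is active."""
    gc.collect()
    if is_torch_npu_available():
        import torch_npu

        torch_npu.npu.empty_cache()
    elif torch.cuda.is_available():
        torch.cuda.empty_cache()
```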
@@ -1091,8 +1096,7 @@ def compute_time_ids(original_size, crops_coords_top_left):
                     # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
                     target_size = (args.resolution, args.resolution)
                     add_time_ids = list(original_size + crops_coords_top_left + target_size)
-                    add_time_ids = torch.tensor([add_time_ids])
-                    add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
+                    add_time_ids = torch.tensor([add_time_ids], device=accelerator.device, dtype=weight_dtype)
                     return add_time_ids
 
                 add_time_ids = torch.cat(
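Building add_time_ids in a single torch.tensor(...) call with device= and dtype= replaces a CPU allocation plus a separate .to() copy with one allocation already on the target device. The difference in isolation, with device and dtype as stand-ins for accelerator.device and weight_dtype:

```python
import torch

device, dtype = torch.device("cpu"), torch.float16  # stand-ins

# Before: allocate on CPU, then copy/cast to the target device.
add_time_ids = torch.tensor([[1024, 1024, 0, 0, 1024, 1024]])
add_time_ids = add_time_ids.to(device, dtype=dtype)

# After: a single allocation, already on the device with the final dtype.
add_time_ids = torch.tensor([[1024, 1024, 0, 0, 1024, 1024]], device=device, dtype=dtype)
```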
@@ -1261,7 +1265,10 @@ def compute_time_ids(original_size, crops_coords_top_left):
                 )
 
                 del pipeline
-                torch.cuda.empty_cache()
+                if is_torch_npu_available():
+                    torch_npu.npu.empty_cache()
+                else:
+                    torch.cuda.empty_cache()
 
         if args.use_ema:
             # Switch back to the original UNet parameters.

src/diffusers/models/attention_processor.py

Lines changed: 3 additions & 2 deletions
@@ -2274,8 +2274,7 @@ def __call__(
             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
         )
 
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim).to(query.dtype)
 
         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
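Chaining the transpose, reshape, and dtype cast is mainly a tidiness change; the same three operations still execute (reshape copies here because transpose makes the tensor non-contiguous). A toy reproduction of the shape and dtype flow, with illustrative sizes:

```python
import torch

batch_size, heads, seq_len, head_dim = 2, 8, 77, 64
query = torch.randn(batch_size, heads, seq_len, head_dim, dtype=torch.float16)
hidden_states = torch.randn(batch_size, heads, seq_len, head_dim, dtype=torch.float32)

# (batch, heads, seq, head_dim) -> (batch, seq, heads * head_dim), cast to match query
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, heads * head_dim).to(query.dtype)
assert hidden_states.shape == (batch_size, seq_len, heads * head_dim)
assert hidden_states.dtype == torch.float16
```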
@@ -4277,6 +4276,7 @@ def __init__(self):
 CROSS_ATTENTION_PROCESSORS = (
     AttnProcessor,
     AttnProcessor2_0,
+    AttnProcessorNPU,
     XFormersAttnProcessor,
     SlicedAttnProcessor,
     IPAdapterAttnProcessor,
@@ -4286,6 +4286,7 @@ def __init__(self):
 AttentionProcessor = Union[
     AttnProcessor,
     AttnProcessor2_0,
+    AttnProcessorNPU,
     FusedAttnProcessor2_0,
     XFormersAttnProcessor,
     SlicedAttnProcessor,
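Registering AttnProcessorNPU in CROSS_ATTENTION_PROCESSORS and in the AttentionProcessor union lets diffusers' processor checks and type hints treat it like the built-in processors. A minimal sketch of opting a model into it; the checkpoint id is illustrative, and an Ascend environment with torch_npu installed is assumed:

```python
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessorNPU

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
)
# Route every attention block through the NPU attention processor.
unet.set_attn_processor(AttnProcessorNPU())
```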
