
Commit f611e5f

Merge branch 'main' into lora_modules
2 parents ff5511c + 5956b68

1 file changed: +14, -4 lines

examples/text_to_image/train_text_to_image_sdxl.py

Lines changed: 14 additions & 4 deletions
@@ -59,6 +59,8 @@
 
 logger = get_logger(__name__)
 if is_torch_npu_available():
+    import torch_npu
+
     torch.npu.config.allow_internal_format = False
 
 DATASET_NAME_MAPPING = {
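
The guard keeps the script importable on machines without Ascend hardware: torch_npu is only pulled in once is_torch_npu_available() reports an NPU, and only then is the NPU-specific config flag touched. A minimal sketch of the same backend-guarded setup, assuming is_torch_npu_available can be imported from diffusers.utils as in other diffusers training scripts:

# Minimal sketch of backend-guarded setup; is_torch_npu_available is assumed
# to come from diffusers.utils, as in other diffusers training scripts.
import torch
from diffusers.utils import is_torch_npu_available

if is_torch_npu_available():
    import torch_npu  # noqa: F401  # importing registers the "npu" device with PyTorch

    device = torch.device("npu")
    torch.npu.config.allow_internal_format = False  # keep tensors in standard formats
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(f"training device: {device}")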
@@ -540,6 +542,9 @@ def compute_vae_encodings(batch, vae):
     with torch.no_grad():
         model_input = vae.encode(pixel_values).latent_dist.sample()
         model_input = model_input * vae.config.scaling_factor
+
+    # There might be a slight performance improvement
+    # by changing model_input.cpu() to accelerator.gather(model_input)
     return {"model_input": model_input.cpu()}
 
 
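The added comment records a possible optimization rather than a behavior change: each process moves its latents to CPU before caching them, and accelerator.gather(model_input) is noted as a potentially faster alternative. A simplified sketch of the cached-latents step being annotated; the pixel preprocessing and dataset map() plumbing of the real script are omitted and the function name is illustrative:

import torch


def encode_to_latents(pixel_values: torch.Tensor, vae) -> dict:
    # Encode once, up front, so the VAE can be deleted before training starts.
    pixel_values = pixel_values.to(vae.device, dtype=vae.dtype)
    with torch.no_grad():
        latents = vae.encode(pixel_values).latent_dist.sample()
        latents = latents * vae.config.scaling_factor
    # Moving the latents to CPU keeps the cache out of accelerator memory; the
    # new comment suggests accelerator.gather(latents) as a possible alternative
    # in the multi-process case.
    return {"model_input": latents.cpu()}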
@@ -935,7 +940,10 @@ def preprocess_train(examples):
     del compute_vae_encodings_fn, compute_embeddings_fn, text_encoder_one, text_encoder_two
     del text_encoders, tokenizers, vae
     gc.collect()
-    torch.cuda.empty_cache()
+    if is_torch_npu_available():
+        torch_npu.npu.empty_cache()
+    elif torch.cuda.is_available():
+        torch.cuda.empty_cache()
 
     def collate_fn(examples):
         model_input = torch.stack([torch.tensor(example["model_input"]) for example in examples])
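
The single torch.cuda.empty_cache() call becomes a branch so that the flush of cached allocator blocks also happens on Ascend NPUs. A hypothetical helper that collects the repeated branch in one place (the script inlines this logic; empty_device_cache is not a name from the codebase):

import gc

import torch
from diffusers.utils import is_torch_npu_available


def empty_device_cache() -> None:
    """Run the garbage collector, then flush the active backend's memory cache."""
    gc.collect()
    if is_torch_npu_available():
        import torch_npu

        torch_npu.npu.empty_cache()
    elif torch.cuda.is_available():
        torch.cuda.empty_cache()

Flushing the cache only helps right after large objects such as the VAE, the text encoders, or the validation pipeline have been deleted, which is exactly where both call sites in this diff live.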
@@ -1091,8 +1099,7 @@ def compute_time_ids(original_size, crops_coords_top_left):
         # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
         target_size = (args.resolution, args.resolution)
         add_time_ids = list(original_size + crops_coords_top_left + target_size)
-        add_time_ids = torch.tensor([add_time_ids])
-        add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
+        add_time_ids = torch.tensor([add_time_ids], device=accelerator.device, dtype=weight_dtype)
         return add_time_ids
 
     add_time_ids = torch.cat(
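
Folding device and dtype into the torch.tensor(...) call allocates the time-id tensor directly on the accelerator instead of building it on CPU and then copying it with .to(). A small illustration of the two forms; the values, device, and dtype below are placeholders rather than values taken from the script:

import torch

values = [1024, 1024, 0, 0, 1024, 1024]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float16

# Before: build the tensor on CPU, then copy it to the accelerator.
add_time_ids = torch.tensor([values])
add_time_ids = add_time_ids.to(device, dtype=dtype)

# After: allocate directly on the target device with the target dtype,
# avoiding the intermediate CPU tensor and the extra copy.
add_time_ids = torch.tensor([values], device=device, dtype=dtype)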
@@ -1261,7 +1268,10 @@ def compute_time_ids(original_size, crops_coords_top_left):
                 )
 
                 del pipeline
-                torch.cuda.empty_cache()
+                if is_torch_npu_available():
+                    torch_npu.npu.empty_cache()
+                elif torch.cuda.is_available():
+                    torch.cuda.empty_cache()
 
                 if args.use_ema:
                     # Switch back to the original UNet parameters.
