diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py index 2ca511c857ae..bcf0fa9eb0ac 100644 --- a/examples/text_to_image/train_text_to_image_sdxl.py +++ b/examples/text_to_image/train_text_to_image_sdxl.py @@ -59,6 +59,8 @@ logger = get_logger(__name__) if is_torch_npu_available(): + import torch_npu + torch.npu.config.allow_internal_format = False DATASET_NAME_MAPPING = { @@ -540,6 +542,9 @@ def compute_vae_encodings(batch, vae): with torch.no_grad(): model_input = vae.encode(pixel_values).latent_dist.sample() model_input = model_input * vae.config.scaling_factor + + # There might have slightly performance improvement + # by changing model_input.cpu() to accelerator.gather(model_input) return {"model_input": model_input.cpu()} @@ -935,7 +940,10 @@ def preprocess_train(examples): del compute_vae_encodings_fn, compute_embeddings_fn, text_encoder_one, text_encoder_two del text_encoders, tokenizers, vae gc.collect() - torch.cuda.empty_cache() + if is_torch_npu_available(): + torch_npu.npu.empty_cache() + elif torch.cuda.is_available(): + torch.cuda.empty_cache() def collate_fn(examples): model_input = torch.stack([torch.tensor(example["model_input"]) for example in examples]) @@ -1091,8 +1099,7 @@ def compute_time_ids(original_size, crops_coords_top_left): # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids target_size = (args.resolution, args.resolution) add_time_ids = list(original_size + crops_coords_top_left + target_size) - add_time_ids = torch.tensor([add_time_ids]) - add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype) + add_time_ids = torch.tensor([add_time_ids], device=accelerator.device, dtype=weight_dtype) return add_time_ids add_time_ids = torch.cat( @@ -1261,7 +1268,10 @@ def compute_time_ids(original_size, crops_coords_top_left): ) del pipeline - torch.cuda.empty_cache() + if is_torch_npu_available(): + torch_npu.npu.empty_cache() + elif torch.cuda.is_available(): + torch.cuda.empty_cache() if args.use_ema: # Switch back to the original UNet parameters.