From e2f0a7b25a3bf276a580ccdf966504faebd348e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=8B=E7=A1=95?= Date: Fri, 11 Oct 2024 10:01:32 +0800 Subject: [PATCH 1/6] Improve the performance and suitable for NPU --- .../text_to_image/train_text_to_image_sdxl.py | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py index 2ca511c857ae..d04ea9c9258e 100644 --- a/examples/text_to_image/train_text_to_image_sdxl.py +++ b/examples/text_to_image/train_text_to_image_sdxl.py @@ -59,6 +59,7 @@ logger = get_logger(__name__) if is_torch_npu_available(): + import torch_npu torch.npu.config.allow_internal_format = False DATASET_NAME_MAPPING = { @@ -531,7 +532,7 @@ def encode_prompt(batch, text_encoders, tokenizers, proportion_empty_prompts, ca return {"prompt_embeds": prompt_embeds.cpu(), "pooled_prompt_embeds": pooled_prompt_embeds.cpu()} -def compute_vae_encodings(batch, vae): +def compute_vae_encodings(batch, accelerator, vae): images = batch.pop("pixel_values") pixel_values = torch.stack(list(images)) pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() @@ -540,7 +541,7 @@ def compute_vae_encodings(batch, vae): with torch.no_grad(): model_input = vae.encode(pixel_values).latent_dist.sample() model_input = model_input * vae.config.scaling_factor - return {"model_input": model_input.cpu()} + return {"model_input": accelerator.gather(model_input)} def generate_timestep_weights(args, num_timesteps): @@ -910,7 +911,7 @@ def preprocess_train(examples): proportion_empty_prompts=args.proportion_empty_prompts, caption_column=args.caption_column, ) - compute_vae_encodings_fn = functools.partial(compute_vae_encodings, vae=vae) + compute_vae_encodings_fn = functools.partial(compute_vae_encodings, accelerator=accelerator, vae=vae) with accelerator.main_process_first(): from datasets.fingerprint import Hasher @@ -935,7 +936,10 @@ def preprocess_train(examples): del compute_vae_encodings_fn, compute_embeddings_fn, text_encoder_one, text_encoder_two del text_encoders, tokenizers, vae gc.collect() - torch.cuda.empty_cache() + if is_torch_npu_available(): + torch_npu.npu.empty_cache() + else: + torch.cuda.empty_cache() def collate_fn(examples): model_input = torch.stack([torch.tensor(example["model_input"]) for example in examples]) @@ -1091,8 +1095,7 @@ def compute_time_ids(original_size, crops_coords_top_left): # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids target_size = (args.resolution, args.resolution) add_time_ids = list(original_size + crops_coords_top_left + target_size) - add_time_ids = torch.tensor([add_time_ids]) - add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype) + add_time_ids = torch.tensor([add_time_ids], device=accelerator.device, dtype=weight_dtype) return add_time_ids add_time_ids = torch.cat( @@ -1261,7 +1264,10 @@ def compute_time_ids(original_size, crops_coords_top_left): ) del pipeline - torch.cuda.empty_cache() + if is_torch_npu_available(): + torch_npu.npu.empty_cache() + else: + torch.cuda.empty_cache() if args.use_ema: # Switch back to the original UNet parameters. From ab3cd4f45bdef2126468ca4febd258cbb93efc2d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=8B=E7=A1=95?= Date: Sat, 12 Oct 2024 09:11:02 +0800 Subject: [PATCH 2/6] Improve the performance and suitable for NPU computing --- examples/text_to_image/train_text_to_image_sdxl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py index d04ea9c9258e..5ba9a09122e9 100644 --- a/examples/text_to_image/train_text_to_image_sdxl.py +++ b/examples/text_to_image/train_text_to_image_sdxl.py @@ -938,7 +938,7 @@ def preprocess_train(examples): gc.collect() if is_torch_npu_available(): torch_npu.npu.empty_cache() - else: + elif torch.cuda.is_available(): torch.cuda.empty_cache() def collate_fn(examples): @@ -1266,7 +1266,7 @@ def compute_time_ids(original_size, crops_coords_top_left): del pipeline if is_torch_npu_available(): torch_npu.npu.empty_cache() - else: + elif torch.cuda.is_available(): torch.cuda.empty_cache() if args.use_ema: From 98f55d0415f9bbd35fd82dc19e57d74ea5d311f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=8B=E7=A1=95?= Date: Sat, 12 Oct 2024 15:38:00 +0800 Subject: [PATCH 3/6] Improve the performance and suitable for NPU --- examples/text_to_image/train_text_to_image_sdxl.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py index 5ba9a09122e9..5d3e955c7d4e 100644 --- a/examples/text_to_image/train_text_to_image_sdxl.py +++ b/examples/text_to_image/train_text_to_image_sdxl.py @@ -532,7 +532,7 @@ def encode_prompt(batch, text_encoders, tokenizers, proportion_empty_prompts, ca return {"prompt_embeds": prompt_embeds.cpu(), "pooled_prompt_embeds": pooled_prompt_embeds.cpu()} -def compute_vae_encodings(batch, accelerator, vae): +def compute_vae_encodings(batch, vae): images = batch.pop("pixel_values") pixel_values = torch.stack(list(images)) pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() @@ -541,7 +541,7 @@ def compute_vae_encodings(batch, accelerator, vae): with torch.no_grad(): model_input = vae.encode(pixel_values).latent_dist.sample() model_input = model_input * vae.config.scaling_factor - return {"model_input": accelerator.gather(model_input)} + return {"model_input": model_input.cpu()} def generate_timestep_weights(args, num_timesteps): @@ -911,7 +911,7 @@ def preprocess_train(examples): proportion_empty_prompts=args.proportion_empty_prompts, caption_column=args.caption_column, ) - compute_vae_encodings_fn = functools.partial(compute_vae_encodings, accelerator=accelerator, vae=vae) + compute_vae_encodings_fn = functools.partial(compute_vae_encodings, vae=vae) with accelerator.main_process_first(): from datasets.fingerprint import Hasher From 4100eb47312f979d8a42ffd3fb42b2a3dafc368c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=8B=E7=A1=95?= Date: Sat, 12 Oct 2024 15:46:17 +0800 Subject: [PATCH 4/6] Improve the performance and suitable for NPU --- examples/text_to_image/train_text_to_image_sdxl.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py index 5d3e955c7d4e..7495c6ae14f4 100644 --- a/examples/text_to_image/train_text_to_image_sdxl.py +++ b/examples/text_to_image/train_text_to_image_sdxl.py @@ -541,6 +541,8 @@ def compute_vae_encodings(batch, vae): with torch.no_grad(): model_input = vae.encode(pixel_values).latent_dist.sample() model_input = model_input * vae.config.scaling_factor + # There might have slightly performance improvement + # by changing model_input.cpu() to accelerator.gather(model_input) return {"model_input": model_input.cpu()} From 8ed6fe086ccb2882fc1ce2dc286a70a9ca55a6f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=8B=E7=A1=95?= Date: Mon, 14 Oct 2024 15:51:36 +0800 Subject: [PATCH 5/6] Improve the performance and suitable for NPU --- examples/text_to_image/train_text_to_image_sdxl.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py index 7495c6ae14f4..15fe7962541f 100644 --- a/examples/text_to_image/train_text_to_image_sdxl.py +++ b/examples/text_to_image/train_text_to_image_sdxl.py @@ -541,6 +541,7 @@ def compute_vae_encodings(batch, vae): with torch.no_grad(): model_input = vae.encode(pixel_values).latent_dist.sample() model_input = model_input * vae.config.scaling_factor + # There might have slightly performance improvement # by changing model_input.cpu() to accelerator.gather(model_input) return {"model_input": model_input.cpu()} From b79ab15e7c27e6e0855b4294abe184208aa7b72d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=8B=E7=A1=95?= Date: Mon, 14 Oct 2024 18:56:13 +0800 Subject: [PATCH 6/6] Improve the performance and suitable for NPU --- examples/text_to_image/train_text_to_image_sdxl.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py index 15fe7962541f..bcf0fa9eb0ac 100644 --- a/examples/text_to_image/train_text_to_image_sdxl.py +++ b/examples/text_to_image/train_text_to_image_sdxl.py @@ -60,6 +60,7 @@ logger = get_logger(__name__) if is_torch_npu_available(): import torch_npu + torch.npu.config.allow_internal_format = False DATASET_NAME_MAPPING = {