Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions examples/text_to_image/train_text_to_image_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@

logger = get_logger(__name__)
if is_torch_npu_available():
import torch_npu
torch.npu.config.allow_internal_format = False

DATASET_NAME_MAPPING = {
Expand Down Expand Up @@ -531,7 +532,7 @@ def encode_prompt(batch, text_encoders, tokenizers, proportion_empty_prompts, ca
return {"prompt_embeds": prompt_embeds.cpu(), "pooled_prompt_embeds": pooled_prompt_embeds.cpu()}


def compute_vae_encodings(batch, vae):
def compute_vae_encodings(batch, accelerator, vae):
images = batch.pop("pixel_values")
pixel_values = torch.stack(list(images))
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
Expand All @@ -540,7 +541,7 @@ def compute_vae_encodings(batch, vae):
with torch.no_grad():
model_input = vae.encode(pixel_values).latent_dist.sample()
model_input = model_input * vae.config.scaling_factor
return {"model_input": model_input.cpu()}
return {"model_input": accelerator.gather(model_input)}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By using the accelerator, the communication time can be reduced.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this may be caused by pixel_values, as it is on vae.device (the accelerator). Therefore, by changing the code, the accelerator can distribute the work and reduce the time cost.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But isn't an all-gather a more expensive op?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In fact, I tested three different approaches. The first is accelerator.gather(model_input): the average FPS is 29.13 with a training duration of 530. The second is model_input.to(accelerator.device): the average FPS is 27.41 with a training duration of 544. The last is the original model_input.cpu(): the average FPS is 28.56 with a training duration of 537. Overall, on the same hardware, the FPS increases slightly with accelerator.gather. I ran the test multiple times with accelerator.gather and model_input.cpu(), and the average FPS with accelerator.gather is higher than with model_input.cpu().

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, thanks! Since the performance improvement seems to be minor, do you think it makes sense to not change this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense, I will change it back to .cpu().

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But I welcome you to also add a note on your findings about accelerator.gather() so that users are aware. I think that'd still be quite valuable.



def generate_timestep_weights(args, num_timesteps):
Expand Down Expand Up @@ -910,7 +911,7 @@ def preprocess_train(examples):
proportion_empty_prompts=args.proportion_empty_prompts,
caption_column=args.caption_column,
)
compute_vae_encodings_fn = functools.partial(compute_vae_encodings, vae=vae)
compute_vae_encodings_fn = functools.partial(compute_vae_encodings, accelerator=accelerator, vae=vae)
with accelerator.main_process_first():
from datasets.fingerprint import Hasher

Expand All @@ -935,7 +936,10 @@ def preprocess_train(examples):
del compute_vae_encodings_fn, compute_embeddings_fn, text_encoder_one, text_encoder_two
del text_encoders, tokenizers, vae
gc.collect()
torch.cuda.empty_cache()
if is_torch_npu_available():
torch_npu.npu.empty_cache()
else:
torch.cuda.empty_cache()

def collate_fn(examples):
model_input = torch.stack([torch.tensor(example["model_input"]) for example in examples])
Expand Down Expand Up @@ -1091,8 +1095,7 @@ def compute_time_ids(original_size, crops_coords_top_left):
# Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
target_size = (args.resolution, args.resolution)
add_time_ids = list(original_size + crops_coords_top_left + target_size)
add_time_ids = torch.tensor([add_time_ids])
add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
add_time_ids = torch.tensor([add_time_ids], device=accelerator.device, dtype=weight_dtype)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, this makes sense!

return add_time_ids

add_time_ids = torch.cat(
Expand Down Expand Up @@ -1261,7 +1264,10 @@ def compute_time_ids(original_size, crops_coords_top_left):
)

del pipeline
torch.cuda.empty_cache()
if is_torch_npu_available():
torch_npu.npu.empty_cache()
else:
torch.cuda.empty_cache()

if args.use_ema:
# Switch back to the original UNet parameters.
Expand Down