Improve the performance and suitable for NPU computing #9642
Merged · +14 −4 · 7 commits

Commits:
- e2f0a7b  Improve the performance and suitable for NPU
- ab3cd4f  Improve the performance and suitable for NPU computing
- 98f55d0  Improve the performance and suitable for NPU
- 4100eb4  Improve the performance and suitable for NPU
- 8ed6fe0  Improve the performance and suitable for NPU
- b79ab15  Improve the performance and suitable for NPU
- a1748fc  Merge branch 'main' into main
Diff:

@@ -59,6 +59,7 @@
 logger = get_logger(__name__)
 if is_torch_npu_available():
     import torch_npu

+    torch.npu.config.allow_internal_format = False

 DATASET_NAME_MAPPING = {
@@ -531,7 +532,7 @@ def encode_prompt(batch, text_encoders, tokenizers, proportion_empty_prompts, ca
     return {"prompt_embeds": prompt_embeds.cpu(), "pooled_prompt_embeds": pooled_prompt_embeds.cpu()}


-def compute_vae_encodings(batch, vae):
+def compute_vae_encodings(batch, accelerator, vae):
     images = batch.pop("pixel_values")
     pixel_values = torch.stack(list(images))
     pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
@@ -540,7 +541,7 @@ def compute_vae_encodings(batch, vae):
     with torch.no_grad():
         model_input = vae.encode(pixel_values).latent_dist.sample()
     model_input = model_input * vae.config.scaling_factor
-    return {"model_input": model_input.cpu()}
+    return {"model_input": accelerator.gather(model_input)}


 def generate_timestep_weights(args, num_timesteps):
@@ -910,7 +911,7 @@ def preprocess_train(examples):
         proportion_empty_prompts=args.proportion_empty_prompts,
         caption_column=args.caption_column,
     )
-    compute_vae_encodings_fn = functools.partial(compute_vae_encodings, vae=vae)
+    compute_vae_encodings_fn = functools.partial(compute_vae_encodings, accelerator=accelerator, vae=vae)
     with accelerator.main_process_first():
         from datasets.fingerprint import Hasher
@@ -935,7 +936,10 @@ def preprocess_train(examples):
     del compute_vae_encodings_fn, compute_embeddings_fn, text_encoder_one, text_encoder_two
     del text_encoders, tokenizers, vae
     gc.collect()
-    torch.cuda.empty_cache()
+    if is_torch_npu_available():
+        torch_npu.npu.empty_cache()
+    else:
+        torch.cuda.empty_cache()

     def collate_fn(examples):
         model_input = torch.stack([torch.tensor(example["model_input"]) for example in examples])
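The NPU-vs-CUDA cache-clearing branch above appears again later in the diff. For reference, here is a minimal, self-contained sketch of the same pattern factored into a helper; the `empty_device_cache` name and the `accelerate.utils` import path are illustrative assumptions, not part of this PR:

```python
import gc

import torch
from accelerate.utils import is_torch_npu_available  # assumed import path for this sketch

if is_torch_npu_available():
    import torch_npu  # registers the torch.npu backend


def empty_device_cache():
    # Release cached allocator blocks on whichever accelerator backend is active.
    gc.collect()
    if is_torch_npu_available():
        torch_npu.npu.empty_cache()
    else:
        torch.cuda.empty_cache()
```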
@@ -1091,8 +1095,7 @@ def compute_time_ids(original_size, crops_coords_top_left):
         # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
         target_size = (args.resolution, args.resolution)
         add_time_ids = list(original_size + crops_coords_top_left + target_size)
-        add_time_ids = torch.tensor([add_time_ids])
-        add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
+        add_time_ids = torch.tensor([add_time_ids], device=accelerator.device, dtype=weight_dtype)
         return add_time_ids

     add_time_ids = torch.cat(

Review comment on this change: Nice, this makes sense!
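This hunk removes an intermediate CPU tensor: the value is created directly on the accelerator device instead of being built on the host and copied over. An annotated excerpt of the two versions from the hunk above (not standalone code):

```python
# Before: the tensor is first materialized on the CPU, then copied to the device.
add_time_ids = torch.tensor([add_time_ids])
add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)

# After: the tensor is created directly on the accelerator device with the target dtype,
# skipping the extra host-to-device copy.
add_time_ids = torch.tensor([add_time_ids], device=accelerator.device, dtype=weight_dtype)
```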
@@ -1261,7 +1264,10 @@ def compute_time_ids(original_size, crops_coords_top_left):
     )

     del pipeline
-    torch.cuda.empty_cache()
+    if is_torch_npu_available():
+        torch_npu.npu.empty_cache()
+    else:
+        torch.cuda.empty_cache()

     if args.use_ema:
         # Switch back to the original UNet parameters.
Review conversation:

Reviewer: Why do we need this?
Author: By using the accelerator, the communication time can be reduced.
Author: I think the reason may be the pixel_values, since they are on vae.device (the accelerator device). By changing the code, the accelerator can distribute the work and reduce the time cost.
Reviewer: But isn't an all-gather a more expensive op?
Author: In fact, I tested three different approaches. The first was accelerator.gather(model_input): average FPS 29.13, training duration 530. The second was model_input.to(accelerator.device): average FPS 27.41, training duration 544. The last was the original model_input.cpu(): average FPS 28.56, training duration 537. Overall, on the same hardware, FPS increases slightly with accelerator.gather. I ran accelerator.gather and model_input.cpu() multiple times, and the average FPS with accelerator.gather was consistently higher than with model_input.cpu().
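To make the comparison concrete, here is a sketch of the function from the diff with the three benchmarked return variants side by side. The explicit move of pixel_values to the VAE's device is assumed from the discussion (it is not shown in the diff), and the FPS numbers are the author's single-run measurements, not general guarantees:

```python
import torch


def compute_vae_encodings(batch, accelerator, vae):
    images = batch.pop("pixel_values")
    pixel_values = torch.stack(list(images))
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
    # Assumed from the discussion: pixel_values lives on the VAE's (accelerator) device.
    pixel_values = pixel_values.to(vae.device)

    with torch.no_grad():
        model_input = vae.encode(pixel_values).latent_dist.sample()
    model_input = model_input * vae.config.scaling_factor

    # Variant 1: all-gather the latents across processes (avg 29.13 FPS, duration 530).
    return {"model_input": accelerator.gather(model_input)}
    # Variant 2: keep the latents on the accelerator device (avg 27.41 FPS, duration 544).
    # return {"model_input": model_input.to(accelerator.device)}
    # Variant 3: original behaviour, move the latents back to CPU (avg 28.56 FPS, duration 537).
    # return {"model_input": model_input.cpu()}
```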
Reviewer: Hmm, thanks! Since the performance improvement seems to be minor, do you think it makes sense to not change this?
Author: Makes sense, I will change it back to .cpu().
Reviewer: But I welcome you to also add a note on your findings on accelerator.gather() so that users are aware. I think that'd still be quite valuable.
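One possible shape for such a note, as a hedged sketch of a comment next to the return line in compute_vae_encodings (the wording is illustrative, not text from the PR):

```python
    # NOTE: in the PR author's tests, accelerator.gather(model_input) gave slightly higher
    # throughput (avg 29.13 FPS vs 28.56 FPS for .cpu()), but the latents are still
    # returned on CPU here because the improvement was minor.
    return {"model_input": model_input.cpu()}
```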