Ptxla sd training #9381
Conversation
Thanks for your contributions! Could we maybe move this to the `research_projects` folder?

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@sayakpaul I moved the files to research_projects. Please review.
```python
def main(args):
    device = xm.xla_device()
    model_path = <output_dir>
```
Can we use a repo id on the Hub that could be loaded here? This way, users can directly try out the snippet without having to look for one.
Uploaded a trained model to the Hub and updated the model_path.
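For reference, here is a minimal sketch of what loading from a Hub repo id looks like on an XLA device; the repo id below is a placeholder, not the one actually uploaded for this PR.

```python
import torch_xla.core.xla_model as xm
from diffusers import DiffusionPipeline

device = xm.xla_device()
# Placeholder repo id; substitute the actual trained model's Hub repo.
pipe = DiffusionPipeline.from_pretrained("your-username/sd2-ptxla-example")
pipe.to(device)
image = pipe("a photo of an astronaut riding a horse").images[0]
```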
```python
    pixel_values,
    input_ids,
):
    with xp.Trace("model.forward"):
```
I am assuming these traces are thin enough to NOT introduce any unnecessary latency?
I think xp.Trace calls are very lightweight and can help with profiling.
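For context, a minimal sketch of how such trace annotations fit into PyTorch/XLA profiling; the port number, function shape, and scope names here are illustrative, not the PR's exact code.

```python
import torch_xla.debug.profiler as xp

# Start the profiler server once per process; the port is arbitrary.
server = xp.start_server(9012)

def training_step(model, pixel_values, input_ids, optimizer):
    # Each named scope shows up as a labeled span in the trace viewer.
    with xp.Trace("model.forward"):
        loss = model(pixel_values, input_ids)  # illustrative model call
    with xp.Trace("model.backward"):
        loss.backward()
    optimizer.step()
    return loss
```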
Thanks! Some minor comments here and there. But from an implementation standpoint, this looks excellent.
In the README, I think it'd be nice to include some gotchas that the users need to be aware of:
- Would the example work on a multi-node TPU host?
- How much wall-clock time can the users expect?
- Would the inference snippet work on multiple TPU chips?
The `train_text_to_image_xla.py` script shows how to fine-tune a Stable Diffusion model on TPU devices using PyTorch/XLA.
It has been tested on v4 and v5p TPU versions.
Can you be a little more specific about the TPU models you used? It would be nice for someone trying to reproduce this to know how many TPUs are required to avoid hitting an OOM.
I've added this to the readme. Please review.
@sayakpaul @tengomucho added changes per the comments. Please review.
Looking really good!
As of 9-11-2024, these are some expected step times.

| accelerator | global batch size | step time (seconds) |
| ----------- | ----------------- | ------------------- |
| v5p-128     | 1024              | 0.245               |
| v5p-256     | 2048              | 0.234               |
| v5p-512     | 4096              | 0.2498              |
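As a rough throughput figure derived from this table: at a global batch size of 1024 and 0.245 s per step, v5p-128 processes about 1024 / 0.245 ≈ 4,180 images per second, and the larger slices scale close to linearly (≈ 8,750 on v5p-256 and ≈ 16,400 on v5p-512).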
This is very helpful, thanks much!
## Create TPU
If there is official documentation from GCP that could be linked here, feel free to do so.
Thank you for this contribution!
@entrpn do you think it could make sense to have something similar for Flux? It's the most popular text-to-image generation model right now. Cc'ing @linoytsaban and @apolinario for awareness, as you do a lot of fine-tuning.

We can revisit this in the future. Thank you for your help merging this.
* enable pxla training of stable diffusion 2.x models.
* run linter/style and run pipeline test for stable diffusion and fix issues.
* update xla libraries
* fix read me newline.
* move files to research folder.
* update per comments.
* rename readme.

Co-authored-by: Juan Acevedo <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
What does this PR do?
@sayakpaul Enables PyTorch/XLA training on TPUs for Stable Diffusion 2.x models.
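To give a flavor of what PyTorch/XLA training involves, here is an illustrative, self-contained sketch (not the PR's actual code): the key difference from plain PyTorch is that `xm.optimizer_step` reduces gradients across TPU cores and triggers execution of the compiled XLA graph.

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(8, 1).to(device)            # stand-in for the SD UNet
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

for _ in range(10):
    x = torch.randn(4, 8, device=device)            # stand-in for a data batch
    target = torch.zeros(4, 1, device=device)
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), target)
    loss.backward()
    xm.optimizer_step(optimizer)  # all-reduce grads, step, and mark the XLA step
```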
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.