
Conversation

@entrpn (Contributor) commented Sep 6, 2024

What does this PR do?

@sayakpaul Enables PyTorch XLA training on TPUs for Stable Diffusion 2.x models.
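
For readers new to PyTorch/XLA, here is a minimal sketch of the training-step pattern such a script typically uses; the model and loss below are placeholders, not the PR's actual code.

```python
import torch
import torch_xla.core.xla_model as xm

def train_step(model, batch, optimizer):
    # Placeholder objective; the real script computes the diffusion
    # denoising loss on pixel_values / input_ids.
    optimizer.zero_grad()
    loss = model(batch).mean()
    loss.backward()
    # All-reduce gradients across TPU cores and apply the update.
    xm.optimizer_step(optimizer)
    # Cut the lazily built XLA graph and dispatch it to the TPU.
    xm.mark_step()
    return loss
```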


Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@sayakpaul (Member) commented:

Thanks for your contribution! Could we maybe move this to the research_projects folder, as we cannot test it at the moment?


@entrpn (Contributor, Author) commented Sep 9, 2024

@sayakpaul I moved the files to research_projects. Please review.


```python
def main(args):
    device = xm.xla_device()
    model_path = <output_dir>
```
@sayakpaul (Member) commented Sep 10, 2024

Can we use a repo id on the Hub that could be loaded here? This way, users can directly try out the snippet without having to look for one.

@entrpn (Contributor, Author) replied:

Uploaded a trained model to the Hub and updated the model_path.
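
For reference, a minimal sketch of what the updated snippet might look like after that change; the repo id below is illustrative, not necessarily the checkpoint that was actually uploaded.

```python
import torch
import torch_xla.core.xla_model as xm
from diffusers import DiffusionPipeline

device = xm.xla_device()

# Illustrative repo id; substitute the fine-tuned checkpoint pushed to the Hub.
model_path = "stabilityai/stable-diffusion-2-base"
pipe = DiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.bfloat16)
pipe = pipe.to(device)

image = pipe("a photograph of an astronaut riding a horse").images[0]
image.save("astronaut.png")
```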

```python
    pixel_values,
    input_ids,
):
    with xp.Trace("model.forward"):
```
Member:

I am assuming these traces are thin enough to NOT introduce any unnecessary latency?

Reply:

I think xp.Trace calls are very lightweight and can help with profiling.
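
For context, this is roughly how `xp.Trace` annotations fit into the PyTorch/XLA profiler; a sketch under assumptions, not the PR's exact code (the port and span name here are arbitrary).

```python
import torch_xla.debug.profiler as xp

# Start the profiler server once per process; traces can then be
# captured from TensorBoard or the capture-profile tooling.
server = xp.start_server(9012)

def forward_pass(model, pixel_values, input_ids):
    # xp.Trace adds a named span to the captured profile; when no
    # capture is running it is close to a no-op, which is why the
    # overhead is negligible.
    with xp.Trace("model.forward"):
        return model(pixel_values, input_ids)
```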

@sayakpaul (Member) left a review:

Thanks! Some minor comments here and there, but from an implementation standpoint this looks excellent.

In the README, I think it'd be nice to include some gotchas that users need to be aware of:

  • Would the example work on a multi-node TPU host?
  • How much wall-clock time can users expect?
  • Would the inference snippet work on multiple TPU chips? (A sketch of the multi-chip pattern follows this list.)
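
On the multi-chip question above, the usual PyTorch/XLA answer is to spawn one process per TPU device; a minimal sketch, with `_mp_fn` standing in for hypothetical per-device work:

```python
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
    # Each spawned process owns one TPU device.
    device = xm.xla_device()
    # Hypothetical per-device work (e.g. loading the pipeline and
    # generating images); not part of the PR itself.
    print(f"process {index} running on {device}")

if __name__ == "__main__":
    # Spawn one process per available TPU device.
    xmp.spawn(_mp_fn, args=())
```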


The `train_text_to_image_xla.py` script shows how to fine-tune a Stable Diffusion model on TPU devices using PyTorch/XLA.

It has been tested on v4 and v5p TPU versions.

Member:

Can you be a little more specific about the TPU models you used? It would be nice for someone who wants to reproduce this to know how many TPUs are required to avoid hitting an OOM.

@entrpn (Contributor, Author) replied Sep 11, 2024

I've added this to the README. Please review.

@entrpn (Contributor, Author) commented Sep 11, 2024

@sayakpaul @tengomucho added changes per the comments. Please review.

@sayakpaul (Member) left a review:

Looking really good!

Comment on lines +10 to +16:

As of 9-11-2024, these are some expected step times.

| accelerator | global batch size | step time (seconds) |
| ----------- | ----------------- | ------------------- |
| v5p-128     | 1024              | 0.245               |
| v5p-256     | 2048              | 0.234               |
| v5p-512     | 4096              | 0.2498              |
Member:

This is very helpful, thanks so much!

On the `## Create TPU` section of the README, a member commented:

If there is official GCP documentation that could be linked here, feel free to add it.

@sayakpaul merged commit 45aa8bb into huggingface:main on Sep 12, 2024. 15 checks passed.
@sayakpaul (Member) commented:

Thank you for this contribution!

@sayakpaul (Member) commented:

@entrpn do you think it could make sense to have something similar for Flux? It's the most popular text-to-image generation model right now.

Cc'ing @linoytsaban and @apolinario for awareness, as you do a lot of fine-tuning.

@entrpn (Contributor, Author) commented Sep 13, 2024

> @entrpn do you think it could make sense to have something similar for Flux? It's the most popular text-to-image generation model right now.

We can revisit this in the future. Thank you for your help merging this.

sayakpaul added a commit that referenced this pull request Dec 23, 2024
* enable pxla training of stable diffusion 2.x models.

* run linter/style and run pipeline test for stable diffusion and fix issues.

* update xla libraries

* fix read me newline.

* move files to research folder.

* update per comments.

* rename readme.

---------

Co-authored-by: Juan Acevedo <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>