Skip to content

Commit a9513c1

Browse files
initial implementation of sdxl training
1 parent 56f7400 commit a9513c1

File tree

4 files changed

+1165
-0
lines changed

4 files changed

+1165
-0
lines changed
Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
# Stable Diffusion text-to-image fine-tuning using PyTorch/XLA
2+
3+
The `train_text_to_image_xla.py` script shows how to fine-tune stable diffusion model on TPU devices using PyTorch/XLA.
4+
5+
It has been tested on v4 and v5p TPU versions. Training code has been tested on multi-host.
6+
7+
This script implements Distributed Data Parallel using GSPMD feature in XLA compiler
8+
where we shard the input batches over the TPU devices.
9+
10+
As of 10-31-2024, these are some expected step times.
11+
12+
| accelerator | global batch size | step time (seconds) |
13+
| ----------- | ----------------- | --------- |
14+
| v5p-512 | 16384 | 1.01 |
15+
| v5p-256 | 8192 | 1.01 |
16+
| v5p-128 | 4096 | 1.0 |
17+
| v5p-64 | 2048 | 1.01 |
18+
19+
## Create TPU
20+
21+
To create a TPU on Google Cloud first set these environment variables:
22+
23+
```bash
24+
export TPU_NAME=<tpu-name>
25+
export PROJECT_ID=<project-id>
26+
export ZONE=<google-cloud-zone>
27+
export ACCELERATOR_TYPE=<accelerator type like v5p-8>
28+
export RUNTIME_VERSION=<runtime version like v2-alpha-tpuv5 for v5p>
29+
```
30+
31+
Then run the create TPU command:
32+
```bash
33+
gcloud alpha compute tpus tpu-vm create ${TPU_NAME} --project ${PROJECT_ID}
34+
--zone ${ZONE} --accelerator-type ${ACCELERATOR_TYPE} --version ${RUNTIME_VERSION}
35+
--reserved
36+
```
37+
38+
You can also use other ways to reserve TPUs like GKE or queued resources.
39+
40+
## Setup TPU environment
41+
42+
Install PyTorch and PyTorch/XLA nightly versions:
43+
```bash
44+
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
45+
--project=${PROJECT_ID} --zone=${ZONE} --worker=all \
46+
--command='
47+
pip3 install --pre torch==2.6.0.dev20241031+cpu torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
48+
pip3 install "torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241031.cxx11-cp310-cp310-linux_x86_64.whl" -f https://storage.googleapis.com/libtpu-releases/index.html
49+
pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
50+
'
51+
```
52+
53+
Verify that PyTorch and PyTorch/XLA were installed correctly:
54+
55+
```bash
56+
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
57+
--project ${PROJECT_ID} --zone ${ZONE} --worker=all \
58+
--command='python3 -c "import torch; import torch_xla;"'
59+
```
60+
61+
Install dependencies:
62+
```bash
63+
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
64+
--project=${PROJECT_ID} --zone=${ZONE} --worker=all \
65+
--command='
66+
git clone https://github.com/huggingface/diffusers.git
67+
cd diffusers
68+
git checkout main
69+
cd examples/research_projects/pytorch_xla
70+
pip3 install -r requirements.txt
71+
pip3 install pillow --upgrade
72+
cd ../../..
73+
pip3 install .'
74+
```
75+
76+
## Run the training job
77+
78+
### Authenticate
79+
80+
Run the following command to authenticate your token.
81+
82+
```bash
83+
huggingface-cli login
84+
```
85+
86+
This script only trains the unet part of the network. The VAE and text encoder
87+
are fixed.
88+
89+
```bash
90+
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
91+
--project=${PROJECT_ID} --zone=${ZONE} --worker=all \
92+
--command='
93+
export XLA_DISABLE_FUNCTIONALIZATION=0
94+
export PROFILE_DIR=/tmp/
95+
export CACHE_DIR=/tmp/
96+
export DATASET_NAME=lambdalabs/naruto-blip-captions
97+
export PER_HOST_BATCH_SIZE=32 # This is known to work on TPU v4. Can set this to 64 for TPU v5p
98+
export TRAIN_STEPS=50
99+
export OUTPUT_DIR=/tmp/trained-model/
100+
python diffusers/examples/research_projects/pytorch_xla/train_text_to_image_xla.py --pretrained_model_name_or_path=stabilityai/stable-diffusion-2-base --dataset_name=$DATASET_NAME --resolution=512 --center_crop --random_flip --train_batch_size=$PER_HOST_BATCH_SIZE --max_train_steps=$TRAIN_STEPS --learning_rate=1e-06 --mixed_precision=bf16 --profile_duration=80000 --output_dir=$OUTPUT_DIR --dataloader_num_workers=8 --loader_prefetch_size=4 --device_prefetch_size=4'
101+
```
102+
103+
Pass `--print_loss` if you would like to see the loss printed at every step. Be aware that printing the loss at every step disrupts the optimized flow execution, thus the step time will be longer.
104+
105+
### Environment Envs Explained
106+
107+
* `XLA_DISABLE_FUNCTIONALIZATION`: To optimize the performance for AdamW optimizer.
108+
* `PROFILE_DIR`: Specify where to put the profiling results.
109+
* `CACHE_DIR`: Directory to store XLA compiled graphs for persistent caching.
110+
* `DATASET_NAME`: Dataset to train the model.
111+
* `PER_HOST_BATCH_SIZE`: Size of the batch to load per CPU host. For e.g. for a v5p-16 with 2 CPU hosts, the global batch size will be 2xPER_HOST_BATCH_SIZE. The input batch is sharded along the batch axis.
112+
* `TRAIN_STEPS`: Total number of training steps to run the training for.
113+
* `OUTPUT_DIR`: Directory to store the fine-tuned model.
114+
115+
## Run inference using the output model
116+
117+
To run inference using the output, you can simply load the model and pass it
118+
input prompts. The first pass will compile the graph and takes longer with the following passes running much faster.
119+
120+
```bash
121+
export CACHE_DIR=/tmp/
122+
```
123+
124+
```python
125+
import torch
126+
import os
127+
import sys
128+
import numpy as np
129+
130+
import torch_xla.core.xla_model as xm
131+
from time import time
132+
from diffusers import StableDiffusionPipeline
133+
import torch_xla.runtime as xr
134+
135+
CACHE_DIR = os.environ.get("CACHE_DIR", None)
136+
if CACHE_DIR:
137+
xr.initialize_cache(CACHE_DIR, readonly=False)
138+
139+
def main():
140+
device = xm.xla_device()
141+
model_path = "jffacevedo/pxla_trained_model"
142+
pipe = StableDiffusionPipeline.from_pretrained(
143+
model_path,
144+
torch_dtype=torch.bfloat16
145+
)
146+
pipe.to(device)
147+
prompt = ["A naruto with green eyes and red legs."]
148+
start = time()
149+
print("compiling...")
150+
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
151+
print(f"compile time: {time() - start}")
152+
print("generate...")
153+
start = time()
154+
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
155+
print(f"generation time (after compile) : {time() - start}")
156+
image.save("naruto.png")
157+
158+
if __name__ == '__main__':
159+
main()
160+
```
161+
162+
Expected Results:
163+
164+
```bash
165+
compiling...
166+
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [10:03<00:00, 20.10s/it]
167+
compile time: 720.656970500946
168+
generate...
169+
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [00:01<00:00, 17.65it/s]
170+
generation time (after compile) : 1.8461642265319824
Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
# Stable Diffusion XL text-to-image fine-tuning using PyTorch/XLA
2+
3+
The `train_text_to_image_xla.py` script shows how to fine-tune stable diffusion model on TPU devices using PyTorch/XLA.
4+
5+
It has been tested on v4 and v5p TPU versions. Training code has been tested on multi-host.
6+
7+
This script implements Distributed Data Parallel using GSPMD feature in XLA compiler
8+
where we shard the input batches over the TPU devices.
9+
10+
As of 10-31-2024, these are some expected step times.
11+
12+
| accelerator | global batch size | step time (seconds) |
13+
| ----------- | ----------------- | --------- |
14+
| v5p-512 | 16384 | 1.01 |
15+
| v5p-256 | 8192 | 1.01 |
16+
| v5p-128 | 4096 | 1.0 |
17+
| v5p-64 | 2048 | 1.01 |
18+
19+
## Create TPU
20+
21+
To create a TPU on Google Cloud first set these environment variables:
22+
23+
```bash
24+
export TPU_NAME=<tpu-name>
25+
export PROJECT_ID=<project-id>
26+
export ZONE=<google-cloud-zone>
27+
export ACCELERATOR_TYPE=<accelerator type like v5p-8>
28+
export RUNTIME_VERSION=<runtime version like v2-alpha-tpuv5 for v5p>
29+
```
30+
31+
Then run the create TPU command:
32+
```bash
33+
gcloud alpha compute tpus tpu-vm create ${TPU_NAME} --project ${PROJECT_ID}
34+
--zone ${ZONE} --accelerator-type ${ACCELERATOR_TYPE} --version ${RUNTIME_VERSION}
35+
--reserved
36+
```
37+
38+
You can also use other ways to reserve TPUs like GKE or queued resources.
39+
40+
## Setup TPU environment
41+
42+
Install PyTorch and PyTorch/XLA nightly versions:
43+
```bash
44+
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
45+
--project=${PROJECT_ID} --zone=${ZONE} --worker=all \
46+
--command='
47+
pip install torch~=2.5.0 torch_xla[tpu]~=2.5.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
48+
pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
49+
pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
50+
'
51+
```
52+
53+
Verify that PyTorch and PyTorch/XLA were installed correctly:
54+
55+
```bash
56+
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
57+
--project ${PROJECT_ID} --zone ${ZONE} --worker=all \
58+
--command='python3 -c "import torch; import torch_xla;"'
59+
```
60+
61+
Install dependencies:
62+
```bash
63+
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
64+
--project=${PROJECT_ID} --zone=${ZONE} --worker=all \
65+
--command='
66+
git clone https://github.com/huggingface/diffusers.git
67+
cd diffusers
68+
git checkout main
69+
cd examples/research_projects/pytorch_xla/training/text_to_image/
70+
pip3 install -r requirements_sdxl.txt
71+
pip3 install pillow --upgrade
72+
cd ../../../../../
73+
pip3 install .'
74+
```
75+
76+
## Run the training job
77+
78+
### Authenticate
79+
80+
Run the following command to authenticate your token.
81+
82+
```bash
83+
huggingface-cli login
84+
```
85+
86+
This script only trains the unet part of the network. The VAE and text encoder
87+
are fixed.
88+
89+
```bash
90+
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
91+
--project=${PROJECT_ID} --zone=${ZONE} --worker=all \
92+
--command='
93+
export XLA_DISABLE_FUNCTIONALIZATION=0
94+
export PROFILE_DIR=/tmp/
95+
export CACHE_DIR=/tmp/
96+
export DATASET_NAME=lambdalabs/naruto-blip-captions
97+
export PER_HOST_BATCH_SIZE=32 # This is known to work on TPU v4. Can set this to 64 for TPU v5p
98+
export TRAIN_STEPS=50
99+
export OUTPUT_DIR=/tmp/trained-model/
100+
python diffusers/examples/research_projects/pytorch_xla/training/text_to_image/train_text_to_image_sdxl.py --pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 --dataset_name=$DATASET_NAME --resolution=512 --center_crop --random_flip --train_batch_size=$PER_HOST_BATCH_SIZE --max_train_steps=$TRAIN_STEPS --learning_rate=1e-06 --mixed_precision=bf16 --profile_duration=80000 --output_dir=$OUTPUT_DIR --dataloader_num_workers=8 --loader_prefetch_size=4 --device_prefetch_size=4'
101+
```
102+
103+
Pass `--print_loss` if you would like to see the loss printed at every step. Be aware that printing the loss at every step disrupts the optimized flow execution, thus the step time will be longer.
104+
105+
### Environment Envs Explained
106+
107+
* `XLA_DISABLE_FUNCTIONALIZATION`: To optimize the performance for AdamW optimizer.
108+
* `PROFILE_DIR`: Specify where to put the profiling results.
109+
* `CACHE_DIR`: Directory to store XLA compiled graphs for persistent caching.
110+
* `DATASET_NAME`: Dataset to train the model.
111+
* `PER_HOST_BATCH_SIZE`: Size of the batch to load per CPU host. For e.g. for a v5p-16 with 2 CPU hosts, the global batch size will be 2xPER_HOST_BATCH_SIZE. The input batch is sharded along the batch axis.
112+
* `TRAIN_STEPS`: Total number of training steps to run the training for.
113+
* `OUTPUT_DIR`: Directory to store the fine-tuned model.
114+
115+
## Run inference using the output model
116+
117+
To run inference using the output, you can simply load the model and pass it
118+
input prompts. The first pass will compile the graph and takes longer with the following passes running much faster.
119+
120+
```bash
121+
export CACHE_DIR=/tmp/
122+
```
123+
124+
```python
125+
import torch
126+
import os
127+
import sys
128+
import numpy as np
129+
130+
import torch_xla.core.xla_model as xm
131+
from time import time
132+
from diffusers import StableDiffusionPipeline
133+
import torch_xla.runtime as xr
134+
135+
CACHE_DIR = os.environ.get("CACHE_DIR", None)
136+
if CACHE_DIR:
137+
xr.initialize_cache(CACHE_DIR, readonly=False)
138+
139+
def main():
140+
device = xm.xla_device()
141+
model_path = "jffacevedo/pxla_trained_model"
142+
pipe = StableDiffusionPipeline.from_pretrained(
143+
model_path,
144+
torch_dtype=torch.bfloat16
145+
)
146+
pipe.to(device)
147+
prompt = ["A naruto with green eyes and red legs."]
148+
start = time()
149+
print("compiling...")
150+
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
151+
print(f"compile time: {time() - start}")
152+
print("generate...")
153+
start = time()
154+
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
155+
print(f"generation time (after compile) : {time() - start}")
156+
image.save("naruto.png")
157+
158+
if __name__ == '__main__':
159+
main()
160+
```
161+
162+
Expected Results:
163+
164+
```bash
165+
compiling...
166+
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [10:03<00:00, 20.10s/it]
167+
compile time: 720.656970500946
168+
generate...
169+
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [00:01<00:00, 17.65it/s]
170+
generation time (after compile) : 1.8461642265319824
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
accelerate>=0.16.0
2+
torch==2.5.1
3+
torchvision==0.21.0
4+
transformers>=4.25.1
5+
datasets>=2.19.1
6+
ftfy
7+
tensorboard
8+
Jinja2
9+
peft==0.7.0

0 commit comments

Comments
 (0)