11# Stable Diffusion XL text-to-image fine-tuning using PyTorch/XLA
22
3- The ` train_text_to_image_xla .py` script shows how to fine-tune stable diffusion model on TPU devices using PyTorch/XLA.
3+ The ` train_text_to_image_sdxl .py` script shows how to fine-tune stable diffusion model on TPU devices using PyTorch/XLA.
44
5- It has been tested on v4 and v5p TPU versions. Training code has been tested on multi-host.
5+ It has been tested on v5p TPU versions. Training code has been tested on a single v5p-8.
66
77This script implements Distributed Data Parallel using GSPMD feature in XLA compiler
88where we shard the input batches over the TPU devices.
99
10- As of 10-31-2024 , these are some expected step times.
10+ As of 04-04-2025 , these are some expected step times.
1111
1212| accelerator | global batch size | step time (seconds) |
1313| ----------- | ----------------- | --------- |
14- | v5p-512 | 16384 | 1.01 |
15- | v5p-256 | 8192 | 1.01 |
16- | v5p-128 | 4096 | 1.0 |
17- | v5p-64 | 2048 | 1.01 |
18-
14+ | v5p-8 | 32 | 1.02 |
15+ | v5p-8 | 48 (with gradient checkpointing) | 1.42 |
16+ | v5p-8 | 64 (with gradient checkpointing) | 1.66 |
17+ |
1918## Create TPU
2019
2120To create a TPU on Google Cloud first set these environment variables:
@@ -39,65 +38,61 @@ You can also use other ways to reserve TPUs like GKE or queued resources.
3938
4039## Setup TPU environment
4140
41+ Assuming that you have conda setup on the VM
4242Install PyTorch and PyTorch/XLA nightly versions:
4343``` bash
44- gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
45- --project=${PROJECT_ID} --zone=${ZONE} --worker=all \
46- --command='
47- pip install torch~=2.5.0 torch_xla[tpu]~=2.5.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
48- pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
49- pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
50- '
44+ conda create -n torch310 python=3.10
45+ conda activate torch310
46+ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0a0%2Bgit01cb351-cp310-cp310-linux_x86_64.whl
47+ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0%2Bgite341ff0-cp310-cp310-linux_x86_64.whl
48+ pip install https://storage.googleapis.com/libtpu-nightly-releases/wheels/libtpu/libtpu-0.0.11.dev20250303+nightly-py3-none-manylinux_2_27_x86_64.whl
49+ pip install torchax
50+ pip install jax==0.5.4.dev20250321 jaxlib==0.5.4.dev20250321 \
51+ -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
52+ pip install --no-deps torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
53+
54+ git clone -b sdxl_training_bbahl
[email protected] :entrpn/diffusers.git
55+ cd diffusers/
56+ pip install -r examples/research_projects/pytorch_xla/training/text_to_image/requirements_sdxl.txt
57+
58+ export CMAKE_PREFIX_PATH=${CONDA_PREFIX:- " $( dirname $( which conda) ) /../" }
59+ ln -sf /usr/lib/x86_64-linux-gnu/libstdc++.so.6 ${CMAKE_PREFIX_PATH} /lib/libstdc++.so.6
60+
61+ pip install -e .
5162```
5263
53- Verify that PyTorch and PyTorch/XLA were installed correctly:
64+ ## Run the training job
5465
55- ``` bash
56- gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
57- --project ${PROJECT_ID} --zone ${ZONE} --worker=all \
58- --command=' python3 -c "import torch; import torch_xla;"'
59- ```
66+ Run the following command to authenticate your token.
6067
61- Install dependencies:
6268``` bash
63- gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
64- --project=${PROJECT_ID} --zone=${ZONE} --worker=all \
65- --command='
66- git clone https://github.com/huggingface/diffusers.git
67- cd diffusers
68- git checkout main
69- cd examples/research_projects/pytorch_xla/training/text_to_image/
70- pip3 install -r requirements_sdxl.txt
71- pip3 install pillow --upgrade
72- cd ../../../../../
73- pip3 install .'
69+ huggingface-cli login
7470```
7571
76- ## Run the training job
77-
78- ### Authenticate
79-
80- Run the following command to authenticate your token.
72+ Please update the variables in ` train.sh ` as needed
8173
8274``` bash
83- huggingface-cli login
75+ export XLA_DISABLE_FUNCTIONALIZATION=1
76+ export TORCH_DISABLE_FUNCTIONALIZATION_META_REFERENCE=1
77+ export PROFILE_DIR=/tmp/xla_profile/
78+ export CACHE_DIR=/tmp/xla_cache/
79+ export DATASET_NAME=lambdalabs/naruto-blip-captions
80+ export PER_HOST_BATCH_SIZE=64 # This is known to work on TPU v5p with gradient checkpointing.
81+ export TRAIN_STEPS=50
82+ export PROFILE_START_STEP=10
83+ export OUTPUT_DIR=/tmp/docker/trained-model/
84+ export HF_HOME=" /tmp/hf_home/"
85+ # export XLA_HLO_DEBUG=1
86+ # export XLA_IR_DEBUG=1
87+ python examples/research_projects/pytorch_xla/training/text_to_image/train_text_to_image_sdxl.py --pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 --dataset_name=$DATASET_NAME --resolution=1024 --center_crop --random_flip --train_batch_size=$PER_HOST_BATCH_SIZE --max_train_steps=$TRAIN_STEPS --measure_start_step=$PROFILE_START_STEP --learning_rate=1e-06 --mixed_precision=bf16 --profile_duration=5000 --output_dir=$OUTPUT_DIR --dataloader_num_workers=8 --loader_prefetch_size=4 --device_prefetch_size=4 --xla_gradient_checkpointing
8488```
8589
90+
8691This script only trains the unet part of the network. The VAE and text encoder
8792are fixed.
8893
8994``` bash
90- gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
91- --project=${PROJECT_ID} --zone=${ZONE} --worker=all \
92- --command='
93- export XLA_DISABLE_FUNCTIONALIZATION=0
94- export PROFILE_DIR=/tmp/
95- export CACHE_DIR=/tmp/
96- export DATASET_NAME=lambdalabs/naruto-blip-captions
97- export PER_HOST_BATCH_SIZE=32 # This is known to work on TPU v4. Can set this to 64 for TPU v5p
98- export TRAIN_STEPS=50
99- export OUTPUT_DIR=/tmp/trained-model/
100- python diffusers/examples/research_projects/pytorch_xla/training/text_to_image/train_text_to_image_sdxl.py --pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 --dataset_name=$DATASET_NAME --resolution=512 --center_crop --random_flip --train_batch_size=$PER_HOST_BATCH_SIZE --max_train_steps=$TRAIN_STEPS --learning_rate=1e-06 --mixed_precision=bf16 --profile_duration=80000 --output_dir=$OUTPUT_DIR --dataloader_num_workers=8 --loader_prefetch_size=4 --device_prefetch_size=4'
95+ ./train.sh
10196```
10297
10398Pass ` --print_loss ` if you would like to see the loss printed at every step. Be aware that printing the loss at every step disrupts the optimized flow execution, thus the step time will be longer.
@@ -119,52 +114,18 @@ input prompts. The first pass will compile the graph and takes longer with the f
119114
120115``` bash
121116export CACHE_DIR=/tmp/
117+ export OUTPUT_DIR=/tmp/trained-model
118+ python inference_sdxl.py
122119```
123-
124- ``` python
125- import torch
126- import os
127- import sys
128- import numpy as np
129-
130- import torch_xla.core.xla_model as xm
131- from time import time
132- from diffusers import StableDiffusionPipeline
133- import torch_xla.runtime as xr
134-
135- CACHE_DIR = os.environ.get(" CACHE_DIR" , None )
136- if CACHE_DIR :
137- xr.initialize_cache(CACHE_DIR , readonly = False )
138-
139- def main ():
140- device = xm.xla_device()
141- model_path = " jffacevedo/pxla_trained_model"
142- pipe = StableDiffusionPipeline.from_pretrained(
143- model_path,
144- torch_dtype = torch.bfloat16
145- )
146- pipe.to(device)
147- prompt = [" A naruto with green eyes and red legs." ]
148- start = time()
149- print (" compiling..." )
150- image = pipe(prompt, num_inference_steps = 30 , guidance_scale = 7.5 ).images[0 ]
151- print (f " compile time: { time() - start} " )
152- print (" generate..." )
153- start = time()
154- image = pipe(prompt, num_inference_steps = 30 , guidance_scale = 7.5 ).images[0 ]
155- print (f " generation time (after compile) : { time() - start} " )
156- image.save(" naruto.png" )
157-
158- if __name__ == ' __main__' :
159- main()
160- ```
161-
162120Expected Results:
163121
164122``` bash
165123compiling...
166- 100%| βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 30/30 [10:03< 00:00, 20.10s/it]
167- compile time: 720.656970500946
124+ 0%| | 0/30 [00:00< ? , ? it/s]/mnt/bbahl/miniconda3/envs/torchverify/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py:351: UserWarning: Device capability of jax unspecified, assuming ` cpu` and ` cuda` . Please specify it via the ` devices` argument of ` register_backend` .
125+ warnings.warn(
126+ 100%| βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 30/30 [01:29< 00:00, 2.97s/it]
127+ compile time: 226.64892053604126
168128generate...
169- 100%| βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 30/30 [00:01< 00:00, 17.65it/s]
170- generation time (after compile) : 1.8461642265319824
129+ 100%| βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 30/30 [00:04< 00:00, 6.17it/s]
130+ generation time (after compile) : 5.120622396469116
131+ ` ` `
0 commit comments