
Commit 47867ee

Update readme and step time calculation. Also add flag for gradient checkpointing
1 parent 8415cb2 commit 47867ee

File tree

5 files changed: +76 −275 lines changed

examples/research_projects/pytorch_xla/training/text_to_image/README.md

Lines changed: 0 additions & 170 deletions
This file was deleted.

Lines changed: 54 additions & 93 deletions
@@ -1,21 +1,20 @@
 # Stable Diffusion XL text-to-image fine-tuning using PyTorch/XLA
 
-The `train_text_to_image_xla.py` script shows how to fine-tune stable diffusion model on TPU devices using PyTorch/XLA.
+The `train_text_to_image_sdxl.py` script shows how to fine-tune the Stable Diffusion model on TPU devices using PyTorch/XLA.
 
-It has been tested on v4 and v5p TPU versions. Training code has been tested on multi-host.
+It has been tested on v5p TPUs. Training code has been tested on a single v5p-8.
 
 This script implements Distributed Data Parallel using the GSPMD feature in the XLA compiler,
 where we shard the input batches over the TPU devices.
 
-As of 10-31-2024, these are some expected step times.
+As of 04-04-2025, these are some expected step times.
 
 | accelerator | global batch size                | step time (seconds) |
 | ----------- | -------------------------------- | ------------------- |
-| v5p-512     | 16384                            | 1.01                |
-| v5p-256     | 8192                             | 1.01                |
-| v5p-128     | 4096                             | 1.0                 |
-| v5p-64      | 2048                             | 1.01                |
+| v5p-8       | 32                               | 1.02                |
+| v5p-8       | 48 (with gradient checkpointing) | 1.42                |
+| v5p-8       | 64 (with gradient checkpointing) | 1.66                |
 
 ## Create TPU
 
 To create a TPU on Google Cloud first set these environment variables:
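The Distributed Data Parallel claim in the hunk above maps to PyTorch/XLA's SPMD API. Below is a minimal, hedged sketch of what sharding the input batch over TPU devices with GSPMD typically looks like; it is illustrative, not the exact code in `train_text_to_image_sdxl.py`, and the mesh axis name and tensor shape are assumptions.

```python
# Illustrative GSPMD data-parallel sharding with PyTorch/XLA SPMD.
# Not the script's exact code; the axis name and latent shape are assumed.
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

xr.use_spmd()  # switch the runtime into SPMD mode

num_devices = xr.global_runtime_device_count()
# One-dimensional device mesh; 'data' is the axis batches are sharded over.
mesh = xs.Mesh(np.array(range(num_devices)), (num_devices,), ('data',))

# A global batch; dim 0 (batch) is split across all TPU devices,
# the remaining dims are replicated.
batch = torch.randn(32, 4, 128, 128).to(xm.xla_device())
xs.mark_sharding(batch, mesh, ('data', None, None, None))
```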
@@ -39,65 +38,61 @@ You can also use other ways to reserve TPUs like GKE or queued resources.
 
 ## Setup TPU environment
 
+Assuming that you have conda set up on the VM.
 Install PyTorch and PyTorch/XLA nightly versions:
 ```bash
-gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
-  --project=${PROJECT_ID} --zone=${ZONE} --worker=all \
-  --command='
-  pip install torch~=2.5.0 torch_xla[tpu]~=2.5.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
-  pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
-  pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
-  '
+conda create -n torch310 python=3.10
+conda activate torch310
+pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0a0%2Bgit01cb351-cp310-cp310-linux_x86_64.whl
+pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0%2Bgite341ff0-cp310-cp310-linux_x86_64.whl
+pip install https://storage.googleapis.com/libtpu-nightly-releases/wheels/libtpu/libtpu-0.0.11.dev20250303+nightly-py3-none-manylinux_2_27_x86_64.whl
+pip install torchax
+pip install jax==0.5.4.dev20250321 jaxlib==0.5.4.dev20250321 \
+    -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
+pip install --no-deps torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
+
+git clone -b sdxl_training_bbahl git@github.com:entrpn/diffusers.git
+cd diffusers/
+pip install -r examples/research_projects/pytorch_xla/training/text_to_image/requirements_sdxl.txt
+
+export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"}
+ln -sf /usr/lib/x86_64-linux-gnu/libstdc++.so.6 ${CMAKE_PREFIX_PATH}/lib/libstdc++.so.6
+
+pip install -e .
 ```
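The hunk below removes the old installation check. Assuming the conda environment above, an equivalent quick sanity check would be something like:

```python
# Assumed post-install check, mirroring the verification step this commit
# removes: torch and torch_xla import, and an XLA device is reachable.
import torch
import torch_xla
import torch_xla.core.xla_model as xm

print("torch:", torch.__version__)
print("torch_xla:", torch_xla.__version__)
print("device:", xm.xla_device())  # expect an XLA device, e.g. xla:0, on the TPU VM
```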
 
-Verify that PyTorch and PyTorch/XLA were installed correctly:
+## Run the training job
 
-```bash
-gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
-  --project ${PROJECT_ID} --zone ${ZONE} --worker=all \
-  --command='python3 -c "import torch; import torch_xla;"'
-```
+Run the following command to authenticate your token.
 
-Install dependencies:
 ```bash
-gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
-  --project=${PROJECT_ID} --zone=${ZONE} --worker=all \
-  --command='
-  git clone https://github.com/huggingface/diffusers.git
-  cd diffusers
-  git checkout main
-  cd examples/research_projects/pytorch_xla/training/text_to_image/
-  pip3 install -r requirements_sdxl.txt
-  pip3 install pillow --upgrade
-  cd ../../../../../
-  pip3 install .'
+huggingface-cli login
 ```
 
-## Run the training job
-
-### Authenticate
-
-Run the following command to authenticate your token.
+Please update the variables in `train.sh` as needed.
 
 ```bash
-huggingface-cli login
+export XLA_DISABLE_FUNCTIONALIZATION=1
+export TORCH_DISABLE_FUNCTIONALIZATION_META_REFERENCE=1
+export PROFILE_DIR=/tmp/xla_profile/
+export CACHE_DIR=/tmp/xla_cache/
+export DATASET_NAME=lambdalabs/naruto-blip-captions
+export PER_HOST_BATCH_SIZE=64 # This is known to work on TPU v5p with gradient checkpointing.
+export TRAIN_STEPS=50
+export PROFILE_START_STEP=10
+export OUTPUT_DIR=/tmp/docker/trained-model/
+export HF_HOME="/tmp/hf_home/"
+# export XLA_HLO_DEBUG=1
+# export XLA_IR_DEBUG=1
+python examples/research_projects/pytorch_xla/training/text_to_image/train_text_to_image_sdxl.py --pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 --dataset_name=$DATASET_NAME --resolution=1024 --center_crop --random_flip --train_batch_size=$PER_HOST_BATCH_SIZE --max_train_steps=$TRAIN_STEPS --measure_start_step=$PROFILE_START_STEP --learning_rate=1e-06 --mixed_precision=bf16 --profile_duration=5000 --output_dir=$OUTPUT_DIR --dataloader_num_workers=8 --loader_prefetch_size=4 --device_prefetch_size=4 --xla_gradient_checkpointing
 ```
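`PROFILE_DIR`, `PROFILE_START_STEP`, and `--profile_duration` indicate the script captures an XLA profile during training. A hedged sketch of the usual PyTorch/XLA profiling pattern follows; the port and wiring are assumptions, not necessarily how `train_text_to_image_sdxl.py` does it.

```python
# Common PyTorch/XLA profiling pattern implied by PROFILE_DIR,
# PROFILE_START_STEP and --profile_duration (assumed, not the script's code).
import os
import torch_xla.debug.profiler as xp

server = xp.start_server(9012)  # expose profiler data from the training process

# At the profile start step, capture a trace of the given duration into PROFILE_DIR:
xp.trace_detached(
    "localhost:9012",
    os.environ.get("PROFILE_DIR", "/tmp/xla_profile/"),
    duration_ms=5000,  # mirrors --profile_duration=5000 above
)
```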
 
+
 This script only trains the unet part of the network. The VAE and text encoder
 are fixed.
 
 ```bash
-gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
-  --project=${PROJECT_ID} --zone=${ZONE} --worker=all \
-  --command='
-  export XLA_DISABLE_FUNCTIONALIZATION=0
-  export PROFILE_DIR=/tmp/
-  export CACHE_DIR=/tmp/
-  export DATASET_NAME=lambdalabs/naruto-blip-captions
-  export PER_HOST_BATCH_SIZE=32 # This is known to work on TPU v4. Can set this to 64 for TPU v5p
-  export TRAIN_STEPS=50
-  export OUTPUT_DIR=/tmp/trained-model/
-  python diffusers/examples/research_projects/pytorch_xla/training/text_to_image/train_text_to_image_sdxl.py --pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 --dataset_name=$DATASET_NAME --resolution=512 --center_crop --random_flip --train_batch_size=$PER_HOST_BATCH_SIZE --max_train_steps=$TRAIN_STEPS --learning_rate=1e-06 --mixed_precision=bf16 --profile_duration=80000 --output_dir=$OUTPUT_DIR --dataloader_num_workers=8 --loader_prefetch_size=4 --device_prefetch_size=4'
+./train.sh
 ```
 
 Pass `--print_loss` if you would like to see the loss printed at every step. Be aware that printing the loss at every step disrupts the optimized flow execution, thus the step time will be longer.
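The commit message also mentions a new flag for gradient checkpointing, passed above as `--xla_gradient_checkpointing`; together with keeping the VAE and text encoders frozen, checkpointing the unet is plausibly what makes the larger batch sizes (48, 64) in the step-time table fit in memory, at the cost of a longer step. A hedged sketch using diffusers' `enable_gradient_checkpointing()` (a real model method; the flag handling here is an assumption):

```python
# Hedged sketch of unet-only training with gradient checkpointing enabled.
# enable_gradient_checkpointing() is a real diffusers method; the flag
# variable stands in for the script's --xla_gradient_checkpointing argument.
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
)
unet.train()  # only the unet is trained; VAE and text encoders stay frozen

xla_gradient_checkpointing = True
if xla_gradient_checkpointing:
    # Recompute activations in the backward pass, trading step time
    # (1.02s -> 1.42s/1.66s in the table above) for memory headroom.
    unet.enable_gradient_checkpointing()

print(unet.is_gradient_checkpointing)  # True
```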
@@ -119,52 +114,18 @@ input prompts. The first pass will compile the graph and takes longer with the f
 
 ```bash
 export CACHE_DIR=/tmp/
+export OUTPUT_DIR=/tmp/trained-model
+python inference_sdxl.py
 ```
-
-```python
-import torch
-import os
-import sys
-import numpy as np
-
-import torch_xla.core.xla_model as xm
-from time import time
-from diffusers import StableDiffusionPipeline
-import torch_xla.runtime as xr
-
-CACHE_DIR = os.environ.get("CACHE_DIR", None)
-if CACHE_DIR:
-    xr.initialize_cache(CACHE_DIR, readonly=False)
-
-def main():
-    device = xm.xla_device()
-    model_path = "jffacevedo/pxla_trained_model"
-    pipe = StableDiffusionPipeline.from_pretrained(
-        model_path,
-        torch_dtype=torch.bfloat16
-    )
-    pipe.to(device)
-    prompt = ["A naruto with green eyes and red legs."]
-    start = time()
-    print("compiling...")
-    image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
-    print(f"compile time: {time() - start}")
-    print("generate...")
-    start = time()
-    image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
-    print(f"generation time (after compile) : {time() - start}")
-    image.save("naruto.png")
-
-if __name__ == '__main__':
-    main()
-```
-
 Expected Results:
 
 ```bash
 compiling...
-100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [10:03<00:00, 20.10s/it]
-compile time: 720.656970500946
+  0%|          | 0/30 [00:00<?, ?it/s]/mnt/bbahl/miniconda3/envs/torchverify/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py:351: UserWarning: Device capability of jax unspecified, assuming `cpu` and `cuda`. Please specify it via the `devices` argument of `register_backend`.
+  warnings.warn(
+100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [01:29<00:00, 2.97s/it]
+compile time: 226.64892053604126
 generate...
-100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [00:01<00:00, 17.65it/s]
-generation time (after compile) : 1.8461642265319824
+100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [00:04<00:00, 6.17it/s]
+generation time (after compile) : 5.120622396469116
+```
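For reference, the inline script removed above shows how `inference_sdxl.py` is expected to consume `CACHE_DIR`: `xr.initialize_cache` persists compiled XLA programs so only the first run pays the full compile cost. A minimal sketch of that cache setup, lifted from the deleted code:

```python
# Persistent XLA compilation cache, as in the script deleted from this README:
# later runs reuse compiled programs instead of recompiling from scratch.
import os
import torch_xla.runtime as xr

CACHE_DIR = os.environ.get("CACHE_DIR", None)
if CACHE_DIR:
    xr.initialize_cache(CACHE_DIR, readonly=False)
```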

examples/research_projects/pytorch_xla/training/text_to_image/inference_sdxl.py

Lines changed: 0 additions & 1 deletion
@@ -27,7 +27,6 @@
 # pipe.vae.enable_xla_attention()
 start = time()
 print("compiling...")
-import pdb; pdb.set_trace()
 image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
 print(f"compile time: {time() - start}")
 print("generate...")
