Skip to content

Commit 6714b6f

Browse files
committed
Fix training script path
1 parent e3b9695 commit 6714b6f

File tree

1 file changed

+2
-2
lines changed
  • examples/research_projects/pytorch_xla/training/text_to_image

1 file changed

+2
-2
lines changed

examples/research_projects/pytorch_xla/training/text_to_image/README_sdxl.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
6363
--command='
6464
git clone -b sdxl_xla https://github.com/pytorch-tpu/diffusers.git
6565
cd diffusers
66-
cd examples/research_projects/pytorch_xla/text_to_image/
66+
cd examples/research_projects/pytorch_xla/training/text_to_image/
6767
pip3 install -r requirements.txt
6868
pip3 install pillow --upgrade
6969
cd ../../..
@@ -95,7 +95,7 @@ export PER_HOST_BATCH_SIZE=64 # This is known to work on TPU v5p
9595
export TRAIN_STEPS=50
9696
export PROFILE_START_STEP=10
9797
export OUTPUT_DIR=/tmp/trained-model/
98-
python diffusers/examples/research_projects/pytorch_xla/text_to_image/train_sdxl.py --pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 --dataset_name=$DATASET_NAME --resolution=1024 --center_crop --random_flip --train_batch_size=$PER_HOST_BATCH_SIZE --max_train_steps=$TRAIN_STEPS --measure_start_step=$PROFILE_START_STEP --learning_rate=1e-06 --mixed_precision=bf16 --profile_duration=5000 --output_dir=$OUTPUT_DIR --dataloader_num_workers=8 --loader_prefetch_size=4 --device_prefetch_size=4 --xla_gradient_checkpointing'
98+
python diffusers/examples/research_projects/pytorch_xla/training/text_to_image/train_sdxl.py --pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 --dataset_name=$DATASET_NAME --resolution=1024 --center_crop --random_flip --train_batch_size=$PER_HOST_BATCH_SIZE --max_train_steps=$TRAIN_STEPS --measure_start_step=$PROFILE_START_STEP --learning_rate=1e-06 --mixed_precision=bf16 --profile_duration=5000 --output_dir=$OUTPUT_DIR --dataloader_num_workers=8 --loader_prefetch_size=4 --device_prefetch_size=4 --xla_gradient_checkpointing'
9999
```
100100

101101
Pass `--print_loss` if you would like to see the loss printed at every step. Be aware that printing the loss at every step disrupts the optimized flow execution, thus the step time will be longer.

0 commit comments

Comments
 (0)