File tree Expand file tree Collapse file tree 1 file changed +2
-2
lines changed
examples/research_projects/pytorch_xla/training/text_to_image Expand file tree Collapse file tree 1 file changed +2
-2
lines changed Original file line number Diff line number Diff line change @@ -63,7 +63,7 @@ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
6363--command='
6464git clone -b sdxl_xla https://github.com/pytorch-tpu/diffusers.git
6565cd diffusers
66- cd examples/research_projects/pytorch_xla/text_to_image/
66+ cd examples/research_projects/pytorch_xla/training/text_to_image/
6767pip3 install -r requirements.txt
6868pip3 install pillow --upgrade
6969cd ../../..
@@ -95,7 +95,7 @@ export PER_HOST_BATCH_SIZE=64 # This is known to work on TPU v5p
9595export TRAIN_STEPS=50
9696export PROFILE_START_STEP=10
9797export OUTPUT_DIR=/tmp/trained-model/
98- python diffusers/examples/research_projects/pytorch_xla/text_to_image/train_sdxl.py --pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 --dataset_name=$DATASET_NAME --resolution=1024 --center_crop --random_flip --train_batch_size=$PER_HOST_BATCH_SIZE --max_train_steps=$TRAIN_STEPS --measure_start_step=$PROFILE_START_STEP --learning_rate=1e-06 --mixed_precision=bf16 --profile_duration=5000 --output_dir=$OUTPUT_DIR --dataloader_num_workers=8 --loader_prefetch_size=4 --device_prefetch_size=4 --xla_gradient_checkpointing'
98+ python diffusers/examples/research_projects/pytorch_xla/training/text_to_image/train_sdxl.py --pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 --dataset_name=$DATASET_NAME --resolution=1024 --center_crop --random_flip --train_batch_size=$PER_HOST_BATCH_SIZE --max_train_steps=$TRAIN_STEPS --measure_start_step=$PROFILE_START_STEP --learning_rate=1e-06 --mixed_precision=bf16 --profile_duration=5000 --output_dir=$OUTPUT_DIR --dataloader_num_workers=8 --loader_prefetch_size=4 --device_prefetch_size=4 --xla_gradient_checkpointing'
9999```
100100
101101Pass `--print_loss` if you would like to see the loss printed at every step. Be aware that printing the loss at every step disrupts the optimized execution flow, so the step time will be longer.
You can’t perform that action at this time.
0 commit comments