@@ -7,13 +7,14 @@ It has been tested on v4 and v5p TPU versions. Training code has been tested on
77This script implements Distributed Data Parallel using the GSPMD feature of the XLA compiler,
88where we shard the input batches over the TPU devices.
99
10- As of 9-11 -2024, these are some expected step times.
10+ As of 10-31-2024, these are some expected step times.
1111
1212| accelerator | global batch size | step time (seconds) |
1313| ----------- | ----------------- | --------- |
14- | v5p-128 | 1024 | 0.245 |
15- | v5p-256 | 2048 | 0.234 |
16- | v5p-512 | 4096 | 0.2498 |
14+ | v5p-512 | 16384 | 1.01 |
15+ | v5p-256 | 8192 | 1.01 |
16+ | v5p-128 | 4096 | 1.0 |
17+ | v5p-64 | 2048 | 1.01 |
1718
1819## Create TPU
1920
@@ -43,8 +44,9 @@ Install PyTorch and PyTorch/XLA nightly versions:
4344gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
4445--project=${PROJECT_ID} --zone=${ZONE} --worker=all \
4546--command='
46- pip3 install --pre torch==2.5.0.dev20240905+cpu torchvision==0.20.0.dev20240905+cpu --index-url https://download.pytorch.org/whl/nightly/cpu
47- pip3 install "torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.5.0.dev20240905-cp310-cp310-linux_x86_64.whl" -f https://storage.googleapis.com/libtpu-releases/index.html
47+ pip3 install --pre torch==2.6.0.dev20241031+cpu torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
48+ pip3 install "torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241031.cxx11-cp310-cp310-linux_x86_64.whl" -f https://storage.googleapis.com/libtpu-releases/index.html
49+ pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
4850'
4951```
5052
@@ -88,7 +90,7 @@ are fixed.
8890gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
8991--project=${PROJECT_ID} --zone=${ZONE} --worker=all \
9092--command='
91- export XLA_DISABLE_FUNCTIONALIZATION=1
93+ export XLA_DISABLE_FUNCTIONALIZATION=0
9294export PROFILE_DIR=/tmp/
9395export CACHE_DIR=/tmp/
9496export DATASET_NAME=lambdalabs/naruto-blip-captions
0 commit comments