Commit f04ee1d

update ptxla example
1 parent 41e4779 commit f04ee1d

3 files changed, +117 -102 lines changed

examples/research_projects/pytorch_xla/README.md

Lines changed: 9 additions & 7 deletions
@@ -7,13 +7,14 @@ It has been tested on v4 and v5p TPU versions. Training code has been tested on
 This script implements Distributed Data Parallel using GSPMD feature in XLA compiler
 where we shard the input batches over the TPU devices.
 
-As of 9-11-2024, these are some expected step times.
+As of 10-31-2024, these are some expected step times.
 
 | accelerator | global batch size | step time (seconds) |
 | ----------- | ----------------- | --------- |
-| v5p-128 | 1024 | 0.245 |
-| v5p-256 | 2048 | 0.234 |
-| v5p-512 | 4096 | 0.2498 |
+| v5p-512 | 16384 | 1.01 |
+| v5p-256 | 8192 | 1.01 |
+| v5p-128 | 4096 | 1.0 |
+| v5p-64 | 2048 | 1.01 |
 
 ## Create TPU

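The updated step-time table is easier to compare once each row is converted to throughput. The following is a minimal sketch that uses only the numbers from the added rows; the helper names `samples_per_second` and `batch_per_unit` are hypothetical and not part of the example script, and the number in the accelerator name is treated purely as a slice-size label.

```python
# Rows copied from the updated table:
# (accelerator, global batch size, step time in seconds).
ROWS = [
    ("v5p-512", 16384, 1.01),
    ("v5p-256", 8192, 1.01),
    ("v5p-128", 4096, 1.0),
    ("v5p-64", 2048, 1.01),
]


def samples_per_second(global_batch_size, step_time_s):
    """Throughput implied by one table row."""
    return global_batch_size / step_time_s


def batch_per_unit(accelerator, global_batch_size):
    """Global batch divided by the number in the accelerator name
    (for example 512 for "v5p-512")."""
    size = int(accelerator.split("-")[1])
    return global_batch_size / size


for name, batch, step in ROWS:
    print(
        f"{name}: {samples_per_second(batch, step):8.0f} samples/s, "
        f"{batch_per_unit(name, batch):.0f} samples per unit of slice size"
    )
```

In the new rows the global batch size doubles with the slice size, so the batch per unit of slice size is 32 in every row, while the step time stays at roughly one second.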
@@ -43,8 +44,9 @@ Install PyTorch and PyTorch/XLA nightly versions:
 gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
 --project=${PROJECT_ID} --zone=${ZONE} --worker=all \
 --command='
-pip3 install --pre torch==2.5.0.dev20240905+cpu torchvision==0.20.0.dev20240905+cpu --index-url https://download.pytorch.org/whl/nightly/cpu
-pip3 install "torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.5.0.dev20240905-cp310-cp310-linux_x86_64.whl" -f https://storage.googleapis.com/libtpu-releases/index.html
+pip3 install --pre torch==2.6.0.dev20241031+cpu torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
+pip3 install "torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241031.cxx11-cp310-cp310-linux_x86_64.whl" -f https://storage.googleapis.com/libtpu-releases/index.html
+pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
 '
 ```

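Because the install step now pins different nightly builds, it can help to confirm which versions actually ended up on a worker. The sketch below uses only the standard library; the function name `check_nightlies` is hypothetical, and the expected version prefixes are assumptions read off the pinned wheel names in this hunk (the exact version string that the `torch_xla` wheel reports is not confirmed by the diff, so only a `2.6.0` prefix is checked for it).

```python
from importlib.metadata import PackageNotFoundError, version

# Assumed prefixes, taken from the pinned wheels in the install command.
EXPECTED_PREFIXES = {
    "torch": "2.6.0.dev20241031",
    "torch_xla": "2.6.0",
}


def check_nightlies(expected=EXPECTED_PREFIXES):
    """Return {distribution: (installed_version_or_None, matches_prefix)}."""
    report = {}
    for dist, prefix in expected.items():
        try:
            installed = version(dist)
        except PackageNotFoundError:
            # The distribution is not installed in this environment.
            installed = None
        matches = installed is not None and installed.startswith(prefix)
        report[dist] = (installed, matches)
    return report


if __name__ == "__main__":
    for dist, (installed, ok) in check_nightlies().items():
        status = "ok" if ok else "unexpected or missing"
        print(f"{dist}: {installed} ({status})")
```

Running it on each worker after the install prints one line per package; a `None` version means the package was not found.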
@@ -88,7 +90,7 @@ are fixed.
 gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
 --project=${PROJECT_ID} --zone=${ZONE} --worker=all \
 --command='
-export XLA_DISABLE_FUNCTIONALIZATION=1
+export XLA_DISABLE_FUNCTIONALIZATION=0
 export PROFILE_DIR=/tmp/
 export CACHE_DIR=/tmp/
 export DATASET_NAME=lambdalabs/naruto-blip-captions