@@ -34,34 +34,17 @@ jobs:
3434 tpu_unit_tests :
3535 runs-on : [linux-x86-ct5lp-224-8tpu]
3636 container :
37- image : python:3.12-slim
38- options : --privileged --cpus=2 --memory=4Gi
37+ image : us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/tpu:latest
38+ options : --privileged
3939 env :
40- TPU_ACCELERATOR_TYPE : " "
41- JAX_PLATFORMS : " cpu"
42- CUDA_VISIBLE_DEVICES : " "
43- TF_CPP_MIN_LOG_LEVEL : " 3"
40+ CLOUD_TPU_ACCELERATOR : v5e-8
41+ JAX_PLATFORMS : tpu
4442 steps :
4543 - name : Checkout code
4644 uses : actions/checkout@v4
4745 with :
4846 fetch-depth : 0
4947
50- - name : Set up Python
51- uses : actions/setup-python@v4
52- with :
53- python-version : ' 3.12'
54-
55- - name : Install system dependencies
56- run : |
57- sudo apt-get update
58- sudo apt-get install -y git curl
59-
60- - name : Set up JAX for TPU
61- run : |
62- pip install --upgrade pip
63- pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
64-
6548 - name : Install tunix dependencies
6649 run : |
6750 pip install -e .
@@ -115,9 +98,10 @@ jobs:
11598 print('Llama3 params loaded successfully')
11699 "
117100
101+
118102 notify_failure :
119103 name : Notify failed build
120- needs : [tpu_unit_tests]
104+ needs : [tpu_unit_tests, tpu_integration_tests ]
121105 if : ${{ always() }}
122106 runs-on : ubuntu-latest
123107 permissions :
@@ -144,7 +128,7 @@ jobs:
144128 notify_success_and_close :
145129 name : Close issue after 3 successful builds
146130 if : ${{ success() && github.event.pull_request == null && github.event_name != 'workflow_dispatch' }}
147- needs : [tpu_unit_tests]
131+ needs : [tpu_unit_tests, tpu_integration_tests ]
148132 runs-on : ubuntu-latest
149133 permissions :
150134 issues : write
0 commit comments