@@ -24,10 +24,10 @@ jobs:
2424 strategy :
2525 fail-fast : false # don't cancel all jobs on failure
2626 matrix :
27- jaxlib-version : ["pypi_latest", "nightly", "nightly+oldest_supported_libtpu"]
27+ jaxlib-version : ["head", " pypi_latest", "nightly", "nightly+oldest_supported_libtpu"]
2828 tpu : [
29- # {type: "v3-8", cores: "4"}, # Enable when we have the v3/v4 type available
30- # {type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"},
29+ # {type: "v3-8", cores: "4"}, # Enable when we have the v3 type available
30+ {type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"},
3131 {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}
3232 ]
3333 python-version : ["3.10"]
@@ -47,14 +47,34 @@ jobs:
4747 # mandates using a specific commit for non-Google actions. We use
4848 # https://github.com/sethvargo/ratchet to pin specific versions.
4949 - uses : actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
50+ # Checkout XLA at head, if we're building jaxlib at head.
51+ - name : Checkout XLA at head
52+ uses : actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
53+ if : ${{ matrix.jaxlib-version == 'head' }}
54+ with :
55+ repository : openxla/xla
56+ path : xla
5057 - name : Install JAX test requirements
5158 run : |
5259 $PYTHON -m pip install -U -r build/test-requirements.txt
5360 $PYTHON -m pip install -U -r build/collect-profile-requirements.txt
5461 - name : Install JAX
5562 run : |
5663 $PYTHON -m pip uninstall -y jax jaxlib libtpu
57- if [ "${{ matrix.jaxlib-version }}" == "pypi_latest" ]; then
64+ if [ "${{ matrix.jaxlib-version }}" == "head" ]; then
65+ # Build and install jaxlib at head
66+ $PYTHON build/build.py --bazel_options=--config=rbe_linux_x86_64 \
67+ --bazel_options="--override_repository=xla=$(pwd)/xla" \
68+ --bazel_options=--color=yes
69+ $PYTHON -m pip install dist/*.whl
70+
71+ # Install "jax" at head
72+ $PYTHON -m pip install -U -e .
73+
74+ # Install libtpu
75+ $PYTHON -m pip install --pre libtpu \
76+ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
77+ elif [ "${{ matrix.jaxlib-version }}" == "pypi_latest" ]; then
5878 $PYTHON -m pip install .[tpu] \
5979 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
6080
0 commit comments