Skip to content

Commit bae387e

Browse files
authored
Add jax test script (#406)
* Add jax test script * Adjust tag filters * Enable rbe for jax-ut * Fix docker url * Add rbe settings manually * Fix rbe pool name * Build the tests instead of running them as they are still red * Restore platform * Use xla path passed as argument
1 parent 061033b commit bae387e

File tree

4 files changed

+26
-49
lines changed

4 files changed

+26
-49
lines changed

build_tools/rocm/platform/linux_x64/BUILD

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ platform(
1313
exec_properties = {
1414
"container-image": "rocm/tensorflow-build@sha256:7cd444ac48657fee2f5087fbda7766266704d3f8fb2299f681952ae4eabed060",
1515
"OSFamily": "Linux",
16-
"Pool": "linux_x64_large",
1716
},
1817
)
1918

@@ -27,6 +26,6 @@ platform(
2726
exec_properties = {
2827
"container-image": "rocm/tensorflow-build@sha256:7cd444ac48657fee2f5087fbda7766266704d3f8fb2299f681952ae4eabed060",
2928
"OSFamily": "Linux",
30-
"Pool": "amd_gpu",
29+
"Pool": "linux_x64_gpu",
3130
},
3231
)

build_tools/rocm/run_jax_ut.sh

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#!/bin/bash
2+
3+
set -e
4+
5+
JAX_DIR=$1
6+
XLA_DIR=$2
7+
8+
pushd $JAX_DIR
9+
10+
python build/build.py build \
11+
--wheels=jax-rocm-plugin \
12+
--configure_only \
13+
--local_xla_path=${XLA_DIR} \
14+
--python_version=3.12
15+
16+
# TODO: run the tests when they are green
17+
bazel build \
18+
--config=rocm \
19+
--build_tag_filters=cpu,gpu,-tpu,-config-cuda-only \
20+
--test_tag_filters=cpu,gpu,-tpu,-config-cuda-only \
21+
--action_env=TF_ROCM_AMDGPU_TARGETS=gfx908,gfx90a,gfx942 \
22+
--//jax:build_jaxlib=true \
23+
"//tests/..."
24+
25+
popd

build_tools/rocm/run_xla_ci_build.sh

Lines changed: 0 additions & 46 deletions
This file was deleted.

build_tools/rocm/run_xla_multi_gpu.sh

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@ EXCLUDED_TESTS=(
7171
# //xla/tools/multihost_hlo_runner:functional_hlo_runner_test
7272
FunctionalHloRunnerTest.Sharded2DevicesHloUnoptimizedSnapshot
7373
FunctionalHloRunnerTest.ShardedComputationUnderStreamCapture
74-
7574
)
7675

7776
SCRIPT_DIR=$(realpath $(dirname $0))

0 commit comments

Comments
 (0)