Skip to content

Commit dd3f34c

Browse files
authored
Use bazel for PR tests (#216)
* Use bazel for running pre-merge CI tests * Don't use HEREDOC * Fix block text * Use bash array * Add bazel install * Put Bazel in the build image * Use Bazelisk * Remove bazel install in Docker * Go back to upstream XLA * Remove bazel test command from workflow * Move test command to build container * Fix string format typos
1 parent 60f51d2 commit dd3f34c

File tree

3 files changed

+29
-1
lines changed

3 files changed

+29
-1
lines changed

.github/workflows/rocm-ci.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ jobs:
5959
name: rocm_jax_r${{ env.ROCM_VERSION }}_py${{ env.PYTHON_VERSION }}_id${{ github.run_id }}
6060
path: ./dist/*.whl
6161
- name: Run tests
62+
env:
63+
GPU_COUNT: "8"
64+
GFX: "gfx90a"
6265
run: |
6366
cd $WORKSPACE_DIR
6467
python3 build/rocm/ci_build test $TEST_IMAGE --test-cmd "pytest tests/core_test.py"

build/rocm/ci_build

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,31 @@ def dist_wheels(
127127
]
128128
)
129129

130+
# Add command for unit tests
131+
cmd.extend(
132+
[
133+
"&&",
134+
"bazel",
135+
"test",
136+
"-k",
137+
"--jobs=4",
138+
"--test_verbose_timeout_warnings=true",
139+
"--test_output=all",
140+
"--test_summary=detailed",
141+
"--local_test_jobs=1",
142+
"--test_env=JAX_ACCELERATOR_COUNT=%i" % 4,
143+
"--test_env=JAX_SKIP_SLOW_TESTS=0",
144+
"--verbose_failures=true",
145+
"--config=rocm",
146+
"--action_env=ROCM_PATH=/opt/rocm",
147+
"--action_env=TF_ROCM_AMDGPU_TARGETS=%s" % "gfx90a",
148+
"--test_tag_filters=-multiaccelerator",
149+
"--test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform",
150+
"--test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow",
151+
"//tests:gpu_tests",
152+
]
153+
)
154+
130155
LOG.info("Running: %s", cmd)
131156
_ = subprocess.run(cmd, check=True)
132157

third_party/xla/workspace.bzl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def repo():
2929
name = "xla",
3030
sha256 = XLA_SHA256,
3131
strip_prefix = "xla-{commit}".format(commit = XLA_COMMIT),
32-
urls = tf_mirror_urls("https://github.com/rocm/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)),
32+
urls = tf_mirror_urls("https://github.com/openxla/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)),
3333
)
3434

3535
# For development, one often wants to make changes to the TF repository as well

0 commit comments

Comments
 (0)