Skip to content

Commit ee6749d

Browse files
authored
Update PJRT plugin API version to 0.57 (#19241)
It closes #19223. `integrations/pjrt/third_party/pjrt_c_api/xla/pjrt/c/pjrt_c_api.h` is updated to the latest (with API version from 0.38 to 0.57), fetching from https://github.com/openxla/xla/blob/a454e14ab0b10e35fb8ad73bd6db7d93782114f6/xla/pjrt/c/pjrt_c_api.h. A blank implementation of `PJRT_Plugin_Attributes` is now provided since an unimplemented `PJRT_Plugin_Attributes` will lead to initialization failure of PJRT plugin (and thus crashes) in recent versions of PJRT clients. Also the JAX version in the CI workflow is updated from 0.4.20 to 0.4.35 and subsequently more tests can be enabled. ci-exactly: build_packages, test_pjrt --------- Signed-off-by: PragmaTwice <[email protected]>
1 parent dfcb594 commit ee6749d

File tree

6 files changed

+295
-152
lines changed

6 files changed

+295
-152
lines changed

.github/workflows/pkgci_test_pjrt.yml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,8 @@ jobs:
6060
# install editable into venv
6161
source ${VENV_DIR}/bin/activate
6262
python -m pip install -v --no-deps -e integrations/pjrt/python_packages/iree_${{ matrix.pjrt_platform }}_plugin
63-
# install jax (must be no larger than 0.4.20, refer to #19223)
64-
# TODO: switch to the latest JAX after #19223 is fixed
65-
python -m pip install jax==0.4.20 jaxlib==0.4.20 'numpy<2'
63+
# install
64+
python -m pip install jax==0.4.35
6665
- name: Run tests
6766
run: |
6867
source ${VENV_DIR}/bin/activate

build_tools/testing/run_jax_tests.sh

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,17 +44,9 @@ diff_jax_test() {
4444
echo "no difference found"
4545
}
4646

47-
# FIXME: due to #19223, we need to use jax no higher than 0.4.20,
48-
# but in such version of jax, 'stablehlo.broadcast_in_dim' op
49-
# will be emitted without attribute 'broadcast_dimensions',
50-
# which leads to an error in IREE PJRT plugin.
51-
# So currently any program with broadcast will fail,
52-
# e.g. test/test_simple.py.
53-
# After #19223 is fixed, we can uncomment the line below.
54-
55-
# diff_jax_test test/test_simple.py
56-
5747
diff_jax_test test/test_add.py
48+
diff_jax_test test/test_degenerate.py
49+
diff_jax_test test/test_simple.py
5850

5951

6052
# FIXME: we can also utilize the native test cases from JAX,

integrations/pjrt/src/iree_pjrt/common/api_impl.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2164,6 +2164,15 @@ void BindMonomorphicApi(PJRT_Api* api) {
21642164
BindUndefineds(api);
21652165
ErrorInstance::BindApi(api);
21662166

2167+
// PJRT_Plugin_Attributes should be implemented since it will always be
2168+
// called from the PJRT client in the initial phase.
2169+
// here we provide a blank implementation to avoid crash due to unimplemented.
2170+
api->PJRT_Plugin_Attributes =
2171+
+[](PJRT_Plugin_Attributes_Args* args) -> PJRT_Error* {
2172+
args->num_attributes = 0;
2173+
args->attributes = nullptr;
2174+
return nullptr;
2175+
};
21672176
api->PJRT_Plugin_Initialize =
21682177
+[](PJRT_Plugin_Initialize_Args* args) -> PJRT_Error* { return nullptr; };
21692178

integrations/pjrt/src/iree_pjrt/common/stubs.inc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,8 @@
9494
_STUB(PJRT_Client_CreateViewOfDeviceBuffer);
9595
_STUB(PJRT_Executable_Fingerprint);
9696
_STUB(PJRT_Client_TopologyDescription);
97+
_STUB(PJRT_Executable_GetCompiledMemoryStats);
98+
_STUB(PJRT_Memory_Kind_Id);
99+
_STUB(PJRT_ExecuteContext_Create);
100+
_STUB(PJRT_ExecuteContext_Destroy);
101+
_STUB(PJRT_Buffer_CopyRawToHost);

integrations/pjrt/third_party/pjrt_c_api/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,4 @@ releases.
88
Last synced from:
99

1010
* https://github.com/openxla/xla.git
11-
* commit: 96d1250d70c0bd6adf2778f31a266c1813fd107a
11+
* commit: a454e14ab0b10e35fb8ad73bd6db7d93782114f6

0 commit comments

Comments
 (0)