Skip to content

Commit 341e63b

Browse files
zuasiajax authors
authored andcommitted
add xla_bridge test guard on cloud tpu env
PiperOrigin-RevId: 640269835
1 parent eba0564 commit 341e63b

File tree

1 file changed

+20
-11
lines changed

1 file changed

+20
-11
lines changed

tests/xla_bridge_test.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from absl import logging
2121
from absl.testing import absltest
2222

23+
from jax import version
2324
from jax._src import compiler
2425
from jax._src import config
2526
from jax._src import test_util as jtu
@@ -202,7 +203,12 @@ def test_register_plugin(self):
202203
self.assertIn("name2", xb._backend_factories)
203204
self.assertEqual(registration.priority, 400)
204205
self.assertTrue(registration.experimental)
205-
mock_make.assert_called_once_with("name1", {}, None)
206+
207+
options = {}
208+
if xb.get_backend().platform == 'tpu' and xla_extension_version >= 267:
209+
options["ml_framework_name"] = "JAX"
210+
options["ml_framework_version"] = version.__version__
211+
mock_make.assert_called_once_with("name1", options, None)
206212

207213
def test_register_plugin_with_config(self):
208214
test_json_file_path = os.path.join(
@@ -229,16 +235,19 @@ def test_register_plugin_with_config(self):
229235
self.assertIn("name1", xb._backend_factories)
230236
self.assertEqual(registration.priority, 400)
231237
self.assertTrue(registration.experimental)
232-
mock_make.assert_called_once_with(
233-
"name1",
234-
{
235-
"int_option": 64,
236-
"int_list_option": [32, 64],
237-
"string_option": "string",
238-
"float_option": 1.0,
239-
},
240-
None,
241-
)
238+
239+
# The expectation is specified in example_pjrt_plugin_config.json.
240+
options = {
241+
"int_option": 64,
242+
"int_list_option": [32, 64],
243+
"string_option": "string",
244+
"float_option": 1.0,
245+
}
246+
if xb.get_backend().platform == 'tpu' and xla_extension_version >= 267:
247+
options["ml_framework_name"] = "JAX"
248+
options["ml_framework_version"] = version.__version__
249+
250+
mock_make.assert_called_once_with("name1", options, None)
242251

243252

244253
class GetBackendTest(jtu.JaxTestCase):

0 commit comments

Comments
 (0)