20
20
from absl import logging
21
21
from absl .testing import absltest
22
22
23
+ from jax import version
23
24
from jax ._src import compiler
24
25
from jax ._src import config
25
26
from jax ._src import test_util as jtu
@@ -202,7 +203,12 @@ def test_register_plugin(self):
202
203
self .assertIn ("name2" , xb ._backend_factories )
203
204
self .assertEqual (registration .priority , 400 )
204
205
self .assertTrue (registration .experimental )
205
- mock_make .assert_called_once_with ("name1" , {}, None )
206
+
207
+ options = {}
208
+ if xb .get_backend ().platform == 'tpu' and xla_extension_version >= 267 :
209
+ options ["ml_framework_name" ] = "JAX"
210
+ options ["ml_framework_version" ] = version .__version__
211
+ mock_make .assert_called_once_with ("name1" , options , None )
206
212
207
213
def test_register_plugin_with_config (self ):
208
214
test_json_file_path = os .path .join (
@@ -229,16 +235,19 @@ def test_register_plugin_with_config(self):
229
235
self .assertIn ("name1" , xb ._backend_factories )
230
236
self .assertEqual (registration .priority , 400 )
231
237
self .assertTrue (registration .experimental )
232
- mock_make .assert_called_once_with (
233
- "name1" ,
234
- {
235
- "int_option" : 64 ,
236
- "int_list_option" : [32 , 64 ],
237
- "string_option" : "string" ,
238
- "float_option" : 1.0 ,
239
- },
240
- None ,
241
- )
238
+
239
+ # The expectation is specified in example_pjrt_plugin_config.json.
240
+ options = {
241
+ "int_option" : 64 ,
242
+ "int_list_option" : [32 , 64 ],
243
+ "string_option" : "string" ,
244
+ "float_option" : 1.0 ,
245
+ }
246
+ if xb .get_backend ().platform == 'tpu' and xla_extension_version >= 267 :
247
+ options ["ml_framework_name" ] = "JAX"
248
+ options ["ml_framework_version" ] = version .__version__
249
+
250
+ mock_make .assert_called_once_with ("name1" , options , None )
242
251
243
252
244
253
class GetBackendTest (jtu .JaxTestCase ):
0 commit comments