@@ -64,6 +64,22 @@ def get_model(type="sequential", input_shape=(10,), layer_list=None):
6464 reason = "Torch backend export (via torch_xla) is incompatible with np 2.0" ,
6565)
6666class ExportSavedModelTest (testing .TestCase ):
def setUp(self):
    """Prepare the extra keyword arguments used by the export tests.

    Stores them on `self.export_kwargs`, which the tests in this class
    unpack into `saved_model.export_saved_model(...)`.

    - When `testing.jax_uses_gpu()` is true, `jax2tf_kwargs` requests the
      `("cpu", "cuda")` native serialization platforms.
    - Otherwise, when `testing.jax_uses_tpu()` is true, it requests
      `("cpu", "tpu")`.
    - In every other configuration the kwargs stay empty.
    """
    super().setUp()
    self.export_kwargs = {}

    # The GPU probe takes precedence; the TPU probe only runs when the
    # GPU probe reports False.
    accelerator = None
    if testing.jax_uses_gpu():
        accelerator = "cuda"
    elif testing.jax_uses_tpu():
        accelerator = "tpu"

    if accelerator is not None:
        platforms = ("cpu", accelerator)
        self.export_kwargs = {
            "jax2tf_kwargs": {"native_serialization_platforms": platforms}
        }
82+
6783 @parameterized .named_parameters (
6884 named_product (model_type = ["sequential" , "functional" , "subclass" ])
6985 )
@@ -74,7 +90,9 @@ def test_standard_model_export(self, model_type):
7490 ref_input = np .random .normal (size = (batch_size , 10 )).astype ("float32" )
7591 ref_output = model (ref_input )
7692
77- saved_model .export_saved_model (model , temp_filepath )
93+ saved_model .export_saved_model (
94+ model , temp_filepath , ** self .export_kwargs
95+ )
7896 revived_model = tf .saved_model .load (temp_filepath )
7997 self .assertAllClose (ref_output , revived_model .serve (ref_input ))
8098 # Test with a different batch size
@@ -106,7 +124,9 @@ def call(self, inputs):
106124 ref_input = tf .random .normal ((3 , 10 ))
107125 ref_output = model (ref_input )
108126
109- saved_model .export_saved_model (model , temp_filepath )
127+ saved_model .export_saved_model (
128+ model , temp_filepath , ** self .export_kwargs
129+ )
110130 revived_model = tf .saved_model .load (temp_filepath )
111131 self .assertEqual (ref_output .shape , revived_model .serve (ref_input ).shape )
112132 # Test with a different batch size
@@ -142,7 +162,9 @@ def call(self, inputs):
142162 model = get_model (model_type , layer_list = [StateLayer ()])
143163 model (tf .random .normal ((3 , 10 )))
144164
145- saved_model .export_saved_model (model , temp_filepath )
165+ saved_model .export_saved_model (
166+ model , temp_filepath , ** self .export_kwargs
167+ )
146168 revived_model = tf .saved_model .load (temp_filepath )
147169
148170 # The non-trainable counter is expected to increment
@@ -164,7 +186,9 @@ def test_model_with_tf_data_layer(self, model_type):
164186 ref_input = np .random .normal (size = (batch_size , 10 )).astype ("float32" )
165187 ref_output = model (ref_input )
166188
167- saved_model .export_saved_model (model , temp_filepath )
189+ saved_model .export_saved_model (
190+ model , temp_filepath , ** self .export_kwargs
191+ )
168192 revived_model = tf .saved_model .load (temp_filepath )
169193 self .assertAllClose (ref_output , revived_model .serve (ref_input ))
170194 # Test with a different batch size
@@ -206,7 +230,9 @@ def call(self, inputs):
206230 temp_filepath = os .path .join (self .get_temp_dir (), "exported_model" )
207231 ref_output = model (tree .map_structure (ops .convert_to_tensor , ref_input ))
208232
209- saved_model .export_saved_model (model , temp_filepath )
233+ saved_model .export_saved_model (
234+ model , temp_filepath , ** self .export_kwargs
235+ )
210236 revived_model = tf .saved_model .load (temp_filepath )
211237 self .assertAllClose (ref_output , revived_model .serve (ref_input ))
212238
@@ -247,7 +273,9 @@ def build(self, y_shape, x_shape):
247273 ref_input_y = np .random .normal (size = (batch_size , 10 )).astype ("float32" )
248274 ref_output = model (ref_input_x , ref_input_y )
249275
250- saved_model .export_saved_model (model , temp_filepath )
276+ saved_model .export_saved_model (
277+ model , temp_filepath , ** self .export_kwargs
278+ )
251279 revived_model = tf .saved_model .load (temp_filepath )
252280 self .assertAllClose (
253281 ref_output , revived_model .serve (ref_input_x , ref_input_y )
@@ -282,7 +310,10 @@ def test_input_signature(self, model_type, input_signature):
282310 else :
283311 input_signature = (input_signature ,)
284312 saved_model .export_saved_model (
285- model , temp_filepath , input_signature = input_signature
313+ model ,
314+ temp_filepath ,
315+ input_signature = input_signature ,
316+ ** self .export_kwargs ,
286317 )
287318 revived_model = tf .saved_model .load (temp_filepath )
288319 self .assertAllClose (
@@ -318,11 +349,17 @@ def test_jax_specific_kwargs(self, model_type, is_static, jax2tf_kwargs):
318349 ref_input = ops .random .uniform ((3 , 10 ))
319350 ref_output = model (ref_input )
320351
352+ export_kwargs = self .export_kwargs .copy ()
353+ if "jax2tf_kwargs" in export_kwargs :
354+ export_kwargs ["jax2tf_kwargs" ].update (jax2tf_kwargs )
355+ else :
356+ export_kwargs ["jax2tf_kwargs" ] = jax2tf_kwargs
357+
321358 saved_model .export_saved_model (
322359 model ,
323360 temp_filepath ,
324361 is_static = is_static ,
325- jax2tf_kwargs = jax2tf_kwargs ,
362+ ** export_kwargs ,
326363 )
327364 revived_model = tf .saved_model .load (temp_filepath )
328365 self .assertAllClose (ref_output , revived_model .serve (ref_input ))
@@ -342,6 +379,22 @@ def test_jax_specific_kwargs(self, model_type, is_static, jax2tf_kwargs):
342379 testing .torch_uses_gpu (), reason = "Leads to core dumps on CI"
343380)
344381class ExportArchiveTest (testing .TestCase ):
def setUp(self):
    """Prepare the extra keyword arguments used by the endpoint tests.

    Stores them on `self.add_endpoint_kwargs`, which the tests in this
    class unpack into `export_archive.add_endpoint(...)`.

    - When `testing.jax_uses_gpu()` is true, `jax2tf_kwargs` requests the
      `("cpu", "cuda")` native serialization platforms.
    - Otherwise, when `testing.jax_uses_tpu()` is true, it requests
      `("cpu", "tpu")`.
    - In every other configuration the kwargs stay empty.
    """
    super().setUp()
    self.add_endpoint_kwargs = {}

    # The GPU probe takes precedence; the TPU probe only runs when the
    # GPU probe reports False.
    accelerator = None
    if testing.jax_uses_gpu():
        accelerator = "cuda"
    elif testing.jax_uses_tpu():
        accelerator = "tpu"

    if accelerator is not None:
        platforms = ("cpu", accelerator)
        self.add_endpoint_kwargs = {
            "jax2tf_kwargs": {"native_serialization_platforms": platforms}
        }
397+
345398 @parameterized .named_parameters (
346399 named_product (model_type = ["sequential" , "functional" , "subclass" ])
347400 )
@@ -365,6 +418,7 @@ def test_low_level_model_export(self, model_type):
365418 "call" ,
366419 model .__call__ ,
367420 input_signature = [tf .TensorSpec (shape = (None , 10 ), dtype = tf .float32 )],
421+ ** self .add_endpoint_kwargs ,
368422 )
369423 export_archive .write_out (temp_filepath )
370424 revived_model = tf .saved_model .load (temp_filepath )
@@ -385,6 +439,7 @@ def test_low_level_model_export_with_alias(self):
385439 "call" ,
386440 model .__call__ ,
387441 input_signature = [tf .TensorSpec (shape = (None , 10 ), dtype = tf .float32 )],
442+ ** self .add_endpoint_kwargs ,
388443 )
389444 export_archive .write_out (
390445 temp_filepath ,
@@ -431,6 +486,7 @@ def call(self, inputs):
431486 tf .TensorSpec (shape = (None , None ), dtype = tf .float32 ),
432487 ]
433488 ],
489+ ** self .add_endpoint_kwargs ,
434490 )
435491 export_archive .write_out (temp_filepath )
436492 revived_model = tf .saved_model .load (temp_filepath )
@@ -445,29 +501,13 @@ def call(self, inputs):
445501 reason = "This test is only for the JAX backend." ,
446502 )
447503 def test_low_level_model_export_with_jax2tf_kwargs (self ):
448- temp_filepath = os .path .join (self .get_temp_dir (), "exported_model" )
449-
450- model = get_model ()
451- ref_input = tf .random .normal ((3 , 10 ))
452- ref_output = model (ref_input )
453-
454504 export_archive = saved_model .ExportArchive ()
455- export_archive .track (model )
456- export_archive .add_endpoint (
457- "call" ,
458- model .__call__ ,
459- input_signature = [tf .TensorSpec (shape = (None , 10 ), dtype = tf .float32 )],
460- jax2tf_kwargs = {
461- "native_serialization" : True ,
462- "native_serialization_platforms" : ("cpu" , "tpu" ),
463- },
464- )
465505 with self .assertRaisesRegex (
466506 ValueError , "native_serialization_platforms.*bogus"
467507 ):
468508 export_archive .add_endpoint (
469- "call2 " ,
470- model . __call__ ,
509+ "call " ,
510+ lambda x : x ,
471511 input_signature = [
472512 tf .TensorSpec (shape = (None , 10 ), dtype = tf .float32 )
473513 ],
@@ -476,9 +516,6 @@ def test_low_level_model_export_with_jax2tf_kwargs(self):
476516 "native_serialization_platforms" : ("cpu" , "bogus" ),
477517 },
478518 )
479- export_archive .write_out (temp_filepath )
480- revived_model = tf .saved_model .load (temp_filepath )
481- self .assertAllClose (ref_output , revived_model .call (ref_input ))
482519
483520 @pytest .mark .skipif (
484521 backend .backend () != "jax" ,
@@ -506,12 +543,13 @@ def call(self, inputs):
506543 "call" ,
507544 model .__call__ ,
508545 input_signature = signature ,
509- jax2tf_kwargs = {} ,
546+ ** self . add_endpoint_kwargs ,
510547 )
511548 export_archive .write_out (temp_filepath )
512549
513550 export_archive = saved_model .ExportArchive ()
514551 export_archive .track (model )
552+ # TODO
515553 export_archive .add_endpoint (
516554 "call" ,
517555 model .__call__ ,
@@ -585,6 +623,7 @@ def model_call(x):
585623 model_call ,
586624 native_serialization = native_jax_compatible ,
587625 polymorphic_shapes = ["(b, 10)" ],
626+ # TODO
588627 )
589628
590629 # you can now build a TF inference function
0 commit comments