Skip to content

Commit 570aad8

Browse files
committed
Fix some TPU tests.
- `jax2onnx` fixed the ONNX export on TPU.
1 parent 5b7ab6f commit 570aad8

File tree

3 files changed

+82
-111
lines changed

3 files changed

+82
-111
lines changed
Lines changed: 1 addition & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -1,79 +1 @@
1-
ConvTransposeBasicTest
2-
ExportArchiveTest::test_jax_endpoint_registration_tf_function
3-
ExportArchiveTest::test_jax_multi_unknown_endpoint_registration
4-
ExportArchiveTest::test_layer_export
5-
ExportArchiveTest::test_low_level_model_export_functional
6-
ExportArchiveTest::test_low_level_model_export_sequential
7-
ExportArchiveTest::test_low_level_model_export_subclass
8-
ExportArchiveTest::test_low_level_model_export_with_alias
9-
ExportArchiveTest::test_low_level_model_export_with_dynamic_dims_functional
10-
ExportArchiveTest::test_low_level_model_export_with_dynamic_dims_sequential
11-
ExportArchiveTest::test_low_level_model_export_with_dynamic_dims_subclass
12-
ExportArchiveTest::test_low_level_model_export_with_jax2tf_kwargs
13-
ExportArchiveTest::test_low_level_model_export_with_jax2tf_polymorphic_shapes
14-
ExportArchiveTest::test_model_combined_with_tf_preprocessing
15-
ExportArchiveTest::test_model_export_method_functional
16-
ExportArchiveTest::test_model_export_method_sequential
17-
ExportArchiveTest::test_model_export_method_subclass
18-
ExportArchiveTest::test_multi_input_output_functional_model
19-
ExportArchiveTest::test_non_standard_layer_signature
20-
ExportArchiveTest::test_non_standard_layer_signature_with_kwargs
21-
ExportArchiveTest::test_track_multiple_layers
22-
ExportONNXTest::test_export_with_input_names
23-
ExportONNXTest::test_export_with_opset_version_18
24-
ExportONNXTest::test_export_with_opset_version_none
25-
ExportONNXTest::test_model_with_input_structure_array
26-
ExportONNXTest::test_model_with_input_structure_dict
27-
ExportONNXTest::test_model_with_input_structure_tuple
28-
ExportONNXTest::test_model_with_multiple_inputs
29-
ExportONNXTest::test_standard_model_export_functional
30-
ExportONNXTest::test_standard_model_export_lstm
31-
ExportONNXTest::test_standard_model_export_sequential
32-
ExportONNXTest::test_standard_model_export_subclass
33-
ExportOpenVINOTest::test_model_with_input_structure_array
34-
ExportOpenVINOTest::test_model_with_input_structure_dict
35-
ExportOpenVINOTest::test_model_with_input_structure_tuple
36-
ExportOpenVINOTest::test_model_with_multiple_inputs
37-
ExportOpenVINOTest::test_standard_model_export_functional
38-
ExportOpenVINOTest::test_standard_model_export_sequential
39-
ExportOpenVINOTest::test_standard_model_export_subclass
40-
ExportSavedModelTest::test_input_signature_functional_<kerastensor shape=(none, 10), dtype=float32, sparse=false, ragged=false, name=inputs>
41-
ExportSavedModelTest::test_input_signature_functional_backend_tensor
42-
ExportSavedModelTest::test_input_signature_functional_inputspec(dtype=float32, shape=(none, 10), ndim=2)
43-
ExportSavedModelTest::test_input_signature_functional_tensorspec(shape=(none, 10), dtype=tf.float32, name='inputs')
44-
ExportSavedModelTest::test_input_signature_sequential_<kerastensor shape=(none, 10), dtype=float32, sparse=false, ragged=false, name=inputs>
45-
ExportSavedModelTest::test_input_signature_sequential_backend_tensor
46-
ExportSavedModelTest::test_input_signature_sequential_inputspec(dtype=float32, shape=(none, 10), ndim=2)
47-
ExportSavedModelTest::test_input_signature_sequential_tensorspec(shape=(none, 10), dtype=tf.float32, name='inputs')
48-
ExportSavedModelTest::test_input_signature_subclass_<kerastensor shape=(none, 10), dtype=float32, sparse=false, ragged=false, name=inputs>
49-
ExportSavedModelTest::test_input_signature_subclass_backend_tensor
50-
ExportSavedModelTest::test_input_signature_subclass_inputspec(dtype=float32, shape=(none, 10), ndim=2)
51-
ExportSavedModelTest::test_input_signature_subclass_tensorspec(shape=(none, 10), dtype=tf.float32, name='inputs')
52-
ExportSavedModelTest::test_jax_specific_kwargs_functional_false_{'enable_xla': true, 'native_serialization': true}
53-
ExportSavedModelTest::test_jax_specific_kwargs_functional_false_none
54-
ExportSavedModelTest::test_jax_specific_kwargs_functional_true_{'enable_xla': true, 'native_serialization': true}
55-
ExportSavedModelTest::test_jax_specific_kwargs_functional_true_none
56-
ExportSavedModelTest::test_jax_specific_kwargs_sequential_false_{'enable_xla': true, 'native_serialization': true}
57-
ExportSavedModelTest::test_jax_specific_kwargs_sequential_false_none
58-
ExportSavedModelTest::test_jax_specific_kwargs_sequential_true_{'enable_xla': true, 'native_serialization': true}
59-
ExportSavedModelTest::test_jax_specific_kwargs_sequential_true_none
60-
ExportSavedModelTest::test_jax_specific_kwargs_subclass_false_{'enable_xla': true, 'native_serialization': true}
61-
ExportSavedModelTest::test_jax_specific_kwargs_subclass_false_none
62-
ExportSavedModelTest::test_jax_specific_kwargs_subclass_true_{'enable_xla': true, 'native_serialization': true}
63-
ExportSavedModelTest::test_jax_specific_kwargs_subclass_true_none
64-
ExportSavedModelTest::test_model_with_input_structure_array
65-
ExportSavedModelTest::test_model_with_input_structure_dict
66-
ExportSavedModelTest::test_model_with_input_structure_tuple
67-
ExportSavedModelTest::test_model_with_multiple_inputs
68-
ExportSavedModelTest::test_model_with_non_trainable_state_export_functional
69-
ExportSavedModelTest::test_model_with_non_trainable_state_export_sequential
70-
ExportSavedModelTest::test_model_with_non_trainable_state_export_subclass
71-
ExportSavedModelTest::test_model_with_rng_export_functional
72-
ExportSavedModelTest::test_model_with_rng_export_sequential
73-
ExportSavedModelTest::test_model_with_rng_export_subclass
74-
ExportSavedModelTest::test_model_with_tf_data_layer_functional
75-
ExportSavedModelTest::test_model_with_tf_data_layer_sequential
76-
ExportSavedModelTest::test_model_with_tf_data_layer_subclass
77-
ExportSavedModelTest::test_standard_model_export_functional
78-
ExportSavedModelTest::test_standard_model_export_sequential
79-
ExportSavedModelTest::test_standard_model_export_subclass
1+
ConvTransposeBasicTest

keras/src/export/onnx_test.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,12 @@ def test_standard_model_export(self, model_type):
101101
ort_inputs = {
102102
k.name: v for k, v in zip(ort_session.get_inputs(), [ref_input])
103103
}
104-
self.assertAllClose(ref_output, ort_session.run(None, ort_inputs)[0])
104+
self.assertAllClose(
105+
ref_output,
106+
ort_session.run(None, ort_inputs)[0],
107+
tpu_atol=1e-3,
108+
tpu_rtol=1e-2,
109+
)
105110
# Test with a different batch size
106111
ort_inputs = {
107112
k.name: v
@@ -291,7 +296,12 @@ def test_export_with_input_names(self):
291296
ort_inputs = {
292297
k.name: v for k, v in zip(ort_session.get_inputs(), [ref_input])
293298
}
294-
self.assertAllClose(ref_output, ort_session.run(None, ort_inputs)[0])
299+
self.assertAllClose(
300+
ref_output,
301+
ort_session.run(None, ort_inputs)[0],
302+
tpu_atol=1e-3,
303+
tpu_rtol=1e-2,
304+
)
295305

296306
@parameterized.named_parameters(
297307
named_product(

keras/src/export/saved_model_test.py

Lines changed: 69 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,22 @@ def get_model(type="sequential", input_shape=(10,), layer_list=None):
6464
reason="Torch backend export (via torch_xla) is incompatible with np 2.0",
6565
)
6666
class ExportSavedModelTest(testing.TestCase):
67+
def setUp(self):
68+
super().setUp()
69+
self.export_kwargs = {}
70+
if testing.jax_uses_gpu():
71+
self.export_kwargs = {
72+
"jax2tf_kwargs": {
73+
"native_serialization_platforms": ("cpu", "cuda")
74+
}
75+
}
76+
elif testing.jax_uses_tpu():
77+
self.export_kwargs = {
78+
"jax2tf_kwargs": {
79+
"native_serialization_platforms": ("cpu", "tpu")
80+
}
81+
}
82+
6783
@parameterized.named_parameters(
6884
named_product(model_type=["sequential", "functional", "subclass"])
6985
)
@@ -74,7 +90,9 @@ def test_standard_model_export(self, model_type):
7490
ref_input = np.random.normal(size=(batch_size, 10)).astype("float32")
7591
ref_output = model(ref_input)
7692

77-
saved_model.export_saved_model(model, temp_filepath)
93+
saved_model.export_saved_model(
94+
model, temp_filepath, **self.export_kwargs
95+
)
7896
revived_model = tf.saved_model.load(temp_filepath)
7997
self.assertAllClose(ref_output, revived_model.serve(ref_input))
8098
# Test with a different batch size
@@ -106,7 +124,9 @@ def call(self, inputs):
106124
ref_input = tf.random.normal((3, 10))
107125
ref_output = model(ref_input)
108126

109-
saved_model.export_saved_model(model, temp_filepath)
127+
saved_model.export_saved_model(
128+
model, temp_filepath, **self.export_kwargs
129+
)
110130
revived_model = tf.saved_model.load(temp_filepath)
111131
self.assertEqual(ref_output.shape, revived_model.serve(ref_input).shape)
112132
# Test with a different batch size
@@ -142,7 +162,9 @@ def call(self, inputs):
142162
model = get_model(model_type, layer_list=[StateLayer()])
143163
model(tf.random.normal((3, 10)))
144164

145-
saved_model.export_saved_model(model, temp_filepath)
165+
saved_model.export_saved_model(
166+
model, temp_filepath, **self.export_kwargs
167+
)
146168
revived_model = tf.saved_model.load(temp_filepath)
147169

148170
# The non-trainable counter is expected to increment
@@ -164,7 +186,9 @@ def test_model_with_tf_data_layer(self, model_type):
164186
ref_input = np.random.normal(size=(batch_size, 10)).astype("float32")
165187
ref_output = model(ref_input)
166188

167-
saved_model.export_saved_model(model, temp_filepath)
189+
saved_model.export_saved_model(
190+
model, temp_filepath, **self.export_kwargs
191+
)
168192
revived_model = tf.saved_model.load(temp_filepath)
169193
self.assertAllClose(ref_output, revived_model.serve(ref_input))
170194
# Test with a different batch size
@@ -206,7 +230,9 @@ def call(self, inputs):
206230
temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
207231
ref_output = model(tree.map_structure(ops.convert_to_tensor, ref_input))
208232

209-
saved_model.export_saved_model(model, temp_filepath)
233+
saved_model.export_saved_model(
234+
model, temp_filepath, **self.export_kwargs
235+
)
210236
revived_model = tf.saved_model.load(temp_filepath)
211237
self.assertAllClose(ref_output, revived_model.serve(ref_input))
212238

@@ -247,7 +273,9 @@ def build(self, y_shape, x_shape):
247273
ref_input_y = np.random.normal(size=(batch_size, 10)).astype("float32")
248274
ref_output = model(ref_input_x, ref_input_y)
249275

250-
saved_model.export_saved_model(model, temp_filepath)
276+
saved_model.export_saved_model(
277+
model, temp_filepath, **self.export_kwargs
278+
)
251279
revived_model = tf.saved_model.load(temp_filepath)
252280
self.assertAllClose(
253281
ref_output, revived_model.serve(ref_input_x, ref_input_y)
@@ -282,7 +310,10 @@ def test_input_signature(self, model_type, input_signature):
282310
else:
283311
input_signature = (input_signature,)
284312
saved_model.export_saved_model(
285-
model, temp_filepath, input_signature=input_signature
313+
model,
314+
temp_filepath,
315+
input_signature=input_signature,
316+
**self.export_kwargs,
286317
)
287318
revived_model = tf.saved_model.load(temp_filepath)
288319
self.assertAllClose(
@@ -318,11 +349,17 @@ def test_jax_specific_kwargs(self, model_type, is_static, jax2tf_kwargs):
318349
ref_input = ops.random.uniform((3, 10))
319350
ref_output = model(ref_input)
320351

352+
export_kwargs = self.export_kwargs.copy()
353+
if "jax2tf_kwargs" in export_kwargs:
354+
export_kwargs["jax2tf_kwargs"].update(jax2tf_kwargs)
355+
else:
356+
export_kwargs["jax2tf_kwargs"] = jax2tf_kwargs
357+
321358
saved_model.export_saved_model(
322359
model,
323360
temp_filepath,
324361
is_static=is_static,
325-
jax2tf_kwargs=jax2tf_kwargs,
362+
**export_kwargs,
326363
)
327364
revived_model = tf.saved_model.load(temp_filepath)
328365
self.assertAllClose(ref_output, revived_model.serve(ref_input))
@@ -342,6 +379,22 @@ def test_jax_specific_kwargs(self, model_type, is_static, jax2tf_kwargs):
342379
testing.torch_uses_gpu(), reason="Leads to core dumps on CI"
343380
)
344381
class ExportArchiveTest(testing.TestCase):
382+
def setUp(self):
383+
super().setUp()
384+
self.add_endpoint_kwargs = {}
385+
if testing.jax_uses_gpu():
386+
self.add_endpoint_kwargs = {
387+
"jax2tf_kwargs": {
388+
"native_serialization_platforms": ("cpu", "cuda")
389+
}
390+
}
391+
elif testing.jax_uses_tpu():
392+
self.add_endpoint_kwargs = {
393+
"jax2tf_kwargs": {
394+
"native_serialization_platforms": ("cpu", "tpu")
395+
}
396+
}
397+
345398
@parameterized.named_parameters(
346399
named_product(model_type=["sequential", "functional", "subclass"])
347400
)
@@ -365,6 +418,7 @@ def test_low_level_model_export(self, model_type):
365418
"call",
366419
model.__call__,
367420
input_signature=[tf.TensorSpec(shape=(None, 10), dtype=tf.float32)],
421+
**self.add_endpoint_kwargs,
368422
)
369423
export_archive.write_out(temp_filepath)
370424
revived_model = tf.saved_model.load(temp_filepath)
@@ -385,6 +439,7 @@ def test_low_level_model_export_with_alias(self):
385439
"call",
386440
model.__call__,
387441
input_signature=[tf.TensorSpec(shape=(None, 10), dtype=tf.float32)],
442+
**self.add_endpoint_kwargs,
388443
)
389444
export_archive.write_out(
390445
temp_filepath,
@@ -431,6 +486,7 @@ def call(self, inputs):
431486
tf.TensorSpec(shape=(None, None), dtype=tf.float32),
432487
]
433488
],
489+
**self.add_endpoint_kwargs,
434490
)
435491
export_archive.write_out(temp_filepath)
436492
revived_model = tf.saved_model.load(temp_filepath)
@@ -445,29 +501,13 @@ def call(self, inputs):
445501
reason="This test is only for the JAX backend.",
446502
)
447503
def test_low_level_model_export_with_jax2tf_kwargs(self):
448-
temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
449-
450-
model = get_model()
451-
ref_input = tf.random.normal((3, 10))
452-
ref_output = model(ref_input)
453-
454504
export_archive = saved_model.ExportArchive()
455-
export_archive.track(model)
456-
export_archive.add_endpoint(
457-
"call",
458-
model.__call__,
459-
input_signature=[tf.TensorSpec(shape=(None, 10), dtype=tf.float32)],
460-
jax2tf_kwargs={
461-
"native_serialization": True,
462-
"native_serialization_platforms": ("cpu", "tpu"),
463-
},
464-
)
465505
with self.assertRaisesRegex(
466506
ValueError, "native_serialization_platforms.*bogus"
467507
):
468508
export_archive.add_endpoint(
469-
"call2",
470-
model.__call__,
509+
"call",
510+
lambda x: x,
471511
input_signature=[
472512
tf.TensorSpec(shape=(None, 10), dtype=tf.float32)
473513
],
@@ -476,9 +516,6 @@ def test_low_level_model_export_with_jax2tf_kwargs(self):
476516
"native_serialization_platforms": ("cpu", "bogus"),
477517
},
478518
)
479-
export_archive.write_out(temp_filepath)
480-
revived_model = tf.saved_model.load(temp_filepath)
481-
self.assertAllClose(ref_output, revived_model.call(ref_input))
482519

483520
@pytest.mark.skipif(
484521
backend.backend() != "jax",
@@ -506,12 +543,13 @@ def call(self, inputs):
506543
"call",
507544
model.__call__,
508545
input_signature=signature,
509-
jax2tf_kwargs={},
546+
**self.add_endpoint_kwargs,
510547
)
511548
export_archive.write_out(temp_filepath)
512549

513550
export_archive = saved_model.ExportArchive()
514551
export_archive.track(model)
552+
# TODO
515553
export_archive.add_endpoint(
516554
"call",
517555
model.__call__,
@@ -585,6 +623,7 @@ def model_call(x):
585623
model_call,
586624
native_serialization=native_jax_compatible,
587625
polymorphic_shapes=["(b, 10)"],
626+
# TODO
588627
)
589628

590629
# you can now build a TF inference function

0 commit comments

Comments
 (0)