Skip to content

Commit e7af66d

Browse files
committed
Fix some TPU tests.
`jax2onnx` fixed the ONNX export on TPU, but the `native_serialization_platforms` argument has to be passed to also export to CPU as the reloading happens with a CPU only version of TensorFlow.
1 parent 9708582 commit e7af66d

File tree

3 files changed

+112
-118
lines changed

3 files changed

+112
-118
lines changed
Lines changed: 1 addition & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -1,79 +1 @@
1-
ConvTransposeBasicTest
2-
ExportArchiveTest::test_jax_endpoint_registration_tf_function
3-
ExportArchiveTest::test_jax_multi_unknown_endpoint_registration
4-
ExportArchiveTest::test_layer_export
5-
ExportArchiveTest::test_low_level_model_export_functional
6-
ExportArchiveTest::test_low_level_model_export_sequential
7-
ExportArchiveTest::test_low_level_model_export_subclass
8-
ExportArchiveTest::test_low_level_model_export_with_alias
9-
ExportArchiveTest::test_low_level_model_export_with_dynamic_dims_functional
10-
ExportArchiveTest::test_low_level_model_export_with_dynamic_dims_sequential
11-
ExportArchiveTest::test_low_level_model_export_with_dynamic_dims_subclass
12-
ExportArchiveTest::test_low_level_model_export_with_jax2tf_kwargs
13-
ExportArchiveTest::test_low_level_model_export_with_jax2tf_polymorphic_shapes
14-
ExportArchiveTest::test_model_combined_with_tf_preprocessing
15-
ExportArchiveTest::test_model_export_method_functional
16-
ExportArchiveTest::test_model_export_method_sequential
17-
ExportArchiveTest::test_model_export_method_subclass
18-
ExportArchiveTest::test_multi_input_output_functional_model
19-
ExportArchiveTest::test_non_standard_layer_signature
20-
ExportArchiveTest::test_non_standard_layer_signature_with_kwargs
21-
ExportArchiveTest::test_track_multiple_layers
22-
ExportONNXTest::test_export_with_input_names
23-
ExportONNXTest::test_export_with_opset_version_18
24-
ExportONNXTest::test_export_with_opset_version_none
25-
ExportONNXTest::test_model_with_input_structure_array
26-
ExportONNXTest::test_model_with_input_structure_dict
27-
ExportONNXTest::test_model_with_input_structure_tuple
28-
ExportONNXTest::test_model_with_multiple_inputs
29-
ExportONNXTest::test_standard_model_export_functional
30-
ExportONNXTest::test_standard_model_export_lstm
31-
ExportONNXTest::test_standard_model_export_sequential
32-
ExportONNXTest::test_standard_model_export_subclass
33-
ExportOpenVINOTest::test_model_with_input_structure_array
34-
ExportOpenVINOTest::test_model_with_input_structure_dict
35-
ExportOpenVINOTest::test_model_with_input_structure_tuple
36-
ExportOpenVINOTest::test_model_with_multiple_inputs
37-
ExportOpenVINOTest::test_standard_model_export_functional
38-
ExportOpenVINOTest::test_standard_model_export_sequential
39-
ExportOpenVINOTest::test_standard_model_export_subclass
40-
ExportSavedModelTest::test_input_signature_functional_<kerastensor shape=(none, 10), dtype=float32, sparse=false, ragged=false, name=inputs>
41-
ExportSavedModelTest::test_input_signature_functional_backend_tensor
42-
ExportSavedModelTest::test_input_signature_functional_inputspec(dtype=float32, shape=(none, 10), ndim=2)
43-
ExportSavedModelTest::test_input_signature_functional_tensorspec(shape=(none, 10), dtype=tf.float32, name='inputs')
44-
ExportSavedModelTest::test_input_signature_sequential_<kerastensor shape=(none, 10), dtype=float32, sparse=false, ragged=false, name=inputs>
45-
ExportSavedModelTest::test_input_signature_sequential_backend_tensor
46-
ExportSavedModelTest::test_input_signature_sequential_inputspec(dtype=float32, shape=(none, 10), ndim=2)
47-
ExportSavedModelTest::test_input_signature_sequential_tensorspec(shape=(none, 10), dtype=tf.float32, name='inputs')
48-
ExportSavedModelTest::test_input_signature_subclass_<kerastensor shape=(none, 10), dtype=float32, sparse=false, ragged=false, name=inputs>
49-
ExportSavedModelTest::test_input_signature_subclass_backend_tensor
50-
ExportSavedModelTest::test_input_signature_subclass_inputspec(dtype=float32, shape=(none, 10), ndim=2)
51-
ExportSavedModelTest::test_input_signature_subclass_tensorspec(shape=(none, 10), dtype=tf.float32, name='inputs')
52-
ExportSavedModelTest::test_jax_specific_kwargs_functional_false_{'enable_xla': true, 'native_serialization': true}
53-
ExportSavedModelTest::test_jax_specific_kwargs_functional_false_none
54-
ExportSavedModelTest::test_jax_specific_kwargs_functional_true_{'enable_xla': true, 'native_serialization': true}
55-
ExportSavedModelTest::test_jax_specific_kwargs_functional_true_none
56-
ExportSavedModelTest::test_jax_specific_kwargs_sequential_false_{'enable_xla': true, 'native_serialization': true}
57-
ExportSavedModelTest::test_jax_specific_kwargs_sequential_false_none
58-
ExportSavedModelTest::test_jax_specific_kwargs_sequential_true_{'enable_xla': true, 'native_serialization': true}
59-
ExportSavedModelTest::test_jax_specific_kwargs_sequential_true_none
60-
ExportSavedModelTest::test_jax_specific_kwargs_subclass_false_{'enable_xla': true, 'native_serialization': true}
61-
ExportSavedModelTest::test_jax_specific_kwargs_subclass_false_none
62-
ExportSavedModelTest::test_jax_specific_kwargs_subclass_true_{'enable_xla': true, 'native_serialization': true}
63-
ExportSavedModelTest::test_jax_specific_kwargs_subclass_true_none
64-
ExportSavedModelTest::test_model_with_input_structure_array
65-
ExportSavedModelTest::test_model_with_input_structure_dict
66-
ExportSavedModelTest::test_model_with_input_structure_tuple
67-
ExportSavedModelTest::test_model_with_multiple_inputs
68-
ExportSavedModelTest::test_model_with_non_trainable_state_export_functional
69-
ExportSavedModelTest::test_model_with_non_trainable_state_export_sequential
70-
ExportSavedModelTest::test_model_with_non_trainable_state_export_subclass
71-
ExportSavedModelTest::test_model_with_rng_export_functional
72-
ExportSavedModelTest::test_model_with_rng_export_sequential
73-
ExportSavedModelTest::test_model_with_rng_export_subclass
74-
ExportSavedModelTest::test_model_with_tf_data_layer_functional
75-
ExportSavedModelTest::test_model_with_tf_data_layer_sequential
76-
ExportSavedModelTest::test_model_with_tf_data_layer_subclass
77-
ExportSavedModelTest::test_standard_model_export_functional
78-
ExportSavedModelTest::test_standard_model_export_sequential
79-
ExportSavedModelTest::test_standard_model_export_subclass
1+
ConvTransposeBasicTest

keras/src/export/onnx_test.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,12 @@ def test_standard_model_export(self, model_type):
101101
ort_inputs = {
102102
k.name: v for k, v in zip(ort_session.get_inputs(), [ref_input])
103103
}
104-
self.assertAllClose(ort_session.run(None, ort_inputs)[0], ref_output)
104+
self.assertAllClose(
105+
ort_session.run(None, ort_inputs)[0],
106+
ref_output,
107+
tpu_atol=1e-3,
108+
tpu_rtol=1e-2,
109+
)
105110
# Test with a different batch size
106111
ort_inputs = {
107112
k.name: v
@@ -291,7 +296,12 @@ def test_export_with_input_names(self):
291296
ort_inputs = {
292297
k.name: v for k, v in zip(ort_session.get_inputs(), [ref_input])
293298
}
294-
self.assertAllClose(ort_session.run(None, ort_inputs)[0], ref_output)
299+
self.assertAllClose(
300+
ort_session.run(None, ort_inputs)[0],
301+
ref_output,
302+
tpu_atol=1e-3,
303+
tpu_rtol=1e-2,
304+
)
295305

296306
@parameterized.named_parameters(
297307
named_product(

0 commit comments

Comments
 (0)