Skip to content

Commit cccbefd

Browse files
authored
Fix TPU tests, removed test exclusion mechanism. (#22571)
- Remove the file-based exclusion mechanism as it is no longer needed. - `jax2onnx` fixed the ONNX export on TPU. - tf_save_model export works on TPU, simply, the `native_serialization_platforms` argument has to be passed to also export to CPU as the reloading happens with a CPU only version of TensorFlow. - Only conv transpose tests remain excluded.
1 parent 050671c commit cccbefd

File tree

5 files changed

+258
-181
lines changed

5 files changed

+258
-181
lines changed

conftest.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -52,23 +52,11 @@ def pytest_collection_modifyitems(config, items):
5252
line.strip() for line in openvino_skipped_tests if line.strip()
5353
]
5454

55-
tpu_skipped_tests = []
5655
if backend() == "jax":
5756
import jax
5857

5958
has_multiple_devices = jax.device_count() > 1
6059

61-
if jax.default_backend() == "tpu":
62-
with open(
63-
"keras/src/backend/jax/excluded_tpu_tests.txt", "r"
64-
) as file:
65-
tpu_skipped_tests = file.readlines()
66-
# it is necessary to check if stripped line is not empty
67-
# and exclude such lines
68-
tpu_skipped_tests = [
69-
line.strip() for line in tpu_skipped_tests if line.strip()
70-
]
71-
7260
requires_trainable_backend = pytest.mark.skipif(
7361
backend() in ["numpy", "openvino"],
7462
reason="Trainer not implemented for NumPy and OpenVINO backend.",
@@ -96,14 +84,6 @@ def pytest_collection_modifyitems(config, items):
9684
"Not supported operation by openvino backend",
9785
)
9886
)
99-
# also, skip concrete tests for TPU when using JAX backend
100-
for skipped_test in tpu_skipped_tests:
101-
if skipped_test in item.nodeid:
102-
item.add_marker(
103-
pytest.mark.skip(
104-
reason="Known TPU test failure",
105-
)
106-
)
10787

10888

10989
def skip_if_backend(given_backend, reason):

keras/src/backend/jax/excluded_tpu_tests.txt

Lines changed: 0 additions & 79 deletions
This file was deleted.

keras/src/export/onnx_test.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,12 @@ def test_standard_model_export(self, model_type):
101101
ort_inputs = {
102102
k.name: v for k, v in zip(ort_session.get_inputs(), [ref_input])
103103
}
104-
self.assertAllClose(ort_session.run(None, ort_inputs)[0], ref_output)
104+
self.assertAllClose(
105+
ort_session.run(None, ort_inputs)[0],
106+
ref_output,
107+
tpu_atol=1e-3,
108+
tpu_rtol=1e-2,
109+
)
105110
# Test with a different batch size
106111
ort_inputs = {
107112
k.name: v
@@ -291,7 +296,12 @@ def test_export_with_input_names(self):
291296
ort_inputs = {
292297
k.name: v for k, v in zip(ort_session.get_inputs(), [ref_input])
293298
}
294-
self.assertAllClose(ort_session.run(None, ort_inputs)[0], ref_output)
299+
self.assertAllClose(
300+
ort_session.run(None, ort_inputs)[0],
301+
ref_output,
302+
tpu_atol=1e-3,
303+
tpu_rtol=1e-2,
304+
)
295305

296306
@parameterized.named_parameters(
297307
named_product(

0 commit comments

Comments
 (0)