From d2b101a9d96c322bd347119adec1891a50f89aed Mon Sep 17 00:00:00 2001 From: Fabien Hertschuh <1091026+hertschuh@users.noreply.github.com> Date: Fri, 27 Mar 2026 11:16:31 -0700 Subject: [PATCH] Fix TPU tests, remove test exclusion mechanism. - Remove the file-based exclusion mechanism as it is no longer needed. - `jax2onnx` fixed the ONNX export on TPU. - `tf_saved_model` export works on TPU; the `native_serialization_platforms` argument simply has to be passed so that the model is also exported for CPU, because the reloading happens with a CPU-only version of TensorFlow. - Only conv transpose tests remain excluded. --- conftest.py | 20 -- keras/src/backend/jax/excluded_tpu_tests.txt | 79 ----- keras/src/export/onnx_test.py | 14 +- keras/src/export/saved_model_test.py | 325 +++++++++++++----- .../convolutional/conv_transpose_test.py | 1 + 5 files changed, 258 insertions(+), 181 deletions(-) delete mode 100644 keras/src/backend/jax/excluded_tpu_tests.txt diff --git a/conftest.py b/conftest.py index 44bf350076d1..e9eae70986c1 100644 --- a/conftest.py +++ b/conftest.py @@ -52,23 +52,11 @@ def pytest_collection_modifyitems(config, items): line.strip() for line in openvino_skipped_tests if line.strip() ] - tpu_skipped_tests = [] if backend() == "jax": import jax has_multiple_devices = jax.device_count() > 1 - if jax.default_backend() == "tpu": - with open( - "keras/src/backend/jax/excluded_tpu_tests.txt", "r" - ) as file: - tpu_skipped_tests = file.readlines() - # it is necessary to check if stripped line is not empty - # and exclude such lines - tpu_skipped_tests = [ - line.strip() for line in tpu_skipped_tests if line.strip() - ] - requires_trainable_backend = pytest.mark.skipif( backend() in ["numpy", "openvino"], reason="Trainer not implemented for NumPy and OpenVINO backend.", @@ -96,14 +84,6 @@ def pytest_collection_modifyitems(config, items): "Not supported operation by openvino backend", ) ) - # also, skip concrete tests for TPU when using JAX backend - for skipped_test in 
tpu_skipped_tests: - if skipped_test in item.nodeid: - item.add_marker( - pytest.mark.skip( - reason="Known TPU test failure", - ) - ) def skip_if_backend(given_backend, reason): diff --git a/keras/src/backend/jax/excluded_tpu_tests.txt b/keras/src/backend/jax/excluded_tpu_tests.txt deleted file mode 100644 index 8898eb782b69..000000000000 --- a/keras/src/backend/jax/excluded_tpu_tests.txt +++ /dev/null @@ -1,79 +0,0 @@ -ConvTransposeBasicTest -ExportArchiveTest::test_jax_endpoint_registration_tf_function -ExportArchiveTest::test_jax_multi_unknown_endpoint_registration -ExportArchiveTest::test_layer_export -ExportArchiveTest::test_low_level_model_export_functional -ExportArchiveTest::test_low_level_model_export_sequential -ExportArchiveTest::test_low_level_model_export_subclass -ExportArchiveTest::test_low_level_model_export_with_alias -ExportArchiveTest::test_low_level_model_export_with_dynamic_dims_functional -ExportArchiveTest::test_low_level_model_export_with_dynamic_dims_sequential -ExportArchiveTest::test_low_level_model_export_with_dynamic_dims_subclass -ExportArchiveTest::test_low_level_model_export_with_jax2tf_kwargs -ExportArchiveTest::test_low_level_model_export_with_jax2tf_polymorphic_shapes -ExportArchiveTest::test_model_combined_with_tf_preprocessing -ExportArchiveTest::test_model_export_method_functional -ExportArchiveTest::test_model_export_method_sequential -ExportArchiveTest::test_model_export_method_subclass -ExportArchiveTest::test_multi_input_output_functional_model -ExportArchiveTest::test_non_standard_layer_signature -ExportArchiveTest::test_non_standard_layer_signature_with_kwargs -ExportArchiveTest::test_track_multiple_layers -ExportONNXTest::test_export_with_input_names -ExportONNXTest::test_export_with_opset_version_18 -ExportONNXTest::test_export_with_opset_version_none -ExportONNXTest::test_model_with_input_structure_array -ExportONNXTest::test_model_with_input_structure_dict -ExportONNXTest::test_model_with_input_structure_tuple 
-ExportONNXTest::test_model_with_multiple_inputs -ExportONNXTest::test_standard_model_export_functional -ExportONNXTest::test_standard_model_export_lstm -ExportONNXTest::test_standard_model_export_sequential -ExportONNXTest::test_standard_model_export_subclass -ExportOpenVINOTest::test_model_with_input_structure_array -ExportOpenVINOTest::test_model_with_input_structure_dict -ExportOpenVINOTest::test_model_with_input_structure_tuple -ExportOpenVINOTest::test_model_with_multiple_inputs -ExportOpenVINOTest::test_standard_model_export_functional -ExportOpenVINOTest::test_standard_model_export_sequential -ExportOpenVINOTest::test_standard_model_export_subclass -ExportSavedModelTest::test_input_signature_functional_ -ExportSavedModelTest::test_input_signature_functional_backend_tensor -ExportSavedModelTest::test_input_signature_functional_inputspec(dtype=float32, shape=(none, 10), ndim=2) -ExportSavedModelTest::test_input_signature_functional_tensorspec(shape=(none, 10), dtype=tf.float32, name='inputs') -ExportSavedModelTest::test_input_signature_sequential_ -ExportSavedModelTest::test_input_signature_sequential_backend_tensor -ExportSavedModelTest::test_input_signature_sequential_inputspec(dtype=float32, shape=(none, 10), ndim=2) -ExportSavedModelTest::test_input_signature_sequential_tensorspec(shape=(none, 10), dtype=tf.float32, name='inputs') -ExportSavedModelTest::test_input_signature_subclass_ -ExportSavedModelTest::test_input_signature_subclass_backend_tensor -ExportSavedModelTest::test_input_signature_subclass_inputspec(dtype=float32, shape=(none, 10), ndim=2) -ExportSavedModelTest::test_input_signature_subclass_tensorspec(shape=(none, 10), dtype=tf.float32, name='inputs') -ExportSavedModelTest::test_jax_specific_kwargs_functional_false_{'enable_xla': true, 'native_serialization': true} -ExportSavedModelTest::test_jax_specific_kwargs_functional_false_none -ExportSavedModelTest::test_jax_specific_kwargs_functional_true_{'enable_xla': true, 'native_serialization': 
true} -ExportSavedModelTest::test_jax_specific_kwargs_functional_true_none -ExportSavedModelTest::test_jax_specific_kwargs_sequential_false_{'enable_xla': true, 'native_serialization': true} -ExportSavedModelTest::test_jax_specific_kwargs_sequential_false_none -ExportSavedModelTest::test_jax_specific_kwargs_sequential_true_{'enable_xla': true, 'native_serialization': true} -ExportSavedModelTest::test_jax_specific_kwargs_sequential_true_none -ExportSavedModelTest::test_jax_specific_kwargs_subclass_false_{'enable_xla': true, 'native_serialization': true} -ExportSavedModelTest::test_jax_specific_kwargs_subclass_false_none -ExportSavedModelTest::test_jax_specific_kwargs_subclass_true_{'enable_xla': true, 'native_serialization': true} -ExportSavedModelTest::test_jax_specific_kwargs_subclass_true_none -ExportSavedModelTest::test_model_with_input_structure_array -ExportSavedModelTest::test_model_with_input_structure_dict -ExportSavedModelTest::test_model_with_input_structure_tuple -ExportSavedModelTest::test_model_with_multiple_inputs -ExportSavedModelTest::test_model_with_non_trainable_state_export_functional -ExportSavedModelTest::test_model_with_non_trainable_state_export_sequential -ExportSavedModelTest::test_model_with_non_trainable_state_export_subclass -ExportSavedModelTest::test_model_with_rng_export_functional -ExportSavedModelTest::test_model_with_rng_export_sequential -ExportSavedModelTest::test_model_with_rng_export_subclass -ExportSavedModelTest::test_model_with_tf_data_layer_functional -ExportSavedModelTest::test_model_with_tf_data_layer_sequential -ExportSavedModelTest::test_model_with_tf_data_layer_subclass -ExportSavedModelTest::test_standard_model_export_functional -ExportSavedModelTest::test_standard_model_export_sequential -ExportSavedModelTest::test_standard_model_export_subclass \ No newline at end of file diff --git a/keras/src/export/onnx_test.py b/keras/src/export/onnx_test.py index 841fb3cfda39..019c1e938e94 100644 --- 
a/keras/src/export/onnx_test.py +++ b/keras/src/export/onnx_test.py @@ -101,7 +101,12 @@ def test_standard_model_export(self, model_type): ort_inputs = { k.name: v for k, v in zip(ort_session.get_inputs(), [ref_input]) } - self.assertAllClose(ort_session.run(None, ort_inputs)[0], ref_output) + self.assertAllClose( + ort_session.run(None, ort_inputs)[0], + ref_output, + tpu_atol=1e-3, + tpu_rtol=1e-2, + ) # Test with a different batch size ort_inputs = { k.name: v @@ -291,7 +296,12 @@ def test_export_with_input_names(self): ort_inputs = { k.name: v for k, v in zip(ort_session.get_inputs(), [ref_input]) } - self.assertAllClose(ort_session.run(None, ort_inputs)[0], ref_output) + self.assertAllClose( + ort_session.run(None, ort_inputs)[0], + ref_output, + tpu_atol=1e-3, + tpu_rtol=1e-2, + ) @parameterized.named_parameters( named_product( diff --git a/keras/src/export/saved_model_test.py b/keras/src/export/saved_model_test.py index 930c9eb24a68..4d9493a06434 100644 --- a/keras/src/export/saved_model_test.py +++ b/keras/src/export/saved_model_test.py @@ -64,6 +64,22 @@ def get_model(type="sequential", input_shape=(10,), layer_list=None): reason="Torch backend export (via torch_xla) is incompatible with np 2.0", ) class ExportSavedModelTest(testing.TestCase): + def setUp(self): + super().setUp() + self.export_kwargs = {} + if testing.jax_uses_gpu(): + self.export_kwargs = { + "jax2tf_kwargs": { + "native_serialization_platforms": ("cpu", "cuda") + } + } + elif testing.jax_uses_tpu(): + self.export_kwargs = { + "jax2tf_kwargs": { + "native_serialization_platforms": ("cpu", "tpu") + } + } + @parameterized.named_parameters( named_product(model_type=["sequential", "functional", "subclass"]) ) @@ -74,9 +90,16 @@ def test_standard_model_export(self, model_type): ref_input = np.random.normal(size=(batch_size, 10)).astype("float32") ref_output = model(ref_input) - saved_model.export_saved_model(model, temp_filepath) + saved_model.export_saved_model( + model, temp_filepath, 
**self.export_kwargs + ) revived_model = tf.saved_model.load(temp_filepath) - self.assertAllClose(revived_model.serve(ref_input), ref_output) + self.assertAllClose( + revived_model.serve(ref_input), + ref_output, + tpu_atol=0.01, + tpu_rtol=0.01, + ) # Test with a different batch size revived_model.serve(tf.random.normal((6, 10))) @@ -106,7 +129,9 @@ def call(self, inputs): ref_input = tf.random.normal((3, 10)) ref_output = model(ref_input) - saved_model.export_saved_model(model, temp_filepath) + saved_model.export_saved_model( + model, temp_filepath, **self.export_kwargs + ) revived_model = tf.saved_model.load(temp_filepath) self.assertEqual(ref_output.shape, revived_model.serve(ref_input).shape) # Test with a different batch size @@ -142,17 +167,19 @@ def call(self, inputs): model = get_model(model_type, layer_list=[StateLayer()]) model(tf.random.normal((3, 10))) - saved_model.export_saved_model(model, temp_filepath) + saved_model.export_saved_model( + model, temp_filepath, **self.export_kwargs + ) revived_model = tf.saved_model.load(temp_filepath) # The non-trainable counter is expected to increment input = tf.random.normal((6, 10)) output1, counter1 = revived_model.serve(input) - self.assertAllClose(output1, input) - self.assertAllClose(counter1, 2) + self.assertAllClose(output1, input, tpu_atol=0.01, tpu_rtol=0.01) + self.assertAllClose(counter1, 2, tpu_atol=0.01, tpu_rtol=0.01) output2, counter2 = revived_model.serve(input) - self.assertAllClose(output2, input) - self.assertAllClose(counter2, 3) + self.assertAllClose(output2, input, tpu_atol=0.01, tpu_rtol=0.01) + self.assertAllClose(counter2, 3, tpu_atol=0.01, tpu_rtol=0.01) @parameterized.named_parameters( named_product(model_type=["sequential", "functional", "subclass"]) @@ -164,9 +191,16 @@ def test_model_with_tf_data_layer(self, model_type): ref_input = np.random.normal(size=(batch_size, 10)).astype("float32") ref_output = model(ref_input) - saved_model.export_saved_model(model, temp_filepath) + 
saved_model.export_saved_model( + model, temp_filepath, **self.export_kwargs + ) revived_model = tf.saved_model.load(temp_filepath) - self.assertAllClose(revived_model.serve(ref_input), ref_output) + self.assertAllClose( + revived_model.serve(ref_input), + ref_output, + tpu_atol=0.01, + tpu_rtol=0.01, + ) # Test with a different batch size revived_model.serve(tf.random.normal((6, 10))) @@ -206,9 +240,16 @@ def call(self, inputs): temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") ref_output = model(tree.map_structure(ops.convert_to_tensor, ref_input)) - saved_model.export_saved_model(model, temp_filepath) + saved_model.export_saved_model( + model, temp_filepath, **self.export_kwargs + ) revived_model = tf.saved_model.load(temp_filepath) - self.assertAllClose(revived_model.serve(ref_input), ref_output) + self.assertAllClose( + revived_model.serve(ref_input), + ref_output, + tpu_atol=0.01, + tpu_rtol=0.01, + ) # Test with keras.saving_lib temp_filepath = os.path.join( @@ -223,8 +264,12 @@ def call(self, inputs): "DictModel": DictModel, }, ) - self.assertAllClose(revived_model(ref_input), ref_output) - saved_model.export_saved_model(revived_model, self.get_temp_dir()) + self.assertAllClose( + revived_model(ref_input), ref_output, tpu_atol=0.01, tpu_rtol=0.01 + ) + saved_model.export_saved_model( + revived_model, self.get_temp_dir(), **self.export_kwargs + ) # Test with a different batch size bigger_input = tree.map_structure( @@ -247,10 +292,15 @@ def build(self, y_shape, x_shape): ref_input_y = np.random.normal(size=(batch_size, 10)).astype("float32") ref_output = model(ref_input_x, ref_input_y) - saved_model.export_saved_model(model, temp_filepath) + saved_model.export_saved_model( + model, temp_filepath, **self.export_kwargs + ) revived_model = tf.saved_model.load(temp_filepath) self.assertAllClose( - revived_model.serve(ref_input_x, ref_input_y), ref_output + revived_model.serve(ref_input_x, ref_input_y), + ref_output, + tpu_atol=0.01, + 
tpu_rtol=0.01, ) # Test with a different batch size revived_model.serve( @@ -282,11 +332,17 @@ def test_input_signature(self, model_type, input_signature): else: input_signature = (input_signature,) saved_model.export_saved_model( - model, temp_filepath, input_signature=input_signature + model, + temp_filepath, + input_signature=input_signature, + **self.export_kwargs, ) revived_model = tf.saved_model.load(temp_filepath) self.assertAllClose( - revived_model.serve(ops.convert_to_numpy(ref_input)), ref_output + revived_model.serve(ops.convert_to_numpy(ref_input)), + ref_output, + tpu_atol=0.01, + tpu_rtol=0.01, ) def test_input_signature_error(self): @@ -318,14 +374,23 @@ def test_jax_specific_kwargs(self, model_type, is_static, jax2tf_kwargs): ref_input = ops.random.uniform((3, 10)) ref_output = model(ref_input) + export_kwargs = self.export_kwargs.copy() + if jax2tf_kwargs is not None: + export_kwargs.setdefault("jax2tf_kwargs", {}).update(jax2tf_kwargs) + saved_model.export_saved_model( model, temp_filepath, is_static=is_static, - jax2tf_kwargs=jax2tf_kwargs, + **export_kwargs, ) revived_model = tf.saved_model.load(temp_filepath) - self.assertAllClose(revived_model.serve(ref_input), ref_output) + self.assertAllClose( + revived_model.serve(ref_input), + ref_output, + tpu_atol=0.01, + tpu_rtol=0.01, + ) @pytest.mark.skipif( @@ -342,6 +407,22 @@ def test_jax_specific_kwargs(self, model_type, is_static, jax2tf_kwargs): testing.torch_uses_gpu(), reason="Leads to core dumps on CI" ) class ExportArchiveTest(testing.TestCase): + def setUp(self): + super().setUp() + self.add_endpoint_kwargs = {} + if testing.jax_uses_gpu(): + self.add_endpoint_kwargs = { + "jax2tf_kwargs": { + "native_serialization_platforms": ("cpu", "cuda") + } + } + elif testing.jax_uses_tpu(): + self.add_endpoint_kwargs = { + "jax2tf_kwargs": { + "native_serialization_platforms": ("cpu", "tpu") + } + } + @parameterized.named_parameters( named_product(model_type=["sequential", "functional", "subclass"]) 
) @@ -365,10 +446,16 @@ def test_low_level_model_export(self, model_type): "call", model.__call__, input_signature=[tf.TensorSpec(shape=(None, 10), dtype=tf.float32)], + **self.add_endpoint_kwargs, ) export_archive.write_out(temp_filepath) revived_model = tf.saved_model.load(temp_filepath) - self.assertAllClose(revived_model.call(ref_input), ref_output) + self.assertAllClose( + revived_model.call(ref_input), + ref_output, + tpu_atol=0.01, + tpu_rtol=0.01, + ) # Test with a different batch size revived_model.call(tf.random.normal((6, 10))) @@ -385,6 +472,7 @@ def test_low_level_model_export_with_alias(self): "call", model.__call__, input_signature=[tf.TensorSpec(shape=(None, 10), dtype=tf.float32)], + **self.add_endpoint_kwargs, ) export_archive.write_out( temp_filepath, @@ -397,7 +485,10 @@ def test_low_level_model_export_with_alias(self): ), ) self.assertAllClose( - revived_model.function_aliases["call_alias"](ref_input), ref_output + revived_model.function_aliases["call_alias"](ref_input), + ref_output, + tpu_atol=0.01, + tpu_rtol=0.01, ) # Test with a different batch size revived_model.function_aliases["call_alias"](tf.random.normal((6, 10))) @@ -431,10 +522,16 @@ def call(self, inputs): tf.TensorSpec(shape=(None, None), dtype=tf.float32), ] ], + **self.add_endpoint_kwargs, ) export_archive.write_out(temp_filepath) revived_model = tf.saved_model.load(temp_filepath) - self.assertAllClose(revived_model.call(ref_input), ref_output) + self.assertAllClose( + revived_model.call(ref_input), + ref_output, + tpu_atol=0.01, + tpu_rtol=0.01, + ) # Test with a different batch size revived_model.call([tf.random.normal((6, 8)), tf.random.normal((6, 6))]) # Test with a different batch size and different dynamic sizes @@ -445,29 +542,13 @@ def call(self, inputs): reason="This test is only for the JAX backend.", ) def test_low_level_model_export_with_jax2tf_kwargs(self): - temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") - - model = get_model() - ref_input = 
tf.random.normal((3, 10)) - ref_output = model(ref_input) - export_archive = saved_model.ExportArchive() - export_archive.track(model) - export_archive.add_endpoint( - "call", - model.__call__, - input_signature=[tf.TensorSpec(shape=(None, 10), dtype=tf.float32)], - jax2tf_kwargs={ - "native_serialization": True, - "native_serialization_platforms": ("cpu", "tpu"), - }, - ) with self.assertRaisesRegex( ValueError, "native_serialization_platforms.*bogus" ): export_archive.add_endpoint( - "call2", - model.__call__, + "call", + lambda x: x, input_signature=[ tf.TensorSpec(shape=(None, 10), dtype=tf.float32) ], @@ -476,9 +557,6 @@ def test_low_level_model_export_with_jax2tf_kwargs(self): "native_serialization_platforms": ("cpu", "bogus"), }, ) - export_archive.write_out(temp_filepath) - revived_model = tf.saved_model.load(temp_filepath) - self.assertAllClose(revived_model.call(ref_input), ref_output) @pytest.mark.skipif( backend.backend() != "jax", @@ -506,21 +584,30 @@ def call(self, inputs): "call", model.__call__, input_signature=signature, - jax2tf_kwargs={}, + **self.add_endpoint_kwargs, ) export_archive.write_out(temp_filepath) export_archive = saved_model.ExportArchive() export_archive.track(model) + add_endpoint_kwargs = self.add_endpoint_kwargs.copy() + add_endpoint_kwargs.setdefault("jax2tf_kwargs", {}).update( + {"polymorphic_shapes": ["(batch, a, a)"]} + ) export_archive.add_endpoint( "call", model.__call__, input_signature=signature, - jax2tf_kwargs={"polymorphic_shapes": ["(batch, a, a)"]}, + **add_endpoint_kwargs, ) export_archive.write_out(temp_filepath) revived_model = tf.saved_model.load(temp_filepath) - self.assertAllClose(revived_model.call(ref_input), ref_output) + self.assertAllClose( + revived_model.call(ref_input), + ref_output, + tpu_atol=0.05, + tpu_rtol=1.0, + ) @pytest.mark.skipif( backend.backend() != "tensorflow", @@ -550,6 +637,7 @@ def my_endpoint(x): export_archive.add_endpoint( "call", my_endpoint, + **self.add_endpoint_kwargs, ) 
export_archive.write_out(temp_filepath) @@ -573,18 +661,16 @@ def test_jax_endpoint_registration_tf_function(self): def model_call(x): return model(x) - from jax import default_backend as jax_device from jax.experimental import jax2tf - native_jax_compatible = not ( - jax_device() == "gpu" - and len(tf.config.list_physical_devices("GPU")) == 0 - ) + add_endpoint_kwargs = self.add_endpoint_kwargs.copy() + add_endpoint_kwargs.setdefault("jax2tf_kwargs", {}) + # now, convert JAX function converted_model_call = jax2tf.convert( model_call, - native_serialization=native_jax_compatible, polymorphic_shapes=["(b, 10)"], + **add_endpoint_kwargs["jax2tf_kwargs"], ) # you can now build a TF inference function @@ -601,13 +687,20 @@ def infer_fn(x): temp_filepath = os.path.join(self.get_temp_dir(), "my_model") export_archive = saved_model.ExportArchive() export_archive.track(model) - export_archive.add_endpoint("serve", infer_fn) + export_archive.add_endpoint( + "serve", infer_fn, **self.add_endpoint_kwargs + ) export_archive.write_out(temp_filepath) # Reload and verify outputs revived_model = tf.saved_model.load(temp_filepath) self.assertFalse(hasattr(revived_model, "_tracked")) - self.assertAllClose(revived_model.serve(ref_input), ref_output) + self.assertAllClose( + revived_model.serve(ref_input), + ref_output, + tpu_atol=0.01, + tpu_rtol=0.01, + ) self.assertLen(revived_model.variables, 8) self.assertLen(revived_model.trainable_variables, 6) self.assertLen(revived_model.non_trainable_variables, 2) @@ -647,18 +740,16 @@ def test_jax_multi_unknown_endpoint_registration(self): def model_call(x): return model(x) - from jax import default_backend as jax_device from jax.experimental import jax2tf - native_jax_compatible = not ( - jax_device() == "gpu" - and len(tf.config.list_physical_devices("GPU")) == 0 - ) + add_endpoint_kwargs = self.add_endpoint_kwargs.copy() + add_endpoint_kwargs.setdefault("jax2tf_kwargs", {}) + # now, convert JAX function converted_model_call = 
jax2tf.convert( model_call, - native_serialization=native_jax_compatible, polymorphic_shapes=["(b, t, 1)"], + **add_endpoint_kwargs["jax2tf_kwargs"], ) # you can now build a TF inference function @@ -678,13 +769,20 @@ def infer_fn(x): temp_filepath = os.path.join(self.get_temp_dir(), "my_model") export_archive = saved_model.ExportArchive() export_archive.track(model) - export_archive.add_endpoint("serve", infer_fn) + export_archive.add_endpoint( + "serve", infer_fn, **self.add_endpoint_kwargs + ) export_archive.write_out(temp_filepath) # Reload and verify outputs revived_model = tf.saved_model.load(temp_filepath) self.assertFalse(hasattr(revived_model, "_tracked")) - self.assertAllClose(revived_model.serve(ref_input), ref_output) + self.assertAllClose( + revived_model.serve(ref_input), + ref_output, + tpu_atol=0.01, + tpu_rtol=0.01, + ) self.assertLen(revived_model.variables, 6) self.assertLen(revived_model.trainable_variables, 6) self.assertLen(revived_model.non_trainable_variables, 0) @@ -708,10 +806,16 @@ def test_layer_export(self): "call", layer.call, input_signature=[tf.TensorSpec(shape=(None, 10), dtype=tf.float32)], + **self.add_endpoint_kwargs, ) export_archive.write_out(temp_filepath) revived_layer = tf.saved_model.load(temp_filepath) - self.assertAllClose(revived_layer.call(ref_input), ref_output) + self.assertAllClose( + revived_layer.call(ref_input), + ref_output, + tpu_atol=0.01, + tpu_rtol=0.01, + ) def test_multi_input_output_functional_model(self): temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") @@ -724,10 +828,20 @@ def test_multi_input_output_functional_model(self): ref_inputs = [tf.random.normal((3, 2)), tf.random.normal((3, 2))] ref_outputs = model(ref_inputs) - model.export(temp_filepath) + model.export(temp_filepath, **self.add_endpoint_kwargs) revived_model = tf.saved_model.load(temp_filepath) - self.assertAllClose(ref_outputs[0], revived_model.serve(ref_inputs)[0]) - self.assertAllClose(ref_outputs[1], 
revived_model.serve(ref_inputs)[1]) + self.assertAllClose( + ref_outputs[0], + revived_model.serve(ref_inputs)[0], + tpu_atol=0.01, + tpu_rtol=0.01, + ) + self.assertAllClose( + ref_outputs[1], + revived_model.serve(ref_inputs)[1], + tpu_atol=0.01, + tpu_rtol=0.01, + ) # Test with a different batch size revived_model.serve( [tf.random.normal((6, 2)), tf.random.normal((6, 2))] @@ -742,10 +856,20 @@ def test_multi_input_output_functional_model(self): } ref_outputs = model(ref_inputs) - model.export(temp_filepath) + model.export(temp_filepath, **self.add_endpoint_kwargs) revived_model = tf.saved_model.load(temp_filepath) - self.assertAllClose(ref_outputs[0], revived_model.serve(ref_inputs)[0]) - self.assertAllClose(ref_outputs[1], revived_model.serve(ref_inputs)[1]) + self.assertAllClose( + ref_outputs[0], + revived_model.serve(ref_inputs)[0], + tpu_atol=0.01, + tpu_rtol=0.01, + ) + self.assertAllClose( + ref_outputs[1], + revived_model.serve(ref_inputs)[1], + tpu_atol=0.01, + tpu_rtol=0.01, + ) # Test with a different batch size revived_model.serve( { @@ -773,7 +897,9 @@ def test_model_with_lookup_table(self): ref_input = tf.convert_to_tensor(["one two three four"]) ref_output = model(ref_input) - saved_model.export_saved_model(model, temp_filepath) + saved_model.export_saved_model( + model, temp_filepath, **self.add_endpoint_kwargs + ) revived_model = tf.saved_model.load(temp_filepath) self.assertAllClose(revived_model.serve(ref_input), ref_output) @@ -803,6 +929,7 @@ def test_model_with_tracked_collection(self): model, temp_filepath, input_signature=[tf.TensorSpec(shape=[1], dtype=tf.string)], + **self.add_endpoint_kwargs, ) revived_model = tf.saved_model.load(temp_filepath) self.assertAllClose(revived_model.serve(ref_input), ref_output) @@ -821,16 +948,28 @@ def test_track_multiple_layers(self): "call_1", layer_1.call, input_signature=[tf.TensorSpec(shape=(None, 4), dtype=tf.float32)], + **self.add_endpoint_kwargs, ) export_archive.add_endpoint( "call_2", 
layer_2.call, input_signature=[tf.TensorSpec(shape=(None, 5), dtype=tf.float32)], + **self.add_endpoint_kwargs, ) export_archive.write_out(temp_filepath) revived_layer = tf.saved_model.load(temp_filepath) - self.assertAllClose(revived_layer.call_1(ref_input_1), ref_output_1) - self.assertAllClose(revived_layer.call_2(ref_input_2), ref_output_2) + self.assertAllClose( + revived_layer.call_1(ref_input_1), + ref_output_1, + tpu_atol=0.01, + tpu_rtol=0.01, + ) + self.assertAllClose( + revived_layer.call_2(ref_input_2), + ref_output_2, + tpu_atol=0.01, + tpu_rtol=0.01, + ) def test_non_standard_layer_signature(self): temp_filepath = os.path.join(self.get_temp_dir(), "exported_layer") @@ -848,10 +987,13 @@ def test_non_standard_layer_signature(self): tf.TensorSpec(shape=(None, 2, 2), dtype=tf.float32), tf.TensorSpec(shape=(None, 2, 2), dtype=tf.float32), ], + **self.add_endpoint_kwargs, ) export_archive.write_out(temp_filepath) revived_layer = tf.saved_model.load(temp_filepath) - self.assertAllClose(revived_layer.call(x1, x2), ref_output) + self.assertAllClose( + revived_layer.call(x1, x2), ref_output, tpu_atol=0.05, tpu_rtol=0.1 + ) def test_non_standard_layer_signature_with_kwargs(self): temp_filepath = os.path.join(self.get_temp_dir(), "exported_layer") @@ -869,10 +1011,16 @@ def test_non_standard_layer_signature_with_kwargs(self): tf.TensorSpec(shape=(None, 2, 2), dtype=tf.float32), tf.TensorSpec(shape=(None, 2, 2), dtype=tf.float32), ], + **self.add_endpoint_kwargs, ) export_archive.write_out(temp_filepath) revived_layer = tf.saved_model.load(temp_filepath) - self.assertAllClose(revived_layer.call(query=x1, value=x2), ref_output) + self.assertAllClose( + revived_layer.call(query=x1, value=x2), + ref_output, + tpu_atol=0.01, + tpu_rtol=0.01, + ) # Test with a different batch size revived_layer.call( query=tf.random.normal((6, 2, 2)), value=tf.random.normal((6, 2, 2)) @@ -896,6 +1044,7 @@ def test_variable_collection(self): "call", model.__call__, 
input_signature=[tf.TensorSpec(shape=(None, 10), dtype=tf.float32)], + **self.add_endpoint_kwargs, ) export_archive.add_variable_collection( "my_vars", model.layers[1].weights @@ -987,6 +1136,7 @@ def test_export_no_assets(self): "call", model.__call__, input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)], + **self.add_endpoint_kwargs, ) export_archive.write_out(temp_filepath) @@ -999,9 +1149,14 @@ def test_model_export_method(self, model_type): ref_input = tf.random.normal((3, 10)) ref_output = model(ref_input) - model.export(temp_filepath) + model.export(temp_filepath, **self.add_endpoint_kwargs) revived_model = tf.saved_model.load(temp_filepath) - self.assertAllClose(revived_model.serve(ref_input), ref_output) + self.assertAllClose( + revived_model.serve(ref_input), + ref_output, + tpu_atol=0.01, + tpu_rtol=0.01, + ) # Test with a different batch size revived_model.serve(tf.random.normal((6, 10))) @@ -1025,6 +1180,7 @@ def test_model_combined_with_tf_preprocessing(self): "model", model, input_signature=[tf.TensorSpec(shape=(None, 5), dtype=tf.float32)], + **self.add_endpoint_kwargs, ) export_archive.track(lookup_table) @@ -1034,10 +1190,19 @@ def combined_fn(x): x = model_fn(x) return x - self.assertAllClose(combined_fn(ref_input), ref_output) + self.assertAllClose( + combined_fn(ref_input), ref_output, tpu_atol=0.01, tpu_rtol=0.01 + ) - export_archive.add_endpoint("combined_fn", combined_fn) + export_archive.add_endpoint( + "combined_fn", combined_fn, **self.add_endpoint_kwargs + ) export_archive.write_out(temp_filepath) revived_model = tf.saved_model.load(temp_filepath) - self.assertAllClose(revived_model.combined_fn(ref_input), ref_output) + self.assertAllClose( + revived_model.combined_fn(ref_input), + ref_output, + tpu_atol=0.01, + tpu_rtol=0.01, + ) diff --git a/keras/src/layers/convolutional/conv_transpose_test.py b/keras/src/layers/convolutional/conv_transpose_test.py index 53fb3e969ee7..8a324abe39a4 100644 --- 
a/keras/src/layers/convolutional/conv_transpose_test.py +++ b/keras/src/layers/convolutional/conv_transpose_test.py @@ -286,6 +286,7 @@ def np_conv3d_transpose( return output +@pytest.mark.skipif(testing.jax_uses_tpu(), reason="Crashes with JAX on TPU") class ConvTransposeBasicTest(testing.TestCase): @parameterized.parameters( {