|
5 | 5 | from absl import logging
|
6 | 6 | from absl.testing import parameterized
|
7 | 7 |
|
| 8 | +from keras.src import backend |
8 | 9 | from keras.src import layers
|
9 | 10 | from keras.src.models import Sequential
|
10 | 11 | from keras.src.saving import saving_api
|
@@ -122,6 +123,11 @@ def get_model(self, dtype=None):
|
122 | 123 | )
|
123 | 124 | def test_basic_load(self, dtype):
|
124 | 125 | """Test basic model loading."""
|
| 126 | + if backend.backend() == "mlx" and dtype == "float64": |
| 127 | + self.skipTest( |
| 128 | + "mlx backend does not yet support float64 in random and uniform" |
| 129 | + ) |
| 130 | + |
125 | 131 | model = self.get_model(dtype)
|
126 | 132 | filepath = os.path.join(self.get_temp_dir(), "test_model.keras")
|
127 | 133 | saving_api.save_model(model, filepath)
|
@@ -208,6 +214,13 @@ def get_model(self, dtype=None):
|
208 | 214 | )
|
209 | 215 | def test_load_keras_weights(self, source_dtype, dest_dtype):
|
210 | 216 | """Test loading keras weights."""
|
| 217 | + if backend.backend() == "mlx": |
| 218 | + if source_dtype == "float64" or dest_dtype == "float64": |
| 219 | + self.skipTest( |
| 220 | + "mlx backend does not yet support float64 in " |
| 221 | + "random and uniform" |
| 222 | + ) |
| 223 | + |
211 | 224 | src_model = self.get_model(dtype=source_dtype)
|
212 | 225 | filepath = os.path.join(self.get_temp_dir(), "test_weights.weights.h5")
|
213 | 226 | src_model.save_weights(filepath)
|
|
0 commit comments