|
15 | 15 | import os |
16 | 16 | from typing import Callable |
17 | 17 | from unittest import mock |
| 18 | +import warnings |
18 | 19 |
|
19 | 20 | from absl import flags |
20 | 21 | from absl.testing import absltest |
|
24 | 25 | from meridian import constants |
25 | 26 | from meridian.analysis import analyzer |
26 | 27 | from meridian.analysis import visualizer |
| 28 | +from meridian.backend import config as backend_config |
27 | 29 | from meridian.backend import test_utils as backend_test_utils |
28 | 30 | from meridian.data import input_data as meridian_input_data |
29 | 31 | from meridian.data import test_utils |
@@ -186,6 +188,82 @@ def test_serialize( |
186 | 188 | serialized_model.model.Unpack(unpacked_model) |
187 | 189 | self.assertEqual(unpacked_model.model_version, str(meridian_version)) |
188 | 190 |
|
| 191 | + def test_serialize_sets_computation_backend(self): |
| 192 | + meridian_model = model.Meridian( |
| 193 | + input_data=_INPUT_DATA, model_spec=test_data.get_default_model_spec() |
| 194 | + ) |
| 195 | + with mock.patch.object( |
| 196 | + model.Meridian, |
| 197 | + 'computation_backend', |
| 198 | + new_callable=mock.PropertyMock, |
| 199 | + ) as mock_backend: |
| 200 | + mock_backend.return_value = 'JAX' |
| 201 | + serialized_model = self.serde.serialize(meridian_model) |
| 202 | + |
| 203 | + unpacked_model = meridian_pb.MeridianModel() |
| 204 | + serialized_model.model.Unpack(unpacked_model) |
| 205 | + self.assertEqual( |
| 206 | + unpacked_model.computation_backend, meridian_pb.ComputationBackend.JAX |
| 207 | + ) |
| 208 | + |
| 209 | + def test_deserialize_warns_on_backend_mismatch(self): |
| 210 | + # Create a proto indicating it was trained with JAX. |
| 211 | + meridian_model = meridian_pb.MeridianModel( |
| 212 | + model_version='1.2.3', |
| 213 | + hyperparameters=test_data.DEFAULT_HYPERPARAMETERS_PROTO, |
| 214 | + prior_tfp_distributions=meridian_pb.PriorTfpDistributions(), |
| 215 | + inference_data=meridian_pb.InferenceData(), |
| 216 | + computation_backend=meridian_pb.ComputationBackend.JAX, |
| 217 | + ) |
| 218 | + any_model = any_pb2.Any() |
| 219 | + any_model.Pack(meridian_model) |
| 220 | + mmm_kernel = kernel_pb.MmmKernel( |
| 221 | + model=any_model, |
| 222 | + marketing_data=test_data.MOCK_PROTO_MEDIA_PAID_GRANULAR_NOT_LAGGED, |
| 223 | + ) |
| 224 | + |
| 225 | + # Force the current environment to be TENSORFLOW. |
| 226 | + with mock.patch.object( |
| 227 | + backend, |
| 228 | + 'computation_backend', |
| 229 | + return_value=backend_config.ComputationBackend.TENSORFLOW, |
| 230 | + ): |
| 231 | + with self.assertWarnsRegex( |
| 232 | + UserWarning, |
| 233 | + 'The model was trained using JAX, but the current backend is' |
| 234 | + ' TENSORFLOW', |
| 235 | + ): |
| 236 | + self.serde.deserialize(mmm_kernel) |
| 237 | + |
| 238 | + def test_deserialize_no_warning_on_backend_match(self): |
| 239 | + # Create a proto indicating it was trained with TENSORFLOW. |
| 240 | + meridian_model = meridian_pb.MeridianModel( |
| 241 | + model_version='1.2.3', |
| 242 | + hyperparameters=test_data.DEFAULT_HYPERPARAMETERS_PROTO, |
| 243 | + prior_tfp_distributions=meridian_pb.PriorTfpDistributions(), |
| 244 | + inference_data=meridian_pb.InferenceData(), |
| 245 | + computation_backend=meridian_pb.ComputationBackend.TENSORFLOW, |
| 246 | + ) |
| 247 | + any_model = any_pb2.Any() |
| 248 | + any_model.Pack(meridian_model) |
| 249 | + mmm_kernel = kernel_pb.MmmKernel( |
| 250 | + model=any_model, |
| 251 | + marketing_data=test_data.MOCK_PROTO_MEDIA_PAID_GRANULAR_NOT_LAGGED, |
| 252 | + ) |
| 253 | + |
| 254 | + # Force the current environment to be TENSORFLOW. |
| 255 | + with mock.patch.object( |
| 256 | + backend, |
| 257 | + 'computation_backend', |
| 258 | + return_value=backend_config.ComputationBackend.TENSORFLOW, |
| 259 | + ): |
| 260 | + with warnings.catch_warnings(record=True) as w: |
| 261 | + warnings.simplefilter('always') |
| 262 | + self.serde.deserialize(mmm_kernel) |
| 263 | + # Ensure no backend-related warnings were raised. |
| 264 | + backend_warnings = [x for x in w if 'backend' in str(x.message).lower()] |
| 265 | + self.assertEmpty(backend_warnings) |
| 266 | + |
189 | 267 | def test_serialize_no_controls(self): |
190 | 268 | meridian_model = model.Meridian( |
191 | 269 | input_data=_INPUT_DATA_NO_CONTROLS, |
|
0 commit comments