Skip to content

Commit e8cd63f

Browse files
andyl7anThe Meridian Authors
authored andcommitted
[JAX] Adds computation backend information to serialized Meridian models.
PiperOrigin-RevId: 872779914
1 parent 254216b commit e8cd63f

File tree

2 files changed

+106
-0
lines changed

2 files changed

+106
-0
lines changed

meridian/schema/serde/meridian_serde.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,16 @@ def _make_meridian_model_proto(
154154
Returns:
155155
A MeridianModel proto.
156156
"""
157+
158+
backend_name = getattr(
159+
mmm, 'computation_backend', 'COMPUTATION_BACKEND_UNSPECIFIED'
160+
)
161+
computation_backend_enum = getattr(
162+
meridian_pb.ComputationBackend,
163+
backend_name,
164+
meridian_pb.ComputationBackend.COMPUTATION_BACKEND_UNSPECIFIED,
165+
)
166+
157167
model_proto = meridian_pb.MeridianModel(
158168
model_id=model_id,
159169
model_version=str(meridian_version),
@@ -167,6 +177,7 @@ def _make_meridian_model_proto(
167177
mmm.inference_data
168178
),
169179
arviz_version=az.__version__,
180+
computation_backend=computation_backend_enum,
170181
)
171182
# For backwards compatibility, only serialize EDA spec if it exists.
172183
if hasattr(mmm, 'eda_spec'):
@@ -264,6 +275,23 @@ def deserialize(
264275
serialized.model.Unpack(ser_meridian)
265276
serialized_version = semver.VersionInfo.parse(ser_meridian.model_version)
266277

278+
stored_backend = ser_meridian.computation_backend
279+
current_backend = backend.computation_backend()
280+
if (
281+
stored_backend
282+
!= meridian_pb.ComputationBackend.COMPUTATION_BACKEND_UNSPECIFIED
283+
and stored_backend != current_backend
284+
):
285+
warnings.warn(
286+
(
287+
'The model was trained using'
288+
f' {meridian_pb.ComputationBackend.Name(stored_backend)}, but the'
289+
f' current backend is {current_backend.name}. This may lead to'
290+
' numerical discrepancies or compatibility issues.'
291+
),
292+
UserWarning,
293+
)
294+
267295
deserialized_hyperparameters = (
268296
hyperparameters.HyperparametersSerde().deserialize(
269297
ser_meridian.hyperparameters, str(serialized_version)

meridian/schema/serde/meridian_serde_test.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import os
1616
from typing import Callable
1717
from unittest import mock
18+
import warnings
1819

1920
from absl import flags
2021
from absl.testing import absltest
@@ -24,6 +25,7 @@
2425
from meridian import constants
2526
from meridian.analysis import analyzer
2627
from meridian.analysis import visualizer
28+
from meridian.backend import config as backend_config
2729
from meridian.backend import test_utils as backend_test_utils
2830
from meridian.data import input_data as meridian_input_data
2931
from meridian.data import test_utils
@@ -186,6 +188,82 @@ def test_serialize(
186188
serialized_model.model.Unpack(unpacked_model)
187189
self.assertEqual(unpacked_model.model_version, str(meridian_version))
188190

191+
def test_serialize_sets_computation_backend(self):
192+
meridian_model = model.Meridian(
193+
input_data=_INPUT_DATA, model_spec=test_data.get_default_model_spec()
194+
)
195+
with mock.patch.object(
196+
model.Meridian,
197+
'computation_backend',
198+
new_callable=mock.PropertyMock,
199+
) as mock_backend:
200+
mock_backend.return_value = 'JAX'
201+
serialized_model = self.serde.serialize(meridian_model)
202+
203+
unpacked_model = meridian_pb.MeridianModel()
204+
serialized_model.model.Unpack(unpacked_model)
205+
self.assertEqual(
206+
unpacked_model.computation_backend, meridian_pb.ComputationBackend.JAX
207+
)
208+
209+
def test_deserialize_warns_on_backend_mismatch(self):
210+
# Create a proto indicating it was trained with JAX.
211+
meridian_model = meridian_pb.MeridianModel(
212+
model_version='1.2.3',
213+
hyperparameters=test_data.DEFAULT_HYPERPARAMETERS_PROTO,
214+
prior_tfp_distributions=meridian_pb.PriorTfpDistributions(),
215+
inference_data=meridian_pb.InferenceData(),
216+
computation_backend=meridian_pb.ComputationBackend.JAX,
217+
)
218+
any_model = any_pb2.Any()
219+
any_model.Pack(meridian_model)
220+
mmm_kernel = kernel_pb.MmmKernel(
221+
model=any_model,
222+
marketing_data=test_data.MOCK_PROTO_MEDIA_PAID_GRANULAR_NOT_LAGGED,
223+
)
224+
225+
# Force the current environment to be TENSORFLOW.
226+
with mock.patch.object(
227+
backend,
228+
'computation_backend',
229+
return_value=backend_config.ComputationBackend.TENSORFLOW,
230+
):
231+
with self.assertWarnsRegex(
232+
UserWarning,
233+
'The model was trained using JAX, but the current backend is'
234+
' TENSORFLOW',
235+
):
236+
self.serde.deserialize(mmm_kernel)
237+
238+
def test_deserialize_no_warning_on_backend_match(self):
239+
# Create a proto indicating it was trained with TENSORFLOW.
240+
meridian_model = meridian_pb.MeridianModel(
241+
model_version='1.2.3',
242+
hyperparameters=test_data.DEFAULT_HYPERPARAMETERS_PROTO,
243+
prior_tfp_distributions=meridian_pb.PriorTfpDistributions(),
244+
inference_data=meridian_pb.InferenceData(),
245+
computation_backend=meridian_pb.ComputationBackend.TENSORFLOW,
246+
)
247+
any_model = any_pb2.Any()
248+
any_model.Pack(meridian_model)
249+
mmm_kernel = kernel_pb.MmmKernel(
250+
model=any_model,
251+
marketing_data=test_data.MOCK_PROTO_MEDIA_PAID_GRANULAR_NOT_LAGGED,
252+
)
253+
254+
# Force the current environment to be TENSORFLOW.
255+
with mock.patch.object(
256+
backend,
257+
'computation_backend',
258+
return_value=backend_config.ComputationBackend.TENSORFLOW,
259+
):
260+
with warnings.catch_warnings(record=True) as w:
261+
warnings.simplefilter('always')
262+
self.serde.deserialize(mmm_kernel)
263+
# Ensure no backend-related warnings were raised.
264+
backend_warnings = [x for x in w if 'backend' in str(x.message).lower()]
265+
self.assertEmpty(backend_warnings)
266+
189267
def test_serialize_no_controls(self):
190268
meridian_model = model.Meridian(
191269
input_data=_INPUT_DATA_NO_CONTROLS,

0 commit comments

Comments
 (0)