Skip to content

Commit 492930a

Browse files
authored
Adding arguments to model builder (#1867)
* Adding arguments to model builder
* changes
1 parent 488ac3c commit 492930a

File tree

3 files changed

+142
-4
lines changed

3 files changed

+142
-4
lines changed

pymc_marketing/model_builder.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -456,13 +456,20 @@ def set_idata_attrs(
456456
idata.attrs = attrs
457457
return idata
458458

459-
def save(self, fname: str) -> None:
459+
def save(self, fname: str, **kwargs) -> None:
460460
"""Save the model's inference data to a file.
461461
462462
Parameters
463463
----------
464464
fname : str
465465
The name and path of the file to save the inference data with model parameters.
466+
**kwargs
467+
Additional keyword arguments to pass to arviz.InferenceData.to_netcdf().
468+
Common options include:
469+
- engine : str, optional (default "h5netcdf")
470+
Library to use for writing files.
471+
- groups : list of str, optional
472+
Groups to save to netcdf. If None, all groups are saved.
466473
467474
Returns
468475
-------
@@ -483,14 +490,20 @@ def save(self, fname: str) -> None:
483490
>>> super().__init__()
484491
>>> model = MyModel()
485492
>>> model.fit(X, y)
493+
>>> # Basic save
494+
>>> model.save("model_results.nc")
495+
>>>
496+
>>> # Save with specific options
486497
>>> model.save(
487-
... "model_results.nc"
488-
... ) # This will call the overridden method in MyModel
498+
... "model_results.nc",
499+
... engine="netcdf4",
500+
... groups=["posterior", "log_likelihood"],
501+
... )
489502
490503
"""
491504
if self.idata is not None and "posterior" in self.idata:
492505
file = Path(str(fname))
493-
self.idata.to_netcdf(str(file))
506+
self.idata.to_netcdf(str(file), **kwargs)
494507
else:
495508
raise RuntimeError("The model hasn't been fit yet, call .fit() first")
496509

tests/mmm/test_mmm.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -765,6 +765,63 @@ def test_save_load(self, mmm_fitted: MMM):
765765
assert model.sampler_config == model2.sampler_config
766766
os.remove("test_save_load")
767767

768+
def test_save_load_with_kwargs(self, mmm_fitted: MMM):
769+
"""Test save/load functionality with kwargs (engine and groups)."""
770+
model = mmm_fitted
771+
772+
# Use kwargs to test functionality - ArviZ supports engine and groups
773+
# Note: ArviZ's to_netcdf has limited compression support compared to xarray
774+
compression_kwargs = {
775+
"engine": "h5netcdf", # Alternative engine that may have better compression
776+
}
777+
778+
model.save("test_save_load_kwargs", **compression_kwargs)
779+
780+
# Load and verify
781+
model2 = MMM.load("test_save_load_kwargs")
782+
assert model.date_column == model2.date_column
783+
assert model.control_columns == model2.control_columns
784+
assert model.channel_columns == model2.channel_columns
785+
assert model.adstock.l_max == model2.adstock.l_max
786+
assert model.validate_data == model2.validate_data
787+
assert model.yearly_seasonality == model2.yearly_seasonality
788+
assert model.model_config == model2.model_config
789+
assert model.sampler_config == model2.sampler_config
790+
791+
os.remove("test_save_load_kwargs")
792+
793+
def test_save_load_engine_comparison(self, mmm_fitted: MMM):
794+
"""Test save/load with different engines and kwargs options."""
795+
model = mmm_fitted
796+
797+
# Save with default engine
798+
model.save("test_save_load_default")
799+
800+
# Save with h5netcdf engine (demonstrates kwargs functionality)
801+
engine_kwargs = {
802+
"engine": "h5netcdf",
803+
}
804+
model.save("test_save_load_h5netcdf", **engine_kwargs)
805+
806+
# Verify both files exist
807+
assert os.path.exists("test_save_load_default")
808+
assert os.path.exists("test_save_load_h5netcdf")
809+
810+
# Verify both can be loaded successfully and have the same data
811+
model_default = MMM.load("test_save_load_default")
812+
model_h5netcdf = MMM.load("test_save_load_h5netcdf")
813+
814+
# Both should have the same model configuration
815+
assert (
816+
model.model_config
817+
== model_default.model_config
818+
== model_h5netcdf.model_config
819+
)
820+
821+
# Clean up
822+
os.remove("test_save_load_default")
823+
os.remove("test_save_load_h5netcdf")
824+
768825
def test_fail_id_after_load(self, monkeypatch, toy_X, toy_y):
769826
# This is the new behavior for the property
770827
def mock_property(self):

tests/test_model_builder.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
import hashlib
1515
import json
16+
import os
1617
import re
1718
import sys
1819
import tempfile
@@ -221,6 +222,73 @@ def test_save_without_fit_raises_runtime_error():
221222
model_builder.save("saved_model")
222223

223224

225+
def test_save_with_kwargs(fitted_model_instance):
226+
"""Test that kwargs are properly passed to to_netcdf"""
227+
import unittest.mock as mock
228+
229+
with mock.patch.object(fitted_model_instance.idata, "to_netcdf") as mock_to_netcdf:
230+
temp = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False)
231+
232+
# Test with kwargs supported by InferenceData.to_netcdf()
233+
kwargs = {"engine": "netcdf4", "groups": ["posterior", "log_likelihood"]}
234+
235+
fitted_model_instance.save(temp.name, **kwargs)
236+
237+
# Verify to_netcdf was called with the correct arguments
238+
mock_to_netcdf.assert_called_once_with(temp.name, **kwargs)
239+
temp.close()
240+
241+
242+
def test_save_with_kwargs_integration(fitted_model_instance):
243+
"""Test save function with actual kwargs (integration test)"""
244+
245+
temp = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False)
246+
temp_path = temp.name
247+
temp.close()
248+
249+
try:
250+
# Test with specific groups - this tests that kwargs are passed through
251+
fitted_model_instance.save(temp_path, groups=["posterior"])
252+
253+
# Verify file was created successfully
254+
assert os.path.exists(temp_path)
255+
256+
# Verify we can read the file and it contains the expected groups
257+
from pymc_marketing.utils import from_netcdf
258+
259+
loaded_idata = from_netcdf(temp_path)
260+
assert "posterior" in loaded_idata.groups()
261+
# Should only have posterior since we specified groups=["posterior"]
262+
assert "fit_data" not in loaded_idata.groups()
263+
264+
finally:
265+
# Clean up
266+
if os.path.exists(temp_path):
267+
os.unlink(temp_path)
268+
269+
270+
def test_save_kwargs_backward_compatibility(fitted_model_instance):
271+
"""Test that save function still works without kwargs (backward compatibility)"""
272+
temp = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False)
273+
temp_path = temp.name
274+
temp.close()
275+
276+
try:
277+
# Test without any kwargs (original behavior)
278+
fitted_model_instance.save(temp_path)
279+
280+
# Verify file was created and can be loaded
281+
assert os.path.exists(temp_path)
282+
loaded_model = ModelBuilderTest.load(temp_path)
283+
assert loaded_model.idata is not None
284+
assert "posterior" in loaded_model.idata.groups()
285+
286+
finally:
287+
# Clean up
288+
if os.path.exists(temp_path):
289+
os.unlink(temp_path)
290+
291+
224292
def test_empty_sampler_config_fit(toy_X, toy_y, mock_pymc_sample):
225293
sampler_config = {}
226294
model_builder = ModelBuilderTest(sampler_config=sampler_config)

0 commit comments

Comments
 (0)