13
13
# limitations under the License.
14
14
"""CLV Model base class."""
15
15
16
- import json
17
16
import warnings
18
17
from collections .abc import Sequence
19
- from pathlib import Path
20
18
from typing import Literal , cast
21
19
22
20
import arviz as az
27
25
from pymc .backends .base import MultiTrace
28
26
from pymc .model .core import Model
29
27
30
- from pymc_marketing .model_builder import ModelBuilder
28
+ from pymc_marketing .model_builder import DifferentModelError , ModelBuilder
31
29
from pymc_marketing .model_config import ModelConfig , parse_model_config
32
- from pymc_marketing .utils import from_netcdf
33
30
34
31
35
32
class CLVModel (ModelBuilder ):
@@ -46,6 +43,7 @@ def __init__(
46
43
sampler_config : dict | None = None ,
47
44
non_distributions : list [str ] | None = None ,
48
45
):
46
+ self .data = data
49
47
model_config = model_config or {}
50
48
51
49
deprecated_keys = [key for key in model_config if key .endswith ("_prior" )]
@@ -60,14 +58,14 @@ def __init__(
60
58
61
59
model_config [new_key ] = model_config .pop (key )
62
60
63
- model_config = parse_model_config (
64
- model_config ,
61
+ super ().__init__ (model_config , sampler_config )
62
+
63
+ # Parse model config after merging with defaults
64
+ self .model_config = parse_model_config (
65
+ self .model_config ,
65
66
non_distributions = non_distributions ,
66
67
)
67
68
68
- super ().__init__ (model_config , sampler_config )
69
- self .data = data
70
-
71
69
@staticmethod
72
70
def _validate_cols (
73
71
data : pd .DataFrame ,
@@ -260,59 +258,39 @@ def _fit_approx(
260
258
)
261
259
262
260
@classmethod
263
- def load (cls , fname : str ):
264
- """Create a ModelBuilder instance from a file.
265
-
266
- Loads inference data for the model.
267
-
268
- Parameters
269
- ----------
270
- fname : string
271
- This denotes the name with path from where idata should be loaded from.
261
+ def idata_to_init_kwargs (cls , idata : az .InferenceData ) -> dict :
262
+ """Create the initialization kwargs from an InferenceData object."""
263
+ kwargs = cls .attrs_to_init_kwargs (idata .attrs )
264
+ kwargs ["data" ] = idata .fit_data .to_dataframe ()
272
265
273
- Returns
274
- -------
275
- Returns an instance of ModelBuilder.
276
-
277
- Raises
278
- ------
279
- ValueError
280
- If the inference data that is loaded doesn't match with the model.
281
-
282
- Examples
283
- --------
284
- >>> class MyModel(ModelBuilder):
285
- >>> ...
286
- >>> name = "./mymodel.nc"
287
- >>> imported_model = MyModel.load(name)
288
-
289
- """
290
- filepath = Path (str (fname ))
291
- idata = from_netcdf (filepath )
292
- return cls ._build_with_idata (idata )
266
+ return kwargs
293
267
294
268
@classmethod
295
- def _build_with_idata (cls , idata : az .InferenceData ):
296
- dataset = idata .fit_data .to_dataframe ()
269
+ def build_from_idata (cls , idata : az .InferenceData ) -> None :
270
+ """Build the model from the InferenceData object."""
271
+ kwargs = cls .idata_to_init_kwargs (idata )
297
272
with warnings .catch_warnings ():
298
273
warnings .filterwarnings (
299
274
"ignore" ,
300
275
category = DeprecationWarning ,
301
276
)
302
- model = cls (
303
- dataset ,
304
- model_config = json .loads (idata .attrs ["model_config" ]), # type: ignore
305
- sampler_config = json .loads (idata .attrs ["sampler_config" ]),
306
- )
277
+ model = cls (** kwargs )
307
278
308
279
model .idata = idata
309
280
model ._rename_posterior_variables ()
310
281
311
282
model .build_model () # type: ignore
312
283
if model .id != idata .attrs ["id" ]:
313
- raise ValueError (f"Inference data not compatible with { cls ._model_type } " )
284
+ msg = (
285
+ "The model id in the InferenceData does not match the model id. "
286
+ "There was no error loading the inference data, but the model may "
287
+ "be different. "
288
+ "Investigate if the model structure or configuration has changed."
289
+ )
290
+ raise DifferentModelError (msg )
314
291
return model
315
292
293
+ # TODO: Remove in 2026Q1?
316
294
def _rename_posterior_variables (self ):
317
295
"""Rename variables in the posterior group to remove the _prior suffix.
318
296
@@ -355,7 +333,7 @@ def thin_fit_result(self, keep_every: int):
355
333
self .fit_result # noqa: B018 (Raise Error if fit didn't happen yet)
356
334
assert self .idata is not None # noqa: S101
357
335
new_idata = self .idata .isel (draw = slice (None , None , keep_every )).copy ()
358
- return type ( self ). _build_with_idata (new_idata )
336
+ return self . build_from_idata (new_idata )
359
337
360
338
@property
361
339
def default_sampler_config (self ) -> dict :
@@ -378,8 +356,3 @@ def fit_summary(self, **kwargs):
378
356
return res ["mean" ].rename ("value" )
379
357
else :
380
358
return az .summary (self .fit_result , ** kwargs )
381
-
382
- @property
383
- def output_var (self ):
384
- """Output variable of the model."""
385
- pass
0 commit comments