Skip to content

Commit b136449

Browse files
committed
Huge chunck of changes requested by William
1 parent 846eee0 commit b136449

16 files changed

+39
-154
lines changed

data/config_files/basic_model.yml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
schema_version: 1
2-
3-
# ----------------------------------------------------------------------
41
model:
52
class: pymc_marketing.mmm.multidimensional.MMM
63
kwargs:

data/config_files/example_with_original_scale_vars.yml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
schema_version: 1
2-
3-
# ----------------------------------------------------------------------
41
model:
52
class: pymc_marketing.mmm.multidimensional.MMM
63
kwargs:

data/config_files/multi_dimensiona_hierarchical_model_nested_config.yml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
schema_version: 1
2-
3-
# ----------------------------------------------------------------------
41
model:
52
class: pymc_marketing.mmm.multidimensional.MMM
63
kwargs:

data/config_files/multi_dimensional_hierarchical_model.yml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
schema_version: 1
2-
3-
# ----------------------------------------------------------------------
41
model:
52
class: pymc_marketing.mmm.multidimensional.MMM
63
kwargs:

data/config_files/multi_dimensional_hierarchical_with_arbitrary_effects_model.yml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
schema_version: 1
2-
3-
# ----------------------------------------------------------------------
41
model:
52
class: pymc_marketing.mmm.multidimensional.MMM
63
kwargs:

data/config_files/multi_dimensional_model.yml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
schema_version: 1
2-
3-
# ----------------------------------------------------------------------
41
model:
52
class: pymc_marketing.mmm.multidimensional.MMM
63
kwargs:

docs/source/notebooks/mmm/mmm_build_from_yml_example.ipynb

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
"import matplotlib.pyplot as plt\n",
3333
"import pandas as pd\n",
3434
"\n",
35-
"from pymc_marketing.mmm.builders.yaml import build_from_yaml\n",
35+
"from pymc_marketing.mmm.builders.yaml import build_mmm_from_yaml\n",
3636
"from pymc_marketing.paths import data_dir\n",
3737
"\n",
3838
"warnings.filterwarnings(\"ignore\")\n",
@@ -205,7 +205,7 @@
205205
"metadata": {},
206206
"outputs": [],
207207
"source": [
208-
"mmm = build_from_yaml(\n",
208+
"mmm = build_mmm_from_yaml(\n",
209209
" X=X, y=y, config_path=data_dir / \"config_files\" / \"basic_model.yml\"\n",
210210
")"
211211
]
@@ -1055,7 +1055,7 @@
10551055
"3. sample_kwargs: Optional parameters for the sampling process\n",
10561056
"4. data: Optional paths to data files\n",
10571057
"\n",
1058-
"The build_from_yaml function:\n",
1058+
"The build_mmm_from_yaml function:\n",
10591059
"- Parses this YAML configuration\n",
10601060
"- Uses the 'build' function to instantiate objects recursively\n",
10611061
"- Handles special cases like priors and distributions\n",
@@ -1076,7 +1076,7 @@
10761076
"metadata": {},
10771077
"outputs": [],
10781078
"source": [
1079-
"mmm2 = build_from_yaml(\n",
1079+
"mmm2 = build_mmm_from_yaml(\n",
10801080
" X=X, y=y, config_path=data_dir / \"config_files\" / \"multi_dimensional_model.yml\"\n",
10811081
")"
10821082
]
@@ -1333,7 +1333,7 @@
13331333
"metadata": {},
13341334
"outputs": [],
13351335
"source": [
1336-
"mmm3 = build_from_yaml(\n",
1336+
"mmm3 = build_mmm_from_yaml(\n",
13371337
" X=X,\n",
13381338
" y=y,\n",
13391339
" config_path=data_dir / \"config_files\" / \"multi_dimensional_hierarchical_model.yml\",\n",
@@ -1625,7 +1625,7 @@
16251625
"metadata": {},
16261626
"outputs": [],
16271627
"source": [
1628-
"mmm4 = build_from_yaml(\n",
1628+
"mmm4 = build_mmm_from_yaml(\n",
16291629
" X=X,\n",
16301630
" y=y,\n",
16311631
" config_path=data_dir\n",

pymc_marketing/mmm/builders/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616
# Expose key functionality
1717
from pymc_marketing.deserialize import register_deserialization
1818
from pymc_marketing.mmm.builders.factories import build, create_prior_from_dict
19-
from pymc_marketing.mmm.builders.yaml import build_from_yaml
19+
from pymc_marketing.mmm.builders.yaml import build_mmm_from_yaml
2020

21-
__all__ = ["build", "build_from_yaml"]
21+
__all__ = ["build", "build_mmm_from_yaml"]
2222

2323
# Register deserializers
2424
register_deserialization(

pymc_marketing/mmm/builders/deserializers.py

Lines changed: 2 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -26,36 +26,6 @@
2626
logger = logging.getLogger(__name__)
2727

2828

29-
def is_prior_dict(data: Any) -> bool:
30-
"""Check if the data is a dictionary representing a Prior."""
31-
if not isinstance(data, dict):
32-
return False
33-
return "distribution" in data
34-
35-
36-
def deserialize_prior(data: dict[str, Any]) -> Prior:
37-
"""Deserialize a Prior from the dictionary representation."""
38-
# Make a copy to avoid modifying the original
39-
data_copy = data.copy()
40-
41-
# Extract distribution
42-
distribution = data_copy.pop("distribution")
43-
44-
# Process nested priors in the data before creating the Prior
45-
for key, value in data_copy.items():
46-
if isinstance(value, dict) and is_prior_dict(value):
47-
data_copy[key] = deserialize_prior(value)
48-
elif (
49-
isinstance(value, dict)
50-
and "class" in value
51-
and value["class"] == "pymc_marketing.prior.Prior"
52-
):
53-
data_copy[key] = deserialize_standard_prior(value)
54-
55-
# Create Prior
56-
return Prior(distribution, **data_copy)
57-
58-
5929
def is_alternative_prior(data: Any) -> bool:
6030
"""Check if the data is a dictionary representing a Prior (alternative check)."""
6131
return isinstance(data, dict) and "distribution" in data
@@ -105,72 +75,14 @@ def deserialize_alternative_prior(data: dict[str, Any]) -> Prior:
10575
)
10676

10777

108-
def is_standard_prior_dict(data: Any) -> tuple[bool, str]:
109-
"""
110-
Check if the data is a standard dictionary in the format used by factories.py.
111-
112-
Returns a tuple of (is_match, target_class_name)
113-
"""
114-
if not isinstance(data, dict):
115-
return False, ""
116-
117-
if (
118-
"class" in data
119-
and data["class"] == "pymc_marketing.prior.Prior"
120-
and "kwargs" in data
121-
):
122-
return True, "Prior"
123-
124-
return False, ""
125-
126-
127-
def deserialize_standard_prior(data: dict[str, Any]) -> Prior:
128-
"""
129-
Deserialize a prior from the standard format used by factories.py.
130-
131-
The expected format is:
132-
{
133-
"class": "pymc_marketing.prior.Prior",
134-
"kwargs": {
135-
"args": ["Distribution"],
136-
"param1": value1,
137-
...
138-
}
139-
}
140-
"""
141-
kwargs = data.get("kwargs", {})
142-
143-
# Get distribution from args
144-
args = kwargs.get("args", ["Normal"])
145-
distribution = args[0]
146-
147-
# Create a new kwargs dict without args
148-
new_kwargs = {k: v for k, v in kwargs.items() if k != "args"}
149-
150-
# Process nested priors in kwargs
151-
for key, value in new_kwargs.items():
152-
if isinstance(value, dict):
153-
if "distribution" in value:
154-
new_kwargs[key] = deserialize_prior(value)
155-
elif "class" in value and value["class"] == "pymc_marketing.prior.Prior":
156-
new_kwargs[key] = deserialize_standard_prior(value)
157-
158-
# Create Prior
159-
logger.info(f"Creating Prior with distribution={distribution}, kwargs={new_kwargs}")
160-
return Prior(distribution, **new_kwargs)
161-
162-
16378
def is_priors_dict(data: Any) -> bool:
16479
"""Check if the data is a dictionary of priors."""
16580
if not isinstance(data, dict):
16681
return False
16782

16883
# Check if any value is a prior-like dictionary
16984
for _key, value in data.items():
170-
if isinstance(value, dict) and (
171-
"distribution" in value
172-
or ("class" in value and value["class"] == "pymc_marketing.prior.Prior")
173-
):
85+
if isinstance(value, dict) and "distribution" in value:
17486
return True
17587
return False
17688

@@ -181,26 +93,16 @@ def deserialize_priors_dict(data: dict[str, Any]) -> dict[str, Any]:
18193
for key, value in data.items():
18294
if isinstance(value, dict):
18395
if "distribution" in value:
184-
result[key] = deserialize_prior(value)
185-
elif "class" in value and value["class"] == "pymc_marketing.prior.Prior":
186-
result[key] = deserialize_standard_prior(value)
96+
result[key] = deserialize(value)
18797
else:
18898
result[key] = value
18999
else:
190100
result[key] = value
191101
return result
192102

193103

194-
# Register the simple prior deserializer for distribution-based format
195-
register_deserialization(is_prior_dict, deserialize_prior)
196-
197104
# Register the alternative prior deserializer for more complex nested cases
198105
register_deserialization(is_alternative_prior, deserialize_alternative_prior)
199106

200-
# Register the standard deserializer used by factories.py
201-
register_deserialization(
202-
lambda x: is_standard_prior_dict(x)[0], deserialize_standard_prior
203-
)
204-
205107
# Register the priors dictionary deserializer
206108
register_deserialization(is_priors_dict, deserialize_priors_dict)

pymc_marketing/mmm/builders/factories.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from collections.abc import Mapping, MutableMapping, Sequence
2020
from typing import Any
2121

22+
from pymc_marketing.deserialize import deserialize
2223
from pymc_marketing.prior import Prior
2324

2425
# Optional short-name registry -------------------------------------------------
@@ -126,13 +127,23 @@ def build(spec: Mapping[str, Any]) -> Any:
126127
isinstance(prior_value, dict)
127128
and "distribution" in prior_value
128129
):
129-
priors_dict[prior_key] = create_prior_from_dict(prior_value)
130+
# Use deserialize for individual priors
131+
try:
132+
priors_dict[prior_key] = deserialize(prior_value)
133+
except Exception:
134+
# Fall back to create_prior_from_dict if deserialize fails
135+
priors_dict[prior_key] = create_prior_from_dict(
136+
prior_value
137+
)
130138
else:
131139
priors_dict[prior_key] = prior_value
132140
kwargs[k] = priors_dict
133141
elif k == "prior" and "distribution" in v:
134-
# Create a single prior object
135-
kwargs[k] = create_prior_from_dict(v) # type: ignore
142+
# Use deserialize for a single prior, with fallback
143+
try:
144+
kwargs[k] = deserialize(v)
145+
except Exception:
146+
kwargs[k] = create_prior_from_dict(v) # type: ignore
136147
else:
137148
kwargs[k] = resolve(v)
138149
else:

0 commit comments

Comments
 (0)