Skip to content

Commit a1f11ec

Browse files
committed
remove the already covered case
1 parent 8522eb1 commit a1f11ec

File tree

2 files changed

+5
-51
lines changed

2 files changed

+5
-51
lines changed

pymc_marketing/mmm/builders/__init__.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,8 @@
1414
"""Configuration I/O for PyMC-Marketing."""
1515

1616
# Expose key functionality
17-
from pymc_marketing.deserialize import register_deserialization
18-
from pymc_marketing.mmm.builders.factories import build, create_prior_from_dict
17+
from pymc_marketing.mmm.builders import deserializers # noqa: F401
18+
from pymc_marketing.mmm.builders.factories import build
1919
from pymc_marketing.mmm.builders.yaml import build_mmm_from_yaml
2020

2121
__all__ = ["build", "build_mmm_from_yaml"]
22-
23-
# Register deserializers
24-
register_deserialization(
25-
is_type=lambda data: isinstance(data, dict) and "distribution" in data,
26-
deserialize=create_prior_from_dict,
27-
)

pymc_marketing/mmm/builders/factories.py

Lines changed: 3 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from typing import Any
2121

2222
from pymc_marketing.deserialize import deserialize
23-
from pymc_marketing.prior import Prior
2423

2524
# Optional short-name registry -------------------------------------------------
2625
REGISTRY: dict[str, Any] = {
@@ -58,31 +57,6 @@ def locate(qualname: str) -> Any:
5857
return getattr(module_obj, obj_name)
5958

6059

61-
def create_prior_from_dict(prior_dict: dict) -> Prior:
62-
"""
63-
Create a Prior object from a dictionary representation.
64-
65-
This handles nested priors by recursively converting dictionaries to Prior objects.
66-
"""
67-
if not isinstance(prior_dict, dict) or "distribution" not in prior_dict:
68-
raise ValueError(f"Invalid prior dictionary: {prior_dict}")
69-
70-
# Make a copy to avoid modifying the original
71-
data = prior_dict.copy()
72-
distribution = data.pop("distribution")
73-
74-
# Convert list dimensions to tuples to avoid unhashable type errors
75-
if "dims" in data and isinstance(data["dims"], list):
76-
data["dims"] = tuple(data["dims"])
77-
78-
# Process nested priors in parameters
79-
for key, value in list(data.items()):
80-
if isinstance(value, dict) and "distribution" in value:
81-
data[key] = create_prior_from_dict(value)
82-
83-
return Prior(distribution, **data)
84-
85-
8660
def build(spec: Mapping[str, Any]) -> Any:
8761
"""
8862
Instantiate the object described by *spec*.
@@ -123,27 +97,13 @@ def build(spec: Mapping[str, Any]) -> Any:
12397
# Create a dictionary of priors
12498
priors_dict = {}
12599
for prior_key, prior_value in v.items():
126-
if (
127-
isinstance(prior_value, dict)
128-
and "distribution" in prior_value
129-
):
130-
# Use deserialize for individual priors
131-
try:
132-
priors_dict[prior_key] = deserialize(prior_value)
133-
except Exception:
134-
# Fall back to create_prior_from_dict if deserialize fails
135-
priors_dict[prior_key] = create_prior_from_dict(
136-
prior_value
137-
)
100+
if isinstance(prior_value, dict):
101+
priors_dict[prior_key] = deserialize(prior_value)
138102
else:
139103
priors_dict[prior_key] = prior_value
140104
kwargs[k] = priors_dict
141105
elif k == "prior" and "distribution" in v:
142-
# Use deserialize for a single prior, with fallback
143-
try:
144-
kwargs[k] = deserialize(v)
145-
except Exception:
146-
kwargs[k] = create_prior_from_dict(v) # type: ignore
106+
kwargs[k] = deserialize(v)
147107
else:
148108
kwargs[k] = resolve(v)
149109
else:

0 commit comments

Comments
 (0)