Skip to content

Commit 965b458

Browse files
committed
move to the prior module
1 parent f7dac39 commit 965b458

File tree

3 files changed

+53
-110
lines changed

3 files changed

+53
-110
lines changed

pymc_marketing/mmm/builders/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@
1313
# limitations under the License.
1414
"""Configuration I/O for PyMC-Marketing."""
1515

16-
# Expose key functionality
17-
from pymc_marketing.mmm.builders import deserializers # noqa: F401
1816
from pymc_marketing.mmm.builders.factories import build
1917
from pymc_marketing.mmm.builders.yaml import build_mmm_from_yaml
2018

pymc_marketing/mmm/builders/deserializers.py

Lines changed: 0 additions & 108 deletions
This file was deleted.

pymc_marketing/prior.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1109,6 +1109,59 @@ def create_likelihood_variable(
11091109
return distribution.create_variable(name)
11101110

11111111

1112+
def is_alternative_prior(data: Any) -> bool:
1113+
"""Check if the data is a dictionary representing a Prior (alternative check)."""
1114+
return isinstance(data, dict) and "distribution" in data
1115+
1116+
1117+
def deserialize_alternative_prior(data: dict[str, Any]) -> Prior:
1118+
"""Alternative deserializer that recursively handles all nested parameters.
1119+
1120+
This implementation is more general and handles cases where any parameter
1121+
might be a nested prior, and also extracts centered and transform parameters.
1122+
1123+
Examples
1124+
--------
1125+
This handles cases like:
1126+
1127+
.. code-block:: yaml
1128+
1129+
distribution: Gamma
1130+
alpha: 1
1131+
beta:
1132+
distribution: HalfNormal
1133+
sigma: 1
1134+
dims: channel
1135+
dims: [brand, channel]
1136+
1137+
"""
1138+
data = copy.deepcopy(data)
1139+
1140+
distribution = data.pop("distribution")
1141+
dims = data.pop("dims", None)
1142+
centered = data.pop("centered", True)
1143+
transform = data.pop("transform", None)
1144+
parameters = data
1145+
1146+
# Recursively deserialize any nested parameters
1147+
parameters = {
1148+
key: value if not isinstance(value, dict) else deserialize(value)
1149+
for key, value in parameters.items()
1150+
}
1151+
1152+
return Prior(
1153+
distribution,
1154+
transform=transform,
1155+
centered=centered,
1156+
dims=dims,
1157+
**parameters,
1158+
)
1159+
1160+
1161+
# Register the alternative prior deserializer for more complex nested cases
1162+
register_deserialization(is_alternative_prior, deserialize_alternative_prior)
1163+
1164+
11121165
class VariableNotFound(Exception):
11131166
"""Variable is not found."""
11141167

0 commit comments

Comments
 (0)