|
20 | 20 | from typing import Any |
21 | 21 |
|
22 | 22 | from pymc_marketing.deserialize import deserialize |
23 | | -from pymc_marketing.prior import Prior |
24 | 23 |
|
25 | 24 | # Optional short-name registry ------------------------------------------------- |
26 | 25 | REGISTRY: dict[str, Any] = { |
@@ -58,31 +57,6 @@ def locate(qualname: str) -> Any: |
58 | 57 | return getattr(module_obj, obj_name) |
59 | 58 |
|
60 | 59 |
|
61 | | -def create_prior_from_dict(prior_dict: dict) -> Prior: |
62 | | - """ |
63 | | - Create a Prior object from a dictionary representation. |
64 | | -
|
65 | | - This handles nested priors by recursively converting dictionaries to Prior objects. |
66 | | - """ |
67 | | - if not isinstance(prior_dict, dict) or "distribution" not in prior_dict: |
68 | | - raise ValueError(f"Invalid prior dictionary: {prior_dict}") |
69 | | - |
70 | | - # Make a copy to avoid modifying the original |
71 | | - data = prior_dict.copy() |
72 | | - distribution = data.pop("distribution") |
73 | | - |
74 | | - # Convert list dimensions to tuples to avoid unhashable type errors |
75 | | - if "dims" in data and isinstance(data["dims"], list): |
76 | | - data["dims"] = tuple(data["dims"]) |
77 | | - |
78 | | - # Process nested priors in parameters |
79 | | - for key, value in list(data.items()): |
80 | | - if isinstance(value, dict) and "distribution" in value: |
81 | | - data[key] = create_prior_from_dict(value) |
82 | | - |
83 | | - return Prior(distribution, **data) |
84 | | - |
85 | | - |
86 | 60 | def build(spec: Mapping[str, Any]) -> Any: |
87 | 61 | """ |
88 | 62 | Instantiate the object described by *spec*. |
@@ -123,27 +97,13 @@ def build(spec: Mapping[str, Any]) -> Any: |
123 | 97 | # Create a dictionary of priors |
124 | 98 | priors_dict = {} |
125 | 99 | for prior_key, prior_value in v.items(): |
126 | | - if ( |
127 | | - isinstance(prior_value, dict) |
128 | | - and "distribution" in prior_value |
129 | | - ): |
130 | | - # Use deserialize for individual priors |
131 | | - try: |
132 | | - priors_dict[prior_key] = deserialize(prior_value) |
133 | | - except Exception: |
134 | | - # Fall back to create_prior_from_dict if deserialize fails |
135 | | - priors_dict[prior_key] = create_prior_from_dict( |
136 | | - prior_value |
137 | | - ) |
| 100 | + if isinstance(prior_value, dict): |
| 101 | + priors_dict[prior_key] = deserialize(prior_value) |
138 | 102 | else: |
139 | 103 | priors_dict[prior_key] = prior_value |
140 | 104 | kwargs[k] = priors_dict |
141 | 105 | elif k == "prior" and "distribution" in v: |
142 | | - # Use deserialize for a single prior, with fallback |
143 | | - try: |
144 | | - kwargs[k] = deserialize(v) |
145 | | - except Exception: |
146 | | - kwargs[k] = create_prior_from_dict(v) # type: ignore |
| 106 | + kwargs[k] = deserialize(v) |
147 | 107 | else: |
148 | 108 | kwargs[k] = resolve(v) |
149 | 109 | else: |
|
0 commit comments