Skip to content

Commit f4fe828

Browse files
authored
Add Censored wrapper for Prior class (#1309)
* support for deserialization of three classes * add the deserialize logic * correct the type hint * add an example * separate out the error * catch error at deserialization * relax the input type * add test suite * add test for arbitrary serialization via Prior * use deserialize within from_json * test for deserialize support within Prior * use general deserialization in individual media transformations * test both deserialize funcs * test arb deserialization in adstock * add similar deserialization check for saturation * better naming of the tests * support parsing of hsgp kwargs * add to the module level docstring * allow VariableFactory in parse_model_config * implement the censored variable as VariableFactory * add test against VariableFactory * test the variables created with censored variable * add few more censored tests * add tests for errors * implement the censored variable as VariableFactory * add test against VariableFactory * test the variables created with censored variable * add few more censored tests * add tests for errors * add module to documentation * add to the module documentation * Reorder the top level documentation * add more docstring about the functions * serialization support for Censored class * run pre-commit * checks that lead to deterministics * switch out with pydantic.dataclasses.dataclass * handle the pytensor variable case * add setter for dims
1 parent 8600ef3 commit f4fe828

File tree

2 files changed

+424
-1
lines changed

2 files changed

+424
-1
lines changed

pymc_marketing/prior.py

Lines changed: 268 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,8 @@ def custom_transform(x):
105105
import pymc as pm
106106
import pytensor.tensor as pt
107107
import xarray as xr
108-
from pydantic import validate_call
108+
from pydantic import InstanceOf, validate_call
109+
from pydantic.dataclasses import dataclass
109110
from pymc.distributions.shape_utils import Dims
110111

111112
from pymc_marketing.deserialize import deserialize, register_deserialization
@@ -1025,8 +1026,274 @@ def create_likelihood_variable(
10251026
return distribution.create_variable(name)
10261027

10271028

1029+
class VariableNotFound(Exception):
    """Raised when a variable cannot be found where it was looked up."""
1031+
1032+
1033+
def _remove_random_variable(var: pt.TensorVariable) -> None:
    """Take a named free random variable back out of the current model.

    The variable is dropped from ``model.free_RVs`` and ``model.named_vars``
    of the model on the context stack, and its ``name`` is reset to ``None``.

    Parameters
    ----------
    var : pt.TensorVariable
        The variable to remove. It must carry a name.

    Raises
    ------
    ValueError
        If ``var`` has no name.
    VariableNotFound
        If ``var`` is not among the free random variables of the model.

    """
    registered_name = var.name
    if registered_name is None:
        raise ValueError("This isn't removable")

    model = pm.modelcontext(None)

    # Locate the position first; nothing is mutated unless the variable is found.
    position = next(
        (idx for idx, free_rv in enumerate(model.free_RVs) if var == free_rv),
        None,
    )
    if position is None:
        raise VariableNotFound(f"Variable {var.name!r} not found")

    var.name = None
    model.free_RVs.pop(position)
    model.named_vars.pop(registered_name)
1050+
1051+
1052+
@dataclass
1053+
class Censored:
1054+
"""Create censored random variable.
1055+
1056+
Examples
1057+
--------
1058+
Create a censored Normal distribution:
1059+
1060+
.. code-block:: python
1061+
1062+
from pymc_marketing.prior import Prior, Censored
1063+
1064+
normal = Prior("Normal")
1065+
censored_normal = Censored(normal, lower=0)
1066+
1067+
Create hierarchical censored Normal distribution:
1068+
1069+
.. code-block:: python
1070+
1071+
from pymc_marketing.prior import Prior, Censored
1072+
1073+
normal = Prior(
1074+
"Normal",
1075+
mu=Prior("Normal"),
1076+
sigma=Prior("HalfNormal"),
1077+
dims="channel",
1078+
)
1079+
censored_normal = Censored(normal, lower=0)
1080+
1081+
coords = {"channel": range(3)}
1082+
samples = censored_normal.sample_prior(coords=coords)
1083+
1084+
"""
1085+
1086+
distribution: InstanceOf[Prior]
1087+
lower: float | InstanceOf[pt.TensorVariable] = -np.inf
1088+
upper: float | InstanceOf[pt.TensorVariable] = np.inf
1089+
1090+
def __post_init__(self) -> None:
1091+
"""Check validity at initialization."""
1092+
if not self.distribution.centered:
1093+
raise ValueError(
1094+
"Censored distribution must be centered so that .dist() API can be used on distribution."
1095+
)
1096+
1097+
if self.distribution.transform is not None:
1098+
raise ValueError(
1099+
"Censored distribution can't have a transform so that .dist() API can be used on distribution."
1100+
)
1101+
1102+
@property
1103+
def dims(self) -> tuple[str, ...]:
1104+
"""The dims from the distribution to censor."""
1105+
return self.distribution.dims
1106+
1107+
@dims.setter
1108+
def dims(self, dims) -> None:
1109+
self.distribution.dims = dims
1110+
1111+
def create_variable(self, name: str) -> pt.TensorVariable:
1112+
"""Create censored random variable."""
1113+
dist = self.distribution.create_variable(name)
1114+
_remove_random_variable(var=dist)
1115+
1116+
return pm.Censored(
1117+
name,
1118+
dist,
1119+
lower=self.lower,
1120+
upper=self.upper,
1121+
dims=self.dims,
1122+
)
1123+
1124+
def to_dict(self) -> dict[str, Any]:
1125+
"""Convert the censored distribution to a dictionary."""
1126+
1127+
def handle_value(value):
1128+
if isinstance(value, pt.TensorVariable):
1129+
return value.eval().tolist()
1130+
1131+
return value
1132+
1133+
return {
1134+
"class": "Censored",
1135+
"data": {
1136+
"dist": self.distribution.to_json(),
1137+
"lower": handle_value(self.lower),
1138+
"upper": handle_value(self.upper),
1139+
},
1140+
}
1141+
1142+
@classmethod
1143+
def from_dict(cls, data: dict[str, Any]) -> Censored:
1144+
"""Create a censored distribution from a dictionary."""
1145+
data = data["data"]
1146+
return cls( # type: ignore
1147+
distribution=Prior.from_json(data["dist"]),
1148+
lower=data["lower"],
1149+
upper=data["upper"],
1150+
)
1151+
1152+
def sample_prior(
1153+
self,
1154+
coords=None,
1155+
name: str = "variable",
1156+
**sample_prior_predictive_kwargs,
1157+
) -> xr.Dataset:
1158+
"""Sample the prior distribution for the variable.
1159+
1160+
Parameters
1161+
----------
1162+
coords : dict[str, list[str]], optional
1163+
The coordinates for the variable, by default None.
1164+
Only required if the dims are specified.
1165+
name : str, optional
1166+
The name of the variable, by default "var".
1167+
sample_prior_predictive_kwargs : dict
1168+
Additional arguments to pass to `pm.sample_prior_predictive`.
1169+
1170+
Returns
1171+
-------
1172+
xr.Dataset
1173+
The dataset of the prior samples.
1174+
1175+
Example
1176+
-------
1177+
Sample from a censored Gamma distribution.
1178+
1179+
.. code-block:: python
1180+
1181+
gamma = Prior("Gamma", mu=1, sigma=1, dims="channel")
1182+
dist = Censored(gamma, lower=0.5)
1183+
1184+
coords = {"channel": ["C1", "C2", "C3"]}
1185+
prior = dist.sample_prior(coords=coords)
1186+
1187+
"""
1188+
coords = coords or {}
1189+
1190+
if missing_keys := set(self.dims) - set(coords.keys()):
1191+
raise KeyError(f"Coords are missing the following dims: {missing_keys}")
1192+
1193+
with pm.Model(coords=coords):
1194+
self.create_variable(name)
1195+
1196+
return pm.sample_prior_predictive(**sample_prior_predictive_kwargs).prior
1197+
1198+
def to_graph(self):
1199+
"""Generate a graph of the variables.
1200+
1201+
Examples
1202+
--------
1203+
Create graph for a censored Normal distribution
1204+
1205+
.. code-block:: python
1206+
1207+
from pymc_marketing.prior import Prior, Censored
1208+
1209+
normal = Prior("Normal")
1210+
censored_normal = Censored(normal, lower=0)
1211+
1212+
censored_normal.to_graph()
1213+
1214+
"""
1215+
coords = {name: ["DUMMY"] for name in self.dims}
1216+
with pm.Model(coords=coords) as model:
1217+
self.create_variable("var")
1218+
1219+
return pm.model_to_graphviz(model)
1220+
1221+
def create_likelihood_variable(
1222+
self,
1223+
name: str,
1224+
mu: pt.TensorLike,
1225+
observed: pt.TensorLike,
1226+
) -> pt.TensorVariable:
1227+
"""Create observed censored variable.
1228+
1229+
Will require that the distribution has a `mu` parameter
1230+
and that it has not been set in the parameters.
1231+
1232+
Parameters
1233+
----------
1234+
name : str
1235+
The name of the variable.
1236+
mu : pt.TensorLike
1237+
The mu parameter for the likelihood.
1238+
observed : pt.TensorLike
1239+
The observed data.
1240+
1241+
Returns
1242+
-------
1243+
pt.TensorVariable
1244+
The PyMC variable.
1245+
1246+
Examples
1247+
--------
1248+
Create a censored likelihood variable in a larger PyMC model.
1249+
1250+
.. code-block:: python
1251+
1252+
import pymc as pm
1253+
from pymc_marketing.prior import Prior, Censored
1254+
1255+
normal = Prior("Normal", sigma=Prior("HalfNormal"))
1256+
dist = Censored(normal, lower=0)
1257+
1258+
observed = 1
1259+
1260+
with pm.Model():
1261+
# Create the likelihood variable
1262+
mu = pm.HalfNormal("mu", sigma=1)
1263+
dist.create_likelihood_variable("y", mu=mu, observed=observed)
1264+
1265+
"""
1266+
if "mu" not in _get_pymc_parameters(self.distribution.pymc_distribution):
1267+
raise UnsupportedDistributionError(
1268+
f"Likelihood distribution {self.distribution.distribution!r} is not supported."
1269+
)
1270+
1271+
if "mu" in self.distribution.parameters:
1272+
raise MuAlreadyExistsError(self.distribution)
1273+
1274+
distribution = self.distribution.deepcopy()
1275+
distribution.parameters["mu"] = mu
1276+
1277+
dist = distribution.create_variable(name)
1278+
_remove_random_variable(var=dist)
1279+
1280+
return pm.Censored(
1281+
name,
1282+
dist,
1283+
observed=observed,
1284+
lower=self.lower,
1285+
upper=self.upper,
1286+
dims=self.dims,
1287+
)
1288+
1289+
10281290
def _is_prior_type(data: dict) -> bool:
    """Return whether ``data`` contains the ``"dist"`` key.

    Used as the type check when registering ``Prior`` deserialization.
    """
    marker_key = "dist"
    return marker_key in data
10301292

10311293

1294+
def _is_censored_type(data: dict) -> bool:
    """Return whether ``data`` has the serialized ``Censored`` layout.

    The keys must be exactly ``"class"`` and ``"data"`` and the ``"class"``
    entry must equal ``"Censored"``.
    """
    if data.keys() != {"class", "data"}:
        return False

    return data["class"] == "Censored"
1296+
1297+
10321298
# Register both serialized layouts with the deserialization registry:
# the Prior layout first, followed by the Censored layout.
register_deserialization(
    is_type=_is_prior_type,
    deserialize=Prior.from_json,
)
register_deserialization(
    is_type=_is_censored_type,
    deserialize=Censored.from_dict,
)

0 commit comments

Comments
 (0)