@@ -105,7 +105,8 @@ def custom_transform(x):
105
105
import pymc as pm
106
106
import pytensor .tensor as pt
107
107
import xarray as xr
108
- from pydantic import validate_call
108
+ from pydantic import InstanceOf , validate_call
109
+ from pydantic .dataclasses import dataclass
109
110
from pymc .distributions .shape_utils import Dims
110
111
111
112
from pymc_marketing .deserialize import deserialize , register_deserialization
@@ -1025,8 +1026,274 @@ def create_likelihood_variable(
1025
1026
return distribution .create_variable (name )
1026
1027
1027
1028
1029
+ class VariableNotFound (Exception ):
1030
+ """Variable is not found."""
1031
+
1032
+
1033
+ def _remove_random_variable (var : pt .TensorVariable ) -> None :
1034
+ if var .name is None :
1035
+ raise ValueError ("This isn't removable" )
1036
+
1037
+ name : str = var .name
1038
+
1039
+ model = pm .modelcontext (None )
1040
+ for idx , free_rv in enumerate (model .free_RVs ):
1041
+ if var == free_rv :
1042
+ index_to_remove = idx
1043
+ break
1044
+ else :
1045
+ raise VariableNotFound (f"Variable { var .name !r} not found" )
1046
+
1047
+ var .name = None
1048
+ model .free_RVs .pop (index_to_remove )
1049
+ model .named_vars .pop (name )
1050
+
1051
+
1052
+ @dataclass
1053
+ class Censored :
1054
+ """Create censored random variable.
1055
+
1056
+ Examples
1057
+ --------
1058
+ Create a censored Normal distribution:
1059
+
1060
+ .. code-block:: python
1061
+
1062
+ from pymc_marketing.prior import Prior, Censored
1063
+
1064
+ normal = Prior("Normal")
1065
+ censored_normal = Censored(normal, lower=0)
1066
+
1067
+ Create hierarchical censored Normal distribution:
1068
+
1069
+ .. code-block:: python
1070
+
1071
+ from pymc_marketing.prior import Prior, Censored
1072
+
1073
+ normal = Prior(
1074
+ "Normal",
1075
+ mu=Prior("Normal"),
1076
+ sigma=Prior("HalfNormal"),
1077
+ dims="channel",
1078
+ )
1079
+ censored_normal = Censored(normal, lower=0)
1080
+
1081
+ coords = {"channel": range(3)}
1082
+ samples = censored_normal.sample_prior(coords=coords)
1083
+
1084
+ """
1085
+
1086
+ distribution : InstanceOf [Prior ]
1087
+ lower : float | InstanceOf [pt .TensorVariable ] = - np .inf
1088
+ upper : float | InstanceOf [pt .TensorVariable ] = np .inf
1089
+
1090
+ def __post_init__ (self ) -> None :
1091
+ """Check validity at initialization."""
1092
+ if not self .distribution .centered :
1093
+ raise ValueError (
1094
+ "Censored distribution must be centered so that .dist() API can be used on distribution."
1095
+ )
1096
+
1097
+ if self .distribution .transform is not None :
1098
+ raise ValueError (
1099
+ "Censored distribution can't have a transform so that .dist() API can be used on distribution."
1100
+ )
1101
+
1102
+ @property
1103
+ def dims (self ) -> tuple [str , ...]:
1104
+ """The dims from the distribution to censor."""
1105
+ return self .distribution .dims
1106
+
1107
+ @dims .setter
1108
+ def dims (self , dims ) -> None :
1109
+ self .distribution .dims = dims
1110
+
1111
+ def create_variable (self , name : str ) -> pt .TensorVariable :
1112
+ """Create censored random variable."""
1113
+ dist = self .distribution .create_variable (name )
1114
+ _remove_random_variable (var = dist )
1115
+
1116
+ return pm .Censored (
1117
+ name ,
1118
+ dist ,
1119
+ lower = self .lower ,
1120
+ upper = self .upper ,
1121
+ dims = self .dims ,
1122
+ )
1123
+
1124
+ def to_dict (self ) -> dict [str , Any ]:
1125
+ """Convert the censored distribution to a dictionary."""
1126
+
1127
+ def handle_value (value ):
1128
+ if isinstance (value , pt .TensorVariable ):
1129
+ return value .eval ().tolist ()
1130
+
1131
+ return value
1132
+
1133
+ return {
1134
+ "class" : "Censored" ,
1135
+ "data" : {
1136
+ "dist" : self .distribution .to_json (),
1137
+ "lower" : handle_value (self .lower ),
1138
+ "upper" : handle_value (self .upper ),
1139
+ },
1140
+ }
1141
+
1142
+ @classmethod
1143
+ def from_dict (cls , data : dict [str , Any ]) -> Censored :
1144
+ """Create a censored distribution from a dictionary."""
1145
+ data = data ["data" ]
1146
+ return cls ( # type: ignore
1147
+ distribution = Prior .from_json (data ["dist" ]),
1148
+ lower = data ["lower" ],
1149
+ upper = data ["upper" ],
1150
+ )
1151
+
1152
+ def sample_prior (
1153
+ self ,
1154
+ coords = None ,
1155
+ name : str = "variable" ,
1156
+ ** sample_prior_predictive_kwargs ,
1157
+ ) -> xr .Dataset :
1158
+ """Sample the prior distribution for the variable.
1159
+
1160
+ Parameters
1161
+ ----------
1162
+ coords : dict[str, list[str]], optional
1163
+ The coordinates for the variable, by default None.
1164
+ Only required if the dims are specified.
1165
+ name : str, optional
1166
+ The name of the variable, by default "var".
1167
+ sample_prior_predictive_kwargs : dict
1168
+ Additional arguments to pass to `pm.sample_prior_predictive`.
1169
+
1170
+ Returns
1171
+ -------
1172
+ xr.Dataset
1173
+ The dataset of the prior samples.
1174
+
1175
+ Example
1176
+ -------
1177
+ Sample from a censored Gamma distribution.
1178
+
1179
+ .. code-block:: python
1180
+
1181
+ gamma = Prior("Gamma", mu=1, sigma=1, dims="channel")
1182
+ dist = Censored(gamma, lower=0.5)
1183
+
1184
+ coords = {"channel": ["C1", "C2", "C3"]}
1185
+ prior = dist.sample_prior(coords=coords)
1186
+
1187
+ """
1188
+ coords = coords or {}
1189
+
1190
+ if missing_keys := set (self .dims ) - set (coords .keys ()):
1191
+ raise KeyError (f"Coords are missing the following dims: { missing_keys } " )
1192
+
1193
+ with pm .Model (coords = coords ):
1194
+ self .create_variable (name )
1195
+
1196
+ return pm .sample_prior_predictive (** sample_prior_predictive_kwargs ).prior
1197
+
1198
+ def to_graph (self ):
1199
+ """Generate a graph of the variables.
1200
+
1201
+ Examples
1202
+ --------
1203
+ Create graph for a censored Normal distribution
1204
+
1205
+ .. code-block:: python
1206
+
1207
+ from pymc_marketing.prior import Prior, Censored
1208
+
1209
+ normal = Prior("Normal")
1210
+ censored_normal = Censored(normal, lower=0)
1211
+
1212
+ censored_normal.to_graph()
1213
+
1214
+ """
1215
+ coords = {name : ["DUMMY" ] for name in self .dims }
1216
+ with pm .Model (coords = coords ) as model :
1217
+ self .create_variable ("var" )
1218
+
1219
+ return pm .model_to_graphviz (model )
1220
+
1221
+ def create_likelihood_variable (
1222
+ self ,
1223
+ name : str ,
1224
+ mu : pt .TensorLike ,
1225
+ observed : pt .TensorLike ,
1226
+ ) -> pt .TensorVariable :
1227
+ """Create observed censored variable.
1228
+
1229
+ Will require that the distribution has a `mu` parameter
1230
+ and that it has not been set in the parameters.
1231
+
1232
+ Parameters
1233
+ ----------
1234
+ name : str
1235
+ The name of the variable.
1236
+ mu : pt.TensorLike
1237
+ The mu parameter for the likelihood.
1238
+ observed : pt.TensorLike
1239
+ The observed data.
1240
+
1241
+ Returns
1242
+ -------
1243
+ pt.TensorVariable
1244
+ The PyMC variable.
1245
+
1246
+ Examples
1247
+ --------
1248
+ Create a censored likelihood variable in a larger PyMC model.
1249
+
1250
+ .. code-block:: python
1251
+
1252
+ import pymc as pm
1253
+ from pymc_marketing.prior import Prior, Censored
1254
+
1255
+ normal = Prior("Normal", sigma=Prior("HalfNormal"))
1256
+ dist = Censored(normal, lower=0)
1257
+
1258
+ observed = 1
1259
+
1260
+ with pm.Model():
1261
+ # Create the likelihood variable
1262
+ mu = pm.HalfNormal("mu", sigma=1)
1263
+ dist.create_likelihood_variable("y", mu=mu, observed=observed)
1264
+
1265
+ """
1266
+ if "mu" not in _get_pymc_parameters (self .distribution .pymc_distribution ):
1267
+ raise UnsupportedDistributionError (
1268
+ f"Likelihood distribution { self .distribution .distribution !r} is not supported."
1269
+ )
1270
+
1271
+ if "mu" in self .distribution .parameters :
1272
+ raise MuAlreadyExistsError (self .distribution )
1273
+
1274
+ distribution = self .distribution .deepcopy ()
1275
+ distribution .parameters ["mu" ] = mu
1276
+
1277
+ dist = distribution .create_variable (name )
1278
+ _remove_random_variable (var = dist )
1279
+
1280
+ return pm .Censored (
1281
+ name ,
1282
+ dist ,
1283
+ observed = observed ,
1284
+ lower = self .lower ,
1285
+ upper = self .upper ,
1286
+ dims = self .dims ,
1287
+ )
1288
+
1289
+
1028
1290
def _is_prior_type (data : dict ) -> bool :
1029
1291
return "dist" in data
1030
1292
1031
1293
1294
+ def _is_censored_type (data : dict ) -> bool :
1295
+ return data .keys () == {"class" , "data" } and data ["class" ] == "Censored"
1296
+
1297
+
1032
1298
register_deserialization (is_type = _is_prior_type , deserialize = Prior .from_json )
1299
+ register_deserialization (is_type = _is_censored_type , deserialize = Censored .from_dict )
0 commit comments