11from abc import ABC , abstractmethod
22from collections .abc import Sequence
3+ from io import StringIO
34from typing import Literal , get_args
45
56import arviz as az
67import pandas as pd
78import pymc as pm
89import pytensor .tensor as pt
10+ import rich
911
1012from pymc .backends .arviz import apply_function_over_dataset
1113from pymc .model .fgraph import clone_model
1416
1517from pymc_experimental .model .marginal .marginal_model import MarginalModel
1618from pymc_experimental .model .modular .utilities import ColumnType , encode_categoricals
19+ from pymc_experimental .printing import model_table
1720
1821LIKELIHOOD_TYPES = Literal ["lognormal" , "logt" , "mixture" , "unmarginalized-mixture" ]
1922valid_likelihoods = get_args (LIKELIHOOD_TYPES )
@@ -43,7 +46,7 @@ def __init__(self, target_col: ColumnType, data: pd.DataFrame):
4346
4447 X_df = data .drop (columns = [target_col ])
4548
46- self .obs_dim = data .index .name
49+ self .obs_dim = data .index .name if data . index . name is not None else "obs_idx"
4750 self .coords = {
4851 self .obs_dim : data .index .values ,
4952 }
@@ -70,6 +73,10 @@ def sample(self, **sample_kwargs):
7073 with self .model :
7174 return pm .sample (** sample_kwargs )
7275
76+ def sample_prior_predictive (self , ** sample_kwargs ):
77+ with self .model :
78+ return pm .sample_prior_predictive (** sample_kwargs )
79+
7380 def predict (
7481 self ,
7582 idata : az .InferenceData ,
@@ -137,212 +144,53 @@ def _get_model_class(self, coords: dict[str, Sequence]) -> pm.Model | MarginalMo
137144 """Return the type of model used by the likelihood function"""
138145 raise NotImplementedError
139146
140- def register_mu (
141- self ,
142- * ,
143- df : pd .DataFrame ,
144- mu = None ,
145- ):
147+ def register_mu (self , mu = None ):
146148 with self .model :
147149 if mu is not None :
148- return pm .Deterministic ("mu" , mu .build (df = df ), dims = [self .obs_dim ])
150+ return pm .Deterministic ("mu" , mu .build (self . model ), dims = [self .obs_dim ])
149151 return pm .Normal ("mu" , 0 , 100 )
150152
151- def register_sigma (
152- self ,
153- * ,
154- df : pd .DataFrame ,
155- sigma = None ,
156- ):
153+ def register_sigma (self , sigma = None ):
157154 with self .model :
158155 if sigma is not None :
159- return pm .Deterministic ("sigma" , pt .exp (sigma .build (df = df )), dims = [self .obs_dim ])
160- return pm .Exponential ("sigma" , lam = 1 )
161-
162-
163- class LogNormalLikelihood (Likelihood ):
164- """Class to represent a log-normal likelihood function for a GLM component."""
165-
166- def __init__ (
167- self ,
168- mu ,
169- sigma ,
170- target_col : ColumnType ,
171- data : pd .DataFrame ,
172- ):
173- super ().__init__ (target_col = target_col , data = data )
174-
175- with self .model :
176- self .register_data (data [target_col ])
177- mu = self .register_mu (mu )
178- sigma = self .register_sigma (sigma )
179-
180- pm .LogNormal (
181- target_col ,
182- mu = mu ,
183- sigma = sigma ,
184- observed = self .model [f"{ target_col } _observed" ],
185- dims = [self .obs_dim ],
186- )
187-
188- def _get_model_class (self , coords : dict [str , Sequence ]) -> pm .Model | MarginalModel :
189- return pm .Model (coords = coords )
190-
191-
192- class LogTLikelihood (Likelihood ):
193- """
194- Class to represent a log-t likelihood function for a GLM component.
195- """
196-
197- def __init__ (
198- self ,
199- mu ,
200- * ,
201- sigma = None ,
202- nu = None ,
203- target_col : ColumnType ,
204- data : pd .DataFrame ,
205- ):
206- def log_student_t (nu , mu , sigma , shape = None ):
207- return pm .math .exp (pm .StudentT .dist (mu = mu , sigma = sigma , nu = nu , shape = shape ))
208-
209- super ().__init__ (target_col = target_col , data = data )
210-
211- with self .model :
212- mu = self .register_mu (mu = mu , df = data )
213- sigma = self .register_sigma (sigma = sigma , df = data )
214- nu = self .register_nu (nu = nu , df = data )
215-
216- pm .CustomDist (
217- target_col ,
218- nu ,
219- mu ,
220- sigma ,
221- observed = self .model [f"{ target_col } _observed" ],
222- shape = mu .shape ,
223- dims = [self .obs_dim ],
224- dist = log_student_t ,
225- class_name = "LogStudentT" ,
226- )
227-
228- def register_nu (self , * , df , nu = None ):
229- with self .model :
230- if nu is not None :
231- return pm .Deterministic ("nu" , pt .exp (nu .build (df = df )), dims = [self .obs_dim ])
232- return pm .Uniform ("nu" , 2 , 30 )
233-
234- def _get_model_class (self , coords : dict [str , Sequence ]) -> pm .Model | MarginalModel :
235- return pm .Model (coords = coords )
236-
237-
238- class BaseMixtureLikelihood (Likelihood ):
239- """
240- Base class for mixture likelihood functions to hold common methods for registering parameters.
241- """
242-
243- def register_sigma (self , * , df , sigma = None ):
244- with self .model :
245- if sigma is None :
246- sigma_not_outlier = pm .Exponential ("sigma_not_outlier" , lam = 1 )
247- else :
248- sigma_not_outlier = pm .Deterministic (
249- "sigma_not_outlier" , pt .exp (sigma .build (df = df )), dims = [self .obs_dim ]
250- )
251- sigma_outlier_offset = pm .Gamma ("sigma_outlier_offset" , mu = 0.2 , sigma = 0.5 )
252- sigma = pm .Deterministic (
253- "sigma" ,
254- pt .as_tensor ([sigma_not_outlier , sigma_not_outlier * (1 + sigma_outlier_offset )]),
255- dims = ["outlier" ],
256- )
257-
258- return sigma
259-
260- def register_p_outlier (self , * , df , p_outlier = None , ** param_kwargs ):
261- mean_p = param_kwargs .get ("mean_p" , 0.1 )
262- concentration = param_kwargs .get ("concentration" , 50 )
263-
264- with self .model :
265- if p_outlier is not None :
266156 return pm .Deterministic (
267- "p_outlier " , pt .sigmoid ( p_outlier .build (df = df )), dims = [self .obs_dim ]
157+ "sigma " , pt .exp ( sigma .build (self . model )), dims = [self .obs_dim ]
268158 )
269- return pm .Beta ("p_outlier" , mean_p * concentration , (1 - mean_p ) * concentration )
270-
271- def _get_model_class (self , coords : dict [str , Sequence ]) -> pm .Model | MarginalModel :
272- coords ["outlier" ] = [False , True ]
273- return MarginalModel (coords = coords )
274-
159+ return pm .Exponential ("sigma" , lam = 1 )
275160
276- class MixtureLikelihood (BaseMixtureLikelihood ):
277- """
278- Class to represent a mixture likelihood function for a GLM component. The mixture is implemented using pm.Mixture,
279- and does not allow for automatic marginalization of components.
280- """
161+ def __repr__ (self ):
162+ table = model_table (self .model )
163+ buffer = StringIO ()
164+ rich .print (table , file = buffer )
281165
282- def __init__ (
283- self ,
284- mu ,
285- sigma ,
286- p_outlier ,
287- target_col : ColumnType ,
288- data : pd .DataFrame ,
289- ):
290- super ().__init__ (target_col = target_col , data = data )
166+ return buffer .getvalue ()
291167
292- with self .model :
293- mu = self .register_mu (mu )
294- sigma = self .register_sigma (sigma )
295- p_outlier = self .register_p_outlier (p_outlier )
168+ def to_graphviz (self ):
169+ return self .model .to_graphviz ()
296170
297- pm .Mixture (
298- target_col ,
299- w = [1 - p_outlier , p_outlier ],
300- comp_dists = pm .LogNormal .dist (mu [..., None ], sigma = sigma .T ),
301- shape = mu .shape ,
302- observed = self .model [f"{ target_col } _observed" ],
303- dims = [self .obs_dim ],
304- )
171+ # def _repr_html_(self):
172+ # return model_table(self.model)
305173
306174
307- class UnmarginalizedMixtureLikelihood ( BaseMixtureLikelihood ):
175+ class NormalLikelihood ( Likelihood ):
308176 """
309- Class to represent an unmarginalized mixture likelihood function for a GLM component. The mixture is implemented using
310- a MarginalModel, and allows for automatic marginalization of components.
177+ A model with normally distributed errors
311178 """
312179
313- def __init__ (
314- self ,
315- mu ,
316- sigma ,
317- p_outlier ,
318- target_col : ColumnType ,
319- data : pd .DataFrame ,
320- ):
180+ def __init__ (self , mu , sigma , target_col : ColumnType , data : pd .DataFrame ):
321181 super ().__init__ (target_col = target_col , data = data )
322182
323183 with self .model :
324184 mu = self .register_mu (mu )
325185 sigma = self .register_sigma (sigma )
326- p_outlier = self .register_p_outlier (p_outlier )
327-
328- is_outlier = pm .Bernoulli (
329- "is_outlier" ,
330- p_outlier ,
331- dims = ["cusip" ],
332- # shape=X_pt.shape[0], # Uncomment after https://github.com/pymc-devs/pymc-experimental/pull/304
333- )
334186
335- pm .LogNormal (
187+ pm .Normal (
336188 target_col ,
337189 mu = mu ,
338- sigma = pm . math . switch ( is_outlier , sigma [ 1 ], sigma [ 0 ]) ,
190+ sigma = sigma ,
339191 observed = self .model [f"{ target_col } _observed" ],
340- shape = mu .shape ,
341- dims = [data .index .name ],
192+ dims = [self .obs_dim ],
342193 )
343194
344- self .model .marginalize (["is_outlier" ])
345-
346195 def _get_model_class (self , coords : dict [str , Sequence ]) -> pm .Model | MarginalModel :
347- coords ["outlier" ] = [False , True ]
348- return MarginalModel (coords = coords )
196+ return pm .Model (coords = coords )
0 commit comments