
Commit 3df6534

Refactor likelihood, add small example
1 parent ab3b4ed commit 3df6534

5 files changed: +74 −198 lines changed
pymc_experimental/model/modular/__init__.py

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+from pymc_experimental.model.modular.components import Intercept, Regression, Spline
+from pymc_experimental.model.modular.likelihood import NormalLikelihood
+
+__all__ = [
+    "Intercept",
+    "Regression",
+    "Spline",
+    "NormalLikelihood",
+]
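With the new package `__init__.py`, the modular API can be imported from the package root rather than from the submodules. A minimal sketch of the import surface this commit exposes (names taken from the `__all__` above):

import pandas as pd

from pymc_experimental.model.modular import (
    Intercept,
    NormalLikelihood,
    Regression,
    Spline,
)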

pymc_experimental/model/modular/components.py

Lines changed: 5 additions & 4 deletions
@@ -4,7 +4,10 @@
 import pymc as pm
 import pytensor.tensor as pt

-from model.modular.utilities import (
+from patsy import dmatrix
+from pytensor.graph import Apply, Op
+
+from pymc_experimental.model.modular.utilities import (
     PRIOR_DEFAULT_KWARGS,
     ColumnType,
     PoolingType,
@@ -13,14 +16,12 @@
     make_hierarchical_prior,
     select_data_columns,
 )
-from patsy import dmatrix
-from pytensor.graph import Apply, Op


 class GLMModel(ABC):
     """Base class for GLM components. Subclasses should implement the build method to construct the component."""

-    def __init__(self, name):
+    def __init__(self, name=None):
         self.model = None
         self.compiled = False
         self.name = name
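Since `name` now defaults to `None`, components can be constructed anonymously. A hypothetical sketch of the subclass contract this implies, assuming `build(model)` is the required hook (the refactored `likelihood.py` below now calls `component.build(self.model)`); `ConstantComponent` and its body are illustrative, not part of this commit:

import pytensor.tensor as pt

from pymc_experimental.model.modular.components import GLMModel


class ConstantComponent(GLMModel):
    # Toy component: contributes a constant zero to the linear predictor.
    def build(self, model):
        return pt.zeros(())


component = ConstantComponent()  # name=None is now allowed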
pymc_experimental/model/modular/likelihood.py

Lines changed: 29 additions & 181 deletions
@@ -1,11 +1,13 @@
 from abc import ABC, abstractmethod
 from collections.abc import Sequence
+from io import StringIO
 from typing import Literal, get_args

 import arviz as az
 import pandas as pd
 import pymc as pm
 import pytensor.tensor as pt
+import rich

 from pymc.backends.arviz import apply_function_over_dataset
 from pymc.model.fgraph import clone_model
@@ -14,6 +16,7 @@

 from pymc_experimental.model.marginal.marginal_model import MarginalModel
 from pymc_experimental.model.modular.utilities import ColumnType, encode_categoricals
+from pymc_experimental.printing import model_table

 LIKELIHOOD_TYPES = Literal["lognormal", "logt", "mixture", "unmarginalized-mixture"]
 valid_likelihoods = get_args(LIKELIHOOD_TYPES)
@@ -43,7 +46,7 @@ def __init__(self, target_col: ColumnType, data: pd.DataFrame):

         X_df = data.drop(columns=[target_col])

-        self.obs_dim = data.index.name
+        self.obs_dim = data.index.name if data.index.name is not None else "obs_idx"
         self.coords = {
             self.obs_dim: data.index.values,
         }
@@ -70,6 +73,10 @@ def sample(self, **sample_kwargs):
         with self.model:
             return pm.sample(**sample_kwargs)

+    def sample_prior_predictive(self, **sample_kwargs):
+        with self.model:
+            return pm.sample_prior_predictive(**sample_kwargs)
+
     def predict(
         self,
         idata: az.InferenceData,
@@ -137,212 +144,53 @@ def _get_model_class(self, coords: dict[str, Sequence]) -> pm.Model | MarginalModel:
         """Return the type on model used by the likelihood function"""
         raise NotImplementedError

-    def register_mu(
-        self,
-        *,
-        df: pd.DataFrame,
-        mu=None,
-    ):
+    def register_mu(self, mu=None):
         with self.model:
             if mu is not None:
-                return pm.Deterministic("mu", mu.build(df=df), dims=[self.obs_dim])
+                return pm.Deterministic("mu", mu.build(self.model), dims=[self.obs_dim])
             return pm.Normal("mu", 0, 100)

-    def register_sigma(
-        self,
-        *,
-        df: pd.DataFrame,
-        sigma=None,
-    ):
+    def register_sigma(self, sigma=None):
         with self.model:
             if sigma is not None:
-                return pm.Deterministic("sigma", pt.exp(sigma.build(df=df)), dims=[self.obs_dim])
-            return pm.Exponential("sigma", lam=1)
-
-
-class LogNormalLikelihood(Likelihood):
-    """Class to represent a log-normal likelihood function for a GLM component."""
-
-    def __init__(
-        self,
-        mu,
-        sigma,
-        target_col: ColumnType,
-        data: pd.DataFrame,
-    ):
-        super().__init__(target_col=target_col, data=data)
-
-        with self.model:
-            self.register_data(data[target_col])
-            mu = self.register_mu(mu)
-            sigma = self.register_sigma(sigma)
-
-            pm.LogNormal(
-                target_col,
-                mu=mu,
-                sigma=sigma,
-                observed=self.model[f"{target_col}_observed"],
-                dims=[self.obs_dim],
-            )
-
-    def _get_model_class(self, coords: dict[str, Sequence]) -> pm.Model | MarginalModel:
-        return pm.Model(coords=coords)
-
-
-class LogTLikelihood(Likelihood):
-    """
-    Class to represent a log-t likelihood function for a GLM component.
-    """
-
-    def __init__(
-        self,
-        mu,
-        *,
-        sigma=None,
-        nu=None,
-        target_col: ColumnType,
-        data: pd.DataFrame,
-    ):
-        def log_student_t(nu, mu, sigma, shape=None):
-            return pm.math.exp(pm.StudentT.dist(mu=mu, sigma=sigma, nu=nu, shape=shape))
-
-        super().__init__(target_col=target_col, data=data)
-
-        with self.model:
-            mu = self.register_mu(mu=mu, df=data)
-            sigma = self.register_sigma(sigma=sigma, df=data)
-            nu = self.register_nu(nu=nu, df=data)
-
-            pm.CustomDist(
-                target_col,
-                nu,
-                mu,
-                sigma,
-                observed=self.model[f"{target_col}_observed"],
-                shape=mu.shape,
-                dims=[self.obs_dim],
-                dist=log_student_t,
-                class_name="LogStudentT",
-            )
-
-    def register_nu(self, *, df, nu=None):
-        with self.model:
-            if nu is not None:
-                return pm.Deterministic("nu", pt.exp(nu.build(df=df)), dims=[self.obs_dim])
-            return pm.Uniform("nu", 2, 30)
-
-    def _get_model_class(self, coords: dict[str, Sequence]) -> pm.Model | MarginalModel:
-        return pm.Model(coords=coords)
-
-
-class BaseMixtureLikelihood(Likelihood):
-    """
-    Base class for mixture likelihood functions to hold common methods for registering parameters.
-    """
-
-    def register_sigma(self, *, df, sigma=None):
-        with self.model:
-            if sigma is None:
-                sigma_not_outlier = pm.Exponential("sigma_not_outlier", lam=1)
-            else:
-                sigma_not_outlier = pm.Deterministic(
-                    "sigma_not_outlier", pt.exp(sigma.build(df=df)), dims=[self.obs_dim]
-                )
-            sigma_outlier_offset = pm.Gamma("sigma_outlier_offset", mu=0.2, sigma=0.5)
-            sigma = pm.Deterministic(
-                "sigma",
-                pt.as_tensor([sigma_not_outlier, sigma_not_outlier * (1 + sigma_outlier_offset)]),
-                dims=["outlier"],
-            )
-
-            return sigma
-
-    def register_p_outlier(self, *, df, p_outlier=None, **param_kwargs):
-        mean_p = param_kwargs.get("mean_p", 0.1)
-        concentration = param_kwargs.get("concentration", 50)
-
-        with self.model:
-            if p_outlier is not None:
                 return pm.Deterministic(
-                    "p_outlier", pt.sigmoid(p_outlier.build(df=df)), dims=[self.obs_dim]
+                    "sigma", pt.exp(sigma.build(self.model)), dims=[self.obs_dim]
                 )
-            return pm.Beta("p_outlier", mean_p * concentration, (1 - mean_p) * concentration)
-
-    def _get_model_class(self, coords: dict[str, Sequence]) -> pm.Model | MarginalModel:
-        coords["outlier"] = [False, True]
-        return MarginalModel(coords=coords)
-
+            return pm.Exponential("sigma", lam=1)

-class MixtureLikelihood(BaseMixtureLikelihood):
-    """
-    Class to represent a mixture likelihood function for a GLM component. The mixture is implemented using pm.Mixture,
-    and does not allow for automatic marginalization of components.
-    """
+    def __repr__(self):
+        table = model_table(self.model)
+        buffer = StringIO()
+        rich.print(table, file=buffer)

-    def __init__(
-        self,
-        mu,
-        sigma,
-        p_outlier,
-        target_col: ColumnType,
-        data: pd.DataFrame,
-    ):
-        super().__init__(target_col=target_col, data=data)
+        return buffer.getvalue()

-        with self.model:
-            mu = self.register_mu(mu)
-            sigma = self.register_sigma(sigma)
-            p_outlier = self.register_p_outlier(p_outlier)
+    def to_graphviz(self):
+        return self.model.to_graphviz()

-            pm.Mixture(
-                target_col,
-                w=[1 - p_outlier, p_outlier],
-                comp_dists=pm.LogNormal.dist(mu[..., None], sigma=sigma.T),
-                shape=mu.shape,
-                observed=self.model[f"{target_col}_observed"],
-                dims=[self.obs_dim],
-            )
+    # def _repr_html_(self):
+    #     return model_table(self.model)


-class UnmarginalizedMixtureLikelihood(BaseMixtureLikelihood):
+class NormalLikelihood(Likelihood):
     """
-    Class to represent an unmarginalized mixture likelihood function for a GLM component. The mixture is implemented using
-    a MarginalModel, and allows for automatic marginalization of components.
+    A model with normally distributed errors
     """

-    def __init__(
-        self,
-        mu,
-        sigma,
-        p_outlier,
-        target_col: ColumnType,
-        data: pd.DataFrame,
-    ):
+    def __init__(self, mu, sigma, target_col: ColumnType, data: pd.DataFrame):
         super().__init__(target_col=target_col, data=data)

         with self.model:
             mu = self.register_mu(mu)
             sigma = self.register_sigma(sigma)
-            p_outlier = self.register_p_outlier(p_outlier)
-
-            is_outlier = pm.Bernoulli(
-                "is_outlier",
-                p_outlier,
-                dims=["cusip"],
-                # shape=X_pt.shape[0], # Uncomment after https://github.com/pymc-devs/pymc-experimental/pull/304
-            )

-            pm.LogNormal(
+            pm.Normal(
                 target_col,
                 mu=mu,
-                sigma=pm.math.switch(is_outlier, sigma[1], sigma[0]),
+                sigma=sigma,
                 observed=self.model[f"{target_col}_observed"],
-                shape=mu.shape,
-                dims=[data.index.name],
+                dims=[self.obs_dim],
             )

-        self.model.marginalize(["is_outlier"])
-
     def _get_model_class(self, coords: dict[str, Sequence]) -> pm.Model | MarginalModel:
-        coords["outlier"] = [False, True]
-        return MarginalModel(coords=coords)
+        return pm.Model(coords=coords)
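End to end, the refactor reduces a normal GLM to: build a `NormalLikelihood` from a DataFrame, then sample. A small usage sketch (the DataFrame contents are illustrative; `mu=None` and `sigma=None` fall back to the default `pm.Normal` and `pm.Exponential` priors registered above):

import numpy as np
import pandas as pd

from pymc_experimental.model.modular import NormalLikelihood

rng = np.random.default_rng(0)
df = pd.DataFrame({"age": rng.normal(size=100), "income": rng.normal(size=100)})

model = NormalLikelihood(mu=None, sigma=None, target_col="income", data=df)
idata = model.sample_prior_predictive()  # new convenience wrapper
print(model)  # __repr__ now renders the rich model_table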

pymc_experimental/model/modular/utilities.py

Lines changed: 0 additions & 13 deletions
@@ -56,7 +56,6 @@ def select_data_columns(

     Returns
     -------
-    X: TensorVariable
         A tensor variable representing the selected columns of the independent data
     """
     model = pm.modelcontext(model)
@@ -350,9 +349,6 @@ def make_unpooled_hierarchy(
 ):
     coords = model.coords

-    sigma_dist = hierarchy_kwargs.pop("sigma_dist", "Gamma")
-    sigma_kwargs = hierarchy_kwargs.pop("sigma_kwargs", {"alpha": 2, "beta": 1})
-
     if X.ndim == 1:
         X = X[:, None]

@@ -367,17 +363,8 @@ def make_unpooled_hierarchy(
         beta = Prior(f"{name}_mu", **prior_kwargs, dims=dims)

     for i, (last_level, level) in enumerate(itertools.pairwise([None, *levels])):
-        if i == 0:
-            sigma_dims = dims
-        else:
-            sigma_dims = [*dims, last_level] if dims is not None else [last_level]
         beta_dims = [*dims, level] if dims is not None else [level]
-
-        sigma = make_sigma(f"{name}_{level}_effect", sigma_dist, sigma_kwargs, sigma_dims)
-
         prior_kwargs["mu"] = beta[..., idx_maps[i]]
-        scale_name = "b" if prior == "Laplace" else "sigma"
-        prior_kwargs[scale_name] = sigma[..., idx_maps[i]]

         beta = Prior(f"{name}_{level}_effect", **prior_kwargs, dims=beta_dims)
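With the per-level `sigma` plumbing removed, the loop in `make_unpooled_hierarchy` only walks consecutive hierarchy levels. For reference, the `itertools.pairwise([None, *levels])` idiom pairs each level with its parent, starting from the root (level names here are illustrative):

import itertools

levels = ["country", "region", "city"]
print(list(itertools.pairwise([None, *levels])))
# [(None, 'country'), ('country', 'region'), ('region', 'city')]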

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+import numpy as np
+import pandas as pd
+import pytest
+
+from pymc_experimental.model.modular.likelihood import NormalLikelihood
+
+
+@pytest.fixture(scope="session")
+def rng():
+    return np.random.default_rng()
+
+
+@pytest.fixture(scope="session")
+def data(rng):
+    city = ["A", "B", "C"]
+    race = ["white", "black", "hispanic"]
+
+    df = pd.DataFrame(
+        {
+            "city": np.random.choice(city, 1000),
+            "age": rng.normal(size=1000),
+            "race": rng.choice(race, size=1000),
+            "income": rng.normal(size=1000),
+        }
+    )
+    return df
+
+
+def test_normal_likelihood(data):
+    model = NormalLikelihood(mu=None, sigma=None, target_col="income", data=data)
+    idata = model.sample_prior_predictive()
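The test stops at prior prediction; posterior sampling goes through the same object via the `sample` wrapper shown in `likelihood.py`. A sketch reusing the `data` fixture above (draw counts are arbitrary):

model = NormalLikelihood(mu=None, sigma=None, target_col="income", data=data)
idata = model.sample(draws=500, tune=500)  # forwards kwargs to pm.sample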
