Skip to content

Commit 4312dc9

Browse files
committed
Resolve additional merge conflicts from remote updates
- Use Prior class create_likelihood_variable method for model definitions
- Updated print_coefficients to handle both sigma and y_hat_sigma variables
- Simplified PropensityScore to use Prior class exclusively
- Deleted interrogate badge to be regenerated
2 parents 7565b7b + f51f994 commit 4312dc9

File tree

3 files changed

+44
-94
lines changed

3 files changed

+44
-94
lines changed

causalpy/pymc_models.py

Lines changed: 40 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,9 @@ class PyMCModel(pm.Model):
9595
def default_priors(self):
9696
return {}
9797

98+
def priors_from_data(self, X, y) -> Dict[str, Any]:
99+
return {}
100+
98101
def __init__(
99102
self,
100103
sample_kwargs: Optional[Dict[str, Any]] = None,
@@ -155,6 +158,8 @@ def fit(self, X, y, coords: Optional[Dict[str, Any]] = None) -> None:
155158
# sample_posterior_predictive() if provided in sample_kwargs.
156159
random_seed = self.sample_kwargs.get("random_seed", None)
157160

161+
self.priors = {**self.priors_from_data(X, y), **self.priors}
162+
158163
self.build_model(X, y, coords)
159164
with self:
160165
self.idata = pm.sample(**self.sample_kwargs)
@@ -250,26 +255,34 @@ def print_coefficients_for_unit(
250255
) -> None:
251256
"""Print coefficients for a single unit"""
252257
# Determine the width of the longest label
253-
max_label_length = max(len(name) for name in labels + ["sigma"])
258+
max_label_length = max(len(name) for name in labels + ["y_hat_sigma"])
254259

255260
for name in labels:
256261
coeff_samples = unit_coeffs.sel(coeffs=name)
257262
print_row(max_label_length, name, coeff_samples, round_to)
258263

259264
# Add coefficient for measurement std
260-
print_row(max_label_length, "sigma", unit_sigma, round_to)
265+
print_row(max_label_length, "y_hat_sigma", unit_sigma, round_to)
261266

262267
print("Model coefficients:")
263268
coeffs = az.extract(self.idata.posterior, var_names="beta")
264269

265-
# Always has treated_units dimension - no branching needed!
270+
# Check if sigma or y_hat_sigma variable exists
271+
sigma_var_name = None
272+
if "sigma" in self.idata.posterior:
273+
sigma_var_name = "sigma"
274+
elif "y_hat_sigma" in self.idata.posterior:
275+
sigma_var_name = "y_hat_sigma"
276+
else:
277+
raise ValueError("Neither 'sigma' nor 'y_hat_sigma' found in posterior")
278+
266279
treated_units = coeffs.coords["treated_units"].values
267280
for unit in treated_units:
268281
if len(treated_units) > 1:
269282
print(f"\nTreated unit: {unit}")
270283

271284
unit_coeffs = coeffs.sel(treated_units=unit)
272-
unit_sigma = az.extract(self.idata.posterior, var_names="sigma").sel(
285+
unit_sigma = az.extract(self.idata.posterior, var_names=sigma_var_name).sel(
273286
treated_units=unit
274287
)
275288
print_coefficients_for_unit(unit_coeffs, unit_sigma, labels, round_to or 2)
@@ -314,7 +327,11 @@ class LinearRegression(PyMCModel):
314327

315328
default_priors = {
316329
"beta": Prior("Normal", mu=0, sigma=50, dims=["treated_units", "coeffs"]),
317-
"y_hat": Prior("Normal", sigma=Prior("HalfNormal", sigma=1), dims="obs_ind"),
330+
"y_hat": Prior(
331+
"Normal",
332+
sigma=Prior("HalfNormal", sigma=1, dims=["treated_units"]),
333+
dims=["obs_ind", "treated_units"],
334+
),
318335
}
319336

320337
def build_model(self, X, y, coords):
@@ -331,11 +348,10 @@ def build_model(self, X, y, coords):
331348
X = pm.Data("X", X, dims=["obs_ind", "coeffs"])
332349
y = pm.Data("y", y, dims=["obs_ind", "treated_units"])
333350
beta = self.priors["beta"].create_variable("beta")
334-
sigma = pm.HalfNormal("sigma", 1, dims="treated_units")
335351
mu = pm.Deterministic(
336352
"mu", pt.dot(X, beta.T), dims=["obs_ind", "treated_units"]
337353
)
338-
pm.Normal("y_hat", mu, sigma, observed=y, dims=["obs_ind", "treated_units"])
354+
self.priors["y_hat"].create_likelihood_variable("y_hat", mu=mu, observed=y)
339355

340356

341357
class WeightedSumFitter(PyMCModel):
@@ -379,26 +395,34 @@ class WeightedSumFitter(PyMCModel):
379395
""" # noqa: W605
380396

381397
default_priors = {
382-
"y_hat": Prior("Normal", sigma=Prior("HalfNormal", sigma=1), dims="obs_ind"),
398+
"y_hat": Prior(
399+
"Normal",
400+
sigma=Prior("HalfNormal", sigma=1, dims=["treated_units"]),
401+
dims=["obs_ind", "treated_units"],
402+
),
383403
}
384404

405+
def priors_from_data(self, X, y) -> Dict[str, Any]:
406+
n_predictors = X.shape[1]
407+
return {
408+
"beta": Prior(
409+
"Dirichlet", a=np.ones(n_predictors), dims=["treated_units", "coeffs"]
410+
),
411+
}
412+
385413
def build_model(self, X, y, coords):
386414
"""
387415
Defines the PyMC model
388416
"""
389417
with self:
390418
self.add_coords(coords)
391-
n_predictors = X.sizes["coeffs"]
392419
X = pm.Data("X", X, dims=["obs_ind", "coeffs"])
393420
y = pm.Data("y", y, dims=["obs_ind", "treated_units"])
394-
beta = pm.Dirichlet(
395-
"beta", a=np.ones(n_predictors), dims=["treated_units", "coeffs"]
396-
)
397-
sigma = pm.HalfNormal("sigma", 1, dims="treated_units")
421+
beta = self.priors["beta"].create_variable("beta")
398422
mu = pm.Deterministic(
399423
"mu", pt.dot(X, beta.T), dims=["obs_ind", "treated_units"]
400424
)
401-
pm.Normal("y_hat", mu, sigma, observed=y, dims=["obs_ind", "treated_units"])
425+
self.priors["y_hat"].create_likelihood_variable("y_hat", mu=mu, observed=y)
402426

403427

404428
class InstrumentalVariableRegression(PyMCModel):
@@ -598,24 +622,8 @@ def build_model(self, X, t, coords, prior=None, noncentred=True):
598622
self.add_coords(coords)
599623
X_data = pm.Data("X", X, dims=["obs_ind", "coeffs"])
600624
t_data = pm.Data("t", t.flatten(), dims="obs_ind")
601-
602-
if prior is not None:
603-
# Use legacy interface for backward compatibility
604-
if noncentred:
605-
mu_beta, sigma_beta = prior["b"]
606-
beta_std = pm.Normal("beta_std", 0, 1, dims="coeffs")
607-
b = pm.Deterministic(
608-
"beta_", mu_beta + sigma_beta * beta_std, dims="coeffs"
609-
)
610-
else:
611-
b = pm.Normal(
612-
"b", mu=prior["b"][0], sigma=prior["b"][1], dims="coeffs"
613-
)
614-
else:
615-
# Use Prior class
616-
b = self.priors["b"].create_variable("b")
617-
618-
mu = pm.math.dot(X_data, b)
625+
b = self.priors["b"].create_variable("b")
626+
mu = pt.dot(X_data, b)
619627
p = pm.Deterministic("p", pm.math.invlogit(mu))
620628
pm.Bernoulli("t_pred", p=p, observed=t_data, dims="obs_ind")
621629

causalpy/tests/test_pymc_models.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def build_model(self, X, y, coords):
4545
X_ = pm.Data(name="X", value=X, dims=["obs_ind", "coeffs"])
4646
y_ = pm.Data(name="y", value=y, dims=["obs_ind", "treated_units"])
4747
beta = pm.Normal("beta", mu=0, sigma=1, dims=["treated_units", "coeffs"])
48-
sigma = pm.HalfNormal("sigma", sigma=1, dims="treated_units")
48+
sigma = pm.HalfNormal("y_hat_sigma", sigma=1, dims="treated_units")
4949
mu = pm.Deterministic(
5050
"mu", pm.math.dot(X_, beta.T), dims=["obs_ind", "treated_units"]
5151
)
@@ -159,7 +159,7 @@ def test_fit_predict(self, coords, rng, mock_pymc_sample) -> None:
159159
2,
160160
2 * 2,
161161
) # (treated_units, coeffs, sample)
162-
assert az.extract(data=model.idata, var_names=["sigma"]).shape == (
162+
assert az.extract(data=model.idata, var_names=["y_hat_sigma"]).shape == (
163163
1,
164164
2 * 2,
165165
) # (treated_units, sample)
@@ -402,7 +402,7 @@ def test_multi_unit_coefficients(self, synthetic_control_data):
402402

403403
# Extract coefficients
404404
beta = az.extract(wsf.idata.posterior, var_names="beta")
405-
sigma = az.extract(wsf.idata.posterior, var_names="sigma")
405+
sigma = az.extract(wsf.idata.posterior, var_names="y_hat_sigma")
406406

407407
# Check beta dimensions: should be (sample, treated_units, coeffs)
408408
assert "treated_units" in beta.dims
@@ -461,7 +461,7 @@ def test_print_coefficients_multi_unit(self, synthetic_control_data, capsys):
461461
assert control in output
462462

463463
# Check that sigma is printed for each unit
464-
assert output.count("sigma") == len(treated_units)
464+
assert output.count("y_hat_sigma") == len(treated_units)
465465

466466
def test_scoring_multi_unit(self, synthetic_control_data):
467467
"""Test that scoring works with multiple treated units."""

docs/source/_static/interrogate_badge.svg

Lines changed: 0 additions & 58 deletions
This file was deleted.

0 commit comments

Comments
 (0)