Skip to content
Open
Show file tree
Hide file tree
Changes from 24 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions causalpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,5 @@
"RegressionKink",
"skl_models",
"SyntheticControl",
"variable_selection_priors",
]
65 changes: 53 additions & 12 deletions causalpy/experiments/instrumental_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,16 @@ class InstrumentalVariable(BaseExperiment):
If priors are not specified we will substitute MLE estimates for
the beta coefficients. Example: ``priors = {"mus": [0, 0],
"sigmas": [1, 1], "eta": 2, "lkj_sd": 2}``.
vs_prior_type : str or None, default=None
Type of variable selection prior: 'spike_and_slab', 'horseshoe', or None.
If None, uses standard normal priors.
vs_hyperparams : dict, optional
Hyperparameters for variable selection priors. Only used if vs_prior_type
is not None.
binary_treatment : bool, default=False
A indicator for whether the treatment to be modelled is binary or not.
Determines which PyMC model we use to model the joint outcome and
treatment.

Example
--------
Expand Down Expand Up @@ -85,6 +95,16 @@ class InstrumentalVariable(BaseExperiment):
... formula=formula,
... model=InstrumentalVariableRegression(sample_kwargs=sample_kwargs),
... )
>>> # With variable selection
>>> iv = cp.InstrumentalVariable(
... instruments_data=instruments_data,
... data=data,
... instruments_formula=instruments_formula,
... formula=formula,
... model=InstrumentalVariableRegression(sample_kwargs=sample_kwargs),
... vs_prior_type="spike_and_slab",
... vs_hyperparams={"slab_sigma": 5.0},
... )
"""

supports_ols = False
Expand All @@ -98,6 +118,9 @@ def __init__(
formula: str,
model: BaseExperiment | None = None,
priors: dict | None = None,
vs_prior_type=None,
vs_hyperparams=None,
binary_treatment=False,
**kwargs: dict,
) -> None:
super().__init__(model=model)
Expand All @@ -107,6 +130,9 @@ def __init__(
self.formula = formula
self.instruments_formula = instruments_formula
self.model = model
self.vs_prior_type = vs_prior_type
self.vs_hyperparams = vs_hyperparams or {}
self.binary_treatment = binary_treatment
self.input_validation()

y, X = dmatrices(formula, self.data)
Expand All @@ -130,15 +156,33 @@ def __init__(
COORDS = {"instruments": self.labels_instruments, "covariates": self.labels}
self.coords = COORDS
if priors is None:
priors = {
"mus": [self.ols_beta_first_params, self.ols_beta_second_params],
"sigmas": [1, 1],
"eta": 2,
"lkj_sd": 1,
}
if binary_treatment:
# Different default priors for binary treatment
priors = {
"mus": [self.ols_beta_first_params, self.ols_beta_second_params],
"sigmas": [1, 1],
"sigma_U": 1.0,
"rho_bounds": [-0.99, 0.99],
}
else:
# Original continuous treatment priors
priors = {
"mus": [self.ols_beta_first_params, self.ols_beta_second_params],
"sigmas": [1, 1],
"eta": 2,
"lkj_sd": 1,
}
self.priors = priors
self.model.fit( # type: ignore[call-arg,union-attr]
X=self.X, Z=self.Z, y=self.y, t=self.t, coords=COORDS, priors=self.priors
X=self.X,
Z=self.Z,
y=self.y,
t=self.t,
coords=COORDS,
priors=self.priors,
vs_prior_type=vs_prior_type,
vs_hyperparams=vs_hyperparams,
binary_treatment=self.binary_treatment,
)

def input_validation(self) -> None:
Expand All @@ -159,11 +203,8 @@ def input_validation(self) -> None:
if check_binary:
warnings.warn(
"""Warning. The treatment variable is not Binary.
This is not necessarily a problem but it violates
the assumption of a simple IV experiment.
The coefficients should be interpreted appropriately.""",
UserWarning,
stacklevel=2,
We will use the multivariate normal likelihood
for continuous treatment."""
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Validation warning ignores binary_treatment flag setting

The input_validation method checks if the treatment variable has more than 2 unique values and warns that "We will use the multivariate normal likelihood for continuous treatment." However, this warning doesn't account for the new binary_treatment parameter. If a user sets binary_treatment=True while having continuous treatment data, the warning incorrectly suggests MVN will be used, when actually the Bernoulli likelihood will be applied (which would fail on non-binary data). The validation needs to cross-check the actual data against the self.binary_treatment flag.

Fix in Cursor Fix in Web

)

def get_2SLS_fit(self) -> None:
Expand Down
Loading
Loading