Skip to content

Commit 5c07e4d

Browse files
committed
change data validation mixins to class methods
1 parent 133ee3b commit 5c07e4d

11 files changed

+157
-193
lines changed

causalpy/data_validation.py

Lines changed: 0 additions & 174 deletions
This file was deleted.

causalpy/exp_inverse_propensity_weighting.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,14 @@
2323
from patsy import dmatrices
2424
from sklearn.linear_model import LinearRegression as sk_lin_reg
2525

26-
from causalpy.data_validation import PropensityDataValidator
26+
from causalpy.custom_exceptions import DataException
2727
from causalpy.experiments import ExperimentalDesign
2828

2929
# from causalpy.pymc_models import PyMCModel
3030
# from causalpy.utils import round_num
3131

3232

33-
class InversePropensityWeighting(ExperimentalDesign, PropensityDataValidator):
33+
class InversePropensityWeighting(ExperimentalDesign):
3434
"""
3535
A class to analyse inverse propensity weighting experiments.
3636
@@ -97,6 +97,28 @@ def __init__(
9797
self.coords = COORDS
9898
self.model.fit(X=self.X, t=self.t, coords=COORDS)
9999

100+
def _input_validation(self):
    """Validate the input data and model formula for correctness.

    Reads ``self.formula``, ``self.data`` and ``self.outcome_variable``.

    Raises
    ------
    DataException
        If the variable named on the left-hand side of ``self.formula`` or
        ``self.outcome_variable`` is not a column of ``self.data``, or if the
        treatment column holds more than two distinct values.
    """
    # The left-hand side of the formula names the treatment variable. Strip it
    # once so the column lookup and the error message use the same name.
    treatment = self.formula.split("~")[0].strip()
    columns = self.data.columns
    if treatment not in columns or self.outcome_variable not in columns:
        raise DataException(
            f"""
            The treatment variable:
            {treatment} must appear in the data to be used
            as an outcome variable. And {self.outcome_variable}
            must also be available in the data to be re-weighted
            """
        )
    # Only the number of distinct values is checked here, not that the values
    # are literally 0 and 1.
    if len(np.unique(self.data[treatment])) > 2:
        raise DataException(
            """Warning. The treatment variable is not 0-1 Binary.
            """
        )
121+
100122
def make_robust_adjustments(self, ps):
101123
"""This estimator is discussed in Aronow
102124
and Miller's book as being related to the

causalpy/expt_diff_in_diff.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,17 @@
1919
import pandas as pd
2020
from patsy import build_design_matrices, dmatrices
2121

22-
from causalpy.data_validation import DiDDataValidator
22+
from causalpy.custom_exceptions import (
23+
DataException,
24+
FormulaException,
25+
)
2326
from causalpy.experiments import ExperimentalDesign
2427
from causalpy.pymc_models import PyMCModel
2528
from causalpy.skl_models import ScikitLearnModel
26-
from causalpy.utils import convert_to_string
29+
from causalpy.utils import _is_variable_dummy_coded, convert_to_string
2730

2831

29-
class DifferenceInDifferences(ExperimentalDesign, DiDDataValidator):
32+
class DifferenceInDifferences(ExperimentalDesign):
3033
"""A class to analyse data from Difference in Difference settings.
3134
3235
.. note::
@@ -174,6 +177,29 @@ def __init__(
174177
else:
175178
raise ValueError("Model type not recognized")
176179

180+
def _input_validation(self):
    """Validate the input data and model formula for correctness.

    Reads ``self.formula``, ``self.data`` and ``self.group_variable_name``.

    Raises
    ------
    FormulaException
        If the text ``post_treatment`` does not occur in ``self.formula``.
    DataException
        If ``self.data`` has no ``post_treatment`` or ``unit`` column, or if
        ``_is_variable_dummy_coded`` rejects the grouping column.
    """
    # Substring test on the raw formula text.
    if "post_treatment" not in self.formula:
        raise FormulaException(
            "A predictor called `post_treatment` should be in the formula"
        )

    if "post_treatment" not in self.data.columns:
        raise DataException(
            "Require a boolean column labelling observations which are `treated`"
        )

    if "unit" not in self.data.columns:
        raise DataException(
            "Require a `unit` column to label unique units. This is used for plotting purposes"  # noqa: E501
        )

    # Use truthiness rather than `is False`: an identity comparison would let
    # any falsy result that is not the `False` singleton slip through.
    if not _is_variable_dummy_coded(self.data[self.group_variable_name]):
        raise DataException(
            f"""The grouping variable {self.group_variable_name} should be dummy
            coded. Consisting of 0's and 1's only."""
        )
202+
177203
def plot(self, round_to=None):
178204
"""
179205
Plot the results

causalpy/expt_instrumental_variable.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,21 @@
1515
Instrumental variable regression
1616
"""
1717

18+
import warnings # noqa: I001
19+
1820
import numpy as np
1921
import pandas as pd
2022
from patsy import dmatrices
2123
from sklearn.linear_model import LinearRegression as sk_lin_reg
2224

23-
from causalpy.data_validation import IVDataValidator
25+
from causalpy.custom_exceptions import DataException
2426
from causalpy.experiments import ExperimentalDesign
2527

2628
# from causalpy.pymc_models import PyMCModel
2729
# from causalpy.utils import round_num
2830

2931

30-
class InstrumentalVariable(ExperimentalDesign, IVDataValidator):
32+
class InstrumentalVariable(ExperimentalDesign):
3133
"""
3234
A class to analyse instrumental variable style experiments.
3335
@@ -140,6 +142,29 @@ def __init__(
140142
X=self.X, Z=self.Z, y=self.y, t=self.t, coords=COORDS, priors=self.priors
141143
)
142144

145+
def _input_validation(self):
    """Validate the input data and model formula for correctness.

    Reads ``self.instruments_formula``, ``self.instruments_data`` and
    ``self.data``.

    Raises
    ------
    DataException
        If the variable named on the left-hand side of
        ``self.instruments_formula`` is not a column of both
        ``self.instruments_data`` and ``self.data``.

    Warns
    -----
    UserWarning
        If that column of ``self.data`` holds more than two distinct values.
    """
    # Strip once so the column lookups and the error message agree on the name.
    treatment = self.instruments_formula.split("~")[0].strip()
    if (
        treatment not in self.instruments_data.columns
        or treatment not in self.data.columns
    ):
        raise DataException(
            f"""
            The treatment variable:
            {treatment} must appear in the instruments_data to be used
            as an outcome variable and in the data object to be used as a covariate.
            """
        )
    # A non-binary treatment is allowed, so this only warns.
    if len(np.unique(self.data[treatment])) > 2:
        warnings.warn(
            """Warning. The treatment variable is not Binary.
            This is not necessarily a problem but it violates
            the assumption of a simple IV experiment.
            The coefficients should be interpreted appropriately.""",
            # Attribute the warning to the caller rather than to this line.
            stacklevel=2,
        )
167+
143168
def get_2SLS_fit(self):
144169
"""
145170
Two Stage Least Squares Fit

causalpy/expt_prepostfit.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,13 @@
2121
import pandas as pd
2222
from patsy import build_design_matrices, dmatrices
2323

24-
from causalpy.data_validation import PrePostFitDataValidator
24+
from causalpy.custom_exceptions import BadIndexException
2525
from causalpy.experiments import ExperimentalDesign
2626
from causalpy.pymc_models import PyMCModel
2727
from causalpy.skl_models import ScikitLearnModel
2828

2929

30-
class PrePostFit(ExperimentalDesign, PrePostFitDataValidator):
30+
class PrePostFit(ExperimentalDesign):
3131
"""
3232
A base class for quasi-experimental designs where parameter estimation is based on
3333
just pre-intervention data. This class is not directly invoked by the user.
@@ -91,6 +91,21 @@ def __init__(
9191
self.post_impact
9292
)
9393

94+
def _input_validation(self, data, treatment_time):
    """Validate the input data and model formula for correctness.

    Checks that the type of ``treatment_time`` matches the type of
    ``data.index``: a ``pd.Timestamp`` if and only if the index is a
    ``pd.DatetimeIndex``.

    Parameters
    ----------
    data
        Object whose ``index`` attribute is inspected.
    treatment_time
        The treatment time to check against the index type.

    Raises
    ------
    BadIndexException
        If exactly one of the two is datetime-like.
    """
    index_is_datetime = isinstance(data.index, pd.DatetimeIndex)
    time_is_timestamp = isinstance(treatment_time, pd.Timestamp)

    if index_is_datetime and not time_is_timestamp:
        raise BadIndexException(
            "If data.index is DatetimeIndex, treatment_time must be pd.Timestamp."
        )
    if not index_is_datetime and time_is_timestamp:
        # The message states what is wrong in this branch: a Timestamp was
        # supplied for a non-datetime index.
        raise BadIndexException(
            "If data.index is not DatetimeIndex, treatment_time must not be pd.Timestamp."  # noqa: E501
        )
108+
94109
def plot(self):
95110
"""
96111
Plot the results

0 commit comments

Comments
 (0)