Skip to content

Commit acccf16

Browse files
authored
Merge pull request #152 from pymc-labs/input-validation
2 parents 7f2be73 + 0752be5 commit acccf16

File tree

7 files changed

+340
-46
lines changed

7 files changed

+340
-46
lines changed

.pre-commit-config.yaml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# See https://pre-commit.com/hooks.html for more hooks
33
repos:
44
- repo: https://github.com/pre-commit/pre-commit-hooks
5-
rev: v4.3.0
5+
rev: v4.4.0
66
hooks:
77
- id: trailing-whitespace
88
exclude_types: [svg]
@@ -11,21 +11,21 @@ repos:
1111
- id: check-yaml
1212
- id: check-added-large-files
1313
- repo: https://github.com/pycqa/isort
14-
rev: 5.11.2
14+
rev: 5.11.4
1515
hooks:
1616
- id: isort
1717
args: [--profile, black]
1818
types: [python]
1919
- repo: https://github.com/ambv/black
20-
rev: 22.10.0
20+
rev: 22.12.0
2121
hooks:
2222
- id: black
2323
- repo: https://github.com/pycqa/flake8
24-
rev: 3.9.2
24+
rev: 6.0.0
2525
hooks:
2626
- id: flake8
2727
- repo: https://github.com/nbQA-dev/nbQA
28-
rev: 1.5.3
28+
rev: 1.6.1
2929
hooks:
3030
- id: nbqa-black
3131
# additional_dependencies: [jupytext] # optional, only if you're using Jupytext

causalpy/custom_exceptions.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
class BadIndexException(Exception):
    """Custom exception used when we have a mismatch in types between the dataframe
    index and an event, typically a treatment or intervention.

    Parameters
    ----------
    message
        Human-readable description of the mismatch. It is stored on the instance
        as ``message`` and is also forwarded to ``Exception`` so that ``args`` and
        ``str()`` are populated through the normal base-class initialiser.
    """

    def __init__(self, message: str):
        # Delegate to the base class rather than relying on BaseException.__new__
        # alone to record the arguments.
        super().__init__(message)
        self.message = message
7+
8+
9+
class FormulaException(Exception):
    """Exception raised when there is some error in a user-provided model formula.

    Parameters
    ----------
    message
        Human-readable description of the problem. It is stored on the instance
        as ``message`` and is also forwarded to ``Exception`` so that ``args`` and
        ``str()`` are populated through the normal base-class initialiser.
    """

    def __init__(self, message: str):
        # Delegate to the base class rather than relying on BaseException.__new__
        # alone to record the arguments.
        super().__init__(message)
        self.message = message
15+
16+
17+
class DataException(Exception):
    """Exception raised when there is some error in a user-provided dataframe.

    Parameters
    ----------
    message
        Human-readable description of the problem. It is stored on the instance
        as ``message`` and is also forwarded to ``Exception`` so that ``args`` and
        ``str()`` are populated through the normal base-class initialiser.
    """

    def __init__(self, message: str):
        # Delegate to the base class rather than relying on BaseException.__new__
        # alone to record the arguments.
        super().__init__(message)
        self.message = message

causalpy/pymc_experiments.py

Lines changed: 73 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Union
2+
13
import arviz as az
24
import matplotlib.pyplot as plt
35
import numpy as np
@@ -6,7 +8,10 @@
68
import xarray as xr
79
from patsy import build_design_matrices, dmatrices
810

11+
from causalpy.custom_exceptions import BadIndexException # NOQA
12+
from causalpy.custom_exceptions import DataException, FormulaException
913
from causalpy.plot_utils import plot_xY
14+
from causalpy.utils import _is_variable_dummy_coded, _series_has_2_levels
1015

1116
LEGEND_FONT_SIZE = 12
1217
az.style.use("arviz-darkgrid")
@@ -54,12 +59,14 @@ class TimeSeriesExperiment(ExperimentalDesign):
5459
def __init__(
5560
self,
5661
data: pd.DataFrame,
57-
treatment_time: int,
62+
treatment_time: Union[int, float, pd.Timestamp],
5863
formula: str,
5964
model=None,
6065
**kwargs,
6166
) -> None:
6267
super().__init__(model=model, **kwargs)
68+
self._input_validation(data, treatment_time)
69+
6370
self.treatment_time = treatment_time
6471
# split data in to pre and post intervention
6572
self.datapre = data[data.index <= self.treatment_time]
@@ -111,6 +118,21 @@ def __init__(
111118
# cumulative impact post
112119
self.post_impact_cumulative = self.post_impact.cumsum(dim="obs_ind")
113120

121+
def _input_validation(self, data, treatment_time):
    """Validate that ``treatment_time`` is compatible with the index of ``data``.

    A ``pd.DatetimeIndex`` must be paired with a ``pd.Timestamp`` treatment time,
    and any other index type must be paired with a treatment time that is not a
    ``pd.Timestamp``.

    Parameters
    ----------
    data
        Dataframe whose ``index`` is inspected.
    treatment_time
        The treatment time that the caller later compares against ``data.index``.

    Raises
    ------
    BadIndexException
        If exactly one of ``data.index`` / ``treatment_time`` is datetime-like.
    """
    index_is_datetime = isinstance(data.index, pd.DatetimeIndex)
    time_is_timestamp = isinstance(treatment_time, pd.Timestamp)

    if index_is_datetime and not time_is_timestamp:
        raise BadIndexException(
            "If data.index is DatetimeIndex, treatment_time must be pd.Timestamp."
        )
    if not index_is_datetime and time_is_timestamp:
        # The message previously said "must be pd.Timestamp", which contradicts
        # the condition that triggers it.
        raise BadIndexException(
            "If data.index is not DatetimeIndex, treatment_time must not be "
            "pd.Timestamp."
        )
135+
114136
def plot(self):
115137

116138
"""Plot the results"""
@@ -263,36 +285,15 @@ def __init__(
263285
self.formula = formula
264286
self.time_variable_name = time_variable_name
265287
self.group_variable_name = group_variable_name
288+
self._input_validation()
289+
266290
y, X = dmatrices(formula, self.data)
267291
self._y_design_info = y.design_info
268292
self._x_design_info = X.design_info
269293
self.labels = X.design_info.column_names
270294
self.y, self.X = np.asarray(y), np.asarray(X)
271295
self.outcome_variable_name = y.design_info.column_names[0]
272296

273-
# Input validation ----------------------------------------------------
274-
assert (
275-
"post_treatment" in formula
276-
), "A predictor called `post_treatment` should be in the dataframe"
277-
assert (
278-
"post_treatment" in self.data.columns
279-
), "Require a boolean column labelling observations which are `treated`"
280-
# Check for `unit` in the incoming dataframe.
281-
# *This is only used for plotting purposes*
282-
assert (
283-
"unit" in self.data.columns
284-
), """
285-
Require a `unit` column to label unique units.
286-
This is used for plotting purposes
287-
"""
288-
# Check that `group_variable_name` is dummy coded. It should be 0 or 1
289-
assert not set(self.data[self.group_variable_name]).difference(
290-
set([0, 1])
291-
), f"""
292-
The grouping variable {self.group_variable_name} should be dummy coded.
293-
Consisting of 0's and 1's only.
294-
"""
295-
296297
COORDS = {"coeffs": self.labels, "obs_indx": np.arange(self.X.shape[0])}
297298
self.model.fit(X=self.X, y=self.y, coords=COORDS)
298299

@@ -361,6 +362,29 @@ def __init__(
361362
if "post_treatment" in label and self.group_variable_name in label:
362363
self.causal_impact = self.idata.posterior["beta"].isel({"coeffs": i})
363364

365+
def _input_validation(self):
    """Validate the formula and dataframe held on the instance.

    Reads ``self.formula``, ``self.data`` and ``self.group_variable_name``.

    Raises
    ------
    FormulaException
        If the text ``post_treatment`` does not occur in ``self.formula``.
    DataException
        If ``self.data`` has no ``post_treatment`` column, has no ``unit`` column,
        or if the column named by ``self.group_variable_name`` fails the
        ``_is_variable_dummy_coded`` check.
    """
    # NOTE: this is a substring test on the formula text, not a parse of its terms.
    if "post_treatment" not in self.formula:
        raise FormulaException(
            "A predictor called `post_treatment` should be in the formula"
        )

    if "post_treatment" not in self.data.columns:
        raise DataException(
            "Require a boolean column labelling observations which are `treated`"
        )

    if "unit" not in self.data.columns:
        raise DataException(
            "Require a `unit` column to label unique units. "
            "This is used for plotting purposes"
        )

    # Truthiness test instead of `is False`: an identity comparison would let a
    # non-`bool` falsy result (e.g. numpy.bool_) slip through unvalidated.
    if not _is_variable_dummy_coded(self.data[self.group_variable_name]):
        # Single-line message: the triple-quoted form leaked a newline and the
        # source indentation into the text shown to the user.
        raise DataException(
            f"The grouping variable {self.group_variable_name} should be dummy "
            "coded. Consisting of 0's and 1's only."
        )
387+
364388
def plot(self):
365389
"""Plot the results.
366390
Creating the combined mean + HDI legend entries is a bit involved.
@@ -536,16 +560,15 @@ def __init__(
536560
self.formula = formula
537561
self.running_variable_name = running_variable_name
538562
self.treatment_threshold = treatment_threshold
563+
self._input_validation()
564+
539565
y, X = dmatrices(formula, self.data)
540566
self._y_design_info = y.design_info
541567
self._x_design_info = X.design_info
542568
self.labels = X.design_info.column_names
543569
self.y, self.X = np.asarray(y), np.asarray(X)
544570
self.outcome_variable_name = y.design_info.column_names[0]
545571

546-
# TODO: `treated` is a deterministic function of x and treatment_threshold, so
547-
# this could be a function rather than supplied data
548-
549572
# DEVIATION FROM SKL EXPERIMENT CODE =============================
550573
# fit the model to the observed (pre-intervention) data
551574
COORDS = {"coeffs": self.labels, "obs_indx": np.arange(self.X.shape[0])}
@@ -586,6 +609,18 @@ def __init__(
586609
- self.pred_discon["posterior_predictive"].sel(obs_ind=0)["mu"]
587610
)
588611

612+
def _input_validation(self):
    """Validate the formula and dataframe held on the instance.

    Reads ``self.formula`` and ``self.data``.

    Raises
    ------
    FormulaException
        If the text ``treated`` does not occur in ``self.formula``.
    DataException
        If the ``treated`` column of ``self.data`` fails the
        ``_is_variable_dummy_coded`` check.
    """
    # NOTE: this is a substring test on the formula text, not a parse of its terms.
    if "treated" not in self.formula:
        raise FormulaException(
            "A predictor called `treated` should be in the formula"
        )

    # Truthiness test instead of `is False`: an identity comparison would let a
    # non-`bool` falsy result (e.g. numpy.bool_) slip through unvalidated.
    if not _is_variable_dummy_coded(self.data["treated"]):
        raise DataException(
            "The treated variable should be dummy coded. "
            "Consisting of 0's and 1's only."
        )
623+
589624
def _is_treated(self, x):
590625
"""Returns ``True`` if `x` is greater than or equal to the treatment threshold.
591626
@@ -673,6 +708,7 @@ def __init__(
673708
self.formula = formula
674709
self.group_variable_name = group_variable_name
675710
self.pretreatment_variable_name = pretreatment_variable_name
711+
self._input_validation()
676712

677713
y, X = dmatrices(formula, self.data)
678714
self._y_design_info = y.design_info
@@ -681,17 +717,6 @@ def __init__(
681717
self.y, self.X = np.asarray(y), np.asarray(X)
682718
self.outcome_variable_name = y.design_info.column_names[0]
683719

684-
# Input validation ----------------------------------------------------
685-
# Check that `group_variable_name` has TWO levels, representing the
686-
# treated/untreated. But it does not matter what the actual names of
687-
# the levels are.
688-
assert (
689-
len(pd.Categorical(self.data[self.group_variable_name]).categories) == 2
690-
), f"""
691-
There must be 2 levels of the grouping variable {self.group_variable_name}
692-
.I.e. the treated and untreated.
693-
"""
694-
695720
# fit the model to the observed (pre-intervention) data
696721
COORDS = {"coeffs": self.labels, "obs_indx": np.arange(self.X.shape[0])}
697722
self.model.fit(X=self.X, y=self.y, coords=COORDS)
@@ -730,6 +755,16 @@ def __init__(
730755

731756
# ================================================================
732757

758+
def _input_validation(self):
    """Validate the grouping variable of the dataframe held on the instance.

    Reads ``self.data`` and ``self.group_variable_name``.

    Raises
    ------
    DataException
        If the column named by ``self.group_variable_name`` fails the
        ``_series_has_2_levels`` check.
    """
    if not _series_has_2_levels(self.data[self.group_variable_name]):
        # Single-line message: the triple-quoted form wrapped the text in blank
        # lines and carried the source indentation into what the user sees.
        raise DataException(
            "There must be 2 levels of the grouping variable "
            f"{self.group_variable_name}. I.e. the treated and untreated."
        )
767+
733768
def plot(self):
734769
"""Plot the results"""
735770
fig, ax = plt.subplots(

0 commit comments

Comments
 (0)