Skip to content

Commit 1ba64fb

Browse files
committed
extract input validation code into private methods
1 parent 93987a4 commit 1ba64fb

File tree

2 files changed

+54
-47
lines changed

2 files changed

+54
-47
lines changed

causalpy/pymc_experiments.py

Lines changed: 51 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -62,16 +62,7 @@ def __init__(
6262
**kwargs,
6363
) -> None:
6464
super().__init__(model=model, **kwargs)
65-
66-
# Input validation
67-
if isinstance(data.index, pd.DatetimeIndex):
68-
assert isinstance(
69-
treatment_time, pd.Timestamp
70-
), "If data.index is DatetimeIndex, treatment_time must be pd.Timestamp."
71-
else:
72-
assert (
73-
isinstance(treatment_time, pd.Timestamp) is False
74-
), "If treatment_time is pd.Timestamp, this only makese sense if data.index is DatetimeIndex." # noqa: E501
65+
self._input_validation(data, treatment_time)
7566

7667
self.treatment_time = treatment_time
7768
# split data in to pre and post intervention
@@ -124,6 +115,17 @@ def __init__(
124115
# cumulative impact post
125116
self.post_impact_cumulative = self.post_impact.cumsum(dim="obs_ind")
126117

118+
def _input_validation(self, data, treatment_time):
119+
"""Validate the input data for correctness"""
120+
if isinstance(data.index, pd.DatetimeIndex):
121+
assert isinstance(
122+
treatment_time, pd.Timestamp
123+
), "If data.index is DatetimeIndex, treatment_time must be pd.Timestamp."
124+
else:
125+
assert (
126+
isinstance(treatment_time, pd.Timestamp) is False
127+
), "If treatment_time is pd.Timestamp, this only makese sense if data.index is DatetimeIndex." # noqa: E501
128+
127129
def plot(self):
128130

129131
"""Plot the results"""
@@ -276,36 +278,15 @@ def __init__(
276278
self.formula = formula
277279
self.time_variable_name = time_variable_name
278280
self.group_variable_name = group_variable_name
281+
self._input_validation()
282+
279283
y, X = dmatrices(formula, self.data)
280284
self._y_design_info = y.design_info
281285
self._x_design_info = X.design_info
282286
self.labels = X.design_info.column_names
283287
self.y, self.X = np.asarray(y), np.asarray(X)
284288
self.outcome_variable_name = y.design_info.column_names[0]
285289

286-
# Input validation ----------------------------------------------------
287-
assert (
288-
"post_treatment" in formula
289-
), "A predictor called `post_treatment` should be in the dataframe"
290-
assert (
291-
"post_treatment" in self.data.columns
292-
), "Require a boolean column labelling observations which are `treated`"
293-
# Check for `unit` in the incoming dataframe.
294-
# *This is only used for plotting purposes*
295-
assert (
296-
"unit" in self.data.columns
297-
), """
298-
Require a `unit` column to label unique units.
299-
This is used for plotting purposes
300-
"""
301-
# Check that `group_variable_name` is dummy coded. It should be 0 or 1
302-
assert not set(self.data[self.group_variable_name]).difference(
303-
set([0, 1])
304-
), f"""
305-
The grouping variable {self.group_variable_name} should be dummy coded.
306-
Consisting of 0's and 1's only.
307-
"""
308-
309290
COORDS = {"coeffs": self.labels, "obs_indx": np.arange(self.X.shape[0])}
310291
self.model.fit(X=self.X, y=self.y, coords=COORDS)
311292

@@ -374,6 +355,30 @@ def __init__(
374355
if "post_treatment" in label and self.group_variable_name in label:
375356
self.causal_impact = self.idata.posterior["beta"].isel({"coeffs": i})
376357

358+
def _input_validation(self):
359+
"""Validate the input data for correctness"""
360+
assert (
361+
"post_treatment" in self.formula
362+
), "A predictor called `post_treatment` should be in the dataframe"
363+
assert (
364+
"post_treatment" in self.data.columns
365+
), "Require a boolean column labelling observations which are `treated`"
366+
# Check for `unit` in the incoming dataframe.
367+
# *This is only used for plotting purposes*
368+
assert (
369+
"unit" in self.data.columns
370+
), """
371+
Require a `unit` column to label unique units.
372+
This is used for plotting purposes
373+
"""
374+
# Check that `group_variable_name` is dummy coded. It should be 0 or 1
375+
assert not set(self.data[self.group_variable_name]).difference(
376+
set([0, 1])
377+
), f"""
378+
The grouping variable {self.group_variable_name} should be dummy coded.
379+
Consisting of 0's and 1's only.
380+
"""
381+
377382
def plot(self):
378383
"""Plot the results.
379384
Creating the combined mean + HDI legend entries is a bit involved.
@@ -686,6 +691,7 @@ def __init__(
686691
self.formula = formula
687692
self.group_variable_name = group_variable_name
688693
self.pretreatment_variable_name = pretreatment_variable_name
694+
self._input_validation()
689695

690696
y, X = dmatrices(formula, self.data)
691697
self._y_design_info = y.design_info
@@ -694,17 +700,6 @@ def __init__(
694700
self.y, self.X = np.asarray(y), np.asarray(X)
695701
self.outcome_variable_name = y.design_info.column_names[0]
696702

697-
# Input validation ----------------------------------------------------
698-
# Check that `group_variable_name` has TWO levels, representing the
699-
# treated/untreated. But it does not matter what the actual names of
700-
# the levels are.
701-
assert (
702-
len(pd.Categorical(self.data[self.group_variable_name]).categories) == 2
703-
), f"""
704-
There must be 2 levels of the grouping variable {self.group_variable_name}
705-
.I.e. the treated and untreated.
706-
"""
707-
708703
# fit the model to the observed (pre-intervention) data
709704
COORDS = {"coeffs": self.labels, "obs_indx": np.arange(self.X.shape[0])}
710705
self.model.fit(X=self.X, y=self.y, coords=COORDS)
@@ -743,6 +738,18 @@ def __init__(
743738

744739
# ================================================================
745740

741+
def _input_validation(self):
742+
"""Validate the input data for correctness"""
743+
# Check that `group_variable_name` has TWO levels, representing the
744+
# treated/untreated. But it does not matter what the actual names of
745+
# the levels are.
746+
assert (
747+
len(pd.Categorical(self.data[self.group_variable_name]).categories) == 2
748+
), f"""
749+
There must be 2 levels of the grouping variable {self.group_variable_name}
750+
.I.e. the treated and untreated.
751+
"""
752+
746753
def plot(self):
747754
"""Plot the results"""
748755
fig, ax = plt.subplots(

img/interrogate_badge.svg

Lines changed: 3 additions & 3 deletions
Loading

0 commit comments

Comments
 (0)