Skip to content

Commit d998024

Browse files
committed
replace asserts with custom exceptions + add utility functions and associated tests
1 parent 0dc39ab commit d998024

File tree

5 files changed

+88
-34
lines changed

5 files changed

+88
-34
lines changed

causalpy/custom_exceptions.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,18 @@ class BadIndexException(Exception):
44

55
def __init__(self, message):
66
self.message = message
7+
8+
9+
class FormulaException(Exception):
    """Exception raised when there is an error in a user-provided model formula."""

    def __init__(self, message):
        # Forward the message to the base class so that ``str(exc)`` and
        # ``exc.args`` carry it even when ``message`` is passed by keyword.
        super().__init__(message)
        # Kept as an attribute for callers that read ``exc.message`` directly.
        self.message = message
15+
16+
17+
class DataException(Exception):
    """Exception raised when there is an error in a user-provided dataframe."""

    def __init__(self, message):
        # Forward the message to the base class so that ``str(exc)`` and
        # ``exc.args`` carry it even when ``message`` is passed by keyword.
        super().__init__(message)
        # Kept as an attribute for callers that read ``exc.message`` directly.
        self.message = message

causalpy/pymc_experiments.py

Lines changed: 36 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,13 @@
88
import xarray as xr
99
from patsy import build_design_matrices, dmatrices
1010

11-
from causalpy.custom_exceptions import BadIndexException
11+
from causalpy.custom_exceptions import (
12+
BadIndexException,
13+
DataException,
14+
FormulaException,
15+
)
1216
from causalpy.plot_utils import plot_xY
17+
from causalpy.utils import _is_variable_dummy_coded, _series_has_2_levels
1318

1419
LEGEND_FONT_SIZE = 12
1520
az.style.use("arviz-darkgrid")
@@ -117,7 +122,7 @@ def __init__(
117122
self.post_impact_cumulative = self.post_impact.cumsum(dim="obs_ind")
118123

119124
def _input_validation(self, data, treatment_time):
120-
"""Validate the input data for correctness"""
125+
"""Validate the input data and model formula for correctness"""
121126
if isinstance(data.index, pd.DatetimeIndex) and not isinstance(
122127
treatment_time, pd.Timestamp
123128
):
@@ -361,28 +366,27 @@ def __init__(
361366
self.causal_impact = self.idata.posterior["beta"].isel({"coeffs": i})
362367

363368
def _input_validation(self):
364-
"""Validate the input data for correctness"""
365-
assert (
366-
"post_treatment" in self.formula
367-
), "A predictor called `post_treatment` should be in the dataframe"
368-
assert (
369-
"post_treatment" in self.data.columns
370-
), "Require a boolean column labelling observations which are `treated`"
371-
# Check for `unit` in the incoming dataframe.
372-
# *This is only used for plotting purposes*
373-
assert (
374-
"unit" in self.data.columns
375-
), """
376-
Require a `unit` column to label unique units.
377-
This is used for plotting purposes
378-
"""
379-
# Check that `group_variable_name` is dummy coded. It should be 0 or 1
380-
assert not set(self.data[self.group_variable_name]).difference(
381-
set([0, 1])
382-
), f"""
383-
The grouping variable {self.group_variable_name} should be dummy coded.
384-
Consisting of 0's and 1's only.
385-
"""
369+
"""Validate the input data and model formula for correctness"""
370+
if "post_treatment" not in self.formula:
371+
raise FormulaException(
372+
"A predictor called `post_treatment` should be in the formula"
373+
)
374+
375+
if "post_treatment" not in self.data.columns:
376+
raise DataException(
377+
"Require a boolean column labelling observations which are `treated`"
378+
)
379+
380+
if "unit" not in self.data.columns:
381+
raise DataException(
382+
"Require a `unit` column to label unique units. This is used for plotting purposes" # noqa: E501
383+
)
384+
385+
if _is_variable_dummy_coded(self.data[self.group_variable_name]) is False:
386+
raise DataException(
387+
f"""The grouping variable {self.group_variable_name} should be dummy
388+
coded. Consisting of 0's and 1's only."""
389+
)
386390

387391
def plot(self):
388392
"""Plot the results.
@@ -744,16 +748,17 @@ def __init__(
744748
# ================================================================
745749

746750
def _input_validation(self):
747-
"""Validate the input data for correctness"""
751+
"""Validate the input data and model formula for correctness"""
748752
# Check that `group_variable_name` has TWO levels, representing the
749753
# treated/untreated. But it does not matter what the actual names of
750754
# the levels are.
751-
assert (
752-
len(pd.Categorical(self.data[self.group_variable_name]).categories) == 2
753-
), f"""
754-
There must be 2 levels of the grouping variable {self.group_variable_name}
755-
.I.e. the treated and untreated.
756-
"""
755+
if not _series_has_2_levels(self.data[self.group_variable_name]):
756+
raise ValueError(
757+
f"""
758+
There must be 2 levels of the grouping variable
759+
{self.group_variable_name}. I.e. the treated and untreated.
760+
"""
761+
)
757762

758763
def plot(self):
759764
"""Plot the results"""

causalpy/tests/test_utils.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import pandas as pd
2+
3+
from causalpy.utils import _is_variable_dummy_coded, _series_has_2_levels
4+
5+
6+
def test_dummy_coding():
    """Check `_is_variable_dummy_coded` against dummy-coded and non-dummy-coded data"""
    # Each entry pairs the raw values of a Series with the expected verdict.
    cases = [
        ([False, True, False, True], True),
        ([False, True, False, "frog"], False),
        ([0, 0, 1, 0, 1], True),
        ([0, 0, 1, 0, 2], False),
        ([0, 0.5, 1, 0, 1], False),
    ]
    for values, expected in cases:
        assert _is_variable_dummy_coded(pd.Series(values)) is expected
13+
14+
15+
def test_2_level_series():
    """Check `_series_has_2_levels` against Series with two and with three levels"""
    # Each entry pairs the raw values of a Series with the expected verdict.
    cases = [
        (["a", "a", "b"], True),
        (["a", "a", "b", "c"], False),
        (["coffee", "tea", "coffee"], True),
        (["water", "tea", "coffee"], False),
        ([0, 1, 0, 1], True),
        ([0, 1, 0, 2], False),
    ]
    for values, expected in cases:
        assert _series_has_2_levels(pd.Series(values)) is expected

causalpy/utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import pandas as pd
2+
3+
4+
def _is_variable_dummy_coded(series: pd.Series) -> bool:
5+
"""Check if a data in the provided Series is dummy coded. It should be 0 or 1
6+
only."""
7+
return len(set(series).difference(set([0, 1]))) == 0
8+
9+
10+
def _series_has_2_levels(series: pd.Series) -> bool:
11+
"""Check that the variable in the provided Series has 2 levels"""
12+
return len(pd.Categorical(series).categories) == 2

img/interrogate_badge.svg

Lines changed: 3 additions & 3 deletions
Loading

0 commit comments

Comments
 (0)