Skip to content

Commit 007404d

Browse files
committed
asserts at end of tests + use isinstance
1 parent 2e1e41f commit 007404d

File tree

1 file changed

+14
-14
lines changed

1 file changed

+14
-14
lines changed

causalpy/tests/test_integration_skl_examples.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,14 @@
1010
@pytest.mark.integration
1111
def test_did():
1212
data = cp.load_data("did")
13-
assert type(data) is pd.DataFrame
1413
result = cp.skl_experiments.DifferenceInDifferences(
1514
data,
1615
formula="y ~ 1 + group + t + treated:group",
1716
time_variable_name="t",
1817
prediction_model=LinearRegression(),
1918
)
20-
assert type(result) is cp.skl_experiments.DifferenceInDifferences
19+
assert isinstance(data, pd.DataFrame)
20+
assert isinstance(result, cp.skl_experiments.DifferenceInDifferences)
2121

2222

2323
@pytest.mark.integration
@@ -28,15 +28,15 @@ def test_rd_drinking():
2828
.assign(treated=lambda df_: df_.age > 21)
2929
.dropna(axis=0)
3030
)
31-
assert type(df) is pd.DataFrame
3231
result = cp.skl_experiments.RegressionDiscontinuity(
3332
df,
3433
formula="all ~ 1 + age + treated",
3534
running_variable_name="age",
3635
prediction_model=LinearRegression(),
3736
treatment_threshold=21,
3837
)
39-
assert type(result) is cp.skl_experiments.RegressionDiscontinuity
38+
assert isinstance(df, pd.DataFrame)
39+
assert isinstance(result, cp.skl_experiments.RegressionDiscontinuity)
4040

4141

4242
@pytest.mark.integration
@@ -45,65 +45,65 @@ def test_its():
4545
df["date"] = pd.to_datetime(df["date"])
4646
df.set_index("date", inplace=True)
4747
treatment_time = pd.to_datetime("2017-01-01")
48-
assert type(df) is pd.DataFrame
4948
result = cp.skl_experiments.SyntheticControl(
5049
df,
5150
treatment_time,
5251
formula="y ~ 1 + t + C(month)",
5352
prediction_model=LinearRegression(),
5453
)
55-
assert type(result) is cp.skl_experiments.SyntheticControl
54+
assert isinstance(df, pd.DataFrame)
55+
assert isinstance(result, cp.skl_experiments.SyntheticControl)
5656

5757

5858
@pytest.mark.integration
5959
def test_sc():
6060
df = cp.load_data("sc")
6161
treatment_time = 70
62-
assert type(df) is pd.DataFrame
6362
result = cp.skl_experiments.SyntheticControl(
6463
df,
6564
treatment_time,
6665
formula="actual ~ 0 + a + b + c + d + e + f + g",
6766
prediction_model=cp.skl_models.WeightedProportion(),
6867
)
69-
assert type(result) is cp.skl_experiments.SyntheticControl
68+
assert isinstance(df, pd.DataFrame)
69+
assert isinstance(result, cp.skl_experiments.SyntheticControl)
7070

7171

7272
@pytest.mark.integration
7373
def test_rd_linear_main_effects():
7474
data = cp.load_data("rd")
75-
assert type(data) is pd.DataFrame
7675
result = cp.skl_experiments.RegressionDiscontinuity(
7776
data,
7877
formula="y ~ 1 + x + treated",
7978
prediction_model=LinearRegression(),
8079
treatment_threshold=0.5,
8180
)
82-
assert type(result) is cp.skl_experiments.RegressionDiscontinuity
81+
assert isinstance(data, pd.DataFrame)
82+
assert isinstance(result, cp.skl_experiments.RegressionDiscontinuity)
8383

8484

8585
@pytest.mark.integration
8686
def test_rd_linear_with_interaction():
8787
data = cp.load_data("rd")
88-
assert type(data) is pd.DataFrame
8988
result = cp.skl_experiments.RegressionDiscontinuity(
9089
data,
9190
formula="y ~ 1 + x + treated + x:treated",
9291
prediction_model=LinearRegression(),
9392
treatment_threshold=0.5,
9493
)
95-
assert type(result) is cp.skl_experiments.RegressionDiscontinuity
94+
assert isinstance(data, pd.DataFrame)
95+
assert isinstance(result, cp.skl_experiments.RegressionDiscontinuity)
9696

9797

9898
@pytest.mark.integration
9999
def test_rd_linear_with_gaussian_process():
100100
data = cp.load_data("rd")
101-
assert type(data) is pd.DataFrame
102101
kernel = 1.0 * ExpSineSquared(1.0, 5.0) + WhiteKernel(1e-1)
103102
result = cp.skl_experiments.RegressionDiscontinuity(
104103
data,
105104
formula="y ~ 1 + x + treated",
106105
prediction_model=GaussianProcessRegressor(kernel=kernel),
107106
treatment_threshold=0.5,
108107
)
109-
assert type(result) is cp.skl_experiments.RegressionDiscontinuity
108+
assert isinstance(data, pd.DataFrame)
109+
assert isinstance(result, cp.skl_experiments.RegressionDiscontinuity)

0 commit comments

Comments
 (0)