Skip to content

Commit 0cf4f46

Browse files
committed
remove unnecessary kwargs in did experiments
1 parent 6a9214b commit 0cf4f46

File tree

4 files changed

+48
-68
lines changed

4 files changed

+48
-68
lines changed

causalpy/tests/test_input_validation.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,6 @@ def test_did_validation_post_treatment_formula():
5858
formula="y ~ 1 + group*post_SOMETHING",
5959
time_variable_name="t",
6060
group_variable_name="group",
61-
treated=1,
62-
untreated=0,
6361
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
6462
)
6563

@@ -91,8 +89,6 @@ def test_did_validation_post_treatment_data():
9189
formula="y ~ 1 + group*post_treatment",
9290
time_variable_name="t",
9391
group_variable_name="group",
94-
treated=1,
95-
untreated=0,
9692
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
9793
)
9894

@@ -124,8 +120,6 @@ def test_did_validation_unit_data():
124120
formula="y ~ 1 + group*post_treatment",
125121
time_variable_name="t",
126122
group_variable_name="group",
127-
treated=1,
128-
untreated=0,
129123
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
130124
)
131125

@@ -157,8 +151,6 @@ def test_did_validation_group_dummy_coded():
157151
formula="y ~ 1 + group*post_treatment",
158152
time_variable_name="t",
159153
group_variable_name="group",
160-
treated=1,
161-
untreated=0,
162154
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
163155
)
164156

causalpy/tests/test_pymc_models.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,6 @@ def test_idata_property():
130130
formula="y ~ 1 + group + t + group:post_treatment",
131131
time_variable_name="t",
132132
group_variable_name="group",
133-
treated=1,
134-
untreated=0,
135133
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
136134
)
137135
assert hasattr(result, "idata")
@@ -157,17 +155,13 @@ def test_result_reproducibility(seed):
157155
formula="y ~ 1 + group + t + group:post_treatment",
158156
time_variable_name="t",
159157
group_variable_name="group",
160-
treated=1,
161-
untreated=0,
162158
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
163159
)
164160
result2 = cp.DifferenceInDifferences(
165161
df,
166162
formula="y ~ 1 + group + t + group:post_treatment",
167163
time_variable_name="t",
168164
group_variable_name="group",
169-
treated=1,
170-
untreated=0,
171165
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
172166
)
173167
assert np.all(result1.idata.posterior.mu == result2.idata.posterior.mu)

0 commit comments

Comments
 (0)