Skip to content

Commit 6c67f80

Browse files
Merge branch 'main' into fix_PyPI_action
2 parents 2fc578c + 590f218 commit 6c67f80

File tree

2 files changed

+29
-8
lines changed

2 files changed

+29
-8
lines changed

causal_testing/testing/estimators.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -304,19 +304,18 @@ def __init__(
304304
):
305305
super().__init__(treatment, treatment_value, control_value, adjustment_set, outcome, df, effect_modifiers)
306306

307-
if product_terms is None:
308-
product_terms = []
309-
for term_a, term_b in product_terms:
310-
self.add_product_term_to_df(term_a, term_b)
311-
for term in self.effect_modifiers:
312-
self.adjustment_set.add(term)
313-
314-
self.product_terms = product_terms
307+
self.product_terms = []
315308
self.square_terms = []
316309
self.inverse_terms = []
317310
self.intercept = intercept
318311
self.model = None
319312

313+
if product_terms:
314+
for term_a, term_b in product_terms:
315+
self.add_product_term_to_df(term_a, term_b)
316+
for term in self.effect_modifiers:
317+
self.adjustment_set.add(term)
318+
320319
def add_modelling_assumptions(self):
321320
"""
322321
Add modelling assumptions to the estimator. This is a list of strings which list the modelling assumptions that

tests/testing_tests/test_estimators.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,3 +331,25 @@ def test_program_15_cate(self):
331331
)
332332
cates_df, _ = causal_forest.estimate_cates()
333333
self.assertGreater(cates_df["cate"].mean(), 0)
334+
335+
336+
class TestLinearRegressionInteraction(unittest.TestCase):
337+
"""Test linear regression for estimating effects involving interaction."""
338+
339+
@classmethod
340+
def setUpClass(cls) -> None:
341+
# Y = 2X1 - 3X2 + 2*X1*X2 + 10
342+
df = pd.DataFrame({"X1": np.random.uniform(-1000, 1000, 1000), "X2": np.random.uniform(-1000, 1000, 1000)})
343+
df["Y"] = 2 * df["X1"] - 3 * df["X2"] + 2 * df["X1"] * df["X2"] + 10
344+
cls.df = df
345+
print(df)
346+
347+
def test_X1_effect(self):
348+
"""When we fix the value of X2 to 0, the effect of X1 on Y should become ~2 (because X2 terms are cancelled)."""
349+
x2 = Input("X2", float)
350+
lr_model = LinearRegressionEstimator(
351+
("X1",), 1, 0, {"X2"}, ("Y",), effect_modifiers={x2: 0}, product_terms=[("X1", "X2")], df=self.df
352+
)
353+
test_results = lr_model.estimate_ate()
354+
ate = test_results[0]
355+
self.assertAlmostEqual(ate, 2.0)

0 commit comments

Comments
 (0)