Skip to content

Commit fc3964f

Browse files
author
AndrewC19
committed
Product terms now accessible in LR estimator.
1 parent e9f51e5 commit fc3964f

File tree

2 files changed

+29
-8
lines changed

2 files changed

+29
-8
lines changed

causal_testing/testing/estimators.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -294,18 +294,17 @@ def __init__(
294294
):
295295
super().__init__(treatment, treatment_value, control_value, adjustment_set, outcome, df, effect_modifiers)
296296

297-
if product_terms is None:
298-
product_terms = []
299-
for (term_a, term_b) in product_terms:
300-
self.add_product_term_to_df(term_a, term_b)
301-
for term in self.effect_modifiers:
302-
self.adjustment_set.add(term)
303-
304-
self.product_terms = product_terms
297+
self.product_terms = []
305298
self.square_terms = []
306299
self.inverse_terms = []
307300
self.intercept = intercept
308301

302+
if product_terms:
303+
for (term_a, term_b) in product_terms:
304+
self.add_product_term_to_df(term_a, term_b)
305+
for term in self.effect_modifiers:
306+
self.adjustment_set.add(term)
307+
309308
def add_modelling_assumptions(self):
310309
"""
311310
Add modelling assumptions to the estimator. This is a list of strings which list the modelling assumptions that

tests/testing_tests/test_estimators.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,3 +331,25 @@ def test_program_15_cate(self):
331331
)
332332
cates_df, _ = causal_forest.estimate_cates()
333333
self.assertGreater(cates_df["cate"].mean(), 0)
334+
335+
336+
class TestLinearRegressionInteraction(unittest.TestCase):
337+
"""Test linear regression for estimating effects involving interaction."""
338+
339+
@classmethod
340+
def setUpClass(cls) -> None:
341+
# Y = 2X1 - 3X2 + 2*X1*X2 + 10
342+
df = pd.DataFrame({'X1': np.random.uniform(-1000, 1000, 1000),
343+
'X2': np.random.uniform(-1000, 1000, 1000)})
344+
df['Y'] = 2*df['X1'] - 3*df['X2'] + 2*df['X1']*df['X2'] + 10
345+
cls.df = df
346+
print(df)
347+
348+
def test_X1_effect(self):
349+
"""When we fix the value of X2 to 0, the effect of X1 on Y should become ~2 (because X2 terms are cancelled)."""
350+
x2 = Input('X2', float)
351+
lr_model = LinearRegressionEstimator(('X1',), 1, 0, {'X2'}, ('Y',), effect_modifiers={x2: 0},
352+
product_terms=[('X1', 'X2')], df=self.df)
353+
test_results = lr_model.estimate_ate()
354+
ate = test_results[0]
355+
self.assertAlmostEqual(ate, 2.0)

0 commit comments

Comments
 (0)