Skip to content

Commit e244470

Browse files
committed
Add: initial unit tests for causal surrogate
1 parent bd0c44f commit e244470

File tree

3 files changed

+82
-0
lines changed

3 files changed

+82
-0
lines changed
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import unittest
2+
from causal_testing.surrogate.causal_surrogate_assisted import SimulationResult, SearchFitnessFunction
3+
from causal_testing.testing.estimators import Estimator, PolynomialRegressionEstimator
4+
5+
class TestSimulationResult(unittest.TestCase):
6+
7+
def setUp(self):
8+
9+
self.data = {'key': 'value'}
10+
11+
def test_inputs(self):
12+
13+
fault_values = [True, False]
14+
15+
relationship_values = ["positive", "negative", None]
16+
17+
for fault in fault_values:
18+
19+
for relationship in relationship_values:
20+
with self.subTest(fault=fault, relationship=relationship):
21+
result = SimulationResult(data=self.data, fault=fault, relationship=relationship)
22+
23+
self.assertIsInstance(result.data, dict)
24+
25+
self.assertEqual(result.fault, fault)
26+
27+
self.assertEqual(result.relationship, relationship)
28+
29+
class TestSearchFitnessFunction(unittest.TestCase):
30+
31+
#TODO: complete tests for causal surrogate
32+
33+
def test_init_valid_values(self):
34+
35+
test_function = lambda x: x **2
36+
37+
surrogate_model = PolynomialRegressionEstimator()
38+
39+
search_function = SearchFitnessFunction(fitness_function=test_function, surrogate_model=surrogate_model)
40+
41+
self.assertIsCallable(search_function.fitness_function)
42+
self.assertIsInstance(search_function.surrogate_model, PolynomialRegressionEstimator)

tests/surrogate_tests/test_surrogate_search_algorithms.py

Whitespace-only changes.

tests/testing_tests/test_estimators.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
CausalForestEstimator,
88
LogisticRegressionEstimator,
99
InstrumentalVariableEstimator,
10+
PolynomialRegressionEstimator
1011
)
1112
from causal_testing.specification.variable import Input
1213
from causal_testing.utils.validation import CausalValidator
@@ -409,6 +410,43 @@ def test_program_11_2_with_robustness_validation(self):
409410
self.assertEqual(round(cv.estimate_robustness(model)["treatments"], 4), 0.7353)
410411

411412

413+
class TestPolynomialRegressionEstimator(TestLinearRegressionEstimator):
414+
415+
@classmethod
416+
417+
def setUpClass(cls):
418+
419+
super().setUpClass()
420+
def test_program_11_3_polynomial(self):
421+
422+
"""Test whether the polynomial regression implementation produces the same results as program 11.3 (p. 162).
423+
https://www.hsph.harvard.edu/miguel-hernan/wp-content/uploads/sites/1268/2023/10/hernanrobins_WhatIf_30sep23.pdf
424+
"""
425+
426+
df = self.chapter_11_df.copy()
427+
428+
polynomial_estimator = PolynomialRegressionEstimator(
429+
"treatments", None, None, set(), "outcomes", 3, df)
430+
431+
model = polynomial_estimator._run_linear_regression()
432+
433+
ate, _ = polynomial_estimator.estimate_coefficient()
434+
435+
self.assertEqual(
436+
round(
437+
model.params["Intercept"]
438+
+ 90 * model.params["treatments"]
439+
+ 90 * 90 * model.params["np.power(treatments, 2)"],
440+
1,
441+
),
442+
197.1,
443+
)
444+
# Increasing treatments from 90 to 100 should be the same as 10 times the unit ATE
445+
self.assertEqual(round(model.params["treatments"], 3), round(ate, 3))
446+
447+
448+
449+
412450
class TestCausalForestEstimator(unittest.TestCase):
413451
"""Test the linear regression estimator against the programming exercises in Section 2 of Hernán and Robins [1].
414452
@@ -491,3 +529,5 @@ def test_X1_effect(self):
491529
test_results = lr_model.estimate_ate()
492530
ate = test_results[0]
493531
self.assertAlmostEqual(ate, 2.0)
532+
533+

0 commit comments

Comments
 (0)