Skip to content

Commit 27c737a

Browse files
committed
Basic tests to stop tests failing
1 parent eeda6e7 commit 27c737a

File tree

3 files changed

+18
-14
lines changed

3 files changed

+18
-14
lines changed

causal_testing/testing/estimators.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -472,12 +472,13 @@ def __init__(
472472
terms = [treatment] + sorted(list(adjustment_set)) + sorted(list(effect_modifiers))
473473
self.formula = f"{outcome} ~ cr({'+'.join(terms)}, df={basis})"
474474

475-
def estimate_ate_calculated(self, adjustment_config: dict = None) -> tuple[float, list[float]]:
475+
def estimate_ate_calculated(self, adjustment_config: dict = None) -> float:
476476
model = self._run_linear_regression()
477477

478478
x = {"Intercept": 1, self.treatment: self.treatment_value}
479-
for k, v in adjustment_config.items():
480-
x[k] = v
479+
if adjustment_config is not None:
480+
for k, v in adjustment_config.items():
481+
x[k] = v
481482
if self.effect_modifiers is not None:
482483
for k, v in self.effect_modifiers.items():
483484
x[k] = v

tests/surrogate_tests/test_causal_surrogate_assisted.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@ def test_init_valid_values(self):
3434

3535
test_function = lambda x: x **2
3636

37-
surrogate_model = CubicSplineRegressionEstimator()
37+
surrogate_model = CubicSplineRegressionEstimator("", 0, 0, set(), "", 4)
3838

3939
search_function = SearchFitnessFunction(fitness_function=test_function, surrogate_model=surrogate_model)
4040

41-
self.assertIsCallable(search_function.fitness_function)
41+
self.assertTrue(callable(search_function.fitness_function))
4242
self.assertIsInstance(search_function.surrogate_model, CubicSplineRegressionEstimator)

tests/testing_tests/test_estimators.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -421,28 +421,31 @@ def test_program_11_3_cublic_spline(self):
421421

422422
"""Test whether the cublic_spline regression implementation produces the same results as program 11.3 (p. 162).
423423
https://www.hsph.harvard.edu/miguel-hernan/wp-content/uploads/sites/1268/2023/10/hernanrobins_WhatIf_30sep23.pdf
424+
Slightly modified as Hernan et al. use linear regression for this example.
424425
"""
425426

426427
df = self.chapter_11_df.copy()
427428

428429
cublic_spline_estimator = CubicSplineRegressionEstimator(
429-
"treatments", None, None, set(), "outcomes", 3, df)
430+
"treatments", 1, 0, set(), "outcomes", 3, df)
430431

431432
model = cublic_spline_estimator._run_linear_regression()
432433

433-
ate, _ = cublic_spline_estimator.estimate_coefficient()
434-
435434
self.assertEqual(
436435
round(
437-
model.params["Intercept"]
438-
+ 90 * model.params["treatments"]
439-
+ 90 * 90 * model.params["np.power(treatments, 2)"],
436+
cublic_spline_estimator.model.predict({"Intercept": 1, "treatments": 90}).iloc[0],
440437
1,
441438
),
442-
197.1,
439+
195.6,
443440
)
444-
# Increasing treatments from 90 to 100 should be the same as 10 times the unit ATE
445-
self.assertEqual(round(model.params["treatments"], 3), round(ate, 3))
441+
442+
ate_1 = cublic_spline_estimator.estimate_ate_calculated()
443+
cublic_spline_estimator.treatment_value = 2
444+
ate_2 = cublic_spline_estimator.estimate_ate_calculated()
445+
446+
# Doubling the treatemebnt value should roughly but not exactly double the ATE
447+
self.assertNotEqual(ate_1 * 2, ate_2)
448+
self.assertAlmostEqual(ate_1 * 2, ate_2)
446449

447450

448451

0 commit comments

Comments
 (0)