Skip to content

Commit 7da3aef

Browse files
authored
Merge pull request #324 from CITCOM-project/323-batch-test-execution
Batch test execution
2 parents 15513be + fcab481 commit 7da3aef

File tree

9 files changed

+175
-96
lines changed

9 files changed

+175
-96
lines changed

causal_testing/__main__.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
"""This module contains the main entrypoint functionality to the Causal Testing Framework."""
22

33
import logging
4+
import tempfile
5+
import json
6+
import os
7+
48
from .main import setup_logging, parse_args, CausalTestingPaths, CausalTestingFramework
59

610

@@ -34,13 +38,27 @@ def main() -> None:
3438

3539
if args.batch_size > 0:
3640
logging.info(f"Running tests in batches of size {args.batch_size}")
37-
results = framework.run_tests_in_batches(batch_size=args.batch_size, silent=args.silent)
41+
with tempfile.TemporaryDirectory() as tmpdir:
42+
output_files = []
43+
for i, results in enumerate(framework.run_tests_in_batches(batch_size=args.batch_size, silent=args.silent)):
44+
temp_file_path = os.path.join(tmpdir, f"output_{i}.json")
45+
framework.save_results(results, temp_file_path)
46+
output_files.append(temp_file_path)
47+
del results
48+
49+
# Now stitch the results together from the temporary files
50+
all_results = []
51+
for file_path in output_files:
52+
with open(file_path, "r", encoding="utf-8") as f:
53+
all_results.extend(json.load(f))
54+
55+
# Save the final stitched results to the output path given by args.output
56+
with open(args.output, "w", encoding="utf-8") as f:
57+
json.dump(all_results, f, indent=4)
3858
else:
3959
logging.info("Running tests in regular mode")
4060
results = framework.run_tests(silent=args.silent)
41-
42-
# Save results
43-
framework.save_results(results)
61+
framework.save_results(results)
4462

4563
logging.info("Causal testing completed successfully.")
4664

causal_testing/estimation/abstract_regression_estimator.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ def __init__(
4545
query=query,
4646
)
4747

48-
self.model = None
4948
if effect_modifiers is None:
5049
effect_modifiers = []
5150
if adjustment_set is None:
@@ -79,15 +78,14 @@ def add_modelling_assumptions(self):
7978
"do not need to be linear."
8079
)
8180

82-
def _run_regression(self, data=None) -> RegressionResultsWrapper:
81+
def fit_model(self, data=None) -> RegressionResultsWrapper:
8382
"""Fit the configured regression model of the treatment and adjustment set against the outcome and return the model.
8483
8584
:return: The model after fitting to data.
8685
"""
8786
if data is None:
8887
data = self.df
8988
model = self.regressor(formula=self.formula, data=data).fit(disp=0)
90-
self.model = model
9189
return model
9290

9391
def _predict(self, data=None, adjustment_config: dict = None) -> pd.DataFrame:
@@ -102,7 +100,7 @@ def _predict(self, data=None, adjustment_config: dict = None) -> pd.DataFrame:
102100
if adjustment_config is None:
103101
adjustment_config = {}
104102

105-
model = self._run_regression(data)
103+
model = self.fit_model(data)
106104

107105
x = pd.DataFrame(columns=self.df.columns)
108106
x["Intercept"] = 1 # self.intercept

causal_testing/estimation/cubic_spline_estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def estimate_ate_calculated(self, adjustment_config: dict = None) -> pd.Series:
5959
6060
:return: The average treatment effect.
6161
"""
62-
model = self._run_regression()
62+
model = self.fit_model()
6363

6464
x = {"Intercept": 1, self.base_test_case.treatment_variable.name: self.treatment_value}
6565
if adjustment_config is not None:

causal_testing/estimation/linear_regression_estimator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def estimate_coefficient(self) -> tuple[pd.Series, list[pd.Series, pd.Series]]:
9898
9999
:return: The unit average treatment effect and the 95% Wald confidence intervals.
100100
"""
101-
model = self._run_regression()
101+
model = self.fit_model()
102102
newline = "\n"
103103
patsy_md = ModelDesc.from_formula(self.base_test_case.treatment_variable.name)
104104

@@ -129,7 +129,7 @@ def estimate_ate(self) -> tuple[pd.Series, list[pd.Series, pd.Series]]:
129129
130130
:return: The average treatment effect and the 95% Wald confidence intervals.
131131
"""
132-
model = self._run_regression()
132+
model = self.fit_model()
133133

134134
# Create an empty individual for the control and treated
135135
individuals = pd.DataFrame(1, index=["control", "treated"], columns=model.params.index)

causal_testing/estimation/logistic_regression_estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def estimate_unit_odds_ratio(self) -> tuple[pd.Series, list[pd.Series, pd.Series
3838
3939
:return: The odds ratio. Confidence intervals are not yet supported.
4040
"""
41-
model = self._run_regression(self.df)
41+
model = self.fit_model(self.df)
4242
ci_low, ci_high = np.exp(model.conf_int(self.alpha).loc[self.base_test_case.treatment_variable.name])
4343
return pd.Series(np.exp(model.params[self.base_test_case.treatment_variable.name])), [
4444
pd.Series(ci_low),

causal_testing/main.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing import Dict, Any, Optional, List, Union, Sequence
99
from tqdm import tqdm
1010

11+
1112
import pandas as pd
1213
import numpy as np
1314

@@ -344,7 +345,6 @@ def run_tests_in_batches(self, batch_size: int = 100, silent: bool = False) -> L
344345
num_batches = int(np.ceil(num_tests / batch_size))
345346

346347
logger.info(f"Processing {num_tests} tests in {num_batches} batches of up to {batch_size} tests each")
347-
all_results = []
348348
with tqdm(total=num_tests, desc="Overall progress", mininterval=0.1) as progress:
349349
# Process each batch
350350
for batch_idx in range(num_batches):
@@ -360,26 +360,23 @@ def run_tests_in_batches(self, batch_size: int = 100, silent: bool = False) -> L
360360
batch_results = []
361361
for test_case in current_batch:
362362
try:
363-
result = test_case.execute_test()
364-
batch_results.append(result)
365-
except (TypeError, AttributeError) as e:
363+
batch_results.append(test_case.execute_test())
364+
# pylint: disable=broad-exception-caught
365+
except Exception as e:
366366
if not silent:
367367
logger.error(f"Type or attribute error in test: {str(e)}")
368368
raise
369-
result = CausalTestResult(
370-
estimator=test_case.estimator,
371-
test_value=TestValue("Error", str(e)),
369+
batch_results.append(
370+
CausalTestResult(
371+
estimator=test_case.estimator,
372+
test_value=TestValue("Error", str(e)),
373+
)
372374
)
373-
batch_results.append(result)
374375

375376
progress.update(1)
376377

377-
all_results.extend(batch_results)
378-
379-
logger.info(f"Completed batch {batch_idx + 1} of {num_batches}")
380-
381-
logger.info(f"Completed processing all {len(all_results)} tests in {num_batches} batches")
382-
return all_results
378+
yield batch_results
379+
logger.info(f"Completed processing in {num_batches} batches")
383380

384381
def run_tests(self, silent=False) -> List[CausalTestResult]:
385382
"""
@@ -399,7 +396,6 @@ def run_tests(self, silent=False) -> List[CausalTestResult]:
399396
try:
400397
result = test_case.execute_test()
401398
results.append(result)
402-
logger.info(f"Test completed: {test_case}")
403399
# pylint: disable=broad-exception-caught
404400
except Exception as e:
405401
if not silent:
@@ -414,9 +410,11 @@ def run_tests(self, silent=False) -> List[CausalTestResult]:
414410

415411
return results
416412

417-
def save_results(self, results: List[CausalTestResult]) -> None:
413+
def save_results(self, results: List[CausalTestResult], output_path: str = None) -> None:
418414
"""Save test results to JSON file in the expected format."""
419-
logger.info(f"Saving results to {self.paths.output_path}")
415+
if output_path is None:
416+
output_path = self.paths.output_path
417+
logger.info(f"Saving results to {output_path}")
420418

421419
# Load original test configs to preserve test metadata
422420
with open(self.paths.test_config_path, "r", encoding="utf-8") as f:
@@ -460,7 +458,7 @@ def save_results(self, results: List[CausalTestResult]) -> None:
460458
json_results.append(output)
461459

462460
# Save to file
463-
with open(self.paths.output_path, "w", encoding="utf-8") as f:
461+
with open(output_path, "w", encoding="utf-8") as f:
464462
json.dump(json_results, f, indent=2)
465463

466464
logger.info("Results saved successfully")

tests/estimation_tests/test_cubic_spline_estimator.py

Lines changed: 4 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,13 @@
11
import unittest
2-
import pandas as pd
3-
import numpy as np
4-
import matplotlib.pyplot as plt
5-
from causal_testing.specification.variable import Input
6-
from causal_testing.utils.validation import CausalValidator
72

83
from causal_testing.estimation.cubic_spline_estimator import CubicSplineRegressionEstimator
9-
from causal_testing.estimation.linear_regression_estimator import LinearRegressionEstimator
10-
11-
from tests.estimation_tests.test_linear_regression_estimator import TestLinearRegressionEstimator
124
from causal_testing.testing.base_test_case import BaseTestCase
135
from causal_testing.specification.variable import Input, Output
146

7+
from tests.estimation_tests.test_linear_regression_estimator import load_chapter_11_df
8+
159

16-
class TestCubicSplineRegressionEstimator(TestLinearRegressionEstimator):
10+
class TestCubicSplineRegressionEstimator(unittest.TestCase):
1711
@classmethod
1812
def setUpClass(cls):
1913
super().setUpClass()
@@ -24,22 +18,14 @@ def test_program_11_3_cublic_spline(self):
2418
Slightly modified as Hernan et al. use linear regression for this example.
2519
"""
2620

27-
df = self.chapter_11_df.copy()
21+
df = load_chapter_11_df()
2822

2923
base_test_case = BaseTestCase(Input("treatments", float), Output("outcomes", float))
3024

3125
cublic_spline_estimator = CubicSplineRegressionEstimator(base_test_case, 1, 0, set(), 3, df)
3226

3327
ate_1 = cublic_spline_estimator.estimate_ate_calculated()
3428

35-
self.assertEqual(
36-
round(
37-
cublic_spline_estimator.model.predict({"Intercept": 1, "treatments": 90}).iloc[0],
38-
1,
39-
),
40-
195.6,
41-
)
42-
4329
cublic_spline_estimator.treatment_value = 2
4430
ate_2 = cublic_spline_estimator.estimate_ate_calculated()
4531

tests/estimation_tests/test_linear_regression_estimator.py

Lines changed: 9 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
import unittest
22
import pandas as pd
33
import numpy as np
4-
import matplotlib.pyplot as plt
5-
from causal_testing.specification.variable import Input
4+
from causal_testing.specification.variable import Input, Output
65
from causal_testing.utils.validation import CausalValidator
76

87
from causal_testing.estimation.linear_regression_estimator import LinearRegressionEstimator
98
from causal_testing.testing.base_test_case import BaseTestCase
10-
from causal_testing.specification.variable import Input, Output
119

1210

1311
def load_nhefs_df():
@@ -77,7 +75,7 @@ def test_linear_regression_categorical_ate(self):
7775
df = self.scarf_df.copy()
7876
base_test_case = BaseTestCase(Input("color", float), Output("completed", float))
7977
logistic_regression_estimator = LinearRegressionEstimator(base_test_case, None, None, set(), df)
80-
ate, confidence = logistic_regression_estimator.estimate_coefficient()
78+
_, confidence = logistic_regression_estimator.estimate_coefficient()
8179
self.assertTrue(all([ci_low < 0 < ci_high for ci_low, ci_high in zip(confidence[0], confidence[1])]))
8280

8381
def test_program_11_2(self):
@@ -86,22 +84,8 @@ def test_program_11_2(self):
8684
linear_regression_estimator = LinearRegressionEstimator(self.base_test_case, None, None, set(), df)
8785
ate, _ = linear_regression_estimator.estimate_coefficient()
8886

89-
self.assertEqual(
90-
round(
91-
linear_regression_estimator.model.params["Intercept"]
92-
+ 90 * linear_regression_estimator.model.params["treatments"],
93-
1,
94-
),
95-
216.9,
96-
)
97-
9887
# Increasing treatments from 90 to 100 should be the same as 10 times the unit ATE
99-
self.assertTrue(
100-
all(
101-
round(linear_regression_estimator.model.params["treatments"], 1) == round(ate_single, 1)
102-
for ate_single in ate
103-
)
104-
)
88+
self.assertTrue(all(round(ate["treatments"], 1) == round(ate_single, 1) for ate_single in ate))
10589

10690
def test_program_11_3(self):
10791
"""Test whether our linear regression implementation produces the same results as program 11.3 (p. 144)."""
@@ -110,23 +94,8 @@ def test_program_11_3(self):
11094
self.base_test_case, None, None, set(), df, formula="outcomes ~ treatments + I(treatments ** 2)"
11195
)
11296
ate, _ = linear_regression_estimator.estimate_coefficient()
113-
print(linear_regression_estimator.model.summary())
114-
self.assertEqual(
115-
round(
116-
linear_regression_estimator.model.params["Intercept"]
117-
+ 90 * linear_regression_estimator.model.params["treatments"]
118-
+ 90 * 90 * linear_regression_estimator.model.params["I(treatments ** 2)"],
119-
1,
120-
),
121-
197.1,
122-
)
12397
# Increasing treatments from 90 to 100 should be the same as 10 times the unit ATE
124-
self.assertTrue(
125-
all(
126-
round(linear_regression_estimator.model.params["treatments"], 3) == round(ate_single, 3)
127-
for ate_single in ate
128-
)
129-
)
98+
self.assertTrue(all(round(ate["treatments"], 3) == round(ate_single, 3) for ate_single in ate))
13099

131100
def test_program_15_1A(self):
132101
"""Test whether our linear regression implementation produces the same results as program 15.1 (p. 163, 184)."""
@@ -161,15 +130,9 @@ def test_program_15_1A(self):
161130
I(smokeyrs ** 2) +
162131
(qsmk * smokeintensity)""",
163132
)
164-
# terms_to_square = ["age", "wt71", "smokeintensity", "smokeyrs"]
165-
# terms_to_product = [("qsmk", "smokeintensity")]
166-
# for term_to_square in terms_to_square:
167-
# for term_a, term_b in terms_to_product:
168-
# linear_regression_estimator.add_product_term_to_df(term_a, term_b)
169133

170-
linear_regression_estimator.estimate_coefficient()
171-
self.assertEqual(round(linear_regression_estimator.model.params["qsmk"], 1), 2.6)
172-
self.assertEqual(round(linear_regression_estimator.model.params["qsmk:smokeintensity"], 2), 0.05)
134+
coefficient, _ = linear_regression_estimator.estimate_coefficient()
135+
self.assertEqual(round(coefficient["qsmk"], 1), 2.6)
173136

174137
def test_program_15_no_interaction(self):
175138
"""Test whether our linear regression implementation produces the same results as program 15.1 (p. 163, 184)
@@ -281,10 +244,11 @@ def test_program_11_2_with_robustness_validation(self):
281244
"""Test whether our linear regression estimator, as used in test_program_11_2, can correctly estimate robustness."""
282245
df = self.chapter_11_df.copy()
283246
linear_regression_estimator = LinearRegressionEstimator(self.base_test_case, 100, 90, set(), df)
284-
linear_regression_estimator.estimate_coefficient()
285247

286248
cv = CausalValidator()
287-
self.assertEqual(round(cv.estimate_robustness(linear_regression_estimator.model)["treatments"], 4), 0.7353)
249+
self.assertEqual(
250+
round(cv.estimate_robustness(linear_regression_estimator.fit_model())["treatments"], 4), 0.7353
251+
)
288252

289253
def test_gp(self):
290254
df = pd.DataFrame()

0 commit comments

Comments
 (0)