Skip to content

Commit aefccd3

Browse files
Merge pull request #219 from CITCOM-project/test-engine-refactor
Test engine refactor
2 parents e028e84 + b2a7d31 commit aefccd3

File tree

18 files changed

+371
-554
lines changed

18 files changed

+371
-554
lines changed

causal_testing/data_collection/data_collector.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,6 @@ def collect_data(self, **kwargs) -> pd.DataFrame:
149149
150150
:return: A pandas dataframe containing execution data that is valid for the scenario-under-test.
151151
"""
152-
153152
execution_data_df = self.data
154153
for meta in self.scenario.metas():
155154
if meta.name not in self.data:

causal_testing/json_front/json_class.py

Lines changed: 16 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import json
66
import logging
77

8-
from collections.abc import Iterable, Mapping
8+
from collections.abc import Mapping
99
from dataclasses import dataclass
1010
from pathlib import Path
1111
from statistics import StatisticsError
@@ -22,7 +22,6 @@
2222
from causal_testing.specification.variable import Input, Meta, Output
2323
from causal_testing.testing.causal_test_case import CausalTestCase
2424
from causal_testing.testing.causal_test_result import CausalTestResult
25-
from causal_testing.testing.causal_test_engine import CausalTestEngine
2625
from causal_testing.testing.estimators import Estimator
2726
from causal_testing.testing.base_test_case import BaseTestCase
2827

@@ -49,12 +48,12 @@ class JsonUtility:
4948
def __init__(self, output_path: str, output_overwrite: bool = False):
5049
self.input_paths = None
5150
self.variables = {"inputs": {}, "outputs": {}, "metas": {}}
52-
self.data = []
5351
self.test_plan = None
5452
self.scenario = None
5553
self.causal_specification = None
5654
self.output_path = Path(output_path)
5755
self.check_file_exists(self.output_path, output_overwrite)
56+
self.data_collector = None
5857

5958
def set_paths(self, json_path: str, dag_path: str, data_paths: list[str] = None):
6059
"""
@@ -69,6 +68,7 @@ def set_paths(self, json_path: str, dag_path: str, data_paths: list[str] = None)
6968

7069
def setup(self, scenario: Scenario):
7170
"""Function to populate all the necessary parts of the json_class needed to execute tests"""
71+
data = []
7272
self.scenario = scenario
7373
self._get_scenario_variables()
7474
self.scenario.setup_treatment_variables()
@@ -80,20 +80,21 @@ def setup(self, scenario: Scenario):
8080
self.test_plan = json.load(f)
8181
# Populate the data
8282
if self.input_paths.data_paths:
83-
self.data = pd.concat([pd.read_csv(data_file, header=0) for data_file in self.input_paths.data_paths])
84-
if len(self.data) == 0:
83+
data = pd.concat([pd.read_csv(data_file, header=0) for data_file in self.input_paths.data_paths])
84+
if len(data) == 0:
8585
raise ValueError(
8686
"No data found, either provide a path to a file containing data or manually populate the .data "
8787
"attribute with a dataframe before calling .setup()"
8888
)
89+
self.data_collector = ObservationalDataCollector(self.scenario, data)
8990
self._populate_metas()
9091

9192
def _create_abstract_test_case(self, test, mutates, effects):
9293
assert len(test["mutations"]) == 1
9394
treatment_var = next(self.scenario.variables[v] for v in test["mutations"])
9495

9596
if not treatment_var.distribution:
96-
fitter = Fitter(self.data[treatment_var.name], distributions=get_common_distributions())
97+
fitter = Fitter(self.data_collector.data[treatment_var.name], distributions=get_common_distributions())
9798
fitter.fit()
9899
(dist, params) = list(fitter.get_best(method="sumsquare_error").items())[0]
99100
treatment_var.distribution = getattr(scipy.stats, dist)(**params)
@@ -149,6 +150,7 @@ def run_json_tests(self, effects: dict, estimators: dict, f_flag: bool = False,
149150
treatment_value=test["treatment_value"],
150151
estimate_type=test["estimate_type"],
151152
)
153+
152154
failed, _ = self._execute_test_case(causal_test_case=causal_test_case, test=test, f_flag=f_flag)
153155

154156
msg = (
@@ -231,6 +233,7 @@ def _execute_tests(self, concrete_tests, test, f_flag):
231233
details = []
232234
if "formula" in test:
233235
self._append_to_file(f"Estimator formula used for test: {test['formula']}")
236+
234237
for concrete_test in concrete_tests:
235238
failed, result = self._execute_test_case(concrete_test, test, f_flag)
236239
details.append(result)
@@ -243,10 +246,10 @@ def _populate_metas(self):
243246
Populate data with meta-variable values and add distributions to Causal Testing Framework Variables
244247
"""
245248
for meta in self.scenario.variables_of_type(Meta):
246-
meta.populate(self.data)
249+
meta.populate(self.data_collector.data)
247250

248251
def _execute_test_case(
249-
self, causal_test_case: CausalTestCase, test: Iterable[Mapping], f_flag: bool
252+
self, causal_test_case: CausalTestCase, test: Mapping, f_flag: bool
250253
) -> (bool, CausalTestResult):
251254
"""Executes a singular test case, prints the results and returns the test case result
252255
:param causal_test_case: The concrete test case to be executed
@@ -258,10 +261,10 @@ def _execute_test_case(
258261
"""
259262
failed = False
260263

261-
causal_test_engine, estimation_model = self._setup_test(
262-
causal_test_case, test, test["conditions"] if "conditions" in test else None
264+
estimation_model = self._setup_test(causal_test_case=causal_test_case, test=test)
265+
causal_test_result = causal_test_case.execute_test(
266+
estimator=estimation_model, data_collector=self.data_collector
263267
)
264-
causal_test_result = causal_test_engine.execute_test(estimation_model, causal_test_case)
265268

266269
test_passes = causal_test_case.expected_causal_effect.apply(causal_test_result)
267270

@@ -283,9 +286,7 @@ def _execute_test_case(
283286
logger.warning(" FAILED- expected %s, got %s", causal_test_case.expected_causal_effect, result_string)
284287
return failed, causal_test_result
285288

286-
def _setup_test(
287-
self, causal_test_case: CausalTestCase, test: Mapping, conditions: list[str] = None
288-
) -> tuple[CausalTestEngine, Estimator]:
289+
def _setup_test(self, causal_test_case: CausalTestCase, test: Mapping) -> Estimator:
289290
"""Create the necessary inputs for a single test case
290291
:param causal_test_case: The concrete test case to be executed
291292
:param test: Single JSON test definition stored in a mapping (dict)
@@ -296,12 +297,6 @@ def _setup_test(
296297
- causal_test_engine - Test Engine instance for the test being run
297298
- estimation_model - Estimator instance for the test being run
298299
"""
299-
300-
data_collector = ObservationalDataCollector(
301-
self.scenario, self.data.query(" & ".join(conditions)) if conditions else self.data
302-
)
303-
causal_test_engine = CausalTestEngine(self.causal_specification, data_collector, index_col=0)
304-
305300
minimal_adjustment_set = self.causal_specification.causal_dag.identification(causal_test_case.base_test_case)
306301
treatment_var = causal_test_case.treatment_variable
307302
minimal_adjustment_set = minimal_adjustment_set - {treatment_var}
@@ -311,14 +306,13 @@ def _setup_test(
311306
"control_value": causal_test_case.control_value,
312307
"adjustment_set": minimal_adjustment_set,
313308
"outcome": causal_test_case.outcome_variable.name,
314-
"df": causal_test_engine.scenario_execution_data_df,
315309
"effect_modifiers": causal_test_case.effect_modifier_configuration,
316310
"alpha": test["alpha"] if "alpha" in test else 0.05,
317311
}
318312
if "formula" in test:
319313
estimator_kwargs["formula"] = test["formula"]
320314
estimation_model = test["estimator"](**estimator_kwargs)
321-
return causal_test_engine, estimation_model
315+
return estimation_model
322316

323317
def _append_to_file(self, line: str, log_level: int = None):
324318
"""Appends given line(s) to the current output file. If log_level is specified it also logs that message to the

causal_testing/specification/causal_specification.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""This module holds the abstract CausalSpecification data class, which holds a Scenario and CausalDag"""
22

3-
from abc import ABC
43
from dataclasses import dataclass
54
from typing import Union
65

@@ -11,9 +10,9 @@
1110

1211

1312
@dataclass
14-
class CausalSpecification(ABC):
13+
class CausalSpecification:
1514
"""
16-
Abstract Class for the Causal Specification (combination of Scenario and Causal Dag)
15+
Data class storing the Causal Specification (combination of Scenario and Causal Dag)
1716
"""
1817

1918
scenario: Scenario

causal_testing/testing/causal_test_case.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@
55
from causal_testing.specification.variable import Variable
66
from causal_testing.testing.causal_test_outcome import CausalTestOutcome
77
from causal_testing.testing.base_test_case import BaseTestCase
8+
from causal_testing.testing.estimators import Estimator
9+
from causal_testing.testing.causal_test_result import CausalTestResult, TestValue
10+
from causal_testing.data_collection.data_collector import DataCollector
11+
812

913
logger = logging.getLogger(__name__)
1014

@@ -73,6 +77,41 @@ def get_treatment_value(self):
7377
"""Return the treatment value of the treatment variable in this causal test case."""
7478
return self.treatment_value
7579

80+
def execute_test(self, estimator: Estimator, data_collector: DataCollector) -> CausalTestResult:
    """Execute this causal test case and return the causal test result.

    If the estimator has no dataframe attached (``estimator.df is None``), one is obtained from
    ``data_collector.collect_data()`` and stored on the estimator before estimation. A dataframe that is
    already attached to the estimator is left untouched.

    :param estimator: An Estimator instance (not the class itself); its ``df`` attribute is read and, when
                      unset, assigned here.
    :param data_collector: The data collector which provides a dataframe for the estimator.
    :return causal_test_result: A CausalTestResult for the executed causal test case.
    """
    # Only collect data when the estimator was constructed without a dataframe.
    if estimator.df is None:
        estimator.df = data_collector.collect_data()

    logger.info("treatments: %s", self.treatment_variable.name)
    logger.info("outcomes: %s", self.outcome_variable)

    return self._return_causal_test_results(estimator)
95+
96+
def _return_causal_test_results(self, estimator: Estimator) -> CausalTestResult:
    """Run the estimation method selected by ``self.estimate_type`` and wrap its output in a CausalTestResult.

    The method looked up on the estimator is ``estimate_<estimate_type>``. It is called with
    ``self.estimate_params`` as keyword arguments and must return an ``(effect, confidence_intervals)`` pair.
    This method does not choose the confidence level itself; the intervals are whatever the estimator returns
    (presumably governed by the estimator's own ``alpha`` setting - confirm against the Estimator class).

    :param estimator: An Estimator instance
    :return: a CausalTestResult object containing the estimated effect and its confidence intervals
    :raises AttributeError: if the estimator has no ``estimate_<estimate_type>`` attribute
    """
    method_name = f"estimate_{self.estimate_type}"
    try:
        # Single lookup: ask for the method and translate a miss into the descriptive error below.
        estimate_effect = getattr(estimator, method_name)
    except AttributeError as err:
        raise AttributeError(f"{estimator.__class__} has no {self.estimate_type} method.") from err

    effect, confidence_intervals = estimate_effect(**self.estimate_params)
    return CausalTestResult(
        estimator=estimator,
        test_value=TestValue(self.estimate_type, effect),
        effect_modifier_configuration=self.effect_modifier_configuration,
        confidence_intervals=confidence_intervals,
    )
114+
76115
def __str__(self):
77116
treatment_config = {self.treatment_variable.name: self.treatment_value}
78117
control_config = {self.treatment_variable.name: self.control_value}

0 commit comments

Comments
 (0)