Skip to content

Commit b8ace8e

Browse files
committed
Removed data collector from causal test suite
1 parent 5495caa commit b8ace8e

File tree

12 files changed

+47
-103
lines changed

12 files changed

+47
-103
lines changed

causal_testing/json_front/json_class.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
import scipy
1515
from fitter import Fitter, get_common_distributions
1616

17-
from causal_testing.data_collection.data_collector import ObservationalDataCollector
1817
from causal_testing.generation.abstract_causal_test_case import AbstractCausalTestCase
1918
from causal_testing.specification.causal_dag import CausalDAG
2019
from causal_testing.specification.causal_specification import CausalSpecification
@@ -56,8 +55,8 @@ def __init__(self, output_path: str, output_overwrite: bool = False):
5655
self.scenario = None
5756
self.causal_specification = None
5857
self.output_path = Path(output_path)
58+
self.df = None
5959
self.check_file_exists(self.output_path, output_overwrite)
60-
self.data_collector = None
6160

6261
def set_paths(self, json_path: str, dag_path: str, data_paths: list[str] = None):
6362
"""
@@ -70,7 +69,7 @@ def set_paths(self, json_path: str, dag_path: str, data_paths: list[str] = None)
7069
data_paths = []
7170
self.input_paths = JsonClassPaths(json_path=json_path, dag_path=dag_path, data_paths=data_paths)
7271

73-
def setup(self, scenario: Scenario, data=None, ignore_cycles=False):
72+
def setup(self, scenario: Scenario, ignore_cycles=False):
7473
"""Function to populate all the necessary parts of the json_class needed to execute tests"""
7574
self.scenario = scenario
7675
self._get_scenario_variables()
@@ -83,21 +82,20 @@ def setup(self, scenario: Scenario, data=None, ignore_cycles=False):
8382
self.test_plan = json.load(f)
8483
# Populate the data
8584
if self.input_paths.data_paths:
86-
data = pd.concat([pd.read_csv(data_file, header=0) for data_file in self.input_paths.data_paths])
87-
if data is None or len(data) == 0:
85+
self.df = pd.concat([pd.read_csv(data_file, header=0) for data_file in self.input_paths.data_paths])
86+
if self.df is None or len(self.df) == 0:
8887
raise ValueError(
8988
"No data found. Please either provide a path to a file containing data or manually populate the .data "
9089
"attribute with a dataframe before calling .setup()"
9190
)
92-
self.data_collector = ObservationalDataCollector(self.scenario, data)
9391
self._populate_metas()
9492

9593
def _create_abstract_test_case(self, test, mutates, effects):
9694
assert len(test["mutations"]) == 1
9795
treatment_var = next(self.scenario.variables[v] for v in test["mutations"])
9896

9997
if not treatment_var.distribution:
100-
fitter = Fitter(self.data_collector.data[treatment_var.name], distributions=get_common_distributions())
98+
fitter = Fitter(self.df[treatment_var.name], distributions=get_common_distributions())
10199
fitter.fit()
102100
(dist, params) = list(fitter.get_best(method="sumsquare_error").items())[0]
103101
treatment_var.distribution = getattr(scipy.stats, dist)(**params)
@@ -257,7 +255,7 @@ def _populate_metas(self):
257255
Populate data with meta-variable values and add distributions to Causal Testing Framework Variables
258256
"""
259257
for meta in self.scenario.variables_of_type(Meta):
260-
meta.populate(self.data_collector.data)
258+
meta.populate(self.df)
261259

262260
def _execute_test_case(
263261
self, causal_test_case: CausalTestCase, test: Mapping, f_flag: bool
@@ -273,9 +271,7 @@ def _execute_test_case(
273271
failed = False
274272

275273
estimation_model = self._setup_test(causal_test_case=causal_test_case, test=test)
276-
causal_test_result = causal_test_case.execute_test(
277-
estimator=estimation_model, data_collector=self.data_collector
278-
)
274+
causal_test_result = causal_test_case.execute_test(estimator=estimation_model)
279275
test_passes = causal_test_case.expected_causal_effect.apply(causal_test_result)
280276

281277
if "coverage" in test and test["coverage"]:
@@ -329,7 +325,7 @@ def _setup_test(self, causal_test_case: CausalTestCase, test: Mapping) -> Estima
329325
estimator_kwargs["control_value"] = causal_test_case.control_value
330326
estimator_kwargs["outcome"] = causal_test_case.outcome_variable.name
331327
estimator_kwargs["effect_modifiers"] = causal_test_case.effect_modifier_configuration
332-
estimator_kwargs["df"] = self.data_collector.collect_data()
328+
estimator_kwargs["df"] = self.df
333329
estimator_kwargs["alpha"] = test["alpha"] if "alpha" in test else 0.05
334330

335331
estimation_model = test["estimator"](**estimator_kwargs)

causal_testing/testing/causal_test_adequacy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def measure_adequacy(self):
105105
else:
106106
estimator.df = estimator.df.sample(len(estimator.df), replace=True, random_state=i)
107107
try:
108-
results.append(self.test_case.execute_test(estimator, None))
108+
results.append(self.test_case.execute_test(estimator))
109109
except LinAlgError:
110110
logger.warning("Adequacy LinAlgError")
111111
continue

causal_testing/testing/causal_test_case.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22

33
import logging
44
from typing import Any
5+
import pandas as pd
56

67
from causal_testing.specification.variable import Variable
78
from causal_testing.testing.causal_test_outcome import CausalTestOutcome
89
from causal_testing.testing.base_test_case import BaseTestCase
910
from causal_testing.estimation.abstract_estimator import Estimator
1011
from causal_testing.testing.causal_test_result import CausalTestResult, TestValue
11-
from causal_testing.data_collection.data_collector import DataCollector
1212

1313

1414
logger = logging.getLogger(__name__)
@@ -58,25 +58,13 @@ def __init__(
5858
else:
5959
self.effect_modifier_configuration = {}
6060

61-
def execute_test(self, estimator: type(Estimator), data_collector: DataCollector) -> CausalTestResult:
61+
def execute_test(self, estimator: type(Estimator)) -> CausalTestResult:
6262
"""Execute a causal test case and return the causal test result.
6363
64-
:param estimator: A reference to an Estimator class.
65-
:param data_collector: The data collector to be used which provides a dataframe for the Estimator
64+
:param estimator: An Estimator class object
6665
:return causal_test_result: A CausalTestResult for the executed causal test case.
6766
"""
68-
if estimator.df is None:
69-
estimator.df = data_collector.collect_data()
70-
71-
causal_test_result = self._return_causal_test_results(estimator)
72-
return causal_test_result
73-
74-
def _return_causal_test_results(self, estimator) -> CausalTestResult:
75-
"""Depending on the estimator used, calculate the 95% confidence intervals and return in a causal_test_result
7667

77-
:param estimator: An Estimator class object
78-
:return: a CausalTestResult object containing the confidence intervals
79-
"""
8068
if not hasattr(estimator, f"estimate_{self.estimate_type}"):
8169
raise AttributeError(f"{estimator.__class__} has no {self.estimate_type} method.")
8270
estimate_effect = getattr(estimator, f"estimate_{self.estimate_type}")

causal_testing/testing/causal_test_suite.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@
22
https://causal-testing-framework.readthedocs.io/en/latest/test_suite.html"""
33

44
import logging
5-
6-
from collections import UserDict
75
from typing import Type, Iterable
6+
from collections import UserDict
7+
import pandas as pd
8+
89
from causal_testing.testing.base_test_case import BaseTestCase
910
from causal_testing.testing.causal_test_case import CausalTestCase
1011
from causal_testing.estimation.abstract_estimator import Estimator
1112
from causal_testing.testing.causal_test_result import CausalTestResult
12-
from causal_testing.data_collection.data_collector import DataCollector
1313
from causal_testing.specification.causal_specification import CausalSpecification
1414

1515
logger = logging.getLogger(__name__)
@@ -47,17 +47,14 @@ def add_test_object(
4747
self.data[base_test_case] = test_object
4848

4949
def execute_test_suite(
50-
self, data_collector: DataCollector, causal_specification: CausalSpecification
50+
self, causal_specification: CausalSpecification, df: pd.DataFrame
5151
) -> dict[str, CausalTestResult]:
5252
"""Execute a suite of causal tests and return the results in a list
53-
:param data_collector: The data collector to be used for the test_suite. Can be observational, experimental or
54-
custom
55-
:param causal_specification:
53+
:param causal_specification: A causal specification object which wraps up the scenario and causal DAG.
54+
:param df: A dataframe containing the test data.
5655
:return: A dictionary where each key is the name of the estimators specified and the values are lists of
5756
causal_test_result objects
5857
"""
59-
if data_collector.data.empty:
60-
raise ValueError("No data has been loaded. Please call load_data prior to executing a causal test case.")
6158
test_suite_results = {}
6259
for edge in self:
6360
logger.info("treatment: %s", edge.treatment_variable)
@@ -79,8 +76,9 @@ def execute_test_suite(
7976
test.control_value,
8077
minimal_adjustment_set,
8178
test.outcome_variable.name,
79+
df=df,
8280
)
83-
causal_test_result = test.execute_test(estimator, data_collector)
81+
causal_test_result = test.execute_test(estimator)
8482
causal_test_results.append(causal_test_result)
8583

8684
results[estimator_class.__name__] = causal_test_results

examples/covasim_/doubling_beta/example_beta.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,7 @@ def doubling_beta_CATE_on_csv(
6565
)
6666

6767
# Add squared terms for beta, since it has a quadratic relationship with cumulative infections
68-
causal_test_result = causal_test_case.execute_test(
69-
estimator=linear_regression_estimator, data_collector=data_collector
70-
)
68+
causal_test_result = causal_test_case.execute_test(estimator=linear_regression_estimator)
7169

7270
# Repeat for association estimate (no adjustment)
7371
no_adjustment_linear_regression_estimator = LinearRegressionEstimator(
@@ -79,9 +77,7 @@ def doubling_beta_CATE_on_csv(
7977
df=past_execution_df,
8078
formula="cum_infections ~ beta + I(beta ** 2)",
8179
)
82-
association_test_result = causal_test_case.execute_test(
83-
estimator=no_adjustment_linear_regression_estimator, data_collector=data_collector
84-
)
80+
association_test_result = causal_test_case.execute_test(estimator=no_adjustment_linear_regression_estimator)
8581

8682
# Store results for plotting
8783
results_dict["association"] = {
@@ -111,9 +107,7 @@ def doubling_beta_CATE_on_csv(
111107
df=counterfactual_past_execution_df,
112108
formula="cum_infections ~ beta + I(beta ** 2) + avg_age + contacts",
113109
)
114-
counterfactual_causal_test_result = causal_test_case.execute_test(
115-
estimator=linear_regression_estimator, data_collector=data_collector
116-
)
110+
counterfactual_causal_test_result = causal_test_case.execute_test(estimator=linear_regression_estimator)
117111

118112
results_dict["counterfactual"] = {
119113
"ate": counterfactual_causal_test_result.test_value.value,

examples/lr91/example_max_conductances.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from causal_testing.specification.scenario import Scenario
66
from causal_testing.specification.variable import Input, Output
77
from causal_testing.specification.causal_specification import CausalSpecification
8-
from causal_testing.data_collection.data_collector import ObservationalDataCollector
98
from causal_testing.testing.causal_test_case import CausalTestCase
109
from causal_testing.testing.causal_test_outcome import Positive, Negative, NoEffect
1110
from causal_testing.estimation.linear_regression_estimator import LinearRegressionEstimator
@@ -134,17 +133,19 @@ def effects_on_APD90(observational_data_path, treatment_var, control_val, treatm
134133
treatment_value=treatment_val,
135134
)
136135

137-
# 7. Create a data collector
138-
data_collector = ObservationalDataCollector(scenario, pd.read_csv(observational_data_path))
139-
140136
# 8. Obtain the minimal adjustment set from the causal DAG
141137
minimal_adjustment_set = causal_dag.identification(base_test_case)
142138
linear_regression_estimator = LinearRegressionEstimator(
143-
treatment_var.name, treatment_val, control_val, minimal_adjustment_set, "APD90"
139+
treatment_var.name,
140+
treatment_val,
141+
control_val,
142+
minimal_adjustment_set,
143+
"APD90",
144+
df=pd.read_csv(observational_data_path),
144145
)
145146

146147
# 9. Run the causal test and print results
147-
causal_test_result = causal_test_case.execute_test(linear_regression_estimator, data_collector)
148+
causal_test_result = causal_test_case.execute_test(linear_regression_estimator)
148149
logger.info("%s", causal_test_result)
149150
return causal_test_result.test_value.value, causal_test_result.confidence_intervals
150151

examples/lr91/example_max_conductances_test_suite.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from causal_testing.specification.scenario import Scenario
66
from causal_testing.specification.variable import Input, Output
77
from causal_testing.specification.causal_specification import CausalSpecification
8-
from causal_testing.data_collection.data_collector import ObservationalDataCollector
98
from causal_testing.testing.causal_test_case import CausalTestCase
109
from causal_testing.testing.causal_test_outcome import Positive, Negative, NoEffect
1110
from causal_testing.estimation.linear_regression_estimator import LinearRegressionEstimator
@@ -143,11 +142,8 @@ def effects_on_APD90(observational_data_path, test_suite):
143142
# 5. Create a causal specification from the scenario and causal DAG
144143
causal_specification = CausalSpecification(scenario, causal_dag)
145144

146-
# 7. Create a data collector
147-
data_collector = ObservationalDataCollector(scenario, pd.read_csv(observational_data_path))
148-
149145
# 8. Run the causal test suite
150-
causal_test_results = test_suite.execute_test_suite(data_collector, causal_specification)
146+
causal_test_results = test_suite.execute_test_suite(causal_specification, pd.read_csv(observational_data_path))
151147
return causal_test_results
152148

153149

examples/poisson-line-process/example_poisson_process.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def causal_test_intensity_num_shapes(
116116
)
117117

118118
# 9. Execute the test
119-
causal_test_result = causal_test_case.execute_test(estimator, None)
119+
causal_test_result = causal_test_case.execute_test(estimator)
120120

121121
return causal_test_result
122122

examples/poisson/example_run_causal_tests.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from causal_testing.testing.causal_test_outcome import ExactValue, Positive, Negative, NoEffect, CausalTestOutcome
88
from causal_testing.testing.causal_test_result import CausalTestResult
99
from causal_testing.json_front.json_class import JsonUtility
10-
from causal_testing.estimation.abstract_estimator import Estimator
1110
from causal_testing.specification.scenario import Scenario
1211
from causal_testing.specification.variable import Input, Output, Meta
1312

tests/testing_tests/test_causal_test_adequacy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def test_data_adequacy_group_by(self):
138138
treatment_value=treatment_strategy,
139139
estimate_type="hazard_ratio",
140140
)
141-
causal_test_result = causal_test_case.execute_test(estimation_model, None)
141+
causal_test_result = causal_test_case.execute_test(estimation_model)
142142
adequacy_metric = DataAdequacy(causal_test_case, estimation_model, group_by="id")
143143
adequacy_metric.measure_adequacy()
144144
adequacy_dict = adequacy_metric.to_dict()

0 commit comments

Comments
 (0)