Skip to content

Commit 7620f96

Browse files
committed
Merge branch 'main' of github.com:CITCOM-project/CausalTestingFramework into test-adequacy
2 parents 3d51429 + 79dcb48 commit 7620f96

File tree

19 files changed

+399
-562
lines changed

19 files changed

+399
-562
lines changed

.github/workflows/publish-to-pypi.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ jobs:
2121
run: |
2222
pip3 install .
2323
pip3 install .[pypi]
24-
pip3 install build
24+
pip3 install build wheel
2525
pip3 install setuptools --upgrade
2626
pip3 install setuptools_scm
2727
- name: Build Package

causal_testing/data_collection/data_collector.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,6 @@ def collect_data(self, **kwargs) -> pd.DataFrame:
150150
151151
:return: A pandas dataframe containing execution data that is valid for the scenario-under-test.
152152
"""
153-
154153
execution_data_df = self.data
155154
for meta in self.scenario.metas():
156155
if meta.name not in self.data:

causal_testing/json_front/json_class.py

Lines changed: 43 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import json
66
import logging
77

8-
from collections.abc import Iterable, Mapping
8+
from collections.abc import Mapping
99
from dataclasses import dataclass
1010
from pathlib import Path
1111
from statistics import StatisticsError
@@ -23,7 +23,6 @@
2323
from causal_testing.specification.variable import Input, Meta, Output
2424
from causal_testing.testing.causal_test_case import CausalTestCase
2525
from causal_testing.testing.causal_test_result import CausalTestResult
26-
from causal_testing.testing.causal_test_engine import CausalTestEngine
2726
from causal_testing.testing.estimators import Estimator
2827
from causal_testing.testing.base_test_case import BaseTestCase
2928
from causal_testing.testing.causal_test_adequacy import DataAdequacy
@@ -51,12 +50,12 @@ class JsonUtility:
5150
def __init__(self, output_path: str, output_overwrite: bool = False):
5251
self.input_paths = None
5352
self.variables = {"inputs": {}, "outputs": {}, "metas": {}}
54-
self.data = []
5553
self.test_plan = None
5654
self.scenario = None
5755
self.causal_specification = None
5856
self.output_path = Path(output_path)
5957
self.check_file_exists(self.output_path, output_overwrite)
58+
self.data_collector = None
6059

6160
def set_paths(self, json_path: str, dag_path: str, data_paths: list[str] = None):
6261
"""
@@ -71,6 +70,7 @@ def set_paths(self, json_path: str, dag_path: str, data_paths: list[str] = None)
7170

7271
def setup(self, scenario: Scenario):
7372
"""Function to populate all the necessary parts of the json_class needed to execute tests"""
73+
data = []
7474
self.scenario = scenario
7575
self._get_scenario_variables()
7676
self.scenario.setup_treatment_variables()
@@ -82,20 +82,21 @@ def setup(self, scenario: Scenario):
8282
self.test_plan = json.load(f)
8383
# Populate the data
8484
if self.input_paths.data_paths:
85-
self.data = pd.concat([pd.read_csv(data_file, header=0) for data_file in self.input_paths.data_paths])
86-
if len(self.data) == 0:
85+
data = pd.concat([pd.read_csv(data_file, header=0) for data_file in self.input_paths.data_paths])
86+
if len(data) == 0:
8787
raise ValueError(
8888
"No data found, either provide a path to a file containing data or manually populate the .data "
8989
"attribute with a dataframe before calling .setup()"
9090
)
91+
self.data_collector = ObservationalDataCollector(self.scenario, data)
9192
self._populate_metas()
9293

9394
def _create_abstract_test_case(self, test, mutates, effects):
9495
assert len(test["mutations"]) == 1
9596
treatment_var = next(self.scenario.variables[v] for v in test["mutations"])
9697

9798
if not treatment_var.distribution:
98-
fitter = Fitter(self.data[treatment_var.name], distributions=get_common_distributions())
99+
fitter = Fitter(self.data_collector.data[treatment_var.name], distributions=get_common_distributions())
99100
fitter.fit()
100101
(dist, params) = list(fitter.get_best(method="sumsquare_error").items())[0]
101102
treatment_var.distribution = getattr(scipy.stats, dist)(**params)
@@ -134,15 +135,36 @@ def run_json_tests(self, effects: dict, estimators: dict, f_flag: bool = False,
134135
failed, msg = self._run_concrete_metamorphic_test()
135136
# If we have a variable to mutate
136137
else:
137-
if test["estimate_type"] == "coefficient":
138-
failed, msg = self._run_coefficient_test(test=test, f_flag=f_flag, effects=effects)
139-
else:
140-
failed, msg = self._run_metamorphic_tests(
141-
test=test, f_flag=f_flag, effects=effects, mutates=mutates
142-
)
138+
outcome_variable = next(
139+
iter(test["expected_effect"])
140+
) # Take first key from dictionary of expected effect
141+
base_test_case = BaseTestCase(
142+
treatment_variable=self.variables["inputs"][test["treatment_variable"]],
143+
outcome_variable=self.variables["outputs"][outcome_variable],
144+
)
145+
146+
causal_test_case = CausalTestCase(
147+
base_test_case=base_test_case,
148+
expected_causal_effect=effects[test["expected_effect"][outcome_variable]],
149+
control_value=test["control_value"],
150+
treatment_value=test["treatment_value"],
151+
estimate_type=test["estimate_type"],
152+
)
153+
154+
failed, msg = self._execute_test_case(causal_test_case=causal_test_case, test=test, f_flag=f_flag)
155+
156+
# msg = (
157+
# f"Executing concrete test: {test['name']} \n"
158+
# + f"treatment variable: {test['treatment_variable']} \n"
159+
# + f"outcome_variable = {outcome_variable} \n"
160+
# + f"control value = {test['control_value']}, treatment value = {test['treatment_value']} \n"
161+
# + f"Result: {'FAILED' if failed else 'Passed'}"
162+
# )
163+
# print(msg)
164+
self._append_to_file(msg, logging.INFO)
143165
test["failed"] = failed
144166
test["result"] = msg
145-
return self.test_plan["tests"]
167+
return self.test_plan["tests"]
146168

147169
def _run_coefficient_test(self, test: dict, f_flag: bool, effects: dict):
148170
"""Builds structures and runs test case for tests with an estimate_type of 'coefficient'.
@@ -242,6 +264,7 @@ def _execute_tests(self, concrete_tests, test, f_flag):
242264
details = []
243265
if "formula" in test:
244266
self._append_to_file(f"Estimator formula used for test: {test['formula']}")
267+
245268
for concrete_test in concrete_tests:
246269
failed, result = self._execute_test_case(concrete_test, test, f_flag)
247270
details.append(result)
@@ -254,10 +277,10 @@ def _populate_metas(self):
254277
Populate data with meta-variable values and add distributions to Causal Testing Framework Variables
255278
"""
256279
for meta in self.scenario.variables_of_type(Meta):
257-
meta.populate(self.data)
280+
meta.populate(self.data_collector.data)
258281

259282
def _execute_test_case(
260-
self, causal_test_case: CausalTestCase, test: Iterable[Mapping], f_flag: bool
283+
self, causal_test_case: CausalTestCase, test: Mapping, f_flag: bool
261284
) -> (bool, CausalTestResult):
262285
"""Executes a singular test case, prints the results and returns the test case result
263286
:param causal_test_case: The concrete test case to be executed
@@ -269,10 +292,10 @@ def _execute_test_case(
269292
"""
270293
failed = False
271294

272-
causal_test_engine, estimation_model = self._setup_test(
273-
causal_test_case, test, test["conditions"] if "conditions" in test else None
295+
estimation_model = self._setup_test(causal_test_case=causal_test_case, test=test)
296+
causal_test_result = causal_test_case.execute_test(
297+
estimator=estimation_model, data_collector=self.data_collector
274298
)
275-
causal_test_result = causal_test_engine.execute_test(estimation_model, causal_test_case)
276299
test_passes = causal_test_case.expected_causal_effect.apply(causal_test_result)
277300

278301
if "coverage" in test and test["coverage"]:
@@ -300,9 +323,7 @@ def _execute_test_case(
300323
# logger.warning(" FAILED - expected %s, got %s", causal_test_case.expected_causal_effect, result_string)
301324
return failed, causal_test_result
302325

303-
def _setup_test(
304-
self, causal_test_case: CausalTestCase, test: Mapping, conditions: list[str] = None
305-
) -> tuple[CausalTestEngine, Estimator]:
326+
def _setup_test(self, causal_test_case: CausalTestCase, test: Mapping) -> Estimator:
306327
"""Create the necessary inputs for a single test case
307328
:param causal_test_case: The concrete test case to be executed
308329
:param test: Single JSON test definition stored in a mapping (dict)
@@ -313,12 +334,6 @@ def _setup_test(
313334
- causal_test_engine - Test Engine instance for the test being run
314335
- estimation_model - Estimator instance for the test being run
315336
"""
316-
317-
data_collector = ObservationalDataCollector(
318-
self.scenario, self.data.query(" & ".join(conditions)) if conditions else self.data
319-
)
320-
causal_test_engine = CausalTestEngine(self.causal_specification, data_collector, index_col=0)
321-
322337
minimal_adjustment_set = self.causal_specification.causal_dag.identification(causal_test_case.base_test_case)
323338
treatment_var = causal_test_case.treatment_variable
324339
minimal_adjustment_set = minimal_adjustment_set - {treatment_var}
@@ -328,14 +343,13 @@ def _setup_test(
328343
"control_value": causal_test_case.control_value,
329344
"adjustment_set": minimal_adjustment_set,
330345
"outcome": causal_test_case.outcome_variable.name,
331-
"df": causal_test_engine.scenario_execution_data_df,
332346
"effect_modifiers": causal_test_case.effect_modifier_configuration,
333347
"alpha": test["alpha"] if "alpha" in test else 0.05,
334348
}
335349
if "formula" in test:
336350
estimator_kwargs["formula"] = test["formula"]
337351
estimation_model = test["estimator"](**estimator_kwargs)
338-
return causal_test_engine, estimation_model
352+
return estimation_model
339353

340354
def _append_to_file(self, line: str, log_level: int = None):
341355
"""Appends given line(s) to the current output file. If log_level is specified it also logs that message to the

causal_testing/specification/causal_specification.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""This module holds the abstract CausalSpecification data class, which holds a Scenario and CausalDag"""
22

3-
from abc import ABC
43
from dataclasses import dataclass
54
from typing import Union
65

@@ -11,9 +10,9 @@
1110

1211

1312
@dataclass
14-
class CausalSpecification(ABC):
13+
class CausalSpecification:
1514
"""
16-
Abstract Class for the Causal Specification (combination of Scenario and Causal Dag)
15+
Data class storing the Causal Specification (combination of Scenario and Causal Dag)
1716
"""
1817

1918
scenario: Scenario

causal_testing/testing/causal_test_case.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@
55
from causal_testing.specification.variable import Variable
66
from causal_testing.testing.causal_test_outcome import CausalTestOutcome
77
from causal_testing.testing.base_test_case import BaseTestCase
8+
from causal_testing.testing.estimators import Estimator
9+
from causal_testing.testing.causal_test_result import CausalTestResult, TestValue
10+
from causal_testing.data_collection.data_collector import DataCollector
11+
812

913
logger = logging.getLogger(__name__)
1014

@@ -73,6 +77,41 @@ def get_treatment_value(self):
7377
"""Return the treatment value of the treatment variable in this causal test case."""
7478
return self.treatment_value
7579

80+
def execute_test(self, estimator: type(Estimator), data_collector: DataCollector) -> CausalTestResult:
81+
"""Execute a causal test case and return the causal test result.
82+
83+
:param estimator: A reference to an Estimator class.
84+
:param data_collector: The data collector to be used which provides a dataframe for the Estimator
85+
:return causal_test_result: A CausalTestResult for the executed causal test case.
86+
"""
87+
if estimator.df is None:
88+
estimator.df = data_collector.collect_data()
89+
90+
logger.info("treatments: %s", self.treatment_variable.name)
91+
logger.info("outcomes: %s", self.outcome_variable)
92+
93+
causal_test_result = self._return_causal_test_results(estimator)
94+
return causal_test_result
95+
96+
def _return_causal_test_results(self, estimator) -> CausalTestResult:
97+
"""Depending on the estimator used, calculate the 95% confidence intervals and return in a causal_test_result
98+
99+
:param estimator: An Estimator class object
100+
:return: a CausalTestResult object containing the confidence intervals
101+
"""
102+
if not hasattr(estimator, f"estimate_{self.estimate_type}"):
103+
raise AttributeError(f"{estimator.__class__} has no {self.estimate_type} method.")
104+
estimate_effect = getattr(estimator, f"estimate_{self.estimate_type}")
105+
effect, confidence_intervals = estimate_effect(**self.estimate_params)
106+
causal_test_result = CausalTestResult(
107+
estimator=estimator,
108+
test_value=TestValue(self.estimate_type, effect),
109+
effect_modifier_configuration=self.effect_modifier_configuration,
110+
confidence_intervals=confidence_intervals,
111+
)
112+
113+
return causal_test_result
114+
76115
def __str__(self):
77116
treatment_config = {self.treatment_variable.name: self.treatment_value}
78117
control_config = {self.treatment_variable.name: self.control_value}

0 commit comments

Comments
 (0)