Skip to content

Commit e509eb3

Browse files
Refactor execute_test_suite into the CausalTestSuite object
1 parent 32df1c9 commit e509eb3

File tree

2 files changed

+113
-21
lines changed

2 files changed

+113
-21
lines changed

causal_testing/testing/causal_test_suite.py

Lines changed: 107 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,17 @@
11
"""This module contains the CausalTestSuite class, for details on using it:
22
https://causal-testing-framework.readthedocs.io/en/latest/test_suite.html"""
3+
import logging
4+
35
from collections import UserDict
46
from typing import Type, Iterable
57
from causal_testing.testing.base_test_case import BaseTestCase
68
from causal_testing.testing.causal_test_case import CausalTestCase
79
from causal_testing.testing.estimators import Estimator
10+
from causal_testing.testing.causal_test_result import CausalTestResult, TestValue
11+
from causal_testing.data_collection.data_collector import ObservationalDataCollector
12+
from causal_testing.specification.causal_specification import CausalSpecification
13+
14+
# Module-level logger used by CausalTestSuite for progress messages and positivity warnings.
logger = logging.getLogger(__name__)
815

916

1017
class CausalTestSuite(UserDict):
@@ -20,11 +27,11 @@ class CausalTestSuite(UserDict):
2027
"""
2128

2229
def add_test_object(
23-
self,
24-
base_test_case: BaseTestCase,
25-
causal_test_case_list: Iterable[CausalTestCase],
26-
estimators_classes: Iterable[Type[Estimator]],
27-
estimate_type: str = "ate",
30+
self,
31+
base_test_case: BaseTestCase,
32+
causal_test_case_list: Iterable[CausalTestCase],
33+
estimators_classes: Iterable[Type[Estimator]],
34+
estimate_type: str = "ate",
2835
):
2936
"""
3037
A setter object to allow for the easy construction of the dictionary test suite structure
@@ -37,3 +44,98 @@ def add_test_object(
3744
"""
3845
test_object = {"tests": causal_test_case_list, "estimators": estimators_classes, "estimate_type": estimate_type}
3946
self.data[base_test_case] = test_object
47+
48+
def execute_test_suite(
    self, data_collector: ObservationalDataCollector, causal_specification: CausalSpecification
) -> dict[BaseTestCase, dict[str, list[CausalTestResult]]]:
    """Execute every causal test in the suite and return the results grouped by base test case.

    For each base test case (edge) stored in the suite, the adjustment set is obtained from the causal DAG,
    positivity is checked against the collected data, and every registered estimator class is instantiated and
    run once per causal test case.

    :param data_collector: Collector whose ``data`` attribute holds the data handed to the estimators.
    :param causal_specification: Specification providing the ``causal_dag`` used for identification and the
                                 ``scenario`` used for the positivity check.
    :return: A dictionary keyed by base test case. Each value is a dictionary mapping the estimator class name to
             the list of causal test results produced by that estimator, in the order the tests were iterated.
    :raises ValueError: If the data collector holds no data, or if a positivity violation is detected.
    """
    if data_collector.data.empty:
        raise ValueError("No data has been loaded. Please call load_data prior to executing a causal test case.")
    data_collector.collect_data()

    test_suite_results = {}
    for edge in self:
        logger.info("treatment: %s", edge.treatment_variable)
        logger.info("outcome: %s", edge.outcome_variable)

        treatment_name = edge.treatment_variable.name
        outcome_name = edge.outcome_variable.name

        # Remove the treatment and outcome by *name*. A set literal is required: set("abc") would split the name
        # into single characters, leaving the intended variable in place and possibly removing unrelated ones.
        minimal_adjustment_set = causal_specification.causal_dag.identification(edge)
        minimal_adjustment_set = minimal_adjustment_set - {treatment_name, outcome_name}

        variables_for_positivity = list(minimal_adjustment_set) + [treatment_name, outcome_name]
        if self._check_positivity_violation(
            variables_for_positivity, causal_specification.scenario, data_collector.data
        ):
            raise ValueError("POSITIVITY VIOLATION -- Cannot proceed.")

        estimators = self[edge]["estimators"]
        tests = self[edge]["tests"]
        results = {}
        for estimator_class in estimators:
            causal_test_results = []
            for test in tests:
                estimator = estimator_class(
                    test.treatment_variable.name,
                    test.treatment_value,
                    test.control_value,
                    minimal_adjustment_set,
                    test.outcome_variable.name,
                )
                # Only supply the collected data when the estimator was not constructed with its own.
                if estimator.df is None:
                    estimator.df = data_collector.data
                causal_test_results.append(self._return_causal_test_results(estimator, test))

            results[estimator_class.__name__] = causal_test_results
        test_suite_results[edge] = results
    return test_suite_results
96+
97+
def _return_causal_test_results(self, estimator, causal_test_case):
98+
"""Depending on the estimator used, calculate the 95% confidence intervals and return in a causal_test_result
99+
100+
:param estimator: An Estimator class object
101+
:param causal_test_case: The concrete test case to be executed
102+
:return: a CausalTestResult object containing the confidence intervals
103+
"""
104+
if not hasattr(estimator, f"estimate_{causal_test_case.estimate_type}"):
105+
raise AttributeError(f"{estimator.__class__} has no {causal_test_case.estimate_type} method.")
106+
estimate_effect = getattr(estimator, f"estimate_{causal_test_case.estimate_type}")
107+
effect, confidence_intervals = estimate_effect(**causal_test_case.estimate_params)
108+
causal_test_result = CausalTestResult(
109+
estimator=estimator,
110+
test_value=TestValue(causal_test_case.estimate_type, effect),
111+
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
112+
confidence_intervals=confidence_intervals,
113+
)
114+
115+
return causal_test_result
116+
117+
def _check_positivity_violation(self, variables_list, scenario, data):
118+
"""Check whether the dataframe has a positivity violation relative to the specified variables list.
119+
120+
A positivity violation occurs when there is a stratum of the dataframe which does not have any data. Put simply,
121+
if we split the dataframe into covariate sub-groups, each sub-group must contain both a treated and untreated
122+
individual. If a positivity violation occurs, causal inference is still possible using a properly specified
123+
parametric estimator. Therefore, we should not throw an exception upon violation but raise a warning instead.
124+
125+
:param variables_list: The list of variables for which positivity must be satisfied.
126+
:return: True if positivity is violated, False otherwise.
127+
"""
128+
if not (set(variables_list) - {x.name for x in scenario.hidden_variables()}).issubset(
129+
data.columns
130+
):
131+
missing_variables = set(variables_list) - set(data.columns)
132+
logger.warning(
133+
"Positivity violation: missing data for variables %s.\n"
134+
"Causal inference is only valid if a well-specified parametric model is used.\n"
135+
"Alternatively, consider restricting analysis to executions without the variables:"
136+
".",
137+
missing_variables,
138+
)
139+
return True
140+
141+
return False

tests/testing_tests/test_causal_test_suite.py

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22
import os
33
import numpy as np
44
import pandas as pd
5-
from causal_testing.testing.causal_test_engine import CausalTestEngine
6-
from causal_testing.testing.causal_test_engine import CausalTestSuite
5+
from causal_testing.testing.causal_test_suite import CausalTestSuite
76
from causal_testing.testing.causal_test_case import CausalTestCase
87
from causal_testing.testing.base_test_case import BaseTestCase
98
from causal_testing.specification.variable import Input, Output
@@ -61,6 +60,9 @@ def setUp(self) -> None:
6160
self.test_suite.add_test_object(
6261
base_test_case=self.base_test_case, causal_test_case_list=test_list, estimators_classes=self.estimators
6362
)
63+
self.causal_specification = CausalSpecification(self.scenario, self.causal_dag)
64+
65+
self.data_collector = ObservationalDataCollector(self.scenario, self.df)
6466

6567
def test_adding_test_object(self):
6668
"test an object can be added to the test_suite using the add_test_object function"
@@ -93,9 +95,8 @@ def test_return_single_test_object(self):
9395

9496
def test_execute_test_suite_single_base_test_case(self):
9597
"""Check that the test suite can return the correct results from dummy data for a single base_test-case"""
96-
causal_test_engine = self.create_causal_test_engine()
9798

98-
causal_test_results = causal_test_engine.execute_test_suite(test_suite=self.test_suite)
99+
causal_test_results = self.test_suite.execute_test_suite(self.data_collector, self.causal_specification)
99100
causal_test_case_result = causal_test_results[self.base_test_case]
100101
self.assertAlmostEqual(causal_test_case_result["LinearRegressionEstimator"][0].test_value.value, 4, delta=1e-10)
101102

@@ -109,21 +110,10 @@ def test_execute_test_suite_multiple_estimators(self):
109110
test_suite_2_estimators.add_test_object(
110111
base_test_case=self.base_test_case, causal_test_case_list=test_list, estimators_classes=estimators
111112
)
112-
causal_test_engine = self.create_causal_test_engine()
113-
causal_test_results = causal_test_engine.execute_test_suite(test_suite=test_suite_2_estimators)
113+
causal_test_results = test_suite_2_estimators.execute_test_suite(self.data_collector, self.causal_specification)
114114
causal_test_case_result = causal_test_results[self.base_test_case]
115115
linear_regression_result = causal_test_case_result["LinearRegressionEstimator"][0]
116116
causal_forrest_result = causal_test_case_result["CausalForestEstimator"][0]
117117
self.assertAlmostEqual(linear_regression_result.test_value.value, 4, delta=1e-1)
118118
self.assertAlmostEqual(causal_forrest_result.test_value.value, 4, delta=1e-1)
119119

120-
def create_causal_test_engine(self):
121-
"""
122-
Creating test engine is relatively computationally complex, this function allows for it to
123-
easily be made in only the tests that require it.
124-
"""
125-
causal_specification = CausalSpecification(self.scenario, self.causal_dag)
126-
127-
data_collector = ObservationalDataCollector(self.scenario, self.df)
128-
causal_test_engine = CausalTestEngine(causal_specification, data_collector)
129-
return causal_test_engine

0 commit comments

Comments
 (0)