Skip to content

Commit c507180

Browse files
Update docstrings and formatting
1 parent b7ec69e commit c507180

File tree

2 files changed

+28
-20
lines changed

2 files changed

+28
-20
lines changed

causal_testing/testing/causal_test_engine.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from causal_testing.testing.causal_test_outcome import CausalTestResult
99
from causal_testing.testing.estimators import Estimator
1010
from causal_testing.testing.base_causal_test import BaseCausalTest
11+
from causal_testing.testing.causal_test_suite import CausalTestSuite
1112

1213
logger = logging.getLogger(__name__)
1314

@@ -53,13 +54,18 @@ def __init__(self, causal_specification: CausalSpecification, data_collector: Da
5354
self.scenario_execution_data_df = self.data_collector.collect_data(**kwargs)
5455

5556
def execute_test_suite(
56-
self, test_suite: dict[dict[CausalTestCase], dict[Estimator], str]
57-
) -> list[CausalTestResult]:
58-
"""Execute a suite of causal tests and return the results in a list"""
57+
self, test_suite: CausalTestSuite) -> list[CausalTestResult]:
58+
59+
"""Execute a suite of causal tests and return the results in a list
60+
:param test_suite: CasualTestSuite object
61+
:return: test_suite results which contains a list of CausalTestResult objects
62+
63+
"""
5964
if self.scenario_execution_data_df.empty:
6065
raise Exception("No data has been loaded. Please call load_data prior to executing a causal test case.")
6166
test_suite_results = {}
62-
for edge in test_suite:
67+
test_suite_dict = test_suite.test_suite
68+
for edge in test_suite_dict:
6369
print("edge: ")
6470
print(edge)
6571
logger.info("treatment: %s", edge.treatment_variable)
@@ -69,16 +75,16 @@ def execute_test_suite(
6975
minimal_adjustment_set = minimal_adjustment_set - set(edge.outcome_variable.name)
7076

7177
variables_for_positivity = (
72-
list(minimal_adjustment_set) + [edge.treatment_variable.name] + [edge.outcome_variable.name]
78+
list(minimal_adjustment_set) + [edge.treatment_variable.name] + [edge.outcome_variable.name]
7379
)
7480
if self._check_positivity_violation(variables_for_positivity):
7581
# TODO: We should allow users to continue because positivity can be overcome with parametric models
7682
# TODO: When we implement causal contracts, we should also note the positivity violation there
7783
raise Exception("POSITIVITY VIOLATION -- Cannot proceed.")
7884

79-
estimators = test_suite[edge]["estimators"]
80-
tests = test_suite[edge]["tests"]
81-
estimate_type = test_suite[edge]["estimate_type"]
85+
estimators = test_suite_dict[edge]["estimators"]
86+
tests = test_suite_dict[edge]["tests"]
87+
estimate_type = test_suite_dict[edge]["estimate_type"]
8288
results = []
8389
for EstimatorClass in estimators:
8490
causal_test_results = []
@@ -104,7 +110,7 @@ def execute_test_suite(
104110
return test_suite_results
105111

106112
def execute_test(
107-
self, estimator: type(Estimator), causal_test_case: CausalTestCase, estimate_type: str = "ate"
113+
self, estimator: type(Estimator), causal_test_case: CausalTestCase, estimate_type: str = "ate"
108114
) -> CausalTestResult:
109115
"""Execute a causal test case and return the causal test result.
110116
@@ -151,7 +157,15 @@ def execute_test(
151157

152158
# TODO (MF) I think that the test oracle procedure should go in here.
153159
# This way, the user can supply it as a function or something, which can be applied to the result of CI
160+
154161
def _return_causal_test_results(self, estimate_type, estimator, causal_test_case):
162+
"""Depending on the estimator used, calculate the 95% confidence intervals and return in a causal_test_result
163+
164+
:param estimate_type: A string which denotes the type of estimate to return
165+
:param estimator: An Estimator class object
166+
:param causal_test_case: The concrete test case to be executed
167+
:return: a CausalTestResult object containing the confidence intervals
168+
"""
155169
# TODO: Some estimators also return the CATE. Find the best way to add this into the causal test engine.
156170
if estimate_type == "cate":
157171
logger.debug("calculating cate")

examples/lr91/causal_test_max_conductances_test_suite.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -86,12 +86,9 @@ def causal_testing_sensitivity_analysis():
8686
def effects_on_APD90(observational_data_path, test_suite):
8787
"""Perform causal testing for the scenario in which we investigate the causal effect of a given input on APD90.
8888
89-
:param observational_data_path: Path to observational data containing previous executions of the LR91 model.
90-
:param treatment_var: The input variable whose effect on APD90 we are interested in.
91-
:param control_val: The control value for the treatment variable (before intervention).
92-
:param treatment_val: The treatment value for the treatment variable (after intervention).
93-
:param expected_causal_effect: The expected causal effect (Positive, Negative, No Effect).
94-
:return: ATE for the effect of G_K on APD90
89+
:param: test_suite: A CausalTestSuite object containing a dictionary of base_test_cases and the treatment/outcome
90+
values to be tested
91+
:return: causal_test_results containing a list of causal_test_result objects
9592
"""
9693
# 1. Define Causal DAG
9794
causal_dag = CausalDAG('./dag.dot')
@@ -128,11 +125,8 @@ def effects_on_APD90(observational_data_path, test_suite):
128125
# 8. Create an instance of the causal test engine
129126
causal_test_engine = CausalTestEngine(causal_specification, data_collector)
130127

131-
# 9. Obtain the minimal adjustment set from the causal DAG
132-
133-
# 10. Run the causal test and print results
134-
causal_test_results = causal_test_engine.execute_test_suite(test_suite.test_suite)
135-
print(causal_test_results)
128+
# 9. Run the causal test suite
129+
causal_test_results = causal_test_engine.execute_test_suite(test_suite)
136130
return causal_test_results
137131

138132

0 commit comments

Comments
 (0)