Skip to content

Commit 886b911

Browse files
committed
All tests pass
1 parent d53fef2 commit 886b911

File tree

3 files changed

+23
-41
lines changed

3 files changed

+23
-41
lines changed

causal_testing/json_front/json_class.py

Lines changed: 11 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,8 @@ def set_paths(self, json_path: str, dag_path: str, data_paths: list[str] = None)
6868
data_paths = []
6969
self.input_paths = JsonClassPaths(json_path=json_path, dag_path=dag_path, data_paths=data_paths)
7070

71-
def setup(self, scenario: Scenario):
71+
def setup(self, scenario: Scenario, data=[]):
7272
"""Function to populate all the necessary parts of the json_class needed to execute tests"""
73-
data = []
7473
self.scenario = scenario
7574
self._get_scenario_variables()
7675
self.scenario.setup_treatment_variables()
@@ -85,7 +84,7 @@ def setup(self, scenario: Scenario):
8584
data = pd.concat([pd.read_csv(data_file, header=0) for data_file in self.input_paths.data_paths])
8685
if len(data) == 0:
8786
raise ValueError(
88-
"No data found, either provide a path to a file containing data or manually populate the .data "
87+
"No data found. Please either provide a path to a file containing data or manually populate the .data "
8988
"attribute with a dataframe before calling .setup()"
9089
)
9190
self.data_collector = ObservationalDataCollector(self.scenario, data)
@@ -132,39 +131,18 @@ def run_json_tests(self, effects: dict, estimators: dict, f_flag: bool = False,
132131
test["estimator"] = estimators[test["estimator"]]
133132
# If we have specified concrete control and treatment value
134133
if "mutations" not in test:
135-
failed, msg = self._run_concrete_metamorphic_test()
134+
failed, msg = self._run_concrete_metamorphic_test(test, f_flag, effects, mutates)
136135
# If we have a variable to mutate
137136
else:
138-
outcome_variable = next(
139-
iter(test["expected_effect"])
140-
) # Take first key from dictionary of expected effect
141-
base_test_case = BaseTestCase(
142-
treatment_variable=self.variables["inputs"][test["treatment_variable"]],
143-
outcome_variable=self.variables["outputs"][outcome_variable],
144-
)
145-
146-
causal_test_case = CausalTestCase(
147-
base_test_case=base_test_case,
148-
expected_causal_effect=effects[test["expected_effect"][outcome_variable]],
149-
control_value=test["control_value"],
150-
treatment_value=test["treatment_value"],
151-
estimate_type=test["estimate_type"],
152-
)
153-
154-
failed, msg = self._execute_test_case(causal_test_case=causal_test_case, test=test, f_flag=f_flag)
155-
156-
# msg = (
157-
# f"Executing concrete test: {test['name']} \n"
158-
# + f"treatment variable: {test['treatment_variable']} \n"
159-
# + f"outcome_variable = {outcome_variable} \n"
160-
# + f"control value = {test['control_value']}, treatment value = {test['treatment_value']} \n"
161-
# + f"Result: {'FAILED' if failed else 'Passed'}"
162-
# )
163-
# print(msg)
164-
self._append_to_file(msg, logging.INFO)
137+
if test["estimate_type"] == "coefficient":
138+
failed, msg = self._run_coefficient_test(test=test, f_flag=f_flag, effects=effects)
139+
else:
140+
failed, msg = self._run_metamorphic_tests(
141+
test=test, f_flag=f_flag, effects=effects, mutates=mutates
142+
)
165143
test["failed"] = failed
166144
test["result"] = msg
167-
return self.test_plan["tests"]
145+
return self.test_plan["tests"]
168146

169147
def _run_coefficient_test(self, test: dict, f_flag: bool, effects: dict):
170148
"""Builds structures and runs test case for tests with an estimate_type of 'coefficient'.
@@ -299,7 +277,7 @@ def _execute_test_case(
299277
test_passes = causal_test_case.expected_causal_effect.apply(causal_test_result)
300278

301279
if "coverage" in test and test["coverage"]:
302-
adequacy_metric = DataAdequacy(causal_test_case, causal_test_engine, estimation_model)
280+
adequacy_metric = DataAdequacy(causal_test_case, estimation_model, self.data_collector)
303281
adequacy_metric.measure_adequacy()
304282
# self._append_to_file(f"KURTOSIS: {effect_estimate.mean()}", logging.INFO)
305283
# self._append_to_file(f"PASSING: {sum(outcomes)}/{len(outcomes)}", logging.INFO)
@@ -331,7 +309,6 @@ def _setup_test(self, causal_test_case: CausalTestCase, test: Mapping) -> Estima
331309
data. Conditions should be in the query format detailed at
332310
https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.query.html
333311
:returns:
334-
- causal_test_engine - Test Engine instance for the test being run
335312
- estimation_model - Estimator instance for the test being run
336313
"""
337314
minimal_adjustment_set = self.causal_specification.causal_dag.identification(causal_test_case.base_test_case)

causal_testing/testing/causal_test_adequacy.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,10 @@ def measure_adequacy(self):
4949
for i in range(self.bootstrap_size):
5050
estimator = deepcopy(self.estimator)
5151
estimator.df = estimator.df.sample(len(estimator.df), replace=True, random_state=i)
52-
try:
53-
results.append(self.test.execute_test(estimator, self.data_collector))
54-
except np.LinAlgError:
55-
continue
52+
# try:
53+
results.append(self.test_case.execute_test(estimator, self.data_collector))
54+
# except np.LinAlgError:
55+
# continue
5656
outcomes = [self.test_case.expected_causal_effect.apply(c) for c in results]
5757
results = pd.DataFrame(c.to_dict() for c in results)[["effect_estimate", "ci_low", "ci_high"]]
5858

tests/testing_tests/test_causal_test_outcome.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,10 @@ def test_None_ci(self):
3636
"treatment_value": 1,
3737
"outcome": "A",
3838
"adjustment_set": set(),
39-
"test_value": test_value,
39+
"effect_estimate": 0,
40+
"effect_measure": "ate",
41+
"ci_high": None,
42+
"ci_low": None,
4043
},
4144
)
4245

@@ -264,7 +267,8 @@ def test_someEffect_str(self):
264267
"treatment_value": 1,
265268
"outcome": "A",
266269
"adjustment_set": set(),
267-
"test_value": test_value,
270+
"effect_estimate": 0,
271+
"effect_measure": "ate",
268272
"ci_low": -0.1,
269273
"ci_high": 0.2,
270274
},
@@ -287,7 +291,8 @@ def test_someEffect_dict(self):
287291
"treatment_value": 1,
288292
"outcome": "A",
289293
"adjustment_set": set(),
290-
"test_value": test_value,
294+
"effect_estimate": 0,
295+
"effect_measure": "ate",
291296
"ci_low": -0.1,
292297
"ci_high": 0.2,
293298
},

0 commit comments

Comments
 (0)