
Commit cab1086

Update json_class.py to work with new test execution
1 parent 3ab87c5 commit cab1086

1 file changed: +33 -33 lines changed

causal_testing/json_front/json_class.py

Lines changed: 33 additions & 33 deletions
@@ -55,6 +55,7 @@ def __init__(self, output_path: str, output_overwrite: bool = False):
         self.causal_specification = None
         self.output_path = Path(output_path)
         self.check_file_exists(self.output_path, output_overwrite)
+        self.data_collector = None
 
     def set_paths(self, json_path: str, dag_path: str, data_paths: list[str] = None):
         """
@@ -87,6 +88,8 @@ def setup(self, scenario: Scenario):
                 "attribute with a dataframe before calling .setup()"
             )
         self._populate_metas()
+        self.data_collector = ObservationalDataCollector(
+            self.scenario, self.data)
 
     def _create_abstract_test_case(self, test, mutates, effects):
         assert len(test["mutations"]) == 1
@@ -149,14 +152,15 @@ def run_json_tests(self, effects: dict, estimators: dict, f_flag: bool = False,
                     treatment_value=test["treatment_value"],
                     estimate_type=test["estimate_type"],
                 )
+
                 failed, _ = self._execute_test_case(causal_test_case=causal_test_case, test=test, f_flag=f_flag)
 
                 msg = (
-                    f"Executing concrete test: {test['name']} \n"
-                    + f"treatment variable: {test['treatment_variable']} \n"
-                    + f"outcome_variable = {outcome_variable} \n"
-                    + f"control value = {test['control_value']}, treatment value = {test['treatment_value']} \n"
-                    + f"Result: {'FAILED' if failed else 'Passed'}"
+                    f"Executing concrete test: {test['name']} \n"
+                    + f"treatment variable: {test['treatment_variable']} \n"
+                    + f"outcome_variable = {outcome_variable} \n"
+                    + f"control value = {test['control_value']}, treatment value = {test['treatment_value']} \n"
+                    + f"Result: {'FAILED' if failed else 'Passed'}"
                 )
                 print(msg)
                 self._append_to_file(msg, logging.INFO)
@@ -183,12 +187,12 @@ def _run_coefficient_test(self, test: dict, f_flag: bool, effects: dict):
         )
         result = self._execute_test_case(causal_test_case=causal_test_case, test=test, f_flag=f_flag)
         msg = (
-            f"Executing test: {test['name']} \n"
-            + f" {causal_test_case} \n"
-            + " "
-            + ("\n ").join(str(result[1]).split("\n"))
-            + "==============\n"
-            + f" Result: {'FAILED' if result[0] else 'Passed'}"
+            f"Executing test: {test['name']} \n"
+            + f" {causal_test_case} \n"
+            + " "
+            + ("\n ").join(str(result[1]).split("\n"))
+            + "==============\n"
+            + f" Result: {'FAILED' if result[0] else 'Passed'}"
         )
         return msg
 
@@ -216,13 +220,13 @@ def _run_ate_test(self, test: dict, f_flag: bool, effects: dict, mutates: dict):
         failures, _ = self._execute_tests(concrete_tests, test, f_flag)
 
         msg = (
-            f"Executing test: {test['name']} \n"
-            + " abstract_test \n"
-            + f" {abstract_test} \n"
-            + f" {abstract_test.treatment_variable.name},"
-            + f" {abstract_test.treatment_variable.distribution} \n"
-            + f" Number of concrete tests for test case: {str(len(concrete_tests))} \n"
-            + f" {failures}/{len(concrete_tests)} failed for {test['name']}"
+            f"Executing test: {test['name']} \n"
+            + " abstract_test \n"
+            + f" {abstract_test} \n"
+            + f" {abstract_test.treatment_variable.name},"
+            + f" {abstract_test.treatment_variable.distribution} \n"
+            + f" Number of concrete tests for test case: {str(len(concrete_tests))} \n"
+            + f" {failures}/{len(concrete_tests)} failed for {test['name']}"
         )
         return msg
 
@@ -231,6 +235,7 @@ def _execute_tests(self, concrete_tests, test, f_flag):
         details = []
         if "formula" in test:
             self._append_to_file(f"Estimator formula used for test: {test['formula']}")
+
         for concrete_test in concrete_tests:
             failed, result = self._execute_test_case(concrete_test, test, f_flag)
             details.append(result)
@@ -246,7 +251,8 @@ def _populate_metas(self):
             meta.populate(self.data)
 
     def _execute_test_case(
-        self, causal_test_case: CausalTestCase, test: Iterable[Mapping], f_flag: bool
+        self, causal_test_case: CausalTestCase, test: Iterable[Mapping],
+        f_flag: bool
     ) -> (bool, CausalTestResult):
         """Executes a singular test case, prints the results and returns the test case result
         :param causal_test_case: The concrete test case to be executed
@@ -258,10 +264,10 @@ def _execute_test_case(
         """
         failed = False
 
-        causal_test_engine, estimation_model = self._setup_test(
-            causal_test_case, test, test["conditions"] if "conditions" in test else None
-        )
-        causal_test_result = causal_test_engine.execute_test(estimation_model, causal_test_case)
+        estimation_model = self._setup_test(causal_test_case=causal_test_case, test=test, data=self.data_collector.data)
+        causal_test_result = causal_test_case.execute_test(estimator=estimation_model,
+                                                           data_collector=self.data_collector,
+                                                           causal_specification=self.causal_specification)
 
         test_passes = causal_test_case.expected_causal_effect.apply(causal_test_result)
 
@@ -284,7 +290,7 @@ def _execute_test_case(
         return failed, causal_test_result
 
     def _setup_test(
-        self, causal_test_case: CausalTestCase, test: Mapping, conditions: list[str] = None
+        self, causal_test_case: CausalTestCase, test: Mapping, data: pd.DataFrame
     ) -> tuple[CausalTestEngine, Estimator]:
         """Create the necessary inputs for a single test case
         :param causal_test_case: The concrete test case to be executed
@@ -296,12 +302,6 @@ def _setup_test(
         - causal_test_engine - Test Engine instance for the test being run
         - estimation_model - Estimator instance for the test being run
         """
-
-        data_collector = ObservationalDataCollector(
-            self.scenario, self.data.query(" & ".join(conditions)) if conditions else self.data
-        )
-        causal_test_engine = CausalTestEngine(self.causal_specification, data_collector, index_col=0)
-
         minimal_adjustment_set = self.causal_specification.causal_dag.identification(causal_test_case.base_test_case)
         treatment_var = causal_test_case.treatment_variable
         minimal_adjustment_set = minimal_adjustment_set - {treatment_var}
@@ -311,14 +311,14 @@ def _setup_test(
             "control_value": causal_test_case.control_value,
             "adjustment_set": minimal_adjustment_set,
             "outcome": causal_test_case.outcome_variable.name,
-            "df": causal_test_engine.scenario_execution_data_df,
+            "df": data,
             "effect_modifiers": causal_test_case.effect_modifier_configuration,
             "alpha": test["alpha"] if "alpha" in test else 0.05,
         }
         if "formula" in test:
             estimator_kwargs["formula"] = test["formula"]
         estimation_model = test["estimator"](**estimator_kwargs)
-        return causal_test_engine, estimation_model
+        return estimation_model
 
     def _append_to_file(self, line: str, log_level: int = None):
         """Appends given line(s) to the current output file. If log_level is specified it also logs that message to the
@@ -370,7 +370,7 @@ def get_args(test_args=None) -> argparse.Namespace:
     parser.add_argument(
         "-w",
         help="Specify to overwrite any existing output files. This can lead to the loss of existing outputs if not "
-        "careful",
+        "careful",
         action="store_true",
     )
     parser.add_argument(
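
Taken together, the hunks above rework how the JSON front end runs a test: the per-test CausalTestEngine is removed, setup() now builds one shared ObservationalDataCollector, _setup_test() returns only the estimator, and the test is executed through CausalTestCase.execute_test. The condensed sketch below restates that flow in one place. It is a non-authoritative summary that uses only the attributes and calls visible in this diff; the import path is assumed from the causal_testing package layout, and everything else in json_class.py is elided.

# Sketch of the execution path after this commit; not a drop-in copy of json_class.py.
# The import path below is assumed, not taken from this diff.
from causal_testing.data_collection.data_collector import ObservationalDataCollector


class JsonUtilitySketch:
    def setup(self, scenario):
        # Existing validation, self.scenario/self.data assignment and
        # self._populate_metas() are elided. A single data collector is now
        # created up front and shared by every test.
        self.data_collector = ObservationalDataCollector(self.scenario, self.data)

    def _setup_test(self, causal_test_case, test, data):
        # Identification and estimator construction are unchanged, but the
        # estimator's dataframe now comes from the shared collector ("df": data)
        # and only the estimator is returned; no CausalTestEngine is built.
        estimator_kwargs = {"df": data}  # plus the control/treatment/adjustment-set kwargs as before
        return test["estimator"](**estimator_kwargs)

    def _execute_test_case(self, causal_test_case, test, f_flag):
        estimation_model = self._setup_test(
            causal_test_case=causal_test_case, test=test, data=self.data_collector.data
        )
        # The test case now executes itself, given the estimator, the shared
        # data collector and the causal specification.
        causal_test_result = causal_test_case.execute_test(
            estimator=estimation_model,
            data_collector=self.data_collector,
            causal_specification=self.causal_specification,
        )
        test_passes = causal_test_case.expected_causal_effect.apply(causal_test_result)
        # f_flag handling and result reporting are unchanged and omitted here.
        return not test_passes, causal_test_result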
