Skip to content

Commit 515d06b

Browse files
Skip identification of adjustment set if formula provided
1 parent f9f0976 commit 515d06b

File tree

1 file changed

+35
-32
lines changed

1 file changed

+35
-32
lines changed

causal_testing/json_front/json_class.py

Lines changed: 35 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -165,12 +165,12 @@ def _run_coefficient_test(self, test: dict, f_flag: bool, effects: dict):
165165
)
166166
failed, result = self._execute_test_case(causal_test_case=causal_test_case, test=test, f_flag=f_flag)
167167
msg = (
168-
f"Executing test: {test['name']} \n"
169-
+ f" {causal_test_case} \n"
170-
+ " "
171-
+ ("\n ").join(str(result).split("\n"))
172-
+ "==============\n"
173-
+ f" Result: {'FAILED' if failed else 'Passed'}"
168+
f"Executing test: {test['name']} \n"
169+
+ f" {causal_test_case} \n"
170+
+ " "
171+
+ ("\n ").join(str(result).split("\n"))
172+
+ "==============\n"
173+
+ f" Result: {'FAILED' if failed else 'Passed'}"
174174
)
175175
self._append_to_file(msg, logging.INFO)
176176
return failed, result
@@ -192,11 +192,11 @@ def _run_concrete_metamorphic_test(self, test: dict, f_flag: bool, effects: dict
192192
failed, msg = self._execute_test_case(causal_test_case=causal_test_case, test=test, f_flag=f_flag)
193193

194194
msg = (
195-
f"Executing concrete test: {test['name']} \n"
196-
+ f"treatment variable: {test['treatment_variable']} \n"
197-
+ f"outcome_variable = {outcome_variable} \n"
198-
+ f"control value = {test['control_value']}, treatment value = {test['treatment_value']} \n"
199-
+ f"Result: {'FAILED' if failed else 'Passed'}"
195+
f"Executing concrete test: {test['name']} \n"
196+
+ f"treatment variable: {test['treatment_variable']} \n"
197+
+ f"outcome_variable = {outcome_variable} \n"
198+
+ f"control value = {test['control_value']}, treatment value = {test['treatment_value']} \n"
199+
+ f"Result: {'FAILED' if failed else 'Passed'}"
200200
)
201201
self._append_to_file(msg, logging.INFO)
202202
return failed, msg
@@ -225,13 +225,13 @@ def _run_metamorphic_tests(self, test: dict, f_flag: bool, effects: dict, mutate
225225
failures, _ = self._execute_tests(concrete_tests, test, f_flag)
226226

227227
msg = (
228-
f"Executing test: {test['name']} \n"
229-
+ " abstract_test \n"
230-
+ f" {abstract_test} \n"
231-
+ f" {abstract_test.treatment_variable.name},"
232-
+ f" {abstract_test.treatment_variable.distribution} \n"
233-
+ f" Number of concrete tests for test case: {str(len(concrete_tests))} \n"
234-
+ f" {failures}/{len(concrete_tests)} failed for {test['name']}"
228+
f"Executing test: {test['name']} \n"
229+
+ " abstract_test \n"
230+
+ f" {abstract_test} \n"
231+
+ f" {abstract_test.treatment_variable.name},"
232+
+ f" {abstract_test.treatment_variable.distribution} \n"
233+
+ f" Number of concrete tests for test case: {str(len(concrete_tests))} \n"
234+
+ f" {failures}/{len(concrete_tests)} failed for {test['name']}"
235235
)
236236
self._append_to_file(msg, logging.INFO)
237237
return failures, msg
@@ -257,7 +257,7 @@ def _populate_metas(self):
257257
meta.populate(self.data_collector.data)
258258

259259
def _execute_test_case(
260-
self, causal_test_case: CausalTestCase, test: Mapping, f_flag: bool
260+
self, causal_test_case: CausalTestCase, test: Mapping, f_flag: bool
261261
) -> (bool, CausalTestResult):
262262
"""Executes a singular test case, prints the results and returns the test case result
263263
:param causal_test_case: The concrete test case to be executed
@@ -307,20 +307,23 @@ def _setup_test(self, causal_test_case: CausalTestCase, test: Mapping) -> Estima
307307
:returns:
308308
- estimation_model - Estimator instance for the test being run
309309
"""
310-
minimal_adjustment_set = self.causal_specification.causal_dag.identification(causal_test_case.base_test_case)
311-
treatment_var = causal_test_case.treatment_variable
312-
minimal_adjustment_set = minimal_adjustment_set - {treatment_var}
313-
estimator_kwargs = {
314-
"treatment": treatment_var.name,
315-
"treatment_value": causal_test_case.treatment_value,
316-
"control_value": causal_test_case.control_value,
317-
"adjustment_set": minimal_adjustment_set,
318-
"outcome": causal_test_case.outcome_variable.name,
319-
"effect_modifiers": causal_test_case.effect_modifier_configuration,
320-
"alpha": test["alpha"] if "alpha" in test else 0.05,
321-
}
310+
estimator_kwargs = {}
322311
if "formula" in test:
323312
estimator_kwargs["formula"] = test["formula"]
313+
estimator_kwargs["adjustment_set"] = {}
314+
else:
315+
minimal_adjustment_set = self.causal_specification.causal_dag.identification(
316+
causal_test_case.base_test_case)
317+
minimal_adjustment_set = minimal_adjustment_set - {causal_test_case.treatment_variable}
318+
estimator_kwargs["adjustment_set"] = minimal_adjustment_set
319+
320+
estimator_kwargs["treatment"] = causal_test_case.treatment_variable.name
321+
estimator_kwargs["treatment_value"] = causal_test_case.treatment_value
322+
estimator_kwargs["control_value"] = causal_test_case.control_value
323+
estimator_kwargs["outcome"] = causal_test_case.outcome_variable.name
324+
estimator_kwargs["effect_modifiers"] = causal_test_case.effect_modifier_configuration
325+
estimator_kwargs["alpha"] = test["alpha"] if "alpha" in test else 0.05
326+
324327
estimation_model = test["estimator"](**estimator_kwargs)
325328
return estimation_model
326329

@@ -374,7 +377,7 @@ def get_args(test_args=None) -> argparse.Namespace:
374377
parser.add_argument(
375378
"-w",
376379
help="Specify to overwrite any existing output files. This can lead to the loss of existing outputs if not "
377-
"careful",
380+
"careful",
378381
action="store_true",
379382
)
380383
parser.add_argument(

0 commit comments

Comments
 (0)