Skip to content

Commit fbd4842

Browse files
Add optional formula to JSON tests
1 parent e09eadb commit fbd4842

File tree

1 file changed

+27
-15
lines changed

1 file changed

+27
-15
lines changed

causal_testing/json_front/json_class.py

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import json
66
import logging
77

8+
from collections.abc import Iterable, Mapping
89
from dataclasses import dataclass
910
from pathlib import Path
1011

@@ -112,8 +113,9 @@ def generate_tests(self, effects: dict, mutates: dict, estimators: dict, f_flag:
112113

113114
def _execute_tests(self, concrete_tests, estimators, test, f_flag):
114115
failures = 0
116+
test["estimator"] = estimators[test["estimator"]]
115117
for concrete_test in concrete_tests:
116-
failed = self._execute_test_case(concrete_test, estimators[test["estimator"]], f_flag)
118+
failed = self._execute_test_case(concrete_test, test, f_flag)
117119
if failed:
118120
failures += 1
119121
return failures
@@ -141,7 +143,7 @@ def _populate_metas(self):
141143
var.distribution = getattr(scipy.stats, dist)(**params)
142144
logger.info(var.name + f" {dist}({params})")
143145

144-
def _execute_test_case(self, causal_test_case: CausalTestCase, estimator: Estimator, f_flag: bool) -> bool:
146+
def _execute_test_case(self, causal_test_case: CausalTestCase, test: Iterable[Mapping], f_flag: bool) -> bool:
145147
"""Executes a singular test case, prints the results and returns the test case result
146148
:param causal_test_case: The concrete test case to be executed
147149
:param f_flag: Failure flag that if True the script will stop executing when a test fails.
@@ -151,7 +153,7 @@ def _execute_test_case(self, causal_test_case: CausalTestCase, estimator: Estima
151153
"""
152154
failed = False
153155

154-
causal_test_engine, estimation_model = self._setup_test(causal_test_case, estimator)
156+
causal_test_engine, estimation_model = self._setup_test(causal_test_case, test)
155157
causal_test_result = causal_test_engine.execute_test(
156158
estimation_model, causal_test_case, estimate_type=causal_test_case.estimate_type
157159
)
@@ -176,7 +178,7 @@ def _execute_test_case(self, causal_test_case: CausalTestCase, estimator: Estima
176178
logger.warning(" FAILED- expected %s, got %s", causal_test_case.expected_causal_effect, result_string)
177179
return failed
178180

179-
def _setup_test(self, causal_test_case: CausalTestCase, estimator: Estimator) -> tuple[CausalTestEngine, Estimator]:
181+
def _setup_test(self, causal_test_case: CausalTestCase, test) -> tuple[CausalTestEngine, Estimator]:
180182
"""Create the necessary inputs for a single test case
181183
:param causal_test_case: The concrete test case to be executed
182184
:returns:
@@ -190,17 +192,27 @@ def _setup_test(self, causal_test_case: CausalTestCase, estimator: Estimator) ->
190192
minimal_adjustment_set = self.causal_specification.causal_dag.identification(causal_test_case.base_test_case)
191193
treatment_var = causal_test_case.treatment_variable
192194
minimal_adjustment_set = minimal_adjustment_set - {treatment_var}
193-
estimation_model = estimator(
194-
treatment=treatment_var.name,
195-
treatment_value=causal_test_case.treatment_value,
196-
control_value=causal_test_case.control_value,
197-
adjustment_set=minimal_adjustment_set,
198-
outcome=causal_test_case.outcome_variable.name,
199-
df=causal_test_engine.scenario_execution_data_df,
200-
effect_modifiers=causal_test_case.effect_modifier_configuration,
201-
)
202-
203-
self.add_modelling_assumptions(estimation_model)
195+
if "formula" in test:
196+
estimation_model = test["estimator"](
197+
treatment=treatment_var.name,
198+
treatment_value=causal_test_case.treatment_value,
199+
control_value=causal_test_case.control_value,
200+
adjustment_set=minimal_adjustment_set,
201+
outcome=causal_test_case.outcome_variable.name,
202+
df=causal_test_engine.scenario_execution_data_df,
203+
effect_modifiers=causal_test_case.effect_modifier_configuration,
204+
formula=test["formula"]
205+
)
206+
else:
207+
estimation_model = test["estimator"](
208+
treatment=treatment_var.name,
209+
treatment_value=causal_test_case.treatment_value,
210+
control_value=causal_test_case.control_value,
211+
adjustment_set=minimal_adjustment_set,
212+
outcome=causal_test_case.outcome_variable.name,
213+
df=causal_test_engine.scenario_execution_data_df,
214+
effect_modifiers=causal_test_case.effect_modifier_configuration,
215+
)
204216

205217
return causal_test_engine, estimation_model
206218

0 commit comments

Comments
 (0)