Skip to content

Commit 106f346

Browse files
Merge pull request #176 from CITCOM-project/json_moddeling_assumption_method
JSON add estimator formula
2 parents b68aee3 + fb06924 commit 106f346

File tree

4 files changed

+58
-42
lines changed

4 files changed

+58
-42
lines changed

causal_testing/json_front/json_class.py

Lines changed: 24 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import json
66
import logging
77

8+
from collections.abc import Iterable, Mapping
89
from dataclasses import dataclass
910
from pathlib import Path
1011
from statistics import StatisticsError
@@ -118,8 +119,11 @@ def generate_tests(self, effects: dict, mutates: dict, estimators: dict, f_flag:
118119

119120
def _execute_tests(self, concrete_tests, estimators, test, f_flag):
120121
failures = 0
122+
test["estimator"] = estimators[test["estimator"]]
123+
if "formula" in test:
124+
self._append_to_file(f"Estimator formula used for test: {test['formula']}")
121125
for concrete_test in concrete_tests:
122-
failed = self._execute_test_case(concrete_test, estimators[test["estimator"]], f_flag)
126+
failed = self._execute_test_case(concrete_test, test, f_flag)
123127
if failed:
124128
failures += 1
125129
return failures
@@ -147,17 +151,18 @@ def _populate_metas(self):
147151
var.distribution = getattr(scipy.stats, dist)(**params)
148152
self._append_to_file(var.name + f" {dist}({params})", logging.INFO)
149153

150-
def _execute_test_case(self, causal_test_case: CausalTestCase, estimator: Estimator, f_flag: bool) -> bool:
154+
def _execute_test_case(self, causal_test_case: CausalTestCase, test: Iterable[Mapping], f_flag: bool) -> bool:
151155
"""Executes a singular test case, prints the results and returns the test case result
152156
:param causal_test_case: The concrete test case to be executed
157+
:param test: Single JSON test definition stored in a mapping (dict)
153158
:param f_flag: Failure flag that if True the script will stop executing when a test fails.
154159
:return: A boolean that if True indicates the causal test case passed and if false indicates the test case
155160
failed.
156161
:rtype: bool
157162
"""
158163
failed = False
159164

160-
causal_test_engine, estimation_model = self._setup_test(causal_test_case, estimator)
165+
causal_test_engine, estimation_model = self._setup_test(causal_test_case, test)
161166
causal_test_result = causal_test_engine.execute_test(
162167
estimation_model, causal_test_case, estimate_type=causal_test_case.estimate_type
163168
)
@@ -183,9 +188,10 @@ def _execute_test_case(self, causal_test_case: CausalTestCase, estimator: Estima
183188
logger.warning(" FAILED- expected %s, got %s", causal_test_case.expected_causal_effect, result_string)
184189
return failed
185190

186-
def _setup_test(self, causal_test_case: CausalTestCase, estimator: Estimator) -> tuple[CausalTestEngine, Estimator]:
191+
def _setup_test(self, causal_test_case: CausalTestCase, test: Mapping) -> tuple[CausalTestEngine, Estimator]:
187192
"""Create the necessary inputs for a single test case
188193
:param causal_test_case: The concrete test case to be executed
194+
:param test: Single JSON test definition stored in a mapping (dict)
189195
:returns:
190196
- causal_test_engine - Test Engine instance for the test being run
191197
- estimation_model - Estimator instance for the test being run
@@ -197,27 +203,21 @@ def _setup_test(self, causal_test_case: CausalTestCase, estimator: Estimator) ->
197203
minimal_adjustment_set = self.causal_specification.causal_dag.identification(causal_test_case.base_test_case)
198204
treatment_var = causal_test_case.treatment_variable
199205
minimal_adjustment_set = minimal_adjustment_set - {treatment_var}
200-
estimation_model = estimator(
201-
treatment=treatment_var.name,
202-
treatment_value=causal_test_case.treatment_value,
203-
control_value=causal_test_case.control_value,
204-
adjustment_set=minimal_adjustment_set,
205-
outcome=causal_test_case.outcome_variable.name,
206-
df=causal_test_engine.scenario_execution_data_df,
207-
effect_modifiers=causal_test_case.effect_modifier_configuration,
208-
)
209-
210-
self.add_modelling_assumptions(estimation_model)
211-
206+
estimator_kwargs = {
207+
"treatment": treatment_var.name,
208+
"treatment_value": causal_test_case.treatment_value,
209+
"control_value": causal_test_case.control_value,
210+
"adjustment_set": minimal_adjustment_set,
211+
"outcome": causal_test_case.outcome_variable.name,
212+
"df": causal_test_engine.scenario_execution_data_df,
213+
"effect_modifiers": causal_test_case.effect_modifier_configuration,
214+
}
215+
if "formula" in test:
216+
estimator_kwargs["formula"] = test["formula"]
217+
218+
estimation_model = test["estimator"](**estimator_kwargs)
212219
return causal_test_engine, estimation_model
213220

214-
def add_modelling_assumptions(self, estimation_model: Estimator): # pylint: disable=unused-argument
215-
"""Optional abstract method where user functionality can be written to determine what assumptions are required
216-
for specific test cases
217-
:param estimation_model: estimator model instance for the current running test.
218-
"""
219-
return
220-
221221
def _append_to_file(self, line: str, log_level: int = None):
222222
"""Appends given line(s) to the current output file. If log_level is specified it also logs that message to the
223223
logging level.
@@ -226,9 +226,7 @@ def _append_to_file(self, line: str, log_level: int = None):
226226
is possible to use the inbuilt logging level variables such as logging.INFO and logging.WARNING
227227
"""
228228
with open(self.output_path, "a", encoding="utf-8") as f:
229-
f.write(
230-
line + "\n",
231-
)
229+
f.write(line)
232230
if log_level:
233231
logger.log(level=log_level, msg=line)
234232

causal_testing/testing/estimators.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,7 @@ def __init__(
320320
self.formula = formula
321321
else:
322322
terms = [treatment] + sorted(list(adjustment_set)) + sorted(list(effect_modifiers))
323-
self.formula = f"{outcome} ~ {'+'.join(((terms)))}"
323+
self.formula = f"{outcome} ~ {'+'.join(terms)}"
324324

325325
for term in self.effect_modifiers:
326326
self.adjustment_set.add(term)

examples/poisson/example_run_causal_tests.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -149,15 +149,6 @@ def populate_num_shapes_unit(data):
149149
}
150150

151151

152-
class MyJsonUtility(JsonUtility):
153-
"""Extension of JsonUtility class to add modelling assumptions to the estimator instance"""
154-
155-
def add_modelling_assumptions(self, estimation_model: Estimator):
156-
# Add squared intensity term as a modelling assumption if intensity is the treatment of the test
157-
if "intensity" in estimation_model.treatment[0]:
158-
estimation_model.intercept = 0
159-
160-
161152
def test_run_causal_tests():
162153
ROOT = os.path.realpath(os.path.dirname(__file__))
163154

@@ -166,7 +157,7 @@ def test_run_causal_tests():
166157
dag_path = f"{ROOT}/dag.dot"
167158
data_path = f"{ROOT}/data.csv"
168159

169-
json_utility = MyJsonUtility(log_path) # Create an instance of the extended JsonUtility class
160+
json_utility = JsonUtility(log_path) # Create an instance of the extended JsonUtility class
170161
json_utility.set_paths(
171162
json_path, dag_path, [data_path]
172163
) # Set the path to the data.csv, dag.dot and causal_tests.json file
@@ -178,8 +169,8 @@ def test_run_causal_tests():
178169

179170

180171
if __name__ == "__main__":
181-
args = MyJsonUtility.get_args()
182-
json_utility = MyJsonUtility(args.log_path) # Create an instance of the extended JsonUtility class
172+
args = JsonUtility.get_args()
173+
json_utility = JsonUtility(args.log_path) # Create an instance of the extended JsonUtility class
183174
json_utility.set_paths(
184175
args.json_path, args.dag_path, args.data_path
185176
) # Set the path to the data.csv, dag.dot and causal_tests.json file

tests/json_front_tests/test_json_class.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import json
77

88
from causal_testing.testing.estimators import LinearRegressionEstimator
9-
from causal_testing.testing.causal_test_outcome import NoEffect
9+
from causal_testing.testing.causal_test_outcome import NoEffect, Positive
1010
from tests.test_helpers import create_temp_dir_if_non_existent, remove_temp_dir_if_existent
1111
from causal_testing.json_front.json_class import JsonUtility, CausalVariables
1212
from causal_testing.specification.variable import Input, Output, Meta
@@ -127,9 +127,36 @@ def test_generate_tests_from_json(self):
127127
temp_out = reader.readlines()
128128
self.assertIn("failed", temp_out[-1])
129129

130+
def test_formula_in_json_test(self):
131+
example_test = {
132+
"tests": [
133+
{
134+
"name": "test1",
135+
"mutations": {"test_input": "Increase"},
136+
"estimator": "LinearRegressionEstimator",
137+
"estimate_type": "ate",
138+
"effect_modifiers": [],
139+
"expectedEffect": {"test_output": "Positive"},
140+
"skip": False,
141+
"formula": "test_output ~ test_input"
142+
}
143+
]
144+
}
145+
self.json_class.test_plan = example_test
146+
effects = {"Positive": Positive()}
147+
mutates = {
148+
"Increase": lambda x: self.json_class.scenario.treatment_variables[x].z3
149+
> self.json_class.scenario.variables[x].z3
150+
}
151+
estimators = {"LinearRegressionEstimator": LinearRegressionEstimator}
152+
153+
self.json_class.generate_tests(effects, mutates, estimators, False)
154+
with open("temp_out.txt", 'r') as reader:
155+
temp_out = reader.readlines()
156+
self.assertIn("test_output ~ test_input", ''.join(temp_out))
157+
130158
def tearDown(self) -> None:
131-
pass
132-
# remove_temp_dir_if_existent()
159+
remove_temp_dir_if_existent()
133160

134161

135162
def populate_example(*args, **kwargs):

0 commit comments

Comments (0)