Commit 36c2e4f

Merge branch 'main' into JSON_treatment_var
# Conflicts:
#	causal_testing/json_front/json_class.py
#	tests/json_front_tests/test_json_class.py
2 parents 87a5425 + 106f346 commit 36c2e4f

File tree

causal_testing/json_front/json_class.py
causal_testing/testing/estimators.py
examples/poisson/example_run_causal_tests.py
tests/json_front_tests/test_json_class.py

4 files changed: +58 −38 lines changed

causal_testing/json_front/json_class.py

Lines changed: 24 additions & 24 deletions
@@ -5,6 +5,7 @@
 import json
 import logging
 
+from collections.abc import Iterable, Mapping
 from dataclasses import dataclass
 from pathlib import Path
 from statistics import StatisticsError
@@ -153,8 +154,11 @@ def _create_abstract_test_case(self, test, mutates, effects):
 
     def _execute_tests(self, concrete_tests, estimators, test, f_flag):
         failures = 0
+        test["estimator"] = estimators[test["estimator"]]
+        if "formula" in test:
+            self._append_to_file(f"Estimator formula used for test: {test['formula']}")
         for concrete_test in concrete_tests:
-            failed = self._execute_test_case(concrete_test, estimators[test["estimator"]], f_flag)
+            failed = self._execute_test_case(concrete_test, test, f_flag)
             if failed:
                 failures += 1
         return failures
@@ -182,16 +186,17 @@ def _populate_metas(self):
             var.distribution = getattr(scipy.stats, dist)(**params)
             self._append_to_file(var.name + f" {dist}({params})", logging.INFO)
 
-    def _execute_test_case(self, causal_test_case: CausalTestCase, estimator: Estimator, f_flag: bool) -> bool:
+    def _execute_test_case(self, causal_test_case: CausalTestCase, test: Iterable[Mapping], f_flag: bool) -> bool:
         """Executes a singular test case, prints the results and returns the test case result
         :param causal_test_case: The concrete test case to be executed
+        :param test: Single JSON test definition stored in a mapping (dict)
         :param f_flag: Failure flag that if True the script will stop executing when a test fails.
         :return: A boolean that if True indicates the causal test case passed and if false indicates the test case
            failed.
         :rtype: bool
         """
         failed = False
-        causal_test_engine, estimation_model = self._setup_test(causal_test_case, estimator)
+        causal_test_engine, estimation_model = self._setup_test(causal_test_case, test)
         causal_test_result = causal_test_engine.execute_test(
             estimation_model, causal_test_case, estimate_type=causal_test_case.estimate_type
         )
@@ -217,9 +222,10 @@ def _execute_test_case(self, causal_test_case: CausalTestCase, estima
             logger.warning(" FAILED- expected %s, got %s", causal_test_case.expected_causal_effect, result_string)
         return failed
 
-    def _setup_test(self, causal_test_case: CausalTestCase, estimator: Estimator) -> tuple[CausalTestEngine, Estimator]:
+    def _setup_test(self, causal_test_case: CausalTestCase, test: Mapping) -> tuple[CausalTestEngine, Estimator]:
         """Create the necessary inputs for a single test case
         :param causal_test_case: The concrete test case to be executed
+        :param test: Single JSON test definition stored in a mapping (dict)
         :returns:
             - causal_test_engine - Test Engine instance for the test being run
             - estimation_model - Estimator instance for the test being run
@@ -231,27 +237,21 @@ def _setup_test(self, causal_test_case: CausalTestCase, estimator: Es
         minimal_adjustment_set = self.causal_specification.causal_dag.identification(causal_test_case.base_test_case)
         treatment_var = causal_test_case.treatment_variable
         minimal_adjustment_set = minimal_adjustment_set - {treatment_var}
-        estimation_model = estimator(
-            treatment=treatment_var.name,
-            treatment_value=causal_test_case.treatment_value,
-            control_value=causal_test_case.control_value,
-            adjustment_set=minimal_adjustment_set,
-            outcome=causal_test_case.outcome_variable.name,
-            df=causal_test_engine.scenario_execution_data_df,
-            effect_modifiers=causal_test_case.effect_modifier_configuration,
-        )
-
-        self.add_modelling_assumptions(estimation_model)
-
+        estimator_kwargs = {
+            "treatment": treatment_var.name,
+            "treatment_value": causal_test_case.treatment_value,
+            "control_value": causal_test_case.control_value,
+            "adjustment_set": minimal_adjustment_set,
+            "outcome": causal_test_case.outcome_variable.name,
+            "df": causal_test_engine.scenario_execution_data_df,
+            "effect_modifiers": causal_test_case.effect_modifier_configuration,
+        }
+        if "formula" in test:
+            estimator_kwargs["formula"] = test["formula"]
+
+        estimation_model = test["estimator"](**estimator_kwargs)
         return causal_test_engine, estimation_model
 
-    def add_modelling_assumptions(self, estimation_model: Estimator):  # pylint: disable=unused-argument
-        """Optional abstract method where user functionality can be written to determine what assumptions are required
-        for specific test cases
-        :param estimation_model: estimator model instance for the current running test.
-        """
-        return
-
     def _append_to_file(self, line: str, log_level: int = None):
         """Appends given line(s) to the current output file. If log_level is specified it also logs that message to the
         logging level.
@@ -261,7 +261,7 @@ def _append_to_file(self, line: str, log_level: int = None):
         """
         with open(self.output_path, "a", encoding="utf-8") as f:
             f.write(
-                line,
+                line + "\n",
             )
         if log_level:
             logger.log(level=log_level, msg=line)
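The net effect of these changes is that _setup_test now receives the whole JSON test mapping rather than a bare estimator class, so optional keys such as "formula" can be forwarded to the estimator as keyword arguments. A minimal, self-contained sketch of that forwarding pattern; ToyEstimator and the test dict below are illustrative stand-ins, not the project's real Estimator API:

# Illustrative sketch only: ToyEstimator stands in for the real Estimator classes.


class ToyEstimator:
    def __init__(self, treatment, outcome, formula=None):
        self.treatment = treatment
        self.outcome = outcome
        # Fall back to a default formula when the JSON test omits one.
        self.formula = formula or f"{outcome} ~ {treatment}"


# A JSON test entry after the estimator name has been resolved to a class.
test = {"estimator": ToyEstimator, "formula": "y ~ x + I(x ** 2)"}

estimator_kwargs = {"treatment": "x", "outcome": "y"}
if "formula" in test:  # the key is optional, so only forward it when present
    estimator_kwargs["formula"] = test["formula"]

estimation_model = test["estimator"](**estimator_kwargs)
print(estimation_model.formula)  # y ~ x + I(x ** 2)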

causal_testing/testing/estimators.py

Lines changed: 1 addition & 1 deletion
@@ -320,7 +320,7 @@ def __init__(
             self.formula = formula
         else:
             terms = [treatment] + sorted(list(adjustment_set)) + sorted(list(effect_modifiers))
-            self.formula = f"{outcome} ~ {'+'.join(((terms)))}"
+            self.formula = f"{outcome} ~ {'+'.join(terms)}"
 
         for term in self.effect_modifiers:
             self.adjustment_set.add(term)
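This branch only builds the fallback formula when the caller does not supply one; the change drops the redundant parentheses around terms without altering the resulting string. A quick illustration of what that default looks like, using made-up variable names:

# Made-up example values; the join order is treatment, then sorted adjustment
# set, then sorted effect modifiers, exactly as in the else-branch above.
treatment = "intensity"
adjustment_set = {"width", "height"}
effect_modifiers = {"colour"}
outcome = "num_shapes_unit"

terms = [treatment] + sorted(list(adjustment_set)) + sorted(list(effect_modifiers))
formula = f"{outcome} ~ {'+'.join(terms)}"
print(formula)  # num_shapes_unit ~ intensity+height+width+colour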

examples/poisson/example_run_causal_tests.py

Lines changed: 3 additions & 12 deletions
@@ -149,15 +149,6 @@ def populate_num_shapes_unit(data):
 }
 
 
-class MyJsonUtility(JsonUtility):
-    """Extension of JsonUtility class to add modelling assumptions to the estimator instance"""
-
-    def add_modelling_assumptions(self, estimation_model: Estimator):
-        # Add squared intensity term as a modelling assumption if intensity is the treatment of the test
-        if "intensity" in estimation_model.treatment[0]:
-            estimation_model.intercept = 0
-
-
 def test_run_causal_tests():
     ROOT = os.path.realpath(os.path.dirname(__file__))
 
@@ -166,7 +157,7 @@ def test_run_causal_tests():
     dag_path = f"{ROOT}/dag.dot"
     data_path = f"{ROOT}/data.csv"
 
-    json_utility = MyJsonUtility(log_path)  # Create an instance of the extended JsonUtility class
+    json_utility = JsonUtility(log_path)  # Create an instance of the extended JsonUtility class
     json_utility.set_paths(
         json_path, dag_path, [data_path]
     )  # Set the path to the data.csv, dag.dot and causal_tests.json file
@@ -178,8 +169,8 @@ def test_run_causal_tests():
 
 
 if __name__ == "__main__":
-    args = MyJsonUtility.get_args()
-    json_utility = MyJsonUtility(args.log_path)  # Create an instance of the extended JsonUtility class
+    args = JsonUtility.get_args()
+    json_utility = JsonUtility(args.log_path)  # Create an instance of the extended JsonUtility class
     json_utility.set_paths(
         args.json_path, args.dag_path, args.data_path
     )  # Set the path to the data.csv, dag.dot and causal_tests.json file
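With the add_modelling_assumptions hook removed, the intercept tweak that MyJsonUtility used to apply can instead be expressed directly in the JSON test definition via the new "formula" key, assuming the estimator hands the formula to a patsy/statsmodels-style parser. A hypothetical test entry in that spirit, shown as a Python dict; the outcome and treatment names and the exact formula are illustrative assumptions, not part of this commit:

# Hypothetical JSON test entry that suppresses the intercept via the formula
# instead of overriding add_modelling_assumptions. Names and formula are
# assumptions for illustration only.
example_test = {
    "name": "poisson intensity test",
    "mutations": {"intensity": "Increase"},
    "estimator": "LinearRegressionEstimator",
    "estimate_type": "ate",
    "effect_modifiers": [],
    "expectedEffect": {"num_shapes_unit": "Positive"},
    "skip": False,
    # "- 1" drops the intercept in patsy-style formulas, mirroring the
    # removed `estimation_model.intercept = 0` behaviour.
    "formula": "num_shapes_unit ~ intensity - 1",
}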

tests/json_front_tests/test_json_class.py

Lines changed: 30 additions & 1 deletion
@@ -6,7 +6,7 @@
 import json
 
 from causal_testing.testing.estimators import LinearRegressionEstimator
-from causal_testing.testing.causal_test_outcome import NoEffect
+from causal_testing.testing.causal_test_outcome import NoEffect, Positive
 from tests.test_helpers import create_temp_dir_if_non_existent, remove_temp_dir_if_existent
 from causal_testing.json_front.json_class import JsonUtility, CausalVariables
 from causal_testing.specification.variable import Input, Output, Meta
@@ -127,6 +127,34 @@ def test_run_json_tests_from_json(self):
             temp_out = reader.readlines()
         self.assertIn("failed", temp_out[-1])
 
+    def test_formula_in_json_test(self):
+        example_test = {
+            "tests": [
+                {
+                    "name": "test1",
+                    "mutations": {"test_input": "Increase"},
+                    "estimator": "LinearRegressionEstimator",
+                    "estimate_type": "ate",
+                    "effect_modifiers": [],
+                    "expectedEffect": {"test_output": "Positive"},
+                    "skip": False,
+                    "formula": "test_output ~ test_input"
+                }
+            ]
+        }
+        self.json_class.test_plan = example_test
+        effects = {"Positive": Positive()}
+        mutates = {
+            "Increase": lambda x: self.json_class.scenario.treatment_variables[x].z3
+            > self.json_class.scenario.variables[x].z3
+        }
+        estimators = {"LinearRegressionEstimator": LinearRegressionEstimator}
+
+        self.json_class.generate_tests(effects, mutates, estimators, False)
+        with open("temp_out.txt", 'r') as reader:
+            temp_out = reader.readlines()
+        self.assertIn("test_output ~ test_input", ''.join(temp_out))
+
     def test_run_concrete_json_testcase(self):
         example_test = {
             "tests": [
@@ -150,6 +178,7 @@ def test_run_concrete_json_testcase(self):
         with open("temp_out.txt", 'r') as reader:
             temp_out = reader.readlines()
         self.assertIn("failed", temp_out[-1])
+
     def tearDown(self) -> None:
         remove_temp_dir_if_existent()
 