Skip to content

Commit 651a358

Browse files
committed
Merge branch 'main' of github.com:CITCOM-project/CausalTestingFramework into json-cate
2 parents 4e939a1 + 106f346 commit 651a358

File tree

4 files changed

+66
-48
lines changed

4 files changed

+66
-48
lines changed

causal_testing/json_front/json_class.py

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import json
66
import logging
77

8+
from collections.abc import Iterable, Mapping
89
from dataclasses import dataclass
910
from pathlib import Path
1011
from statistics import StatisticsError
@@ -154,10 +155,11 @@ def generate_tests(self, effects: dict, mutates: dict, estimators: dict, f_flag:
154155

155156
def _execute_tests(self, concrete_tests, estimators, test, f_flag):
156157
failures = 0
158+
test["estimator"] = estimators[test["estimator"]]
159+
if "formula" in test:
160+
self._append_to_file(f"Estimator formula used for test: {test['formula']}")
157161
for concrete_test in concrete_tests:
158-
failed = self._execute_test_case(
159-
concrete_test, estimators[test["estimator"]], f_flag, test.get("conditions", [])
160-
)
162+
failed = self._execute_test_case(concrete_test, test, f_flag)
161163
if failed:
162164
failures += 1
163165
return failures
@@ -178,19 +180,26 @@ def _populate_metas(self):
178180
for meta in self.scenario.variables_of_type(Meta):
179181
meta.populate(self.data)
180182

181-
def _execute_test_case(
182-
self, causal_test_case: CausalTestCase, estimator: Estimator, f_flag: bool, conditions: list[str]
183-
) -> bool:
183+
def _execute_test_case(self, causal_test_case: CausalTestCase, test: Iterable[Mapping], f_flag: bool) -> bool:
184184
"""Executes a singular test case, prints the results and returns the test case result
185185
:param causal_test_case: The concrete test case to be executed
186+
:param test: Single JSON test definition stored in a mapping (dict)
186187
:param f_flag: Failure flag that if True the script will stop executing when a test fails.
187188
:return: A boolean that if True indicates the causal test case failed and if False indicates the test case
188189
passed.
189190
:rtype: bool
190191
"""
191192
failed = False
192193

193-
causal_test_engine, estimation_model = self._setup_test(causal_test_case, estimator, conditions)
194+
for var in self.scenario.variables_of_type(Meta).union(self.scenario.variables_of_type(Output)):
195+
if not var.distribution:
196+
fitter = Fitter(self.data[var.name], distributions=get_common_distributions())
197+
fitter.fit()
198+
(dist, params) = list(fitter.get_best(method="sumsquare_error").items())[0]
199+
var.distribution = getattr(scipy.stats, dist)(**params)
200+
self._append_to_file(var.name + f" {dist}({params})", logging.INFO)
201+
202+
causal_test_engine, estimation_model = self._setup_test(causal_test_case, test)
194203
causal_test_result = causal_test_engine.execute_test(
195204
estimation_model, causal_test_case, estimate_type=causal_test_case.estimate_type
196205
)
@@ -216,11 +225,10 @@ def _execute_test_case(
216225
logger.warning(" FAILED- expected %s, got %s", causal_test_case.expected_causal_effect, result_string)
217226
return failed
218227

219-
def _setup_test(
220-
self, causal_test_case: CausalTestCase, estimator: Estimator, conditions: list[str]
221-
) -> tuple[CausalTestEngine, Estimator]:
228+
def _setup_test(self, causal_test_case: CausalTestCase, test: Mapping, conditions: list[str]=None) -> tuple[CausalTestEngine, Estimator]:
222229
"""Create the necessary inputs for a single test case
223230
:param causal_test_case: The concrete test case to be executed
231+
:param test: Single JSON test definition stored in a mapping (dict)
224232
:returns:
225233
- causal_test_engine - Test Engine instance for the test being run
226234
- estimation_model - Estimator instance for the test being run
@@ -234,27 +242,21 @@ def _setup_test(
234242
minimal_adjustment_set = self.causal_specification.causal_dag.identification(causal_test_case.base_test_case)
235243
treatment_var = causal_test_case.treatment_variable
236244
minimal_adjustment_set = minimal_adjustment_set - {treatment_var}
237-
estimation_model = estimator(
238-
treatment=treatment_var.name,
239-
treatment_value=causal_test_case.treatment_value,
240-
control_value=causal_test_case.control_value,
241-
adjustment_set=minimal_adjustment_set,
242-
outcome=causal_test_case.outcome_variable.name,
243-
df=causal_test_engine.scenario_execution_data_df,
244-
effect_modifiers=causal_test_case.effect_modifier_configuration,
245-
)
246-
247-
self.add_modelling_assumptions(estimation_model)
248-
245+
estimator_kwargs = {
246+
"treatment": treatment_var.name,
247+
"treatment_value": causal_test_case.treatment_value,
248+
"control_value": causal_test_case.control_value,
249+
"adjustment_set": minimal_adjustment_set,
250+
"outcome": causal_test_case.outcome_variable.name,
251+
"df": causal_test_engine.scenario_execution_data_df,
252+
"effect_modifiers": causal_test_case.effect_modifier_configuration,
253+
}
254+
if "formula" in test:
255+
estimator_kwargs["formula"] = test["formula"]
256+
257+
estimation_model = test["estimator"](**estimator_kwargs)
249258
return causal_test_engine, estimation_model
250259

251-
def add_modelling_assumptions(self, estimation_model: Estimator): # pylint: disable=unused-argument
252-
"""Optional abstract method where user functionality can be written to determine what assumptions are required
253-
for specific test cases
254-
:param estimation_model: estimator model instance for the current running test.
255-
"""
256-
return
257-
258260
def _append_to_file(self, line: str, log_level: int = None):
259261
"""Appends given line(s) to the current output file. If log_level is specified it also logs that message to the
260262
logging level.
@@ -263,9 +265,7 @@ def _append_to_file(self, line: str, log_level: int = None):
263265
is possible to use the inbuilt logging level variables such as logging.INFO and logging.WARNING
264266
"""
265267
with open(self.output_path, "a", encoding="utf-8") as f:
266-
f.write(
267-
line + "\n",
268-
)
268+
f.write(line)
269269
if log_level:
270270
logger.log(level=log_level, msg=line)
271271

causal_testing/testing/estimators.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,7 @@ def __init__(
312312
self.formula = formula
313313
else:
314314
terms = [treatment] + sorted(list(adjustment_set)) + sorted(list(effect_modifiers))
315-
self.formula = f"{outcome} ~ {'+'.join(((terms)))}"
315+
self.formula = f"{outcome} ~ {'+'.join(terms)}"
316316

317317
for term in self.effect_modifiers:
318318
self.adjustment_set.add(term)

examples/poisson/example_run_causal_tests.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -149,15 +149,6 @@ def populate_num_shapes_unit(data):
149149
}
150150

151151

152-
class MyJsonUtility(JsonUtility):
153-
"""Extension of JsonUtility class to add modelling assumptions to the estimator instance"""
154-
155-
def add_modelling_assumptions(self, estimation_model: Estimator):
156-
# Add squared intensity term as a modelling assumption if intensity is the treatment of the test
157-
if "intensity" in estimation_model.treatment[0]:
158-
estimation_model.intercept = 0
159-
160-
161152
def test_run_causal_tests():
162153
ROOT = os.path.realpath(os.path.dirname(__file__))
163154

@@ -166,7 +157,7 @@ def test_run_causal_tests():
166157
dag_path = f"{ROOT}/dag.dot"
167158
data_path = f"{ROOT}/data.csv"
168159

169-
json_utility = MyJsonUtility(log_path) # Create an instance of the extended JsonUtility class
160+
json_utility = JsonUtility(log_path) # Create an instance of the JsonUtility class
170161
json_utility.set_paths(
171162
json_path, dag_path, [data_path]
172163
) # Set the path to the data.csv, dag.dot and causal_tests.json file
@@ -178,8 +169,8 @@ def test_run_causal_tests():
178169

179170

180171
if __name__ == "__main__":
181-
args = MyJsonUtility.get_args()
182-
json_utility = MyJsonUtility(args.log_path) # Create an instance of the extended JsonUtility class
172+
args = JsonUtility.get_args()
173+
json_utility = JsonUtility(args.log_path) # Create an instance of the JsonUtility class
183174
json_utility.set_paths(
184175
args.json_path, args.dag_path, args.data_path
185176
) # Set the path to the data.csv, dag.dot and causal_tests.json file

tests/json_front_tests/test_json_class.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import json
77

88
from causal_testing.testing.estimators import LinearRegressionEstimator
9-
from causal_testing.testing.causal_test_outcome import NoEffect
9+
from causal_testing.testing.causal_test_outcome import NoEffect, Positive
1010
from tests.test_helpers import create_temp_dir_if_non_existent, remove_temp_dir_if_existent
1111
from causal_testing.json_front.json_class import JsonUtility, CausalVariables
1212
from causal_testing.specification.variable import Input, Output, Meta
@@ -186,9 +186,36 @@ def test_generate_tests_from_json_no_dist(self):
186186
temp_out = reader.readlines()
187187
self.assertIn("failed", temp_out[-1])
188188

189+
def test_formula_in_json_test(self):
190+
example_test = {
191+
"tests": [
192+
{
193+
"name": "test1",
194+
"mutations": {"test_input": "Increase"},
195+
"estimator": "LinearRegressionEstimator",
196+
"estimate_type": "ate",
197+
"effect_modifiers": [],
198+
"expectedEffect": {"test_output": "Positive"},
199+
"skip": False,
200+
"formula": "test_output ~ test_input"
201+
}
202+
]
203+
}
204+
self.json_class.test_plan = example_test
205+
effects = {"Positive": Positive()}
206+
mutates = {
207+
"Increase": lambda x: self.json_class.scenario.treatment_variables[x].z3
208+
> self.json_class.scenario.variables[x].z3
209+
}
210+
estimators = {"LinearRegressionEstimator": LinearRegressionEstimator}
211+
212+
self.json_class.generate_tests(effects, mutates, estimators, False)
213+
with open("temp_out.txt", 'r') as reader:
214+
temp_out = reader.readlines()
215+
self.assertIn("test_output ~ test_input", ''.join(temp_out))
216+
189217
def tearDown(self) -> None:
190-
pass
191-
# remove_temp_dir_if_existent()
218+
remove_temp_dir_if_existent()
192219

193220

194221
def populate_example(*args, **kwargs):

0 commit comments

Comments
 (0)