Skip to content

Commit 979e005

Browse files
Merge pull request #233 from CITCOM-project/adjustment_set_formula_check
Remove Redundant Adjustment Set Calc when Formula
2 parents 18b6ccd + 9d51557 commit 979e005

File tree

3 files changed

+49
-15
lines changed

3 files changed

+49
-15
lines changed

causal_testing/json_front/json_class.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from causal_testing.specification.variable import Input, Meta, Output
2323
from causal_testing.testing.causal_test_case import CausalTestCase
2424
from causal_testing.testing.causal_test_result import CausalTestResult
25-
from causal_testing.testing.estimators import Estimator
25+
from causal_testing.testing.estimators import Estimator, LinearRegressionEstimator, LogisticRegressionEstimator
2626
from causal_testing.testing.base_test_case import BaseTestCase
2727
from causal_testing.testing.causal_test_adequacy import DataAdequacy
2828

@@ -307,20 +307,29 @@ def _setup_test(self, causal_test_case: CausalTestCase, test: Mapping) -> Estima
307307
:returns:
308308
- estimation_model - Estimator instance for the test being run
309309
"""
310-
minimal_adjustment_set = self.causal_specification.causal_dag.identification(causal_test_case.base_test_case)
311-
treatment_var = causal_test_case.treatment_variable
312-
minimal_adjustment_set = minimal_adjustment_set - {treatment_var}
313-
estimator_kwargs = {
314-
"treatment": treatment_var.name,
315-
"treatment_value": causal_test_case.treatment_value,
316-
"control_value": causal_test_case.control_value,
317-
"adjustment_set": minimal_adjustment_set,
318-
"outcome": causal_test_case.outcome_variable.name,
319-
"effect_modifiers": causal_test_case.effect_modifier_configuration,
320-
"alpha": test["alpha"] if "alpha" in test else 0.05,
321-
}
310+
estimator_kwargs = {}
322311
if "formula" in test:
312+
if test["estimator"] != (LinearRegressionEstimator or LogisticRegressionEstimator):
313+
raise TypeError(
314+
"Currently only LinearRegressionEstimator and LogisticRegressionEstimator supports the use of "
315+
"formulas"
316+
)
323317
estimator_kwargs["formula"] = test["formula"]
318+
estimator_kwargs["adjustment_set"] = None
319+
else:
320+
minimal_adjustment_set = self.causal_specification.causal_dag.identification(
321+
causal_test_case.base_test_case
322+
)
323+
minimal_adjustment_set = minimal_adjustment_set - {causal_test_case.treatment_variable}
324+
estimator_kwargs["adjustment_set"] = minimal_adjustment_set
325+
326+
estimator_kwargs["treatment"] = causal_test_case.treatment_variable.name
327+
estimator_kwargs["treatment_value"] = causal_test_case.treatment_value
328+
estimator_kwargs["control_value"] = causal_test_case.control_value
329+
estimator_kwargs["outcome"] = causal_test_case.outcome_variable.name
330+
estimator_kwargs["effect_modifiers"] = causal_test_case.effect_modifier_configuration
331+
estimator_kwargs["alpha"] = test["alpha"] if "alpha" in test else 0.05
332+
324333
estimation_model = test["estimator"](**estimator_kwargs)
325334
return estimation_model
326335

causal_testing/testing/estimators.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,7 @@ def estimate_coefficient(self) -> float:
334334
treatment = design_info.column_names[design_info.term_name_slices[self.treatment]]
335335
assert set(treatment).issubset(
336336
model.params.index.tolist()
337-
), f"{treatment} not in\n{' '+str(model.params.index).replace(newline, newline+' ')}"
337+
), f"{treatment} not in\n{' ' + str(model.params.index).replace(newline, newline + ' ')}"
338338
unit_effect = model.params[treatment] # Unit effect is the coefficient of the treatment
339339
[ci_low, ci_high] = self._get_confidence_intervals(model, treatment)
340340
if str(self.df.dtypes[self.treatment]) != "object":

tests/json_front_tests/test_json_class.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import scipy
55
import os
66

7-
from causal_testing.testing.estimators import LinearRegressionEstimator
7+
from causal_testing.testing.estimators import LinearRegressionEstimator, CausalForestEstimator
88
from causal_testing.testing.causal_test_outcome import NoEffect, Positive
99
from tests.test_helpers import remove_temp_dir_if_existent
1010
from causal_testing.json_front.json_class import JsonUtility, CausalVariables
@@ -291,6 +291,31 @@ def test_no_data_provided(self):
291291
with self.assertRaises(ValueError):
292292
json_class.setup(self.scenario)
293293

294+
def test_estimator_formula_type_check(self):
295+
example_test = {
296+
"tests": [
297+
{
298+
"name": "test1",
299+
"mutations": {"test_input": "Increase"},
300+
"estimator": "CausalForestEstimator",
301+
"estimate_type": "ate",
302+
"effect_modifiers": [],
303+
"expected_effect": {"test_output": "Positive"},
304+
"skip": False,
305+
"formula": "test_output ~ test_input",
306+
}
307+
]
308+
}
309+
self.json_class.test_plan = example_test
310+
effects = {"Positive": Positive()}
311+
mutates = {
312+
"Increase": lambda x: self.json_class.scenario.treatment_variables[x].z3
313+
> self.json_class.scenario.variables[x].z3
314+
}
315+
estimators = {"CausalForestEstimator": CausalForestEstimator}
316+
with self.assertRaises(TypeError):
317+
self.json_class.run_json_tests(effects=effects, mutates=mutates, estimators=estimators, f_flag=False)
318+
294319
def tearDown(self) -> None:
295320
remove_temp_dir_if_existent()
296321
if os.path.exists("temp_out.txt"):

0 commit comments

Comments
 (0)