Skip to content

Commit 5f93882

Browse files
committed
Improved categorical treatment handling
1 parent 9401840 commit 5f93882

File tree

4 files changed

+15
-15
lines changed

4 files changed

+15
-15
lines changed

causal_testing/testing/causal_test_outcome.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,18 +41,21 @@ def apply(self, res: CausalTestResult) -> bool:
4141
class NoEffect(CausalTestOutcome):
4242
"""An extension of TestOutcome representing that the expected causal effect should be zero."""
4343

44-
def apply(self, res: CausalTestResult, atol: float = 1e-10) -> bool:
44+
def __init__(self, atol: float = 1e-10):
45+
self.atol = atol
46+
47+
def apply(self, res: CausalTestResult) -> bool:
4548
if res.test_value.type == "ate":
46-
return (res.ci_low() < 0 < res.ci_high()) or (abs(res.test_value.value) < atol)
49+
return (res.ci_low() < 0 < res.ci_high()) or (abs(res.test_value.value) < self.atol)
4750
if res.test_value.type == "coefficient":
4851
ci_low = res.ci_low() if isinstance(res.ci_low(), Iterable) else [res.ci_low()]
4952
ci_high = res.ci_high() if isinstance(res.ci_high(), Iterable) else [res.ci_high()]
5053
value = res.test_value.value if isinstance(res.ci_high(), Iterable) else [res.test_value.value]
5154
return all(ci_low < 0 < ci_high for ci_low, ci_high in zip(ci_low, ci_high)) or all(
52-
abs(v) < 1e-10 for v in value
55+
abs(v) < self.atol for v in value
5356
)
5457
if res.test_value.type == "risk_ratio":
55-
return (res.ci_low() < 1 < res.ci_high()) or np.isclose(res.test_value.value, 1.0, atol=atol)
58+
return (res.ci_low() < 1 < res.ci_high()) or np.isclose(res.test_value.value, 1.0, atol=self.atol)
5659
raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")
5760

5861

causal_testing/testing/estimators.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -339,13 +339,8 @@ def estimate_unit_ate(self) -> float:
339339
print(model.conf_int())
340340
treatment = [self.treatment]
341341
if str(self.df.dtypes[self.treatment]) == "object":
342-
reference = min(self.df[self.treatment])
343-
treatment = [
344-
x.replace("[", "[T.")
345-
for x in dmatrix(
346-
f"{self.treatment}-1", self.df.query(f"{self.treatment} != '{reference}'"), return_type="dataframe"
347-
).columns
348-
]
342+
design_info = dmatrix(self.formula.split("~")[1], self.df).design_info
343+
treatment = design_info.column_names[design_info.term_name_slices[self.treatment]]
349344
assert set(treatment).issubset(
350345
model.params.index.tolist()
351346
), f"{treatment} not in\n{' '+str(model.params.index).replace(newline, newline+' ')}"

examples/poisson/example_run_causal_tests.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,9 @@ def test_run_causal_tests():
163163
) # Set the path to the data.csv, dag.dot and causal_tests.json file
164164

165165
# Load the Causal Variables into the JsonUtility class ready to be used in the tests
166-
json_utility.setup(scenario=modelling_scenario) # Sets up all the necessary parts of the json_class needed to execute tests
166+
json_utility.setup(
167+
scenario=modelling_scenario
168+
) # Sets up all the necessary parts of the json_class needed to execute tests
167169

168170
json_utility.run_json_tests(effects=effects, mutates=mutates, estimators=estimators, f_flag=False)
169171

@@ -176,6 +178,8 @@ def test_run_causal_tests():
176178
) # Set the path to the data.csv, dag.dot and causal_tests.json file
177179

178180
# Load the Causal Variables into the JsonUtility class ready to be used in the tests
179-
json_utility.setup(scenario=modelling_scenario) # Sets up all the necessary parts of the json_class needed to execute tests
181+
json_utility.setup(
182+
scenario=modelling_scenario
183+
) # Sets up all the necessary parts of the json_class needed to execute tests
180184

181185
json_utility.run_json_tests(effects=effects, mutates=mutates, estimators=estimators, f_flag=args.f)

tests/json_front_tests/test_json_class.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,6 @@ def test_setting_no_path(self):
4848
json_class.set_paths(self.json_path, self.dag_path, None)
4949
self.assertEqual(json_class.input_paths.data_paths, []) # Needs to be list of Paths
5050

51-
52-
5351
def test_setting_paths(self):
5452
self.assertEqual(self.json_class.input_paths.json_path, Path(self.json_path))
5553
self.assertEqual(self.json_class.input_paths.dag_path, Path(self.dag_path))

0 commit comments

Comments
 (0)