Skip to content

Commit 08ab5aa

Browse files
authored
Merge pull request #307 from CITCOM-project/jmafoster1/ignore-cycles-dafni
Jmafoster1/ignore cycles dafni
2 parents a20954c + 3a6e2df commit 08ab5aa

File tree

6 files changed

+43
-33
lines changed

6 files changed

+43
-33
lines changed

causal_testing/estimation/abstract_regression_estimator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def __init__(
4141
outcome=outcome,
4242
df=df,
4343
effect_modifiers=effect_modifiers,
44+
alpha=alpha,
4445
query=query,
4546
)
4647

causal_testing/estimation/logistic_regression_estimator.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import logging
44

55
import numpy as np
6+
import pandas as pd
67
import statsmodels.formula.api as smf
78

89
from causal_testing.estimation.abstract_regression_estimator import RegressionEstimator
@@ -31,11 +32,12 @@ def add_modelling_assumptions(self):
3132
self.modelling_assumptions.append("The outcome must be binary.")
3233
self.modelling_assumptions.append("Independently and identically distributed errors.")
3334

34-
def estimate_unit_odds_ratio(self) -> float:
35+
def estimate_unit_odds_ratio(self) -> tuple[pd.Series, list[pd.Series, pd.Series]]:
3536
"""Estimate the odds ratio of increasing the treatment by one. In logistic regression, this corresponds to the
3637
exponentiated coefficient of the treatment of interest.
3738
3839
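:return: The odds ratio as a single-element series, together with a two-element list holding the lower and upper confidence interval bounds (each a single-element series).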
3940
"""
4041
model = self._run_regression(self.df)
41-
return np.exp(model.params[self.treatment])
42+
ci_low, ci_high = np.exp(model.conf_int(self.alpha).loc[self.treatment])
43+
return pd.Series(np.exp(model.params[self.treatment])), [pd.Series(ci_low), pd.Series(ci_high)]

causal_testing/json_front/json_class.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from statistics import StatisticsError
1212

1313
import pandas as pd
14+
import numpy as np
1415
import scipy
1516
from fitter import Fitter, get_common_distributions
1617

@@ -21,7 +22,7 @@
2122
from causal_testing.specification.scenario import Scenario
2223
from causal_testing.specification.variable import Input, Meta, Output
2324
from causal_testing.testing.causal_test_case import CausalTestCase
24-
from causal_testing.testing.causal_test_result import CausalTestResult
25+
from causal_testing.testing.causal_test_result import CausalTestResult, TestValue
2526
from causal_testing.testing.base_test_case import BaseTestCase
2627
from causal_testing.testing.causal_test_adequacy import DataAdequacy
2728

@@ -136,8 +137,10 @@ def run_json_tests(self, effects: dict, estimators: dict, f_flag: bool = False,
136137
failed, msg = self._run_concrete_metamorphic_test(test, f_flag, effects)
137138
# If we have a variable to mutate
138139
else:
139-
if test["estimate_type"] == "coefficient":
140-
failed, msg = self._run_coefficient_test(test=test, f_flag=f_flag, effects=effects)
140+
if test["estimate_type"] in ["coefficient", "unit_odds_ratio"]:
141+
failed, msg = self._run_coefficient_test(
142+
test=test, f_flag=f_flag, effects=effects, estimate_type=test["estimate_type"]
143+
)
141144
else:
142145
failed, msg = self._run_metamorphic_tests(
143146
test=test, f_flag=f_flag, effects=effects, mutates=mutates
@@ -146,7 +149,7 @@ def run_json_tests(self, effects: dict, estimators: dict, f_flag: bool = False,
146149
test["result"] = msg
147150
return self.test_plan["tests"]
148151

149-
def _run_coefficient_test(self, test: dict, f_flag: bool, effects: dict):
152+
def _run_coefficient_test(self, test: dict, f_flag: bool, effects: dict, estimate_type: str = "coefficient"):
150153
"""Builds structures and runs test case for tests with an estimate_type of 'coefficient' or 'unit_odds_ratio'.
151154
152155
:param test: Single JSON test definition stored in a mapping (dict)
@@ -163,10 +166,11 @@ def _run_coefficient_test(self, test: dict, f_flag: bool, effects: dict):
163166
causal_test_case = CausalTestCase(
164167
base_test_case=base_test_case,
165168
expected_causal_effect=next(effects[effect] for variable, effect in test["expected_effect"].items()),
166-
estimate_type="coefficient",
169+
estimate_type=estimate_type,
167170
effect_modifier_configuration={self.scenario.variables[v] for v in test.get("effect_modifiers", [])},
168171
)
169172
failed, result = self._execute_test_case(causal_test_case=causal_test_case, test=test, f_flag=f_flag)
173+
170174
msg = (
171175
f"Executing test: {test['name']} \n"
172176
+ f" {causal_test_case} \n"
@@ -273,10 +277,17 @@ def _execute_test_case(
273277
failed = False
274278

275279
estimation_model = self._setup_test(causal_test_case=causal_test_case, test=test)
276-
causal_test_result = causal_test_case.execute_test(
277-
estimator=estimation_model, data_collector=self.data_collector
278-
)
279-
test_passes = causal_test_case.expected_causal_effect.apply(causal_test_result)
280+
try:
281+
causal_test_result = causal_test_case.execute_test(
282+
estimator=estimation_model, data_collector=self.data_collector
283+
)
284+
test_passes = causal_test_case.expected_causal_effect.apply(causal_test_result)
285+
except np.linalg.LinAlgError as e:
286+
result = CausalTestResult(
287+
estimator=estimation_model,
288+
test_value=TestValue("Error", str(e)),
289+
)
290+
return None, result
280291

281292
if "coverage" in test and test["coverage"]:
282293
adequacy_metric = DataAdequacy(causal_test_case, estimation_model)

causal_testing/testing/causal_test_outcome.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class SomeEffect(CausalTestOutcome):
2929
def apply(self, res: CausalTestResult) -> bool:
3030
if res.ci_low() is None or res.ci_high() is None:
3131
return None
32-
if res.test_value.type in ("risk_ratio", "hazard_ratio"):
32+
if res.test_value.type in ("risk_ratio", "hazard_ratio", "unit_odds_ratio"):
3333
return any(
3434
1 < ci_low < ci_high or ci_low < ci_high < 1 for ci_low, ci_high in zip(res.ci_low(), res.ci_high())
3535
)
@@ -54,7 +54,7 @@ def __init__(self, atol: float = 1e-10, ctol: float = 0.05):
5454
self.ctol = ctol
5555

5656
def apply(self, res: CausalTestResult) -> bool:
57-
if res.test_value.type in ("risk_ratio", "hazard_ratio"):
57+
if res.test_value.type in ("risk_ratio", "hazard_ratio", "unit_odds_ratio"):
5858
return any(
5959
ci_low < 1 < ci_high or np.isclose(value, 1.0, atol=self.atol)
6060
for ci_low, ci_high, value in zip(res.ci_low(), res.ci_high(), res.test_value.value)

dafni/main_dafni.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from causal_testing.specification.variable import Input, Output
1313
from causal_testing.testing.causal_test_outcome import Positive, Negative, NoEffect, SomeEffect
1414
from causal_testing.estimation.linear_regression_estimator import LinearRegressionEstimator
15+
from causal_testing.estimation.logistic_regression_estimator import LogisticRegressionEstimator
1516
from causal_testing.json_front.json_class import JsonUtility
1617

1718

@@ -29,44 +30,36 @@ def get_args(test_args=None) -> argparse.Namespace:
2930
- argparse.Namespace - A Namespace consisting of the arguments to this script
3031
"""
3132
parser = argparse.ArgumentParser(description="A script for running the CTF on DAFNI.")
32-
33-
parser.add_argument("--data_path", required=True, help="Path to the input runtime data (.csv)", nargs="+")
34-
35-
parser.add_argument(
36-
"--tests_path", required=True, help="Input configuration file path " "containing the causal tests (.json)"
37-
)
38-
33+
parser.add_argument("-d", "--data_path", required=True, help="Path to the input runtime data (.csv)", nargs="+")
3934
parser.add_argument(
40-
"-i", "--ignore_cycles", action="store_true", help="Whether to ignore cycles in the DAG.", default=False
35+
"-t", "--tests_path", required=True, help="Input configuration file path " "containing the causal tests (.json)"
4136
)
42-
4337
parser.add_argument(
38+
"-v",
4439
"--variables_path",
4540
required=True,
4641
help="Input configuration file path " "containing the predefined variables (.json)",
4742
)
48-
4943
parser.add_argument(
44+
"-D",
5045
"--dag_path",
5146
required=True,
5247
help="Input configuration file path containing a valid DAG (.dot). "
5348
"Note: this must be supplied if the --tests argument isn't provided.",
5449
)
55-
56-
parser.add_argument("--output_path", required=False, help="Path to the output directory.")
57-
50+
parser.add_argument(
51+
"-i", "--ignore_cycles", action="store_true", help="Whether to ignore cycles in the DAG.", default=False
52+
)
5853
parser.add_argument(
5954
"-f", default=False, help="(Optional) Failure flag to step the framework from running if a test has failed."
6055
)
61-
56+
parser.add_argument("-o", "--output_path", required=False, help="Path to the output directory.")
6257
parser.add_argument(
6358
"-w",
6459
default=False,
6560
help="(Optional) Specify to overwrite any existing output files. "
66-
"This can lead to the loss of existing outputs if not "
67-
"careful",
61+
"This can lead to the loss of existing outputs if not careful",
6862
)
69-
7063
args = parser.parse_args(test_args)
7164

7265
# Convert these to Path objects for main()
@@ -165,7 +158,10 @@ def main():
165158

166159
modelling_scenario.setup_treatment_variables()
167160

168-
estimators = {"LinearRegressionEstimator": LinearRegressionEstimator}
161+
estimators = {
162+
"LinearRegressionEstimator": LinearRegressionEstimator,
163+
"LogisticRegressionEstimator": LogisticRegressionEstimator,
164+
}
169165

170166
# Step 3: Define the expected variables
171167

tests/estimation_tests/test_logistic_regression_estimator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,5 @@ def setUpClass(cls) -> None:
1919
def test_odds_ratio(self):
2020
df = self.scarf_df.copy()
2121
logistic_regression_estimator = LogisticRegressionEstimator("length_in", 65, 55, set(), "completed", df)
22-
odds = logistic_regression_estimator.estimate_unit_odds_ratio()
23-
self.assertEqual(round(odds, 4), 0.8948)
22+
odds, _ = logistic_regression_estimator.estimate_unit_odds_ratio()
23+
self.assertEqual(round(odds[0], 4), 0.8948)

0 commit comments

Comments
 (0)