Skip to content

Commit 18bafcc

Browse files
Refactor to use ** instead of catching
1 parent 9e35c06 commit 18bafcc

File tree

2 files changed

+5
-6
lines changed

2 files changed

+5
-6
lines changed

causal_testing/testing/causal_test_case.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def __init__(
3030
control_value: Any = None,
3131
treatment_value: Any = None,
3232
estimate_type: str = "ate",
33-
estimate_params: dict = "None",
33+
estimate_params: dict = None,
3434
effect_modifier_configuration: dict[Variable:Any] = None,
3535
):
3636
"""
@@ -48,7 +48,8 @@ def __init__(
4848
self.treatment_variable = base_test_case.treatment_variable
4949
self.treatment_value = treatment_value
5050
self.estimate_type = estimate_type
51-
self.estimate_params = estimate_params
51+
if estimate_params is None:
52+
self.estimate_params = {}
5253
self.effect = base_test_case.effect
5354

5455
if effect_modifier_configuration:

causal_testing/testing/causal_test_engine.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -162,10 +162,8 @@ def _return_causal_test_results(self, estimator, causal_test_case):
162162
)
163163
elif causal_test_case.estimate_type == "risk_ratio":
164164
logger.debug("calculating risk_ratio")
165-
try:
166-
risk_ratio, confidence_intervals = estimator.estimate_risk_ratio(causal_test_case.estimate_params)
167-
except TypeError:
168-
risk_ratio, confidence_intervals = estimator.estimate_risk_ratio()
165+
risk_ratio, confidence_intervals = estimator.estimate_risk_ratio(**causal_test_case.estimate_params)
166+
169167
causal_test_result = CausalTestResult(
170168
estimator=estimator,
171169
test_value=TestValue("risk_ratio", risk_ratio),

0 commit comments

Comments
 (0)