|
1 | 1 | """This module contains the CausalTestCase class, a class that holds the information required for a causal test"""
|
2 | 2 | import logging
|
3 | 3 | from typing import Any
|
| 4 | +import numpy as np |
4 | 5 |
|
5 | 6 | from causal_testing.specification.variable import Variable
|
6 | 7 | from causal_testing.testing.causal_test_outcome import CausalTestOutcome
|
@@ -81,21 +82,19 @@ def _return_causal_test_results(self, estimator) -> CausalTestResult:
|
81 | 82 | estimate_effect = getattr(estimator, f"estimate_{self.estimate_type}")
|
82 | 83 | try:
|
83 | 84 | effect, confidence_intervals = estimate_effect(**self.estimate_params)
|
84 |
| - causal_test_result = CausalTestResult( |
| 85 | + return CausalTestResult( |
85 | 86 | estimator=estimator,
|
86 | 87 | test_value=TestValue(self.estimate_type, effect),
|
87 | 88 | effect_modifier_configuration=self.effect_modifier_configuration,
|
88 | 89 | confidence_intervals=confidence_intervals,
|
89 | 90 | )
|
90 | 91 | except np.linalg.LinAlgError:
|
91 |
| - causal_test_result = CausalTestResult( |
| 92 | + return CausalTestResult( |
92 | 93 | estimator=estimator,
|
93 | 94 | test_value=TestValue(self.estimate_type, "LinAlgError"),
|
94 | 95 | effect_modifier_configuration=self.effect_modifier_configuration,
|
95 | 96 | confidence_intervals=None,
|
96 | 97 | )
|
97 |
| - finally: |
98 |
| - return causal_test_result |
99 | 98 |
|
100 | 99 | def __str__(self):
|
101 | 100 | treatment_config = {self.treatment_variable.name: self.treatment_value}
|
|
0 commit comments