|
1 | 1 | """This module contains the CausalTestCase class, a class that holds the information required for a causal test"""
|
2 | 2 | import logging
|
3 | 3 | from typing import Any
|
| 4 | +import numpy as np |
4 | 5 |
|
5 | 6 | from causal_testing.specification.variable import Variable
|
6 | 7 | from causal_testing.testing.causal_test_outcome import CausalTestOutcome
|
@@ -79,15 +80,21 @@ def _return_causal_test_results(self, estimator) -> CausalTestResult:
|
79 | 80 | if not hasattr(estimator, f"estimate_{self.estimate_type}"):
|
80 | 81 | raise AttributeError(f"{estimator.__class__} has no {self.estimate_type} method.")
|
81 | 82 | estimate_effect = getattr(estimator, f"estimate_{self.estimate_type}")
|
82 |
| - effect, confidence_intervals = estimate_effect(**self.estimate_params) |
83 |
| - causal_test_result = CausalTestResult( |
84 |
| - estimator=estimator, |
85 |
| - test_value=TestValue(self.estimate_type, effect), |
86 |
| - effect_modifier_configuration=self.effect_modifier_configuration, |
87 |
| - confidence_intervals=confidence_intervals, |
88 |
| - ) |
89 |
| - |
90 |
| - return causal_test_result |
| 83 | + try: |
| 84 | + effect, confidence_intervals = estimate_effect(**self.estimate_params) |
| 85 | + return CausalTestResult( |
| 86 | + estimator=estimator, |
| 87 | + test_value=TestValue(self.estimate_type, effect), |
| 88 | + effect_modifier_configuration=self.effect_modifier_configuration, |
| 89 | + confidence_intervals=confidence_intervals, |
| 90 | + ) |
| 91 | + except np.linalg.LinAlgError: |
| 92 | + return CausalTestResult( |
| 93 | + estimator=estimator, |
| 94 | + test_value=TestValue(self.estimate_type, "LinAlgError"), |
| 95 | + effect_modifier_configuration=self.effect_modifier_configuration, |
| 96 | + confidence_intervals=None, |
| 97 | + ) |
91 | 98 |
|
92 | 99 | def __str__(self):
|
93 | 100 | treatment_config = {self.treatment_variable.name: self.treatment_value}
|
|
0 commit comments