Skip to content

Commit 3b00417

Browse files
authored
Merge pull request #240 from CITCOM-project/fix-linalg-error
Fixed linalg error
2 parents 53b2a91 + d1089b1 commit 3b00417

File tree

1 file changed

+16
-9
lines changed

1 file changed

+16
-9
lines changed

causal_testing/testing/causal_test_case.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""This module contains the CausalTestCase class, a class that holds the information required for a causal test"""
22
import logging
33
from typing import Any
4+
import numpy as np
45

56
from causal_testing.specification.variable import Variable
67
from causal_testing.testing.causal_test_outcome import CausalTestOutcome
@@ -79,15 +80,21 @@ def _return_causal_test_results(self, estimator) -> CausalTestResult:
7980
if not hasattr(estimator, f"estimate_{self.estimate_type}"):
8081
raise AttributeError(f"{estimator.__class__} has no {self.estimate_type} method.")
8182
estimate_effect = getattr(estimator, f"estimate_{self.estimate_type}")
82-
effect, confidence_intervals = estimate_effect(**self.estimate_params)
83-
causal_test_result = CausalTestResult(
84-
estimator=estimator,
85-
test_value=TestValue(self.estimate_type, effect),
86-
effect_modifier_configuration=self.effect_modifier_configuration,
87-
confidence_intervals=confidence_intervals,
88-
)
89-
90-
return causal_test_result
83+
try:
84+
effect, confidence_intervals = estimate_effect(**self.estimate_params)
85+
return CausalTestResult(
86+
estimator=estimator,
87+
test_value=TestValue(self.estimate_type, effect),
88+
effect_modifier_configuration=self.effect_modifier_configuration,
89+
confidence_intervals=confidence_intervals,
90+
)
91+
except np.linalg.LinAlgError:
92+
return CausalTestResult(
93+
estimator=estimator,
94+
test_value=TestValue(self.estimate_type, "LinAlgError"),
95+
effect_modifier_configuration=self.effect_modifier_configuration,
96+
confidence_intervals=None,
97+
)
9198

9299
def __str__(self):
93100
treatment_config = {self.treatment_variable.name: self.treatment_value}

0 commit comments

Comments
 (0)