Skip to content

Commit 4abe735

Browse files
committed
Minor corrections
1 parent a619069 commit 4abe735

File tree

4 files changed

+10
-7
lines changed

4 files changed

+10
-7
lines changed

causal_testing/specification/capabilities.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@
33
treatment sequences that operate over time.
44
"""
55
from causal_testing.specification.variable import Variable
6+
from typing import Any
67

78

89
class Capability:
910
"""
1011
Data class to encapsulate temporal interventions.
1112
"""
1213

13-
def __init__(self, variable: Variable, value: any, start_time: int, end_time: float):
14+
def __init__(self, variable: Variable, value: Any, start_time: int, end_time: int):
1415
self.variable = variable
1516
self.value = value
1617
self.start_time = start_time

causal_testing/testing/causal_test_adequacy.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
from causal_testing.specification.causal_dag import CausalDAG
1313
from causal_testing.testing.estimators import Estimator
1414
from causal_testing.testing.causal_test_case import CausalTestCase
15+
import logging
16+
17+
logger = logging.getLogger(__name__)
1518

1619

1720
class DAGAdequacy:
@@ -104,10 +107,13 @@ def measure_adequacy(self):
104107
try:
105108
results.append(self.test_case.execute_test(estimator, None))
106109
except LinAlgError:
110+
logger.warning("Adequacy LinAlgError")
107111
continue
108112
except ConvergenceError:
113+
logger.warning("Adequacy ConvergenceError")
109114
continue
110-
except ValueError:
115+
except ValueError as e:
116+
logger.warning(f"Adequacy ValueError: {e}")
111117
continue
112118
outcomes = [self.test_case.expected_causal_effect.apply(c) for c in results]
113119
results = pd.DataFrame(c.to_dict() for c in results)[["effect_estimate", "ci_low", "ci_high"]]

causal_testing/testing/estimators.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -697,7 +697,7 @@ def setup_fault_t_do(self, individual: pd.DataFrame):
697697

698698
return pd.DataFrame({"fault_t_do": fault_t_do})
699699

700-
def setup_fault_time(self, individual, perturbation=-0.001):
700+
def setup_fault_time(self, individual: pd.DataFrame, perturbation: float = -0.001):
701701
"""
702702
Return the time at which the event of interest (i.e. a fault) occurred.
703703
"""

tests/specification_tests/test_capabilities.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,4 @@
11
import unittest
2-
from enum import Enum
3-
import z3
4-
from scipy.stats import norm, kstest
5-
62
from causal_testing.specification.capabilities import Capability, TreatmentSequence
73

84

0 commit comments

Comments
 (0)