Skip to content

Commit 61e7def

Browse files
add type hints for causal_test_suite.py + black
1 parent 71b6d21 commit 61e7def

File tree

3 files changed

+12
-4
lines changed

3 files changed

+12
-4
lines changed

causal_testing/generation/abstract_causal_test_case.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,6 @@ def _generate_concrete_tests(
146146
+ f"{constraints}\nUsing value {v.cast(model[v.z3])} instead in test\n{concrete_test}"
147147
)
148148

149-
150149
if not any([vars(t) == vars(concrete_test) for t in concrete_tests]):
151150
concrete_tests.append(concrete_test)
152151
# Control run
@@ -162,7 +161,6 @@ def _generate_concrete_tests(
162161
treatment_run["bin"] = index
163162
runs.append(treatment_run)
164163

165-
166164
return concrete_tests, pd.DataFrame(runs, columns=run_columns + ["bin"])
167165

168166
def generate_concrete_tests(

causal_testing/testing/causal_test_suite.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
11
from collections import UserDict
2+
from typing import Type, Iterable
3+
from causal_testing.testing.base_test_case import BaseTestCase
4+
from causal_testing.testing.causal_test_case import CausalTestCase
5+
from causal_testing.testing.estimators import Estimator
26

37

48
class CausalTestSuite(UserDict):
@@ -13,7 +17,13 @@ class CausalTestSuite(UserDict):
1317
base_test_case's and execute each causal_test_case with each iterator.
1418
"""
1519

16-
def add_test_object(self, base_test_case, causal_test_case_list, estimators_classes, estimate_type: str = "ate"):
20+
def add_test_object(
21+
self,
22+
base_test_case: BaseTestCase,
23+
causal_test_case_list: Iterable[CausalTestCase],
24+
estimators_classes: Iterable[Type[Estimator]],
25+
estimate_type: str = "ate",
26+
):
1727
"""
1828
A setter object to allow for the easy construction of the dictionary test suite structure
1929

causal_testing/testing/estimators.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ class LinearRegressionEstimator(Estimator):
282282

283283
def __init__(
284284
self,
285-
treatment: tuple,
285+
treatment: tuple,
286286
treatment_value: float,
287287
control_value: float,
288288
adjustment_set: set,

0 commit comments

Comments
 (0)