Skip to content

Commit b2a7d31

Browse files
Update docstrings to use more generic DataCollector
1 parent 0a15398 commit b2a7d31

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

causal_testing/testing/causal_test_case.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from causal_testing.testing.base_test_case import BaseTestCase
88
from causal_testing.testing.estimators import Estimator
99
from causal_testing.testing.causal_test_result import CausalTestResult, TestValue
10-
from causal_testing.data_collection.data_collector import ObservationalDataCollector
10+
from causal_testing.data_collection.data_collector import DataCollector
1111

1212

1313
logger = logging.getLogger(__name__)
@@ -77,7 +77,7 @@ def get_treatment_value(self):
7777
"""Return the treatment value of the treatment variable in this causal test case."""
7878
return self.treatment_value
7979

80-
def execute_test(self, estimator: type(Estimator), data_collector: ObservationalDataCollector) -> CausalTestResult:
80+
def execute_test(self, estimator: type(Estimator), data_collector: DataCollector) -> CausalTestResult:
8181
"""Execute a causal test case and return the causal test result.
8282
8383
:param estimator: A reference to an Estimator class.

causal_testing/testing/causal_test_suite.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from causal_testing.testing.causal_test_case import CausalTestCase
99
from causal_testing.testing.estimators import Estimator
1010
from causal_testing.testing.causal_test_result import CausalTestResult
11-
from causal_testing.data_collection.data_collector import ObservationalDataCollector
11+
from causal_testing.data_collection.data_collector import DataCollector
1212
from causal_testing.specification.causal_specification import CausalSpecification
1313

1414
logger = logging.getLogger(__name__)
@@ -46,7 +46,7 @@ def add_test_object(
4646
self.data[base_test_case] = test_object
4747

4848
def execute_test_suite(
49-
self, data_collector: ObservationalDataCollector, causal_specification: CausalSpecification
49+
self, data_collector: DataCollector, causal_specification: CausalSpecification
5050
) -> dict[str, CausalTestResult]:
5151
"""Execute a suite of causal tests and return the results in a list
5252
:param data_collector: The data collector to be used for the test_suite. Can be observational, experimental or

0 commit comments

Comments
 (0)