Skip to content

Commit 6bcfa0e

Browse files
pylint
1 parent a696899 commit 6bcfa0e

File tree

5 files changed

+53
-63
lines changed

5 files changed

+53
-63
lines changed

causal_testing/data_collection/data_collector.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ class DataCollector(ABC):
1818

1919
def __init__(self, scenario: Scenario):
2020
self.scenario = scenario
21+
2122
@abstractmethod
2223
def collect_data(self, **kwargs) -> pd.DataFrame:
2324
"""

causal_testing/json_front/json_class.py

Lines changed: 26 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import json
66
import logging
77

8-
from collections.abc import Iterable, Mapping
8+
from collections.abc import Mapping
99
from dataclasses import dataclass
1010
from pathlib import Path
1111
from statistics import StatisticsError
@@ -86,11 +86,9 @@ def setup(self, scenario: Scenario):
8686
"No data found, either provide a path to a file containing data or manually populate the .data "
8787
"attribute with a dataframe before calling .setup()"
8888
)
89-
self.data_collector = ObservationalDataCollector(
90-
self.scenario, data)
89+
self.data_collector = ObservationalDataCollector(self.scenario, data)
9190
self._populate_metas()
9291

93-
9492
def _create_abstract_test_case(self, test, mutates, effects):
9593
assert len(test["mutations"]) == 1
9694
treatment_var = next(self.scenario.variables[v] for v in test["mutations"])
@@ -156,11 +154,11 @@ def run_json_tests(self, effects: dict, estimators: dict, f_flag: bool = False,
156154
failed, _ = self._execute_test_case(causal_test_case=causal_test_case, test=test, f_flag=f_flag)
157155

158156
msg = (
159-
f"Executing concrete test: {test['name']} \n"
160-
+ f"treatment variable: {test['treatment_variable']} \n"
161-
+ f"outcome_variable = {outcome_variable} \n"
162-
+ f"control value = {test['control_value']}, treatment value = {test['treatment_value']} \n"
163-
+ f"Result: {'FAILED' if failed else 'Passed'}"
157+
f"Executing concrete test: {test['name']} \n"
158+
+ f"treatment variable: {test['treatment_variable']} \n"
159+
+ f"outcome_variable = {outcome_variable} \n"
160+
+ f"control value = {test['control_value']}, treatment value = {test['treatment_value']} \n"
161+
+ f"Result: {'FAILED' if failed else 'Passed'}"
164162
)
165163
print(msg)
166164
self._append_to_file(msg, logging.INFO)
@@ -187,12 +185,12 @@ def _run_coefficient_test(self, test: dict, f_flag: bool, effects: dict):
187185
)
188186
result = self._execute_test_case(causal_test_case=causal_test_case, test=test, f_flag=f_flag)
189187
msg = (
190-
f"Executing test: {test['name']} \n"
191-
+ f" {causal_test_case} \n"
192-
+ " "
193-
+ ("\n ").join(str(result[1]).split("\n"))
194-
+ "==============\n"
195-
+ f" Result: {'FAILED' if result[0] else 'Passed'}"
188+
f"Executing test: {test['name']} \n"
189+
+ f" {causal_test_case} \n"
190+
+ " "
191+
+ ("\n ").join(str(result[1]).split("\n"))
192+
+ "==============\n"
193+
+ f" Result: {'FAILED' if result[0] else 'Passed'}"
196194
)
197195
return msg
198196

@@ -220,13 +218,13 @@ def _run_ate_test(self, test: dict, f_flag: bool, effects: dict, mutates: dict):
220218
failures, _ = self._execute_tests(concrete_tests, test, f_flag)
221219

222220
msg = (
223-
f"Executing test: {test['name']} \n"
224-
+ " abstract_test \n"
225-
+ f" {abstract_test} \n"
226-
+ f" {abstract_test.treatment_variable.name},"
227-
+ f" {abstract_test.treatment_variable.distribution} \n"
228-
+ f" Number of concrete tests for test case: {str(len(concrete_tests))} \n"
229-
+ f" {failures}/{len(concrete_tests)} failed for {test['name']}"
221+
f"Executing test: {test['name']} \n"
222+
+ " abstract_test \n"
223+
+ f" {abstract_test} \n"
224+
+ f" {abstract_test.treatment_variable.name},"
225+
+ f" {abstract_test.treatment_variable.distribution} \n"
226+
+ f" Number of concrete tests for test case: {str(len(concrete_tests))} \n"
227+
+ f" {failures}/{len(concrete_tests)} failed for {test['name']}"
230228
)
231229
return msg
232230

@@ -251,8 +249,7 @@ def _populate_metas(self):
251249
meta.populate(self.data_collector.data)
252250

253251
def _execute_test_case(
254-
self, causal_test_case: CausalTestCase, test: Mapping,
255-
f_flag: bool
252+
self, causal_test_case: CausalTestCase, test: Mapping, f_flag: bool
256253
) -> (bool, CausalTestResult):
257254
"""Executes a singular test case, prints the results and returns the test case result
258255
:param causal_test_case: The concrete test case to be executed
@@ -265,8 +262,9 @@ def _execute_test_case(
265262
failed = False
266263

267264
estimation_model = self._setup_test(causal_test_case=causal_test_case, test=test)
268-
causal_test_result = causal_test_case.execute_test(estimator=estimation_model,
269-
data_collector=self.data_collector)
265+
causal_test_result = causal_test_case.execute_test(
266+
estimator=estimation_model, data_collector=self.data_collector
267+
)
270268

271269
test_passes = causal_test_case.expected_causal_effect.apply(causal_test_result)
272270

@@ -288,9 +286,7 @@ def _execute_test_case(
288286
logger.warning(" FAILED- expected %s, got %s", causal_test_case.expected_causal_effect, result_string)
289287
return failed, causal_test_result
290288

291-
def _setup_test(
292-
self, causal_test_case: CausalTestCase, test: Mapping
293-
) -> Estimator:
289+
def _setup_test(self, causal_test_case: CausalTestCase, test: Mapping) -> Estimator:
294290
"""Create the necessary inputs for a single test case
295291
:param causal_test_case: The concrete test case to be executed
296292
:param test: Single JSON test definition stored in a mapping (dict)
@@ -368,7 +364,7 @@ def get_args(test_args=None) -> argparse.Namespace:
368364
parser.add_argument(
369365
"-w",
370366
help="Specify to overwrite any existing output files. This can lead to the loss of existing outputs if not "
371-
"careful",
367+
"careful",
372368
action="store_true",
373369
)
374370
parser.add_argument(

causal_testing/specification/causal_specification.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""This module holds the abstract CausalSpecification data class, which holds a Scenario and CausalDag"""
22

3-
from abc import ABC
43
from dataclasses import dataclass
54
from typing import Union
65

causal_testing/testing/causal_test_case.py

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,14 @@
22
import logging
33
from typing import Any
44

5-
import pandas as pd
6-
75
from causal_testing.specification.variable import Variable
86
from causal_testing.testing.causal_test_outcome import CausalTestOutcome
97
from causal_testing.testing.base_test_case import BaseTestCase
108
from causal_testing.testing.estimators import Estimator
119
from causal_testing.testing.causal_test_result import CausalTestResult, TestValue
1210
from causal_testing.data_collection.data_collector import ObservationalDataCollector
13-
from causal_testing.specification.causal_dag import CausalDAG
14-
from causal_testing.specification.scenario import Scenario
1511

16-
from causal_testing.specification.causal_specification import CausalSpecification
12+
1713
logger = logging.getLogger(__name__)
1814

1915

@@ -31,15 +27,15 @@ class CausalTestCase:
3127
"""
3228

3329
def __init__(
34-
# pylint: disable=too-many-arguments
35-
self,
36-
base_test_case: BaseTestCase,
37-
expected_causal_effect: CausalTestOutcome,
38-
control_value: Any = None,
39-
treatment_value: Any = None,
40-
estimate_type: str = "ate",
41-
estimate_params: dict = None,
42-
effect_modifier_configuration: dict[Variable:Any] = None,
30+
# pylint: disable=too-many-arguments
31+
self,
32+
base_test_case: BaseTestCase,
33+
expected_causal_effect: CausalTestOutcome,
34+
control_value: Any = None,
35+
treatment_value: Any = None,
36+
estimate_type: str = "ate",
37+
estimate_params: dict = None,
38+
effect_modifier_configuration: dict[Variable:Any] = None,
4339
):
4440
"""
4541
:param base_test_case: A BaseTestCase object consisting of a treatment variable, outcome variable and effect
@@ -85,7 +81,7 @@ def execute_test(self, estimator: type(Estimator), data_collector: Observational
8581
"""Execute a causal test case and return the causal test result.
8682
8783
:param estimator: A reference to an Estimator class.
88-
:param causal_test_case: The CausalTestCase object to be tested
84+
:param data_collector: The data collector to be used which provides a dataframe for the Estimator
8985
:return causal_test_result: A CausalTestResult for the executed causal test case.
9086
"""
9187
if estimator.df is None:
@@ -100,11 +96,10 @@ def execute_test(self, estimator: type(Estimator), data_collector: Observational
10096
causal_test_result = self._return_causal_test_results(estimator)
10197
return causal_test_result
10298

103-
def _return_causal_test_results(self, estimator):
99+
def _return_causal_test_results(self, estimator) -> CausalTestResult:
104100
"""Depending on the estimator used, calculate the 95% confidence intervals and return in a causal_test_result
105101
106102
:param estimator: An Estimator class object
107-
:param causal_test_case: The concrete test case to be executed
108103
:return: a CausalTestResult object containing the confidence intervals
109104
"""
110105
if not hasattr(estimator, f"estimate_{self.estimate_type}"):
@@ -128,4 +123,3 @@ def __str__(self):
128123
f"Running {treatment_config} instead of {control_config} should cause the following "
129124
f"changes to {outcome_variable}: {self.expected_causal_effect}."
130125
)
131-

causal_testing/testing/causal_test_suite.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77
from causal_testing.testing.base_test_case import BaseTestCase
88
from causal_testing.testing.causal_test_case import CausalTestCase
99
from causal_testing.testing.estimators import Estimator
10-
from causal_testing.testing.causal_test_result import CausalTestResult, TestValue
10+
from causal_testing.testing.causal_test_result import CausalTestResult
1111
from causal_testing.data_collection.data_collector import ObservationalDataCollector
1212
from causal_testing.specification.causal_specification import CausalSpecification
13+
1314
logger = logging.getLogger(__name__)
1415

1516

@@ -26,11 +27,11 @@ class CausalTestSuite(UserDict):
2627
"""
2728

2829
def add_test_object(
29-
self,
30-
base_test_case: BaseTestCase,
31-
causal_test_case_list: Iterable[CausalTestCase],
32-
estimators_classes: Iterable[Type[Estimator]],
33-
estimate_type: str = "ate",
30+
self,
31+
base_test_case: BaseTestCase,
32+
causal_test_case_list: Iterable[CausalTestCase],
33+
estimators_classes: Iterable[Type[Estimator]],
34+
estimate_type: str = "ate",
3435
):
3536
"""
3637
A setter object to allow for the easy construction of the dictionary test suite structure
@@ -44,16 +45,18 @@ def add_test_object(
4445
test_object = {"tests": causal_test_case_list, "estimators": estimators_classes, "estimate_type": estimate_type}
4546
self.data[base_test_case] = test_object
4647

47-
def execute_test_suite(self, data_collector: ObservationalDataCollector,
48-
causal_specification: CausalSpecification) -> list[CausalTestResult]:
48+
def execute_test_suite(
49+
self, data_collector: ObservationalDataCollector, causal_specification: CausalSpecification
50+
) -> dict[str, CausalTestResult]:
4951
"""Execute a suite of causal tests and return the results in a list
50-
:param test_suite: CasualTestSuite object
52+
:param data_collector: The data collector to be used for the test_suite. Can be observational, experimental or
53+
custom
54+
:param causal_specification:
5155
:return: A dictionary where each key is the name of the estimators specified and the values are lists of
5256
causal_test_result objects
5357
"""
5458
if data_collector.data.empty:
5559
raise ValueError("No data has been loaded. Please call load_data prior to executing a causal test case.")
56-
data_collector.collect_data()
5760
test_suite_results = {}
5861
for edge in self:
5962
logger.info("treatment: %s", edge.treatment_variable)
@@ -76,12 +79,9 @@ def execute_test_suite(self, data_collector: ObservationalDataCollector,
7679
minimal_adjustment_set,
7780
test.outcome_variable.name,
7881
)
79-
if estimator.df is None:
80-
estimator.df = data_collector.collect_data()
81-
causal_test_result = test._return_causal_test_results(estimator)
82+
causal_test_result = test.execute_test(estimator, data_collector)
8283
causal_test_results.append(causal_test_result)
8384

8485
results[estimator_class.__name__] = causal_test_results
8586
test_suite_results[edge] = results
8687
return test_suite_results
87-

0 commit comments

Comments
 (0)