
Commit 1604d7a

Commit message: linter
1 parent 886b911 · commit 1604d7a

4 files changed: +49 −18 lines

causal_testing/json_front/json_class.py

Lines changed: 4 additions & 5 deletions
@@ -11,7 +11,6 @@
 from statistics import StatisticsError
 
 import pandas as pd
-import numpy as np
 import scipy
 from fitter import Fitter, get_common_distributions
 
@@ -68,7 +67,7 @@ def set_paths(self, json_path: str, dag_path: str, data_paths: list[str] = None)
             data_paths = []
         self.input_paths = JsonClassPaths(json_path=json_path, dag_path=dag_path, data_paths=data_paths)
 
-    def setup(self, scenario: Scenario, data=[]):
+    def setup(self, scenario: Scenario, data=None):
         """Function to populate all the necessary parts of the json_class needed to execute tests"""
         self.scenario = scenario
         self._get_scenario_variables()
@@ -82,7 +81,7 @@ def setup(self, scenario: Scenario, data=[]):
         # Populate the data
         if self.input_paths.data_paths:
            data = pd.concat([pd.read_csv(data_file, header=0) for data_file in self.input_paths.data_paths])
-        if len(data) == 0:
+        if data is None or len(data) == 0:
            raise ValueError(
                "No data found. Please either provide a path to a file containing data or manually populate the .data "
                "attribute with a dataframe before calling .setup()"
@@ -131,7 +130,7 @@ def run_json_tests(self, effects: dict, estimators: dict, f_flag: bool = False,
            test["estimator"] = estimators[test["estimator"]]
            # If we have specified concrete control and treatment value
            if "mutations" not in test:
-                failed, msg = self._run_concrete_metamorphic_test(test, f_flag, effects, mutates)
+                failed, msg = self._run_concrete_metamorphic_test(test, f_flag, effects)
            # If we have a variable to mutate
            else:
                if test["estimate_type"] == "coefficient":
@@ -176,7 +175,7 @@ def _run_coefficient_test(self, test: dict, f_flag: bool, effects: dict):
        self._append_to_file(msg, logging.INFO)
        return failed, result
 
-    def _run_concrete_metamorphic_test(self, test: dict, f_flag: bool, effects: dict, mutates: dict):
+    def _run_concrete_metamorphic_test(self, test: dict, f_flag: bool, effects: dict):
        outcome_variable = next(iter(test["expected_effect"]))  # Take first key from dictionary of expected effect
        base_test_case = BaseTestCase(
            treatment_variable=self.variables["inputs"][test["treatment_variable"]],
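After this change, the dispatch in `run_json_tests` depends only on whether a test entry carries a `"mutations"` key: entries without one are run as concrete tests via `_run_concrete_metamorphic_test`, which no longer needs the `mutates` dictionary. A hypothetical concrete test entry, written as a Python dict and restricted to the fields this diff actually touches (the real JSON schema contains more fields, and these values are made up):

```python
# Hypothetical entry for illustration only; field values are invented.
concrete_test = {
    "estimator": "my_estimator",           # looked up in the estimators dict passed to run_json_tests
    "estimate_type": "ate",                # anything other than "coefficient" for this branch
    "treatment_variable": "x",             # resolved via self.variables["inputs"]
    "expected_effect": {"y": "NoEffect"},  # first key is taken as the outcome variable
    # no "mutations" key, so _run_concrete_metamorphic_test(test, f_flag, effects) is called
}
```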

causal_testing/testing/causal_test_adequacy.py

Lines changed: 41 additions & 10 deletions
@@ -1,39 +1,66 @@
 """
 This module contains code to measure various aspects of causal test adequacy.
 """
+from itertools import combinations
+from copy import deepcopy
+import pandas as pd
+
 from causal_testing.testing.causal_test_suite import CausalTestSuite
 from causal_testing.data_collection.data_collector import DataCollector
 from causal_testing.specification.causal_specification import CausalSpecification
 from causal_testing.testing.estimators import Estimator
 from causal_testing.testing.causal_test_case import CausalTestCase
-from itertools import combinations
-from copy import deepcopy
-from sklearn.model_selection import KFold
-from sklearn.metrics import mean_squared_error as mse
-import numpy as np
-from sklearn.model_selection import cross_val_score
-import pandas as pd
 
 
 class DAGAdequacy:
+    """
+    Measures the adequacy of a given DAG by how many edges and independences are tested.
+    """
+
     def __init__(
         self,
         causal_specification: CausalSpecification,
         test_suite: CausalTestSuite,
     ):
         self.causal_dag = causal_specification.causal_dag
         self.test_suite = test_suite
+        self.tested_pairs = None
+        self.pairs_to_test = None
+        self.untested_edges = None
+        self.dag_adequacy = None
 
     def measure_adequacy(self):
+        """
+        Calculate the adequacy measurement, and populate the `dag_adequacy` field.
+        """
         self.tested_pairs = {
-            (t.base_test_case.treatment_variable, t.base_test_case.outcome_variable) for t in self.causal_test_suite
+            (t.base_test_case.treatment_variable, t.base_test_case.outcome_variable) for t in self.test_suite
         }
         self.pairs_to_test = set(combinations(self.causal_dag.graph.nodes, 2))
-        self.untested_edges = pairs_to_test.difference(tested_pairs)
-        self.dag_adequacy = len(tested_pairs) / len(pairs_to_test)
+        self.untested_edges = self.pairs_to_test.difference(self.tested_pairs)
+        self.dag_adequacy = len(self.tested_pairs) / len(self.pairs_to_test)
+
+    def to_dict(self):
+        "Returns the adequacy object as a dictionary."
+        return {
+            "causal_dag": self.causal_dag,
+            "test_suite": self.test_suite,
+            "tested_pairs": self.tested_pairs,
+            "pairs_to_test": self.pairs_to_test,
+            "untested_edges": self.untested_edges,
+            "dag_adequacy": self.dag_adequacy,
+        }
 
 
 class DataAdequacy:
+    """
+    Measures the adequacy of a given test according to the Fisher kurtosis of the bootstrapped result.
+    - Positive kurtoses indicate the model doesn't have enough data so is unstable.
+    - Negative kurtoses indicate the model doesn't have enough data, but is too stable, indicating that the spread of
+      inputs is insufficient.
+    - Zero kurtosis is optimal.
+    """
+
     def __init__(
         self, test_case: CausalTestCase, estimator: Estimator, data_collector: DataCollector, bootstrap_size: int = 100
     ):
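As the `DAGAdequacy` portion of the hunk above shows, `measure_adequacy` treats DAG adequacy as pair coverage: the fraction of all unordered node pairs in the DAG that appear as a (treatment variable, outcome variable) pair somewhere in the test suite, with the remainder recorded as `untested_edges`. A self-contained sketch of that calculation, using a plain networkx graph and strings in place of the framework's CausalDAG and variable objects:

```python
from itertools import combinations

import networkx as nx

# Toy DAG and the (treatment, outcome) pairs exercised by a test suite.
dag = nx.DiGraph([("x", "y"), ("y", "z")])
tested_pairs = {("x", "y")}

# Every unordered node pair is a candidate; adequacy is the tested fraction.
pairs_to_test = set(combinations(dag.nodes, 2))
untested_edges = pairs_to_test.difference(tested_pairs)
dag_adequacy = len(tested_pairs) / len(pairs_to_test)

print(untested_edges)  # {('x', 'z'), ('y', 'z')}
print(dag_adequacy)    # 0.333...
```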
@@ -45,6 +72,9 @@ def __init__(
         self.bootstrap_size = bootstrap_size
 
     def measure_adequacy(self):
+        """
+        Calculate the adequacy measurement, and populate the data_adequacy field.
+        """
         results = []
         for i in range(self.bootstrap_size):
             estimator = deepcopy(self.estimator)
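`DataAdequacy.measure_adequacy` re-runs a deep-copied estimator on `bootstrap_size` resamples and, per the class docstring above, summarises the spread of the bootstrapped estimates with Fisher kurtosis (zero being optimal). A rough, self-contained sketch of that idea using scipy and a stand-in slope estimator; this is not the library's actual implementation:

```python
import numpy as np
import pandas as pd
from scipy.stats import kurtosis

rng = np.random.default_rng(0)
data = pd.DataFrame({"x": rng.normal(size=200)})
data["y"] = 2 * data["x"] + rng.normal(size=200)


def estimate_effect(df):
    """Stand-in estimator: OLS slope of y on x."""
    slope, _ = np.polyfit(df["x"], df["y"], 1)
    return slope


bootstrap_size = 100
estimates = []
for i in range(bootstrap_size):
    resample = data.sample(len(data), replace=True, random_state=i)
    estimates.append(estimate_effect(resample))

# Fisher kurtosis (normal distribution -> 0); values far from zero suggest the
# estimate is not yet stable for this amount or spread of data.
print(kurtosis(estimates, fisher=True))
```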
@@ -75,4 +105,5 @@ def convert_to_df(field):
         self.outcomes = sum(outcomes)
 
     def to_dict(self):
+        "Returns the adequacy object as a dictionary."
         return {"kurtosis": self.kurtosis.to_dict(), "bootstrap_size": self.bootstrap_size, "passing": self.outcomes}

causal_testing/testing/causal_test_case.py

Lines changed: 2 additions & 2 deletions
@@ -18,8 +18,8 @@ class CausalTestCase:
     """
     A CausalTestCase extends the information held in a BaseTestCase. As well as storing the treatment and outcome
     variables, a CausalTestCase stores the values of these variables. Also the outcome variable and value are
-    specified. The goal of a CausalTestCase is to test whether the intervention made to the control via the treatment causes the
-    model-under-test to produce the expected change.
+    specified. The goal of a CausalTestCase is to test whether the intervention made to the control via the treatment
+    causes the model-under-test to produce the expected change.
     """
 
     def __init__(

causal_testing/testing/causal_test_outcome.py

Lines changed: 2 additions & 1 deletion
@@ -43,7 +43,8 @@ class NoEffect(CausalTestOutcome):
 
     def __init__(self, atol: float = 1e-10, ctol: float = 0.05):
         """
-        :param atol: Arithmetic tolerance. The test will pass if the absolute value of the causal effect is less than atol.
+        :param atol: Arithmetic tolerance. The test will pass if the absolute value of the causal effect is less than
+        atol.
         :param ctol: Categorical tolerance. The test will pass if this proportion of categories pass.
         """
         self.atol = atol
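The reflowed docstring documents the numeric acceptance rule for `NoEffect`: the test passes when the absolute causal effect is below `atol`. A minimal illustrative check (`passes_no_effect` is a stand-alone helper, not the class's real test method; the categorical `ctol` path is omitted because its exact logic is not shown in this diff):

```python
def passes_no_effect(effect: float, atol: float = 1e-10) -> bool:
    """Pass if the absolute value of the causal effect is below the arithmetic tolerance."""
    return abs(effect) < atol


print(passes_no_effect(5e-11))  # True: within atol
print(passes_no_effect(0.3))    # False: a non-negligible effect was observed
```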
