1
1
"""
2
2
This module contains code to measure various aspects of causal test adequacy.
3
3
"""
4
+ from itertools import combinations
5
+ from copy import deepcopy
6
+ import pandas as pd
7
+
4
8
from causal_testing .testing .causal_test_suite import CausalTestSuite
5
9
from causal_testing .data_collection .data_collector import DataCollector
6
10
from causal_testing .specification .causal_specification import CausalSpecification
7
11
from causal_testing .testing .estimators import Estimator
8
12
from causal_testing .testing .causal_test_case import CausalTestCase
9
- from itertools import combinations
10
- from copy import deepcopy
11
- from sklearn .model_selection import KFold
12
- from sklearn .metrics import mean_squared_error as mse
13
- import numpy as np
14
- from sklearn .model_selection import cross_val_score
15
- import pandas as pd
16
13
17
14
18
15
class DAGAdequacy:
    """
    Measures the adequacy of a given DAG by how many edges and independences are tested.

    Every unordered pair of nodes in the DAG is treated as one thing to test. A pair
    counts as tested when the test suite contains a test whose (treatment, outcome)
    matches the pair in either direction.
    """

    def __init__(
        self,
        causal_specification: "CausalSpecification",
        test_suite: "CausalTestSuite",
    ):
        """
        :param causal_specification: Specification whose `causal_dag` is to be measured.
        :param test_suite: The tests whose (treatment, outcome) pairs are checked against the DAG.
        """
        self.causal_dag = causal_specification.causal_dag
        self.test_suite = test_suite
        # The fields below stay None until `measure_adequacy` has been called.
        self.tested_pairs = None
        self.pairs_to_test = None
        self.untested_edges = None
        self.dag_adequacy = None

    def measure_adequacy(self):
        """
        Calculate the adequacy measurement, and populate the `dag_adequacy` field.

        Populates:
        - `tested_pairs`: the (treatment, outcome) tuples taken from the test suite.
        - `pairs_to_test`: every 2-combination of the DAG's nodes.
        - `untested_edges`: the members of `pairs_to_test` not covered in either direction.
        - `dag_adequacy`: the fraction of `pairs_to_test` that are covered, in [0, 1].
          If the DAG has fewer than two nodes there is nothing to test and this is 1.0.

        NOTE(review): this assumes the treatment/outcome variables of a test compare equal
        to the nodes of the DAG's graph - confirm against the callers.
        """
        self.tested_pairs = {
            (t.base_test_case.treatment_variable, t.base_test_case.outcome_variable) for t in self.test_suite
        }
        self.pairs_to_test = set(combinations(self.causal_dag.graph.nodes, 2))
        # `combinations` fixes the orientation of each pair by node iteration order, whereas a
        # test is oriented (treatment, outcome), so both directions have to be checked.
        self.untested_edges = {
            (first, second)
            for first, second in self.pairs_to_test
            if (first, second) not in self.tested_pairs and (second, first) not in self.tested_pairs
        }
        if not self.pairs_to_test:
            # Avoid dividing by zero: with nothing to test, the DAG is vacuously fully tested.
            self.dag_adequacy = 1.0
            return
        # Derive the score from the covered pairs so it always agrees with `untested_edges`,
        # and so duplicated or out-of-DAG tests cannot push it above 1.
        covered = len(self.pairs_to_test) - len(self.untested_edges)
        self.dag_adequacy = covered / len(self.pairs_to_test)

    def to_dict(self):
        "Returns the adequacy object as a dictionary."
        return {
            "causal_dag": self.causal_dag,
            "test_suite": self.test_suite,
            "tested_pairs": self.tested_pairs,
            "pairs_to_test": self.pairs_to_test,
            "untested_edges": self.untested_edges,
            "dag_adequacy": self.dag_adequacy,
        }
34
53
35
54
36
55
class DataAdequacy :
56
+ """
57
+ Measures the adequacy of a given test according to the Fisher kurtosis of the bootstrapped result.
58
+ - Positive kurtoses indicate the model doesn't have enough data so is unstable.
59
+ - Negative kurtoses indicate the model doesn't have enough data, but is too stable, indicating that the spread of
60
+ inputs is insufficient.
61
+ - Zero kurtosis is optimal.
62
+ """
63
+
37
64
def __init__ (
38
65
self , test_case : CausalTestCase , estimator : Estimator , data_collector : DataCollector , bootstrap_size : int = 100
39
66
):
@@ -45,6 +72,9 @@ def __init__(
45
72
self .bootstrap_size = bootstrap_size
46
73
47
74
def measure_adequacy (self ):
75
+ """
76
+ Calculate the adequacy measurement, and populate the data_adequacy field.
77
+ """
48
78
results = []
49
79
for i in range (self .bootstrap_size ):
50
80
estimator = deepcopy (self .estimator )
@@ -75,4 +105,5 @@ def convert_to_df(field):
75
105
self .outcomes = sum (outcomes )
76
106
77
107
def to_dict (self ):
108
+ "Returns the adequacy object as a dictionary."
78
109
return {"kurtosis" : self .kurtosis .to_dict (), "bootstrap_size" : self .bootstrap_size , "passing" : self .outcomes }
0 commit comments