4
4
from dataclasses import dataclass
5
5
from typing import Callable
6
6
import pandas as pd
7
- from causal_testing .data_collection .data_collector import ObservationalDataCollector
8
7
from causal_testing .specification .causal_specification import CausalSpecification
9
8
from causal_testing .testing .base_test_case import BaseTestCase
10
9
from causal_testing .estimation .cubic_spline_estimator import CubicSplineRegressionEstimator
@@ -73,21 +72,20 @@ def __init__(
73
72
74
73
def execute (
75
74
self ,
76
- data_collector : ObservationalDataCollector ,
75
+ df : pd . DataFrame ,
77
76
max_executions : int = 200 ,
78
77
custom_data_aggregator : Callable [[dict , dict ], dict ] = None ,
79
78
):
80
79
"""For this specific test case, a search algorithm is used to find the most contradictory point in the input
81
80
space which is, therefore, most likely to indicate incorrect behaviour. This cadidate test case is run against
82
81
the simulator, checked for faults and the result returned with collected data
83
- :param data_collector : An ObservationalDataCollector which gathers data relevant to the specified scenario
82
+ :param df : An dataframe which contains data relevant to the specified scenario
84
83
:param max_executions: Maximum number of simulator executions before exiting the search
85
84
:param custom_data_aggregator:
86
85
:return: tuple containing SimulationResult or str, execution number and collected data"""
87
- data_collector .collect_data ()
88
86
89
87
for i in range (max_executions ):
90
- surrogate_models = self .generate_surrogates (self .specification , data_collector )
88
+ surrogate_models = self .generate_surrogates (self .specification , df )
91
89
candidate_test_case , _ , surrogate = self .search_algorithm .search (surrogate_models , self .specification )
92
90
93
91
self .simulator .startup ()
@@ -96,10 +94,10 @@ def execute(
96
94
self .simulator .shutdown ()
97
95
98
96
if custom_data_aggregator is not None :
99
- if data_collector . data is not None :
100
- data_collector . data = custom_data_aggregator (data_collector . data , test_result .data )
97
+ if df is not None :
98
+ df = custom_data_aggregator (df , test_result .data )
101
99
else :
102
- data_collector . data = pd .concat ([data_collector . data , test_result_df ], ignore_index = True )
100
+ df = pd .concat ([df , test_result_df ], ignore_index = True )
103
101
if test_result .fault :
104
102
print (
105
103
f"Fault found between { surrogate .treatment } causing { surrogate .outcome } . Contradiction with "
@@ -108,17 +106,17 @@ def execute(
108
106
test_result .relationship = (
109
107
f"{ surrogate .treatment } -> { surrogate .outcome } expected { surrogate .expected_relationship } "
110
108
)
111
- return test_result , i + 1 , data_collector . data
109
+ return test_result , i + 1 , df
112
110
113
111
print ("No fault found" )
114
- return "No fault found" , i + 1 , data_collector . data
112
+ return "No fault found" , i + 1 , df
115
113
116
114
def generate_surrogates (
117
- self , specification : CausalSpecification , data_collector : ObservationalDataCollector
115
+ self , specification : CausalSpecification , df : pd . DataFrame
118
116
) -> list [CubicSplineRegressionEstimator ]:
119
117
"""Generate a surrogate model for each edge of the dag that specifies it is included in the DAG metadata.
120
118
:param specification: The Causal Specification (combination of Scenario and Causal Dag)
121
- :param data_collector : An ObservationalDataCollector which gathers data relevant to the specified scenario
119
+ :param df : An dataframe which contains data relevant to the specified scenario
122
120
:return: A list of surrogate models
123
121
"""
124
122
surrogate_models = []
@@ -139,7 +137,7 @@ def generate_surrogates(
139
137
minimal_adjustment_set ,
140
138
v ,
141
139
4 ,
142
- df = data_collector . data ,
140
+ df = df ,
143
141
expected_relationship = edge_metadata ["expected" ],
144
142
)
145
143
surrogate_models .append (surrogate )
0 commit comments