Skip to content

Commit 5495caa

Browse files
committed
Removed datacollector from surrogate assisted
1 parent a20954c commit 5495caa

File tree

2 files changed

+16
-21
lines changed

2 files changed

+16
-21
lines changed

causal_testing/surrogate/causal_surrogate_assisted.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from dataclasses import dataclass
55
from typing import Callable
66
import pandas as pd
7-
from causal_testing.data_collection.data_collector import ObservationalDataCollector
87
from causal_testing.specification.causal_specification import CausalSpecification
98
from causal_testing.testing.base_test_case import BaseTestCase
109
from causal_testing.estimation.cubic_spline_estimator import CubicSplineRegressionEstimator
@@ -73,21 +72,20 @@ def __init__(
7372

7473
def execute(
7574
self,
76-
data_collector: ObservationalDataCollector,
75+
df: pd.DataFrame,
7776
max_executions: int = 200,
7877
custom_data_aggregator: Callable[[dict, dict], dict] = None,
7978
):
8079
"""For this specific test case, a search algorithm is used to find the most contradictory point in the input
8180
space which is, therefore, most likely to indicate incorrect behaviour. This cadidate test case is run against
8281
the simulator, checked for faults and the result returned with collected data
83-
:param data_collector: An ObservationalDataCollector which gathers data relevant to the specified scenario
82+
:param df: An dataframe which contains data relevant to the specified scenario
8483
:param max_executions: Maximum number of simulator executions before exiting the search
8584
:param custom_data_aggregator:
8685
:return: tuple containing SimulationResult or str, execution number and collected data"""
87-
data_collector.collect_data()
8886

8987
for i in range(max_executions):
90-
surrogate_models = self.generate_surrogates(self.specification, data_collector)
88+
surrogate_models = self.generate_surrogates(self.specification, df)
9189
candidate_test_case, _, surrogate = self.search_algorithm.search(surrogate_models, self.specification)
9290

9391
self.simulator.startup()
@@ -96,10 +94,10 @@ def execute(
9694
self.simulator.shutdown()
9795

9896
if custom_data_aggregator is not None:
99-
if data_collector.data is not None:
100-
data_collector.data = custom_data_aggregator(data_collector.data, test_result.data)
97+
if df is not None:
98+
df = custom_data_aggregator(df, test_result.data)
10199
else:
102-
data_collector.data = pd.concat([data_collector.data, test_result_df], ignore_index=True)
100+
df = pd.concat([df, test_result_df], ignore_index=True)
103101
if test_result.fault:
104102
print(
105103
f"Fault found between {surrogate.treatment} causing {surrogate.outcome}. Contradiction with "
@@ -108,17 +106,17 @@ def execute(
108106
test_result.relationship = (
109107
f"{surrogate.treatment} -> {surrogate.outcome} expected {surrogate.expected_relationship}"
110108
)
111-
return test_result, i + 1, data_collector.data
109+
return test_result, i + 1, df
112110

113111
print("No fault found")
114-
return "No fault found", i + 1, data_collector.data
112+
return "No fault found", i + 1, df
115113

116114
def generate_surrogates(
117-
self, specification: CausalSpecification, data_collector: ObservationalDataCollector
115+
self, specification: CausalSpecification, df: pd.DataFrame
118116
) -> list[CubicSplineRegressionEstimator]:
119117
"""Generate a surrogate model for each edge of the dag that specifies it is included in the DAG metadata.
120118
:param specification: The Causal Specification (combination of Scenario and Causal Dag)
121-
:param data_collector: An ObservationalDataCollector which gathers data relevant to the specified scenario
119+
:param df: An dataframe which contains data relevant to the specified scenario
122120
:return: A list of surrogate models
123121
"""
124122
surrogate_models = []
@@ -139,7 +137,7 @@ def generate_surrogates(
139137
minimal_adjustment_set,
140138
v,
141139
4,
142-
df=data_collector.data,
140+
df=df,
143141
expected_relationship=edge_metadata["expected"],
144142
)
145143
surrogate_models.append(surrogate)

tests/surrogate_tests/test_causal_surrogate_assisted.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import unittest
2-
from causal_testing.data_collection.data_collector import ObservationalDataCollector
32
from causal_testing.specification.causal_dag import CausalDAG
43
from causal_testing.specification.causal_specification import CausalSpecification
54
from causal_testing.specification.scenario import Scenario
@@ -69,7 +68,7 @@ def test_surrogate_model_generation(self):
6968
scenario = Scenario(variables={z, x, m, y})
7069
specification = CausalSpecification(scenario, causal_dag)
7170

72-
surrogate_models = c_s_a_test_case.generate_surrogates(specification, ObservationalDataCollector(scenario, df))
71+
surrogate_models = c_s_a_test_case.generate_surrogates(specification, df)
7372
self.assertEqual(len(surrogate_models), 2)
7473

7574
for surrogate in surrogate_models:
@@ -101,7 +100,7 @@ def test_causal_surrogate_assisted_execution(self):
101100

102101
c_s_a_test_case = CausalSurrogateAssistedTestCase(specification, search_algorithm, simulator)
103102

104-
result, iterations, result_data = c_s_a_test_case.execute(ObservationalDataCollector(scenario, df))
103+
result, iterations, result_data = c_s_a_test_case.execute(df)
105104

106105
self.assertIsInstance(result, SimulationResult)
107106
self.assertEqual(iterations, 1)
@@ -131,7 +130,7 @@ def test_causal_surrogate_assisted_execution_failure(self):
131130

132131
c_s_a_test_case = CausalSurrogateAssistedTestCase(specification, search_algorithm, simulator)
133132

134-
result, iterations, result_data = c_s_a_test_case.execute(ObservationalDataCollector(scenario, df), 1)
133+
result, iterations, result_data = c_s_a_test_case.execute(df, 1)
135134

136135
self.assertIsInstance(result, str)
137136
self.assertEqual(iterations, 1)
@@ -161,9 +160,7 @@ def test_causal_surrogate_assisted_execution_custom_aggregator(self):
161160

162161
c_s_a_test_case = CausalSurrogateAssistedTestCase(specification, search_algorithm, simulator)
163162

164-
result, iterations, result_data = c_s_a_test_case.execute(
165-
ObservationalDataCollector(scenario, df), custom_data_aggregator=data_double_aggregator
166-
)
163+
result, iterations, result_data = c_s_a_test_case.execute(df, custom_data_aggregator=data_double_aggregator)
167164

168165
self.assertIsInstance(result, SimulationResult)
169166
self.assertEqual(iterations, 1)
@@ -197,7 +194,7 @@ def test_causal_surrogate_assisted_execution_incorrect_search_config(self):
197194
self.assertRaises(
198195
ValueError,
199196
c_s_a_test_case.execute,
200-
data_collector=ObservationalDataCollector(scenario, df),
197+
df=df,
201198
custom_data_aggregator=data_double_aggregator,
202199
)
203200

0 commit comments

Comments
 (0)