Skip to content

Commit eda70f8

Browse files
Directly pass dataframe
1 parent 4c1f30a commit eda70f8

File tree

2 files changed

+5
-7
lines changed

2 files changed

+5
-7
lines changed

causal_testing/data_collection/data_collector.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,9 +124,9 @@ def run_system_with_input_configuration(self, input_configuration: dict) -> pd.D
124124
class ObservationalDataCollector(DataCollector):
125125
"""A data collector that extracts data that is relevant to the specified scenario from a csv of execution data."""
126126

127-
def __init__(self, scenario: Scenario, csv_path: str):
127+
def __init__(self, scenario: Scenario, data: pd.DataFrame):
128128
super().__init__(scenario)
129-
self.csv_path = csv_path
129+
self.data = data
130130

131131
def collect_data(self, **kwargs) -> pd.DataFrame:
132132
"""Read a csv containing execution data for the system-under-test into a pandas dataframe and filter to remove
@@ -137,7 +137,7 @@ def collect_data(self, **kwargs) -> pd.DataFrame:
137137
:return: A pandas dataframe containing execution data that is valid for the scenario-under-test.
138138
"""
139139

140-
execution_data_df = pd.read_csv(self.csv_path, **kwargs)
140+
execution_data_df = self.data
141141
for meta in self.scenario.metas():
142142
meta.populate(execution_data_df)
143143
scenario_execution_data_df = self.filter_valid_data(execution_data_df)

causal_testing/json_front/json_class.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -195,10 +195,8 @@ def _setup_test(self, causal_test_case: CausalTestCase, estimator: Estimator) ->
195195
- estimation_model - Estimator instance for the test being run
196196
"""
197197

198-
with tempfile.TemporaryFile(delete=False) as temp:
199-
self.data.to_csv(temp)
200-
data_collector = ObservationalDataCollector(self.modelling_scenario, temp.name)
201-
causal_test_engine = CausalTestEngine(self.causal_specification, data_collector, index_col=0)
198+
data_collector = ObservationalDataCollector(self.modelling_scenario, self.data)
199+
causal_test_engine = CausalTestEngine(self.causal_specification, data_collector, index_col=0)
202200

203201
minimal_adjustment_set = self.causal_specification.causal_dag.identification(causal_test_case.base_test_case)
204202
treatment_var = causal_test_case.treatment_variable

0 commit comments

Comments
 (0)