Skip to content

Commit 273fd3d

Browse files
add temp file for observational collector
1 parent 1166240 commit 273fd3d

File tree

1 file changed

+12
-7
lines changed

1 file changed

+12
-7
lines changed

causal_testing/json_front/json_class.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -136,12 +136,12 @@ def _json_parse(self):
136136
for data_file in self.paths.data_paths:
137137
df = pd.read_csv(data_file, header=0)
138138
self.data.append(df)
139-
139+
self.data = pd.concat(self.data)
140+
breakpoint()
140141
def _populate_metas(self):
141142
"""
142143
Populate data with meta-variable values and add distributions to Causal Testing Framework Variables
143144
"""
144-
145145
for meta in self.variables.metas:
146146
meta.populate(self.data)
147147

@@ -195,8 +195,13 @@ def _setup_test(self, causal_test_case: CausalTestCase, estimator: Estimator) ->
195195
- causal_test_engine - Test Engine instance for the test being run
196196
- estimation_model - Estimator instance for the test being run
197197
"""
198-
data_collector = ObservationalDataCollector(self.modelling_scenario, self.paths.data_path)
199-
causal_test_engine = CausalTestEngine(self.causal_specification, data_collector, index_col=0)
198+
199+
with tempfile.TemporaryFile(delete=False) as temp:
200+
self.data.to_csv(temp)
201+
breakpoint()
202+
data_collector = ObservationalDataCollector(self.modelling_scenario, temp.name)
203+
causal_test_engine = CausalTestEngine(self.causal_specification, data_collector, index_col=0)
204+
200205
minimal_adjustment_set = self.causal_specification.causal_dag.identification(causal_test_case.base_test_case)
201206
treatment_var = causal_test_case.treatment_variable
202207
minimal_adjustment_set = minimal_adjustment_set - {treatment_var}
@@ -280,12 +285,12 @@ class JsonClassPaths:
280285

281286
json_path: Path
282287
dag_path: Path
283-
data_path: Path
288+
data_paths: list[Path]
284289

285-
def __init__(self, json_path: str, dag_path: str, data_path: str):
290+
def __init__(self, json_path: str, dag_path: str, data_paths: str):
286291
self.json_path = Path(json_path)
287292
self.dag_path = Path(dag_path)
288-
self.data_paths = [Path(path) for path in data_path]
293+
self.data_paths = [Path(path) for path in data_paths]
289294

290295

291296
@dataclass()

0 commit comments

Comments
 (0)