@@ -138,6 +138,13 @@ def load_dag(self) -> CausalDAG:
138
138
logger .error (f"Failed to load DAG: { str (e )} " )
139
139
raise
140
140
141
+ def _read_dataframe (self , data_path ):
142
+ if str (data_path ).endswith (".csv" ):
143
+ return pd .read_csv (data_path )
144
+ if str (data_path ).endswith (".pqt" ):
145
+ return pd .read_parquet (data_path )
146
+ raise ValueError (f"Invalid file type { data_path } . Can only read CSV (.csv) or parquet (.pqt) files." )
147
+
141
148
def load_data (self , query : Optional [str ] = None ) -> pd .DataFrame :
142
149
"""Load and combine all data sources with optional filtering.
143
150
@@ -148,7 +155,7 @@ def load_data(self, query: Optional[str] = None) -> pd.DataFrame:
148
155
logger .info (f"Loading data from { len (self .paths .data_paths )} source(s)" )
149
156
150
157
try :
151
- dfs = [pd . read_csv (data_path ) for data_path in self .paths .data_paths ]
158
+ dfs = [self . _read_dataframe (data_path ) for data_path in self .paths .data_paths ]
152
159
data = pd .concat (dfs , axis = 0 , ignore_index = True )
153
160
logger .info (f"Initial data shape: { data .shape } " )
154
161
@@ -171,20 +178,19 @@ def create_variables(self) -> None:
171
178
172
179
173
180
"""
174
- for node in self .dag .graph .nodes ():
175
- dtype = self .data [node ].dtype .type if node in self .data .columns else str
181
+ for node_name , node_data in self .dag .graph .nodes (data = True ):
182
+ if node_name not in self .data .columns and not node_data .get ("hidden" , False ):
183
+ raise ValueError (f"Node { node_name } missing from data. Should it be marked as hidden?" )
176
184
177
- # If node has no incoming edges, it's an input
178
- if self .dag .graph .in_degree (node ) == 0 :
179
- self .variables ["inputs" ][node ] = Input (name = node , datatype = dtype )
185
+ dtype = self .data .dtypes .get (node_name )
180
186
181
- # If node has outgoing edges, it can be an input
182
- if self .dag .graph .out_degree ( node ) > 0 :
183
- self .variables ["inputs" ][node ] = Input (name = node , datatype = dtype )
187
+ # If node has no incoming edges, it's an input
188
+ if self .dag .graph .in_degree ( node_name ) == 0 :
189
+ self .variables ["inputs" ][node_name ] = Input (name = node_name , datatype = dtype )
184
190
185
- # If node has incoming edges, it can be an output
186
- if self .dag .graph .in_degree (node ) > 0 :
187
- self .variables ["outputs" ][node ] = Output (name = node , datatype = dtype )
191
+ # Otherwise it's an output
192
+ if self .dag .graph .in_degree (node_name ) > 0 :
193
+ self .variables ["outputs" ][node_name ] = Output (name = node_name , datatype = dtype )
188
194
189
195
def create_scenario_and_specification (self ) -> None :
190
196
"""Create scenario and causal specification objects from loaded data.
@@ -259,7 +265,7 @@ def create_base_test(self, test: dict) -> BaseTestCase:
259
265
:return: BaseTestCase object
260
266
:raises: KeyError if required variables are not found in inputs or outputs
261
267
"""
262
- treatment_name = test ["mutations" ][ 0 ]
268
+ treatment_name = test ["treatment_variable" ]
263
269
outcome_name = next (iter (test ["expected_effect" ].keys ()))
264
270
265
271
# Look for treatment variable in both inputs and outputs
@@ -333,12 +339,11 @@ def create_causal_test(self, test: dict, base_test: BaseTestCase) -> CausalTestC
333
339
raise ValueError (f"Unknown estimator: { test ['estimator' ]} " )
334
340
335
341
# Create the estimator with correct parameters
336
- adjustment_set = self .causal_specification .causal_dag .identification (base_test )
337
342
estimator = estimator_class (
338
343
base_test_case = base_test ,
339
- treatment_value = 1.0 , # hardcode these for now
340
- control_value = 0.0 ,
341
- adjustment_set = adjustment_set ,
344
+ treatment_value = test . get ( "treatment_value" ),
345
+ control_value = test . get ( "control_value" ) ,
346
+ adjustment_set = test . get ( " adjustment_set" , self . causal_specification . causal_dag . identification ( base_test )) ,
342
347
df = self .data ,
343
348
effect_modifiers = None ,
344
349
formula = test .get ("formula" ),
@@ -416,7 +421,7 @@ def save_results(self, results: List[CausalTestResult]) -> None:
416
421
"name" : test_config ["name" ],
417
422
"estimate_type" : test_config ["estimate_type" ],
418
423
"effect" : test_config .get ("effect" , "direct" ),
419
- "mutations " : test_config ["mutations " ],
424
+ "treatment_variable " : test_config ["treatment_variable " ],
420
425
"expected_effect" : test_config ["expected_effect" ],
421
426
"formula" : test_config .get ("formula" ),
422
427
"alpha" : test_config .get ("alpha" , 0.05 ),
0 commit comments