@@ -103,7 +103,6 @@ def setup(self) -> None:
        Set up the framework by loading the DAG and runtime CSV data, then creating the scenario and causal specification.

        :raises: FileNotFoundError if required files are missing
-        :raises: Exception if setup process fails
        """
        logger.info("Setting up Causal Testing Framework...")
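This hunk drops the blanket `:raises: Exception` note: with the catch-all wrappers removed in the hunks below, errors now surface directly at the call site. A minimal sketch of a hypothetical caller handling the one documented exception itself (the `framework` object and `run` wrapper are illustrative assumptions, not part of this PR):

```python
# Hypothetical call site; only FileNotFoundError is documented by setup(),
# so anything else propagates as a genuine bug rather than being swallowed.
import logging

logger = logging.getLogger(__name__)

def run(framework) -> int:
    try:
        framework.setup()
    except FileNotFoundError as e:
        logger.error(f"Missing input file: {e}")
        return 1
    return 0
```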
@@ -128,17 +127,11 @@ def setup(self) -> None:
    def load_dag(self) -> CausalDAG:
        """
        Load the causal DAG from the specified file path.
-
-        :raises: Exception if DAG loading fails
        """
        logger.info(f"Loading DAG from {self.paths.dag_path}")
-        try:
-            dag = CausalDAG(str(self.paths.dag_path), ignore_cycles=self.ignore_cycles)
-            logger.info(f"DAG loaded with {len(dag.graph.nodes)} nodes and {len(dag.graph.edges)} edges")
-            return dag
-        except Exception as e:
-            logger.error(f"Failed to load DAG: {str(e)}")
-            raise
+        dag = CausalDAG(str(self.paths.dag_path), ignore_cycles=self.ignore_cycles)
+        logger.info(f"DAG loaded with {len(dag.graph.nodes)} nodes and {len(dag.graph.edges)} edges")
+        return dag

    def _read_dataframe(self, data_path):
        if str(data_path).endswith(".csv"):
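The diff truncates `_read_dataframe` right after the `.csv` check. A sketch of the extension dispatch that check implies, assuming pandas readers; the non-CSV fallback here is a guess, since the real branch is not shown:

```python
import pandas as pd

def _read_dataframe(data_path):
    # Dispatch on file extension; only the .csv branch is visible in the diff.
    if str(data_path).endswith(".csv"):
        return pd.read_csv(data_path)
    # Assumed fallback: the actual handling of non-CSV inputs is cut off above.
    raise ValueError(f"Unsupported data format: {data_path}")
```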
@@ -152,33 +145,22 @@ def load_data(self, query: Optional[str] = None) -> pd.DataFrame:

        :param query: Optional pandas query string to filter the loaded data
        :return: Combined pandas DataFrame containing all loaded and filtered data
-        :raises: Exception if data loading or query application fails
        """
        logger.info(f"Loading data from {len(self.paths.data_paths)} source(s)")

-        try:
-            dfs = [self._read_dataframe(data_path) for data_path in self.paths.data_paths]
-            data = pd.concat(dfs, axis=0, ignore_index=True)
-            logger.info(f"Initial data shape: {data.shape}")
-
-            if query:
-                try:
-                    logger.info(f"Attempting to apply query: '{query}'")
-                    data = data.query(query)
-                except Exception as e:
-                    logger.error(f"Failed to apply query '{query}': {str(e)}")
-                    raise
+        dfs = [self._read_dataframe(data_path) for data_path in self.paths.data_paths]
+        data = pd.concat(dfs, axis=0, ignore_index=True)
+        logger.info(f"Initial data shape: {data.shape}")
+
+        if query:
+            logger.info(f"Attempting to apply query: '{query}'")
+            data = data.query(query)

-            return data
-        except Exception as e:
-            logger.error(f"Failed to load data: {str(e)}")
-            raise
+        return data

    def create_variables(self) -> None:
        """
        Create variable objects from DAG nodes based on their connectivity.
-
-
        """
        for node_name, node_data in self.dag.graph.nodes(data=True):
            if node_name not in self.data.columns and not node_data.get("hidden", False):
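With the inner try/except gone, `load_data` now relies on `pandas.DataFrame.query` to raise on a bad filter string. A quick illustration of the query syntax the `query` parameter accepts (column names invented for the example):

```python
import pandas as pd

df = pd.DataFrame({"dose": [0.1, 0.5, 0.9], "response": [1.2, 3.4, 5.6]})

# Valid pandas query strings use column names, comparisons, and/not/or.
filtered = df.query("dose > 0.2 and response < 5")
print(filtered.shape)  # (1, 2)

# A malformed string, e.g. df.query("dose >>> 1"), now raises immediately
# instead of being logged and re-raised by the framework.
```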
@@ -195,63 +177,25 @@ def create_variables(self) -> None:
                self.variables["outputs"][node_name] = Output(name=node_name, datatype=dtype)

    def create_scenario_and_specification(self) -> None:
-        """Create scenario and causal specification objects from loaded data.
-
-
-        :raises: ValueError if scenario constraints filter out all data points
-        """
+        """Create scenario and causal specification objects from loaded data."""
        # Create scenario
        all_variables = list(self.variables["inputs"].values()) + list(self.variables["outputs"].values())
        self.scenario = Scenario(variables=all_variables)

        # Set up treatment variables
        self.scenario.setup_treatment_variables()

-        # Apply scenario constraints to data
-        self.apply_scenario_constraints()
-
        # Create causal specification
        self.causal_specification = CausalSpecification(scenario=self.scenario, causal_dag=self.dag)

-    def apply_scenario_constraints(self) -> None:
-        """
-        Apply scenario constraints to the loaded data.
-
-        :raises: ValueError if all data points are filtered out by constraints
-        """
-        if not self.scenario.constraints:
-            logger.info("No scenario constraints to apply")
-            return
-
-        original_rows = len(self.data)
-
-        # Apply each constraint directly as a query string
-        for constraint in self.scenario.constraints:
-            self.data = self.data.query(str(constraint))
-            logger.debug(f"Applied constraint: {constraint}")
-
-        filtered_rows = len(self.data)
-        if filtered_rows < original_rows:
-            logger.info(f"Scenario constraints filtered data from {original_rows} to {filtered_rows} rows")
-
-        if filtered_rows == 0:
-            raise ValueError("Scenario constraints filtered out all data points. Check your constraints and data.")
-
    def load_tests(self) -> None:
        """
        Load and prepare test configurations from file.
-
-
-        :raises: Exception if test configuration loading fails
        """
        logger.info(f"Loading test configurations from {self.paths.test_config_path}")

-        try:
-            with open(self.paths.test_config_path, "r", encoding="utf-8") as f:
-                test_configs = json.load(f)
-        except Exception as e:
-            logger.error(f"Failed to load test configurations: {str(e)}")
-            raise
+        with open(self.paths.test_config_path, "r", encoding="utf-8") as f:
+            test_configs = json.load(f)

        self.test_cases = self.create_test_cases(test_configs)
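`load_tests` expects the config file to decode to an object with a `tests` array (see the `test_configs["tests"]` access in `save_results` below). A plausible minimal config, expressed as the Python dict it decodes to; the field values and the shape of `expected_effect` are illustrative assumptions inferred from the keys read elsewhere in this file:

```python
# What json.load(f) might return for a one-test config file; every value
# here is an assumed example, only the key names come from this module.
test_configs = {
    "tests": [
        {
            "name": "dose_increases_response",
            "estimate_type": "ate",
            "effect": "direct",          # defaulted via .get() in save_results
            "treatment_variable": "dose",
            "expected_effect": {"response": "Positive"},  # shape is a guess
            "formula": None,
            "alpha": 0.05,
            "skip": False,
        }
    ]
}
```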
@@ -400,63 +344,53 @@ def save_results(self, results: List[CausalTestResult]) -> None:
        """Save test results to JSON file in the expected format."""
        logger.info(f"Saving results to {self.paths.output_path}")

-        try:
-            # Load original test configs to preserve test metadata
-            with open(self.paths.test_config_path, "r", encoding="utf-8") as f:
-                test_configs = json.load(f)
-
-            # Combine test configs with their results
-            json_results = []
-            for test_config, test_case, result in zip(test_configs["tests"], self.test_cases, results):
-                # Handle effect estimate - could be a Series or other format
-                effect_estimate = result.test_value.value
-                if isinstance(effect_estimate, pd.Series):
-                    effect_estimate = effect_estimate.to_dict()
-
-                # Handle confidence intervals - convert to list if needed
-                ci_low = result.ci_low()
-                ci_high = result.ci_high()
-                if isinstance(ci_low, pd.Series):
-                    ci_low = ci_low.tolist()
-                if isinstance(ci_high, pd.Series):
-                    ci_high = ci_high.tolist()
-
-                # Determine if test failed based on expected vs actual effect
-                test_passed = (
-                    test_case.expected_causal_effect.apply(result) if result.test_value.type != "Error" else False
-                )
-
-                output = {
-                    "name": test_config["name"],
-                    "estimate_type": test_config["estimate_type"],
-                    "effect": test_config.get("effect", "direct"),
-                    "treatment_variable": test_config["treatment_variable"],
-                    "expected_effect": test_config["expected_effect"],
-                    "formula": test_config.get("formula"),
-                    "alpha": test_config.get("alpha", 0.05),
-                    "skip": test_config.get("skip", False),
-                    "passed": test_passed,
-                    "result": {
-                        "treatment": result.estimator.base_test_case.treatment_variable.name,
-                        "outcome": result.estimator.base_test_case.outcome_variable.name,
-                        "adjustment_set": list(result.adjustment_set) if result.adjustment_set else [],
-                        "effect_measure": result.test_value.type,
-                        "effect_estimate": effect_estimate,
-                        "ci_low": ci_low,
-                        "ci_high": ci_high,
-                    },
-                }
-                json_results.append(output)
-
-            # Save to file
-            with open(self.paths.output_path, "w", encoding="utf-8") as f:
-                json.dump(json_results, f, indent=2)
-
-            logger.info("Results saved successfully")
-
-        except Exception as e:
-            logger.error(f"Failed to save results: {str(e)}")
-            raise
+        # Load original test configs to preserve test metadata
+        with open(self.paths.test_config_path, "r", encoding="utf-8") as f:
+            test_configs = json.load(f)
+
+        # Combine test configs with their results
+        json_results = []
+        for test_config, test_case, result in zip(test_configs["tests"], self.test_cases, results):
+            # Handle effect estimate - could be a Series or other format
+            effect_estimate = result.test_value.value
+            if isinstance(effect_estimate, pd.Series):
+                effect_estimate = effect_estimate.to_dict()
+
+            # Handle confidence intervals - convert to list if needed
+            ci_low = result.ci_low()
+            ci_high = result.ci_high()
+
+            # Determine if test failed based on expected vs actual effect
+            test_passed = test_case.expected_causal_effect.apply(result) if result.test_value.type != "Error" else False
+
+            output = {
+                "name": test_config["name"],
+                "estimate_type": test_config["estimate_type"],
+                "effect": test_config.get("effect", "direct"),
+                "treatment_variable": test_config["treatment_variable"],
+                "expected_effect": test_config["expected_effect"],
+                "formula": test_config.get("formula"),
+                "alpha": test_config.get("alpha", 0.05),
+                "skip": test_config.get("skip", False),
+                "passed": test_passed,
+                "result": {
+                    "treatment": result.estimator.base_test_case.treatment_variable.name,
+                    "outcome": result.estimator.base_test_case.outcome_variable.name,
+                    "adjustment_set": list(result.adjustment_set) if result.adjustment_set else [],
+                    "effect_measure": result.test_value.type,
+                    "effect_estimate": effect_estimate,
+                    "ci_low": ci_low,
+                    "ci_high": ci_high,
+                },
+            }
+            json_results.append(output)
+
+        # Save to file
+        with open(self.paths.output_path, "w", encoding="utf-8") as f:
+            json.dump(json_results, f, indent=2)
+
+        logger.info("Results saved successfully")
+        return json_results

def setup_logging(verbose: bool = False) -> None:
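For reference, the shape of one entry in the `json_results` list built by the `save_results` hunk above. The key names come straight from the diff; the values are placeholders. Note that in the new version `ci_low`/`ci_high` are written as whatever `result.ci_low()`/`result.ci_high()` return, since the Series-to-list coercion was dropped:

```python
# One json_results entry; placeholder values, keys from save_results above.
example_entry = {
    "name": "dose_increases_response",
    "estimate_type": "ate",
    "effect": "direct",
    "treatment_variable": "dose",
    "expected_effect": {"response": "Positive"},
    "formula": None,
    "alpha": 0.05,
    "skip": False,
    "passed": True,
    "result": {
        "treatment": "dose",
        "outcome": "response",
        "adjustment_set": [],
        "effect_measure": "ate",
        "effect_estimate": 2.1,
        "ci_low": 1.4,   # no longer coerced from pd.Series to list
        "ci_high": 2.8,
    },
}
```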