8
8
from typing import Dict , Any , Optional , List , Union , Sequence
9
9
from tqdm import tqdm
10
10
11
+
11
12
import pandas as pd
12
13
import numpy as np
13
14
@@ -344,7 +345,6 @@ def run_tests_in_batches(self, batch_size: int = 100, silent: bool = False) -> L
344
345
num_batches = int (np .ceil (num_tests / batch_size ))
345
346
346
347
logger .info (f"Processing { num_tests } tests in { num_batches } batches of up to { batch_size } tests each" )
347
- all_results = []
348
348
with tqdm (total = num_tests , desc = "Overall progress" , mininterval = 0.1 ) as progress :
349
349
# Process each batch
350
350
for batch_idx in range (num_batches ):
@@ -360,26 +360,23 @@ def run_tests_in_batches(self, batch_size: int = 100, silent: bool = False) -> L
360
360
batch_results = []
361
361
for test_case in current_batch :
362
362
try :
363
- result = test_case .execute_test ()
364
- batch_results . append ( result )
365
- except ( TypeError , AttributeError ) as e :
363
+ batch_results . append ( test_case .execute_test () )
364
+ # pylint: disable=broad-exception-caught
365
+ except Exception as e :
366
366
if not silent :
367
367
logger .error (f"Type or attribute error in test: { str (e )} " )
368
368
raise
369
- result = CausalTestResult (
370
- estimator = test_case .estimator ,
371
- test_value = TestValue ("Error" , str (e )),
369
+ batch_results .append (
370
+ CausalTestResult (
371
+ estimator = test_case .estimator ,
372
+ test_value = TestValue ("Error" , str (e )),
373
+ )
372
374
)
373
- batch_results .append (result )
374
375
375
376
progress .update (1 )
376
377
377
- all_results .extend (batch_results )
378
-
379
- logger .info (f"Completed batch { batch_idx + 1 } of { num_batches } " )
380
-
381
- logger .info (f"Completed processing all { len (all_results )} tests in { num_batches } batches" )
382
- return all_results
378
+ yield batch_results
379
+ logger .info (f"Completed processing in { num_batches } batches" )
383
380
384
381
def run_tests (self , silent = False ) -> List [CausalTestResult ]:
385
382
"""
@@ -399,7 +396,6 @@ def run_tests(self, silent=False) -> List[CausalTestResult]:
399
396
try :
400
397
result = test_case .execute_test ()
401
398
results .append (result )
402
- logger .info (f"Test completed: { test_case } " )
403
399
# pylint: disable=broad-exception-caught
404
400
except Exception as e :
405
401
if not silent :
@@ -414,9 +410,11 @@ def run_tests(self, silent=False) -> List[CausalTestResult]:
414
410
415
411
return results
416
412
417
- def save_results (self , results : List [CausalTestResult ]) -> None :
413
+ def save_results (self , results : List [CausalTestResult ], output_path : str = None ) -> None :
418
414
"""Save test results to JSON file in the expected format."""
419
- logger .info (f"Saving results to { self .paths .output_path } " )
415
+ if output_path is None :
416
+ output_path = self .paths .output_path
417
+ logger .info (f"Saving results to { output_path } " )
420
418
421
419
# Load original test configs to preserve test metadata
422
420
with open (self .paths .test_config_path , "r" , encoding = "utf-8" ) as f :
@@ -460,7 +458,7 @@ def save_results(self, results: List[CausalTestResult]) -> None:
460
458
json_results .append (output )
461
459
462
460
# Save to file
463
- with open (self . paths . output_path , "w" , encoding = "utf-8" ) as f :
461
+ with open (output_path , "w" , encoding = "utf-8" ) as f :
464
462
json .dump (json_results , f , indent = 2 )
465
463
466
464
logger .info ("Results saved successfully" )
0 commit comments