@@ -307,7 +307,7 @@ def create_causal_test(self, test: dict, base_test: BaseTestCase) -> CausalTestC
307
307
estimator = estimator ,
308
308
)
309
309
310
- def run_tests_in_batches (self , batch_size = 100 , silent = False ) -> List [CausalTestResult ]:
310
+ def run_tests_in_batches (self , batch_size : int = 100 , silent : bool = False ) -> List [CausalTestResult ]:
311
311
"""
312
312
Run tests in batches to reduce memory usage.
313
313
@@ -326,12 +326,11 @@ def run_tests_in_batches(self, batch_size=100, silent=False) -> List[CausalTestR
326
326
327
327
logger .info (f"Processing { num_tests } tests in { num_batches } batches of up to { batch_size } tests each" )
328
328
all_results = []
329
- with tqdm (total = num_tests , desc = "Overall progress" , mininterval = 0.1 ) as overall_pbar :
329
+ with tqdm (total = num_tests , desc = "Overall progress" , mininterval = 0.1 ) as progress :
330
330
# Process each batch
331
331
for batch_idx in range (num_batches ):
332
332
start_idx = batch_idx * batch_size
333
333
end_idx = min (start_idx + batch_size , num_tests )
334
- current_batch_size = end_idx - start_idx
335
334
336
335
logger .info (f"Processing batch { batch_idx + 1 } of { num_batches } (tests { start_idx } to { end_idx - 1 } )" )
337
336
@@ -344,21 +343,21 @@ def run_tests_in_batches(self, batch_size=100, silent=False) -> List[CausalTestR
344
343
try :
345
344
result = test_case .execute_test ()
346
345
batch_results .append (result )
347
- except Exception as e :
346
+ except ( TypeError , AttributeError ) as e :
348
347
if not silent :
349
- logger .error (f"Error running test: { str (e )} " )
348
+ logger .error (f"Type or attribute error in test: { str (e )} " )
350
349
raise
351
350
result = CausalTestResult (
352
351
estimator = test_case .estimator ,
353
352
test_value = TestValue ("Error" , str (e )),
354
353
)
355
354
batch_results .append (result )
356
355
357
- overall_pbar .update (1 )
356
+ progress .update (1 )
358
357
359
358
all_results .extend (batch_results )
360
359
361
- logger .info (f"Completed batch { batch_idx + 1 } of { num_batches } ( { current_batch_size } tests) " )
360
+ logger .info (f"Completed batch { batch_idx + 1 } of { num_batches } " )
362
361
363
362
logger .info (f"Completed processing all { len (all_results )} tests in { num_batches } batches" )
364
363
return all_results
0 commit comments