|
9 | 9 | from tqdm import tqdm
|
10 | 10 |
|
11 | 11 | import pandas as pd
|
| 12 | +import numpy as np |
12 | 13 |
|
13 | 14 | from .specification.causal_dag import CausalDAG
|
14 | 15 | from .specification.scenario import Scenario
|
|
22 | 23 | from .estimation.logistic_regression_estimator import LogisticRegressionEstimator
|
23 | 24 |
|
24 | 25 | logger = logging.getLogger(__name__)
|
25 |
| -logger.setLevel(logging.ERROR) |
26 | 26 |
|
27 | 27 |
|
28 | 28 | @dataclass
|
@@ -307,6 +307,62 @@ def create_causal_test(self, test: dict, base_test: BaseTestCase) -> CausalTestC
|
307 | 307 | estimator=estimator,
|
308 | 308 | )
|
309 | 309 |
|
| 310 | + def run_tests_in_batches(self, batch_size=100, silent=False) -> List[CausalTestResult]: |
| 311 | + """ |
| 312 | + Run tests in batches to reduce memory usage. |
| 313 | +
|
| 314 | + :param batch_size: Number of tests to run in each batch |
| 315 | + :param silent: Whether to suppress errors |
| 316 | + :return: List of all test results |
| 317 | + :raises: ValueError if no tests are loaded |
| 318 | + """ |
| 319 | + logger.info("Running causal tests in batches...") |
| 320 | + |
| 321 | + if not self.test_cases: |
| 322 | + raise ValueError("No tests loaded. Call load_tests() first.") |
| 323 | + |
| 324 | + num_tests = len(self.test_cases) |
| 325 | + num_batches = int(np.ceil(num_tests / batch_size)) |
| 326 | + |
| 327 | + logger.info(f"Processing {num_tests} tests in {num_batches} batches of up to {batch_size} tests each") |
| 328 | + all_results = [] |
| 329 | + with tqdm(total=num_tests, desc="Overall progress", mininterval=0.1) as overall_pbar: |
| 330 | + # Process each batch |
| 331 | + for batch_idx in range(num_batches): |
| 332 | + start_idx = batch_idx * batch_size |
| 333 | + end_idx = min(start_idx + batch_size, num_tests) |
| 334 | + current_batch_size = end_idx - start_idx |
| 335 | + |
| 336 | + logger.info(f"Processing batch {batch_idx + 1} of {num_batches} (tests {start_idx} to {end_idx - 1})") |
| 337 | + |
| 338 | + # Get current batch of tests |
| 339 | + current_batch = self.test_cases[start_idx:end_idx] |
| 340 | + |
| 341 | + # Process the current batch |
| 342 | + batch_results = [] |
| 343 | + for test_case in current_batch: |
| 344 | + try: |
| 345 | + result = test_case.execute_test() |
| 346 | + batch_results.append(result) |
| 347 | + except Exception as e: |
| 348 | + if not silent: |
| 349 | + logger.error(f"Error running test: {str(e)}") |
| 350 | + raise |
| 351 | + result = CausalTestResult( |
| 352 | + estimator=test_case.estimator, |
| 353 | + test_value=TestValue("Error", str(e)), |
| 354 | + ) |
| 355 | + batch_results.append(result) |
| 356 | + |
| 357 | + overall_pbar.update(1) |
| 358 | + |
| 359 | + all_results.extend(batch_results) |
| 360 | + |
| 361 | + logger.info(f"Completed batch {batch_idx + 1} of {num_batches} ({current_batch_size} tests)") |
| 362 | + |
| 363 | + logger.info(f"Completed processing all {len(all_results)} tests in {num_batches} batches") |
| 364 | + return all_results |
| 365 | + |
310 | 366 | def run_tests(self, silent=False) -> List[CausalTestResult]:
|
311 | 367 | """
|
312 | 368 | Run all test cases and return their results.
|
@@ -418,5 +474,11 @@ def parse_args(args: Optional[Sequence[str]] = None) -> argparse.Namespace:
|
418 | 474 | help="Do not crash on error. If set to true, errors are recorded as test results.",
|
419 | 475 | default=False,
|
420 | 476 | )
|
| 477 | + parser.add_argument( |
| 478 | + "--batch-size", |
| 479 | + type=int, |
| 480 | + default=0, |
| 481 | + help="Run tests in batches of the specified size (default: 0, which means no batching)", |
| 482 | + ) |
421 | 483 |
|
422 | 484 | return parser.parse_args(args)
|
0 commit comments