 from tqdm import tqdm

 import pandas as pd
-
-from .specification.causal_dag import CausalDAG
-from .specification.scenario import Scenario
-from .specification.variable import Input, Output
-from .specification.causal_specification import CausalSpecification
-from .testing.causal_test_case import CausalTestCase
-from .testing.base_test_case import BaseTestCase
-from .testing.causal_test_outcome import NoEffect, SomeEffect, Positive, Negative
-from .testing.causal_test_result import CausalTestResult, TestValue
-from .estimation.linear_regression_estimator import LinearRegressionEstimator
-from .estimation.logistic_regression_estimator import LogisticRegressionEstimator
+import numpy as np
+
+from causal_testing.specification.causal_dag import CausalDAG
+from causal_testing.specification.scenario import Scenario
+from causal_testing.specification.variable import Input, Output
+from causal_testing.specification.causal_specification import CausalSpecification
+from causal_testing.testing.causal_test_case import CausalTestCase
+from causal_testing.testing.base_test_case import BaseTestCase
+from causal_testing.testing.causal_test_outcome import NoEffect, SomeEffect, Positive, Negative
+from causal_testing.testing.causal_test_result import CausalTestResult, TestValue
+from causal_testing.estimation.linear_regression_estimator import LinearRegressionEstimator
+from causal_testing.estimation.logistic_regression_estimator import LogisticRegressionEstimator

 logger = logging.getLogger(__name__)
-logger.setLevel(logging.ERROR)


 @dataclass
@@ -307,6 +307,61 @@ def create_causal_test(self, test: dict, base_test: BaseTestCase) -> CausalTestC
             estimator=estimator,
         )

+    def run_tests_in_batches(self, batch_size: int = 100, silent: bool = False) -> List[CausalTestResult]:
+        """
+        Run tests in batches to reduce peak memory usage.
+
+        :param batch_size: Number of tests to run in each batch.
+        :param silent: If True, record errors as test results instead of raising.
+        :return: List of all test results.
+        :raises ValueError: If no tests are loaded.
+        """
+        logger.info("Running causal tests in batches...")
+
+        if not self.test_cases:
+            raise ValueError("No tests loaded. Call load_tests() first.")
+
+        num_tests = len(self.test_cases)
+        num_batches = int(np.ceil(num_tests / batch_size))
+
+        logger.info(f"Processing {num_tests} tests in {num_batches} batches of up to {batch_size} tests each")
+        all_results = []
+        with tqdm(total=num_tests, desc="Overall progress", mininterval=0.1) as progress:
+            # Process each batch
+            for batch_idx in range(num_batches):
+                start_idx = batch_idx * batch_size
+                end_idx = min(start_idx + batch_size, num_tests)
+
+                logger.info(f"Processing batch {batch_idx + 1} of {num_batches} (tests {start_idx} to {end_idx - 1})")
+
+                # Get the current batch of tests
+                current_batch = self.test_cases[start_idx:end_idx]
+
+                # Process the current batch
+                batch_results = []
+                for test_case in current_batch:
+                    try:
+                        result = test_case.execute_test()
+                        batch_results.append(result)
+                    except (TypeError, AttributeError) as e:
+                        if not silent:
+                            logger.error(f"Type or attribute error in test: {str(e)}")
+                            raise
+                        result = CausalTestResult(
+                            estimator=test_case.estimator,
+                            test_value=TestValue("Error", str(e)),
+                        )
+                        batch_results.append(result)
+
+                    progress.update(1)
+
+                all_results.extend(batch_results)
+
+                logger.info(f"Completed batch {batch_idx + 1} of {num_batches}")
+
+        logger.info(f"Completed processing all {len(all_results)} tests in {num_batches} batches")
+        return all_results
+
     def run_tests(self, silent=False) -> List[CausalTestResult]:
         """
         Run all test cases and return their results.
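
For reference, a minimal usage sketch of the new method (illustrative only: `framework` stands for whatever object owns `load_tests()` and `run_tests_in_batches()`; its construction is not shown in this diff):

    framework.load_tests()  # populate framework.test_cases, as required by run_tests_in_batches()
    results = framework.run_tests_in_batches(batch_size=50, silent=True)  # failures become error CausalTestResults
    print(f"Collected {len(results)} results")

With `silent=True`, a failing test contributes a `CausalTestResult` whose `TestValue` is constructed as `TestValue("Error", str(e))`, so callers can filter errors out of the returned list afterwards.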
@@ -418,5 +473,11 @@ def parse_args(args: Optional[Sequence[str]] = None) -> argparse.Namespace:
         help="Do not crash on error. If set to true, errors are recorded as test results.",
         default=False,
     )
+    parser.add_argument(
+        "--batch-size",
+        type=int,
+        default=0,
+        help="Run tests in batches of the specified size (default: 0, which means no batching)",
+    )

     return parser.parse_args(args)
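
The wiring of the new flag into the runner is not part of this diff; a plausible sketch of how `main()` might dispatch on it (the `framework` object is hypothetical, as above; only `parse_args()`, `run_tests()`, and `run_tests_in_batches()` appear in this change):

    args = parse_args()
    if args.batch_size > 0:
        # batched path added by this change
        results = framework.run_tests_in_batches(batch_size=args.batch_size)
    else:
        # existing unbatched path
        results = framework.run_tests()

Keeping `0` as the default means existing command lines behave exactly as before, and batching is strictly opt-in, e.g. by passing `--batch-size 200`.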