Skip to content

Commit 4b531e1

Browse files
committed
feat: ability to run tests in batches
1 parent b4d3030 commit 4b531e1

File tree

2 files changed

+70
-2
lines changed

2 files changed

+70
-2
lines changed

causal_testing/__main__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,13 @@ def main() -> None:
3131

3232
# Load and run tests
3333
framework.load_tests()
34-
results = framework.run_tests(silent=args.silent)
34+
35+
if args.batch_size > 0:
36+
logging.info(f"Running tests in batches of size {args.batch_size}")
37+
results = framework.run_tests_in_batches(batch_size=args.batch_size, silent=args.silent)
38+
else:
39+
logging.info("Running tests in regular mode")
40+
results = framework.run_tests(silent=args.silent)
3541

3642
# Save results
3743
framework.save_results(results)

causal_testing/main.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from tqdm import tqdm
1010

1111
import pandas as pd
12+
import numpy as np
1213

1314
from .specification.causal_dag import CausalDAG
1415
from .specification.scenario import Scenario
@@ -22,7 +23,6 @@
2223
from .estimation.logistic_regression_estimator import LogisticRegressionEstimator
2324

2425
logger = logging.getLogger(__name__)
25-
logger.setLevel(logging.ERROR)
2626

2727

2828
@dataclass
@@ -307,6 +307,62 @@ def create_causal_test(self, test: dict, base_test: BaseTestCase) -> CausalTestC
307307
estimator=estimator,
308308
)
309309

310+
def run_tests_in_batches(self, batch_size=100, silent=False) -> List[CausalTestResult]:
311+
"""
312+
Run tests in batches to reduce memory usage.
313+
314+
:param batch_size: Number of tests to run in each batch
315+
:param silent: Whether to suppress errors
316+
:return: List of all test results
317+
:raises: ValueError if no tests are loaded
318+
"""
319+
logger.info("Running causal tests in batches...")
320+
321+
if not self.test_cases:
322+
raise ValueError("No tests loaded. Call load_tests() first.")
323+
324+
num_tests = len(self.test_cases)
325+
num_batches = int(np.ceil(num_tests / batch_size))
326+
327+
logger.info(f"Processing {num_tests} tests in {num_batches} batches of up to {batch_size} tests each")
328+
all_results = []
329+
with tqdm(total=num_tests, desc="Overall progress", mininterval=0.1) as overall_pbar:
330+
# Process each batch
331+
for batch_idx in range(num_batches):
332+
start_idx = batch_idx * batch_size
333+
end_idx = min(start_idx + batch_size, num_tests)
334+
current_batch_size = end_idx - start_idx
335+
336+
logger.info(f"Processing batch {batch_idx + 1} of {num_batches} (tests {start_idx} to {end_idx - 1})")
337+
338+
# Get current batch of tests
339+
current_batch = self.test_cases[start_idx:end_idx]
340+
341+
# Process the current batch
342+
batch_results = []
343+
for test_case in current_batch:
344+
try:
345+
result = test_case.execute_test()
346+
batch_results.append(result)
347+
except Exception as e:
348+
if not silent:
349+
logger.error(f"Error running test: {str(e)}")
350+
raise
351+
result = CausalTestResult(
352+
estimator=test_case.estimator,
353+
test_value=TestValue("Error", str(e)),
354+
)
355+
batch_results.append(result)
356+
357+
overall_pbar.update(1)
358+
359+
all_results.extend(batch_results)
360+
361+
logger.info(f"Completed batch {batch_idx + 1} of {num_batches} ({current_batch_size} tests)")
362+
363+
logger.info(f"Completed processing all {len(all_results)} tests in {num_batches} batches")
364+
return all_results
365+
310366
def run_tests(self, silent=False) -> List[CausalTestResult]:
311367
"""
312368
Run all test cases and return their results.
@@ -418,5 +474,11 @@ def parse_args(args: Optional[Sequence[str]] = None) -> argparse.Namespace:
418474
help="Do not crash on error. If set to true, errors are recorded as test results.",
419475
default=False,
420476
)
477+
parser.add_argument(
478+
"--batch-size",
479+
type=int,
480+
default=0,
481+
help="Run tests in batches of the specified size (default: 0, which means no batching)",
482+
)
421483

422484
return parser.parse_args(args)

0 commit comments

Comments
 (0)