Skip to content

Commit d6ba0b4

Browse files
committed
Changed run_tests_in_batches. Still takes up all the memory
1 parent eb96ea7 commit d6ba0b4

File tree

2 files changed

+38
-22
lines changed

2 files changed

+38
-22
lines changed

causal_testing/__main__.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
"""This module contains the main entrypoint functionality to the Causal Testing Framework."""
22

33
import logging
4+
import tempfile
5+
import json
6+
import os
7+
48
from .main import setup_logging, parse_args, CausalTestingPaths, CausalTestingFramework
59

610

@@ -34,13 +38,27 @@ def main() -> None:
3438

3539
if args.batch_size > 0:
3640
logging.info(f"Running tests in batches of size {args.batch_size}")
37-
results = framework.run_tests_in_batches(batch_size=args.batch_size, silent=args.silent)
41+
with tempfile.TemporaryDirectory() as tmpdir:
42+
output_files = []
43+
for i, results in enumerate(framework.run_tests_in_batches(batch_size=args.batch_size, silent=args.silent)):
44+
temp_file_path = os.path.join(tmpdir, f"output_{i}.json")
45+
framework.save_results(results, temp_file_path)
46+
output_files.append(temp_file_path)
47+
del results
48+
49+
# Now stitch the results together from the temporary files
50+
all_results = []
51+
for file_path in output_files:
52+
with open(file_path, "r") as f:
53+
all_results.extend(json.load(f))
54+
55+
# Save the final stitched results to your desired location
56+
with open(args.output, "w") as f:
57+
json.dump(all_results, f, indent=4)
3858
else:
3959
logging.info("Running tests in regular mode")
4060
results = framework.run_tests(silent=args.silent)
41-
42-
# Save results
43-
framework.save_results(results)
61+
framework.save_results(results)
4462

4563
logging.info("Causal testing completed successfully.")
4664

causal_testing/main.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing import Dict, Any, Optional, List, Union, Sequence
99
from tqdm import tqdm
1010

11+
1112
import pandas as pd
1213
import numpy as np
1314

@@ -344,7 +345,6 @@ def run_tests_in_batches(self, batch_size: int = 100, silent: bool = False) -> L
344345
num_batches = int(np.ceil(num_tests / batch_size))
345346

346347
logger.info(f"Processing {num_tests} tests in {num_batches} batches of up to {batch_size} tests each")
347-
all_results = []
348348
with tqdm(total=num_tests, desc="Overall progress", mininterval=0.1) as progress:
349349
# Process each batch
350350
for batch_idx in range(num_batches):
@@ -360,26 +360,23 @@ def run_tests_in_batches(self, batch_size: int = 100, silent: bool = False) -> L
360360
batch_results = []
361361
for test_case in current_batch:
362362
try:
363-
result = test_case.execute_test()
364-
batch_results.append(result)
365-
except (TypeError, AttributeError) as e:
363+
batch_results.append(test_case.execute_test())
364+
# pylint: disable=broad-exception-caught
365+
except Exception as e:
366366
if not silent:
367367
logger.error(f"Type or attribute error in test: {str(e)}")
368368
raise
369-
result = CausalTestResult(
370-
estimator=test_case.estimator,
371-
test_value=TestValue("Error", str(e)),
369+
batch_results.append(
370+
CausalTestResult(
371+
estimator=test_case.estimator,
372+
test_value=TestValue("Error", str(e)),
373+
)
372374
)
373-
batch_results.append(result)
374375

375376
progress.update(1)
376377

377-
all_results.extend(batch_results)
378-
379-
logger.info(f"Completed batch {batch_idx + 1} of {num_batches}")
380-
381-
logger.info(f"Completed processing all {len(all_results)} tests in {num_batches} batches")
382-
return all_results
378+
yield batch_results
379+
logger.info(f"Completed processing in {num_batches} batches")
383380

384381
def run_tests(self, silent=False) -> List[CausalTestResult]:
385382
"""
@@ -399,7 +396,6 @@ def run_tests(self, silent=False) -> List[CausalTestResult]:
399396
try:
400397
result = test_case.execute_test()
401398
results.append(result)
402-
logger.info(f"Test completed: {test_case}")
403399
# pylint: disable=broad-exception-caught
404400
except Exception as e:
405401
if not silent:
@@ -414,9 +410,11 @@ def run_tests(self, silent=False) -> List[CausalTestResult]:
414410

415411
return results
416412

417-
def save_results(self, results: List[CausalTestResult]) -> None:
413+
def save_results(self, results: List[CausalTestResult], output_path: str = None) -> None:
418414
"""Save test results to JSON file in the expected format."""
419-
logger.info(f"Saving results to {self.paths.output_path}")
415+
if output_path is None:
416+
output_path = self.paths.output_path
417+
logger.info(f"Saving results to {output_path}")
420418

421419
# Load original test configs to preserve test metadata
422420
with open(self.paths.test_config_path, "r", encoding="utf-8") as f:
@@ -460,7 +458,7 @@ def save_results(self, results: List[CausalTestResult]) -> None:
460458
json_results.append(output)
461459

462460
# Save to file
463-
with open(self.paths.output_path, "w", encoding="utf-8") as f:
461+
with open(output_path, "w", encoding="utf-8") as f:
464462
json.dump(json_results, f, indent=2)
465463

466464
logger.info("Results saved successfully")

0 commit comments

Comments
 (0)