
Commit a5fa736 (merge commit, 2 parents: 9c7e4cd, 5bdb82e)

Merge branch 'f-allian/front-end' of github.com:CITCOM-project/CausalTestingFramework into f-allian/front-end

2 files changed (+80, -13 lines)


causal_testing/__main__.py

Lines changed: 7 additions & 1 deletion
@@ -31,7 +31,13 @@ def main() -> None:

     # Load and run tests
     framework.load_tests()
-    results = framework.run_tests(silent=args.silent)
+
+    if args.batch_size > 0:
+        logging.info(f"Running tests in batches of size {args.batch_size}")
+        results = framework.run_tests_in_batches(batch_size=args.batch_size, silent=args.silent)
+    else:
+        logging.info("Running tests in regular mode")
+        results = framework.run_tests(silent=args.silent)

     # Save results
     framework.save_results(results)
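For orientation, a minimal sketch of how the new flag might be invoked once this change lands. The framework's other required command-line arguments are not shown in this diff, so they are elided here rather than guessed:

    # Hypothetical invocations; '...' stands for the framework's other required arguments
    python -m causal_testing ... --batch-size 200   # batched mode: run_tests_in_batches()
    python -m causal_testing ...                    # default --batch-size 0: regular run_tests()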

causal_testing/main.py

Lines changed: 73 additions & 12 deletions
@@ -9,20 +9,20 @@
 from tqdm import tqdm

 import pandas as pd
-
-from .specification.causal_dag import CausalDAG
-from .specification.scenario import Scenario
-from .specification.variable import Input, Output
-from .specification.causal_specification import CausalSpecification
-from .testing.causal_test_case import CausalTestCase
-from .testing.base_test_case import BaseTestCase
-from .testing.causal_test_outcome import NoEffect, SomeEffect, Positive, Negative
-from .testing.causal_test_result import CausalTestResult, TestValue
-from .estimation.linear_regression_estimator import LinearRegressionEstimator
-from .estimation.logistic_regression_estimator import LogisticRegressionEstimator
+import numpy as np
+
+from causal_testing.specification.causal_dag import CausalDAG
+from causal_testing.specification.scenario import Scenario
+from causal_testing.specification.variable import Input, Output
+from causal_testing.specification.causal_specification import CausalSpecification
+from causal_testing.testing.causal_test_case import CausalTestCase
+from causal_testing.testing.base_test_case import BaseTestCase
+from causal_testing.testing.causal_test_outcome import NoEffect, SomeEffect, Positive, Negative
+from causal_testing.testing.causal_test_result import CausalTestResult, TestValue
+from causal_testing.estimation.linear_regression_estimator import LinearRegressionEstimator
+from causal_testing.estimation.logistic_regression_estimator import LogisticRegressionEstimator

 logger = logging.getLogger(__name__)
-logger.setLevel(logging.ERROR)


 @dataclass
@@ -307,6 +307,61 @@ def create_causal_test(self, test: dict, base_test: BaseTestCase) -> CausalTestCase:
             estimator=estimator,
         )

+    def run_tests_in_batches(self, batch_size: int = 100, silent: bool = False) -> List[CausalTestResult]:
+        """
+        Run tests in batches to reduce memory usage.
+
+        :param batch_size: Number of tests to run in each batch
+        :param silent: Whether to suppress errors
+        :return: List of all test results
+        :raises: ValueError if no tests are loaded
+        """
+        logger.info("Running causal tests in batches...")
+
+        if not self.test_cases:
+            raise ValueError("No tests loaded. Call load_tests() first.")
+
+        num_tests = len(self.test_cases)
+        num_batches = int(np.ceil(num_tests / batch_size))
+
+        logger.info(f"Processing {num_tests} tests in {num_batches} batches of up to {batch_size} tests each")
+        all_results = []
+        with tqdm(total=num_tests, desc="Overall progress", mininterval=0.1) as progress:
+            # Process each batch
+            for batch_idx in range(num_batches):
+                start_idx = batch_idx * batch_size
+                end_idx = min(start_idx + batch_size, num_tests)
+
+                logger.info(f"Processing batch {batch_idx + 1} of {num_batches} (tests {start_idx} to {end_idx - 1})")
+
+                # Get current batch of tests
+                current_batch = self.test_cases[start_idx:end_idx]
+
+                # Process the current batch
+                batch_results = []
+                for test_case in current_batch:
+                    try:
+                        result = test_case.execute_test()
+                        batch_results.append(result)
+                    except (TypeError, AttributeError) as e:
+                        if not silent:
+                            logger.error(f"Type or attribute error in test: {str(e)}")
+                            raise
+                        result = CausalTestResult(
+                            estimator=test_case.estimator,
+                            test_value=TestValue("Error", str(e)),
+                        )
+                        batch_results.append(result)
+
+                    progress.update(1)
+
+                all_results.extend(batch_results)
+
+                logger.info(f"Completed batch {batch_idx + 1} of {num_batches}")
+
+        logger.info(f"Completed processing all {len(all_results)} tests in {num_batches} batches")
+        return all_results
+
     def run_tests(self, silent=False) -> List[CausalTestResult]:
         """
         Run all test cases and return their results.
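To illustrate the batch arithmetic the new method relies on (np.ceil for the batch count, then slicing each window, with min() guarding the final, possibly shorter batch), here is a small self-contained sketch using dummy test identifiers standing in for self.test_cases:

    import numpy as np

    test_cases = list(range(7))   # stand-in for self.test_cases
    batch_size = 3
    num_batches = int(np.ceil(len(test_cases) / batch_size))   # 3 batches

    for batch_idx in range(num_batches):
        start_idx = batch_idx * batch_size
        end_idx = min(start_idx + batch_size, len(test_cases))
        print(f"batch {batch_idx + 1}: tests {start_idx} to {end_idx - 1}")
    # batch 1: tests 0 to 2
    # batch 2: tests 3 to 5
    # batch 3: tests 6 to 6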
@@ -418,5 +473,11 @@ def parse_args(args: Optional[Sequence[str]] = None) -> argparse.Namespace:
         help="Do not crash on error. If set to true, errors are recorded as test results.",
         default=False,
     )
+    parser.add_argument(
+        "--batch-size",
+        type=int,
+        default=0,
+        help="Run tests in batches of the specified size (default: 0, which means no batching)",
+    )

     return parser.parse_args(args)
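As a quick check of how the new option surfaces in code, a standalone argparse sketch (not the framework's full parse_args) showing that --batch-size maps to args.batch_size and defaults to 0, which is what routes __main__.py to the unbatched path:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--batch-size", type=int, default=0,
                        help="Run tests in batches of the specified size (0 means no batching)")

    print(parser.parse_args([]).batch_size)                       # 0  -> regular mode
    print(parser.parse_args(["--batch-size", "50"]).batch_size)   # 50 -> batched mode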
