|
1 | 1 | import unittest
|
| 2 | +from pathlib import Path |
| 3 | +import tempfile |
| 4 | +import os |
| 5 | + |
2 | 6 | import shutil
|
3 | 7 | import json
|
4 | 8 | import pandas as pd
|
5 |
| -from pathlib import Path |
6 |
| -from causal_testing.main import CausalTestingPaths, CausalTestingFramework, parse_args |
7 |
| -from causal_testing.__main__ import main |
8 | 9 | from unittest.mock import patch
|
9 | 10 |
|
| 11 | +from causal_testing.main import CausalTestingPaths, CausalTestingFramework |
| 12 | +from causal_testing.__main__ import main |
| 13 | + |
10 | 14 |
|
11 | 15 | class TestCausalTestingPaths(unittest.TestCase):
|
12 | 16 |
|
@@ -144,6 +148,29 @@ def test_ctf(self):
|
144 | 148 |
|
145 | 149 | self.assertEqual(tests_passed, [True])
|
146 | 150 |
|
| 151 | + def test_ctf_batches(self): |
| 152 | + framework = CausalTestingFramework(self.paths) |
| 153 | + framework.setup() |
| 154 | + |
| 155 | + # Load and run tests |
| 156 | + framework.load_tests() |
| 157 | + |
| 158 | + output_files = [] |
| 159 | + with tempfile.TemporaryDirectory() as tmpdir: |
| 160 | + for i, results in enumerate(framework.run_tests_in_batches()): |
| 161 | + temp_file_path = os.path.join(tmpdir, f"output_{i}.json") |
| 162 | + framework.save_results(results, temp_file_path) |
| 163 | + output_files.append(temp_file_path) |
| 164 | + del results |
| 165 | + |
| 166 | + # Now stitch the results together from the temporary files |
| 167 | + all_results = [] |
| 168 | + for file_path in output_files: |
| 169 | + with open(file_path, "r", encoding="utf-8") as f: |
| 170 | + all_results.extend(json.load(f)) |
| 171 | + |
| 172 | + self.assertEqual([result["passed"] for result in all_results], [True]) |
| 173 | + |
147 | 174 | def test_global_query(self):
|
148 | 175 | framework = CausalTestingFramework(self.paths)
|
149 | 176 | framework.setup()
|
|
0 commit comments