Skip to content

Commit 0509fa2

Browse files
committed
Test running in batches produces equivalent output to running normally
1 parent 11d39aa commit 0509fa2

File tree

1 file changed

+35
-0
lines changed

1 file changed

+35
-0
lines changed

tests/main_tests/test_main.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,41 @@ def test_ctf_batches(self):
171171

172172
self.assertEqual([result["passed"] for result in all_results], [True])
173173

174+
def test_ctf_batches_matches_run_tests(self):
175+
# Run the tests normally
176+
framework = CausalTestingFramework(self.paths)
177+
framework.setup()
178+
framework.load_tests()
179+
normale_results = framework.run_tests()
180+
181+
# Run the tests in batches
182+
output_files = []
183+
with tempfile.TemporaryDirectory() as tmpdir:
184+
for i, results in enumerate(framework.run_tests_in_batches()):
185+
temp_file_path = os.path.join(tmpdir, f"output_{i}.json")
186+
framework.save_results(results, temp_file_path)
187+
output_files.append(temp_file_path)
188+
del results
189+
190+
# Now stitch the results together from the temporary files
191+
all_results = []
192+
for file_path in output_files:
193+
with open(file_path, "r", encoding="utf-8") as f:
194+
all_results.extend(json.load(f))
195+
196+
with tempfile.TemporaryDirectory() as tmpdir:
197+
normal_output = os.path.join(tmpdir, f"normal.json")
198+
framework.save_results(normale_results, normal_output)
199+
with open(normal_output) as f:
200+
normal_results = json.load(f)
201+
202+
batch_output = os.path.join(tmpdir, f"batch.json")
203+
with open(batch_output, "w") as f:
204+
json.dump(all_results, f)
205+
with open(batch_output) as f:
206+
batch_results = json.load(f)
207+
self.assertEqual(normal_results, batch_results)
208+
174209
def test_global_query(self):
175210
framework = CausalTestingFramework(self.paths)
176211
framework.setup()

0 commit comments

Comments
 (0)