@@ -171,6 +171,41 @@ def test_ctf_batches(self):
171
171
172
172
self .assertEqual ([result ["passed" ] for result in all_results ], [True ])
173
173
174
+ def test_ctf_batches_matches_run_tests (self ):
175
+ # Run the tests normally
176
+ framework = CausalTestingFramework (self .paths )
177
+ framework .setup ()
178
+ framework .load_tests ()
179
+ normale_results = framework .run_tests ()
180
+
181
+ # Run the tests in batches
182
+ output_files = []
183
+ with tempfile .TemporaryDirectory () as tmpdir :
184
+ for i , results in enumerate (framework .run_tests_in_batches ()):
185
+ temp_file_path = os .path .join (tmpdir , f"output_{ i } .json" )
186
+ framework .save_results (results , temp_file_path )
187
+ output_files .append (temp_file_path )
188
+ del results
189
+
190
+ # Now stitch the results together from the temporary files
191
+ all_results = []
192
+ for file_path in output_files :
193
+ with open (file_path , "r" , encoding = "utf-8" ) as f :
194
+ all_results .extend (json .load (f ))
195
+
196
+ with tempfile .TemporaryDirectory () as tmpdir :
197
+ normal_output = os .path .join (tmpdir , f"normal.json" )
198
+ framework .save_results (normale_results , normal_output )
199
+ with open (normal_output ) as f :
200
+ normal_results = json .load (f )
201
+
202
+ batch_output = os .path .join (tmpdir , f"batch.json" )
203
+ with open (batch_output , "w" ) as f :
204
+ json .dump (all_results , f )
205
+ with open (batch_output ) as f :
206
+ batch_results = json .load (f )
207
+ self .assertEqual (normal_results , batch_results )
208
+
174
209
def test_global_query (self ):
175
210
framework = CausalTestingFramework (self .paths )
176
211
framework .setup ()
0 commit comments