@@ -290,7 +290,7 @@ def check_draft(self, ground_truth, input_ids, draft_tokens):
290
290
291
291
return input_ids
292
292
293
- def check_data_consistency_across_ranks (self , data , group = None , fail_when_mismatch = True ):
293
+ def check_data_consistency_across_ranks (self , data , group = None , fail_when_mismatch = False ):
294
294
"""This function checks the data consistency across all ranks in the group.
295
295
296
296
Use rank 0 data as the golden set to broadcast to all ranks.
@@ -331,9 +331,7 @@ def validate(
331
331
332
332
if ground_truth is None :
333
333
ground_truth = self .get_ground_truth (input_ids , osl )
334
- ground_truth = self .check_data_consistency_across_ranks (
335
- ground_truth , fail_when_mismatch = False
336
- )
334
+ ground_truth = self .check_data_consistency_across_ranks (ground_truth )
337
335
338
336
cnt = 0
339
337
draft_tokens = None
@@ -348,16 +346,12 @@ def validate(
348
346
349
347
if tree_paths :
350
348
input_id , draft_tokens , pred_tokens = self .model .tree_decode (input_ids , tree = tree )
351
- pred_tokens = self .check_data_consistency_across_ranks (
352
- pred_tokens , fail_when_mismatch = False
353
- )
349
+ pred_tokens = self .check_data_consistency_across_ranks (pred_tokens )
354
350
else :
355
351
input_id , draft_tokens = self .model .pseudo_speculative_generate (
356
352
input_ids , steps = steps
357
353
)
358
- draft_tokens = self .check_data_consistency_across_ranks (
359
- draft_tokens , fail_when_mismatch = False
360
- )
354
+ draft_tokens = self .check_data_consistency_across_ranks (draft_tokens )
361
355
362
356
input_id = self .check_data_consistency_across_ranks (input_id )
363
357
input_ids = torch .cat ((input_ids , input_id ), dim = - 1 )
0 commit comments