@@ -289,7 +289,7 @@ def check_draft(self, ground_truth, input_ids, draft_tokens):
289
289
290
290
return input_ids
291
291
292
- def check_data_consistency_across_ranks (self , data , group = None , fail_when_mismatch = True ):
292
+ def check_data_consistency_across_ranks (self , data , group = None , fail_when_mismatch = False ):
293
293
"""This function checks the data consistency across all ranks in the group.
294
294
295
295
Use rank 0 data as the golden set to broadcast to all ranks.
@@ -330,9 +330,7 @@ def validate(
330
330
331
331
if ground_truth is None :
332
332
ground_truth = self .get_ground_truth (input_ids , osl )
333
- ground_truth = self .check_data_consistency_across_ranks (
334
- ground_truth , fail_when_mismatch = False
335
- )
333
+ ground_truth = self .check_data_consistency_across_ranks (ground_truth )
336
334
337
335
cnt = 0
338
336
draft_tokens = None
@@ -347,16 +345,12 @@ def validate(
347
345
348
346
if tree_paths :
349
347
input_id , draft_tokens , pred_tokens = self .model .tree_decode (input_ids , tree = tree )
350
- pred_tokens = self .check_data_consistency_across_ranks (
351
- pred_tokens , fail_when_mismatch = False
352
- )
348
+ pred_tokens = self .check_data_consistency_across_ranks (pred_tokens )
353
349
else :
354
350
input_id , draft_tokens = self .model .pseudo_speculative_generate (
355
351
input_ids , steps = steps
356
352
)
357
- draft_tokens = self .check_data_consistency_across_ranks (
358
- draft_tokens , fail_when_mismatch = False
359
- )
353
+ draft_tokens = self .check_data_consistency_across_ranks (draft_tokens )
360
354
361
355
input_id = self .check_data_consistency_across_ranks (input_id )
362
356
input_ids = torch .cat ((input_ids , input_id ), dim = - 1 )
0 commit comments