@@ -289,7 +289,7 @@ def check_draft(self, ground_truth, input_ids, draft_tokens):
289289
290290 return input_ids
291291
292- def check_data_consistency_across_ranks (self , data , group = None , fail_when_mismatch = True ):
292+ def check_data_consistency_across_ranks (self , data , group = None , fail_when_mismatch = False ):
293293 """This function checks the data consistency across all ranks in the group.
294294
295295 Use rank 0 data as the golden set to broadcast to all ranks.
@@ -330,9 +330,7 @@ def validate(
330330
331331 if ground_truth is None :
332332 ground_truth = self .get_ground_truth (input_ids , osl )
333- ground_truth = self .check_data_consistency_across_ranks (
334- ground_truth , fail_when_mismatch = False
335- )
333+ ground_truth = self .check_data_consistency_across_ranks (ground_truth )
336334
337335 cnt = 0
338336 draft_tokens = None
@@ -347,16 +345,12 @@ def validate(
347345
348346 if tree_paths :
349347 input_id , draft_tokens , pred_tokens = self .model .tree_decode (input_ids , tree = tree )
350- pred_tokens = self .check_data_consistency_across_ranks (
351- pred_tokens , fail_when_mismatch = False
352- )
348+ pred_tokens = self .check_data_consistency_across_ranks (pred_tokens )
353349 else :
354350 input_id , draft_tokens = self .model .pseudo_speculative_generate (
355351 input_ids , steps = steps
356352 )
357- draft_tokens = self .check_data_consistency_across_ranks (
358- draft_tokens , fail_when_mismatch = False
359- )
353+ draft_tokens = self .check_data_consistency_across_ranks (draft_tokens )
360354
361355 input_id = self .check_data_consistency_across_ranks (input_id )
362356 input_ids = torch .cat ((input_ids , input_id ), dim = - 1 )
0 commit comments