Skip to content

Commit be260f6

Browse files
committed
set fail_when_mismatch to False
Signed-off-by: Ye Yu <[email protected]>
1 parent 5328b55 commit be260f6

File tree

1 file changed

+4
-10
lines changed

1 file changed

+4
-10
lines changed

modelopt/torch/speculative/utils.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ def check_draft(self, ground_truth, input_ids, draft_tokens):
290290

291291
return input_ids
292292

293-
def check_data_consistency_across_ranks(self, data, group=None, fail_when_mismatch=True):
293+
def check_data_consistency_across_ranks(self, data, group=None, fail_when_mismatch=False):
294294
"""This function checks the data consistency across all ranks in the group.
295295
296296
Use rank 0 data as the golden set to broadcast to all ranks.
@@ -331,9 +331,7 @@ def validate(
331331

332332
if ground_truth is None:
333333
ground_truth = self.get_ground_truth(input_ids, osl)
334-
ground_truth = self.check_data_consistency_across_ranks(
335-
ground_truth, fail_when_mismatch=False
336-
)
334+
ground_truth = self.check_data_consistency_across_ranks(ground_truth)
337335

338336
cnt = 0
339337
draft_tokens = None
@@ -348,16 +346,12 @@ def validate(
348346

349347
if tree_paths:
350348
input_id, draft_tokens, pred_tokens = self.model.tree_decode(input_ids, tree=tree)
351-
pred_tokens = self.check_data_consistency_across_ranks(
352-
pred_tokens, fail_when_mismatch=False
353-
)
349+
pred_tokens = self.check_data_consistency_across_ranks(pred_tokens)
354350
else:
355351
input_id, draft_tokens = self.model.pseudo_speculative_generate(
356352
input_ids, steps=steps
357353
)
358-
draft_tokens = self.check_data_consistency_across_ranks(
359-
draft_tokens, fail_when_mismatch=False
360-
)
354+
draft_tokens = self.check_data_consistency_across_ranks(draft_tokens)
361355

362356
input_id = self.check_data_consistency_across_ranks(input_id)
363357
input_ids = torch.cat((input_ids, input_id), dim=-1)

0 commit comments

Comments
 (0)