Skip to content

Commit 292ec59

Browse files
committed
set fail_when_mismatch to False
Signed-off-by: Ye Yu <[email protected]>
1 parent 0f75bba commit 292ec59

File tree

1 file changed

+4
-10
lines changed

1 file changed

+4
-10
lines changed

modelopt/torch/speculative/utils.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ def check_draft(self, ground_truth, input_ids, draft_tokens):
289289

290290
return input_ids
291291

292-
def check_data_consistency_across_ranks(self, data, group=None, fail_when_mismatch=True):
292+
def check_data_consistency_across_ranks(self, data, group=None, fail_when_mismatch=False):
293293
"""This function checks the data consistency across all ranks in the group.
294294
295295
Use rank 0 data as the golden set to broadcast to all ranks.
@@ -330,9 +330,7 @@ def validate(
330330

331331
if ground_truth is None:
332332
ground_truth = self.get_ground_truth(input_ids, osl)
333-
ground_truth = self.check_data_consistency_across_ranks(
334-
ground_truth, fail_when_mismatch=False
335-
)
333+
ground_truth = self.check_data_consistency_across_ranks(ground_truth)
336334

337335
cnt = 0
338336
draft_tokens = None
@@ -347,16 +345,12 @@ def validate(
347345

348346
if tree_paths:
349347
input_id, draft_tokens, pred_tokens = self.model.tree_decode(input_ids, tree=tree)
350-
pred_tokens = self.check_data_consistency_across_ranks(
351-
pred_tokens, fail_when_mismatch=False
352-
)
348+
pred_tokens = self.check_data_consistency_across_ranks(pred_tokens)
353349
else:
354350
input_id, draft_tokens = self.model.pseudo_speculative_generate(
355351
input_ids, steps=steps
356352
)
357-
draft_tokens = self.check_data_consistency_across_ranks(
358-
draft_tokens, fail_when_mismatch=False
359-
)
353+
draft_tokens = self.check_data_consistency_across_ranks(draft_tokens)
360354

361355
input_id = self.check_data_consistency_across_ranks(input_id)
362356
input_ids = torch.cat((input_ids, input_id), dim=-1)

0 commit comments

Comments
 (0)