
Commit f35da6e ("fix")

1 parent d358b71

File tree: 1 file changed (+7, -7 lines)

lightllm/server/router/model_infer/mode_backend/generic_post_process.py (7 additions & 7 deletions)
@@ -43,7 +43,7 @@ def sample(logits, reqs, eos_id: List[int] = [2]):
     logits.div_(temperatures.view((-1, 1)))
     probs = torch.softmax(logits, dim=-1)
 
-    if get_env_start_args().sampling_backend == "triton":
+    if get_env_start_args().sampling_backend == "triton":
         probs_sort, probs_idx = _top_p_top_k(probs, top_ps, top_ks)
         sampled_index = torch.multinomial(probs_sort, num_samples=1, replacement=True)

@@ -56,12 +56,12 @@ def sample(logits, reqs, eos_id: List[int] = [2]):
         from sgl_kernel import top_k_top_p_sampling_from_probs
 
         batch_next_token_ids = top_k_top_p_sampling_from_probs(
-            probs,
-            top_ks,
-            top_ps,
-            filter_apply_order="joint",
-            check_nan=True,
-        )
+            probs,
+            top_ks,
+            top_ps,
+            filter_apply_order="joint",
+            check_nan=True,
+        )
         int64_batch_next_token_ids = torch.empty_like(batch_next_token_ids, dtype=torch.int64)
         int64_batch_next_token_ids[:] = batch_next_token_ids
         batch_next_token_probs = torch.gather(probs, dim=1, index=int64_batch_next_token_ids.view(-1, 1))
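Context on the touched code: sample() dispatches between two samplers, a Triton path that filters the softmax probabilities with _top_p_top_k and then draws from torch.multinomial, and an sgl_kernel path whose top_k_top_p_sampling_from_probs call fuses the filter and the draw (filter_apply_order="joint" appears to apply the top-k and top-p cutoffs together before renormalizing, and check_nan=True guards against invalid probabilities). Note the removed and added lines above are textually identical; the page extraction flattened leading whitespace, so the "fix" is most likely indentation-only. Below is a minimal, self-contained sketch of what the Triton branch computes, assuming a standard sort-then-mask top-k/top-p filter; the function name top_k_top_p_sample and the filtering logic are illustrative assumptions, not lightllm's actual _top_p_top_k implementation.

    # Hedged sketch, not lightllm's real kernel: reference top-k/top-p
    # sampling with per-request parameters, mirroring the diff's tensor
    # names (probs, top_ks, top_ps). Assumed helper, for illustration only.
    import torch

    def top_k_top_p_sample(logits: torch.Tensor,
                           temperatures: torch.Tensor,
                           top_ks: torch.Tensor,
                           top_ps: torch.Tensor) -> torch.Tensor:
        # Temperature scaling and softmax, as in the unchanged context lines.
        logits = logits / temperatures.view(-1, 1)
        probs = torch.softmax(logits, dim=-1)

        # Sort descending so the top-k and top-p masks can share one pass.
        probs_sort, probs_idx = probs.sort(dim=-1, descending=True)

        # Top-k: zero every sorted position at or past each row's k.
        ranks = torch.arange(probs.shape[-1], device=probs.device)
        probs_sort[ranks.unsqueeze(0) >= top_ks.view(-1, 1)] = 0.0

        # Top-p: zero tokens whose preceding cumulative mass already
        # exceeds p (the first token always survives).
        cumprobs = probs_sort.cumsum(dim=-1) - probs_sort
        probs_sort[cumprobs > top_ps.view(-1, 1)] = 0.0

        # Renormalize and draw one token per row, as the Triton branch does.
        probs_sort = probs_sort / probs_sort.sum(dim=-1, keepdim=True)
        sampled = torch.multinomial(probs_sort, num_samples=1, replacement=True)

        # Map sorted positions back to vocabulary ids.
        return torch.gather(probs_idx, dim=-1, index=sampled)

    # Example: a batch of 2 requests over a toy 5-token vocabulary.
    logits = torch.randn(2, 5)
    ids = top_k_top_p_sample(logits,
                             temperatures=torch.tensor([1.0, 0.8]),
                             top_ks=torch.tensor([3, 5]),
                             top_ps=torch.tensor([0.9, 1.0]))

The int64 copy after the sgl_kernel call in the diff is there because torch.gather requires an int64 index tensor; presumably the kernel returns ids in a narrower integer dtype.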
