
Commit 576542c

update
1 parent 2485128 commit 576542c

File tree

2 files changed: 106 additions & 61 deletions


tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 67 additions & 32 deletions
@@ -991,20 +991,33 @@ def handle_logprobs(
             )
         else:
             assert beam_width == 1, "beam width must be 1 for non-beam search"
-
+
             sampled_tokens = request.get_tokens(0)[-count:]
 
             if request.py_num_logprobs == 0:
                 # Return only the sampled token's logprob
                 # Compute at least top-1 to determine rank
-                if hasattr(request, 'py_sampled_logprobs') and request.py_sampled_logprobs is not None:
+                if (
+                    hasattr(request, "py_sampled_logprobs")
+                    and request.py_sampled_logprobs is not None
+                ):
                     sampled_logprobs = request.py_sampled_logprobs[:count]
                     topk_log_probs_vals = request.py_topk_logprobs_vals[:count]  # At least k=1
                     topk_log_probs_indices = request.py_topk_logprobs_indices[:count]
 
                     token_log_probs = []
-                    for step, (sampled_token, sampled_logprob, topk_tokens, topk_logprobs) in enumerate(
-                        zip(sampled_tokens, sampled_logprobs, topk_log_probs_indices, topk_log_probs_vals)
+                    for step, (
+                        sampled_token,
+                        sampled_logprob,
+                        topk_tokens,
+                        topk_logprobs,
+                    ) in enumerate(
+                        zip(
+                            sampled_tokens,
+                            sampled_logprobs,
+                            topk_log_probs_indices,
+                            topk_log_probs_vals,
+                        )
                     ):
                         topk_tokens_list = topk_tokens.tolist()
                         if sampled_token in topk_tokens_list:
@@ -1014,38 +1027,53 @@ def handle_logprobs(
                             # TODO: fix rank
                             rank = 2
 
-                        step_dict = {sampled_token: Logprob(logprob=sampled_logprob.item(), rank=rank)}
+                        step_dict = {
+                            sampled_token: Logprob(logprob=sampled_logprob.item(), rank=rank)
+                        }
                         token_log_probs.append(step_dict)
                 else:
-                    raise ValueError("py_sampled_logprobs not available when py_num_logprobs == 0")
+                    raise ValueError(
+                        "py_sampled_logprobs not available when py_num_logprobs == 0"
+                    )
             else:
                 # Return top-K logprobs + logprob of sampled token
                 sampled_logprobs = request.py_sampled_logprobs[:count]
                 topk_log_probs_vals = request.py_topk_logprobs_vals[:count]
                 topk_log_probs_indices = request.py_topk_logprobs_indices[:count]
 
                 token_log_probs = []
-                for step, (sampled_token, sampled_logprob, topk_tokens, topk_logprobs) in enumerate(
-                    zip(sampled_tokens, sampled_logprobs, topk_log_probs_indices, topk_log_probs_vals)
+                for step, (
+                    sampled_token,
+                    sampled_logprob,
+                    topk_tokens,
+                    topk_logprobs,
+                ) in enumerate(
+                    zip(
+                        sampled_tokens,
+                        sampled_logprobs,
+                        topk_log_probs_indices,
+                        topk_log_probs_vals,
+                    )
                 ):
                     step_dict = {}
                     topk_tokens_list = topk_tokens.tolist()
                     topk_logprobs_list = topk_logprobs.tolist()
 
-                    for rank_idx, (token, logprob) in enumerate(zip(topk_tokens_list, topk_logprobs_list), start=1):
+                    for rank_idx, (token, logprob) in enumerate(
+                        zip(topk_tokens_list, topk_logprobs_list), start=1
+                    ):
                         step_dict[token] = Logprob(logprob=logprob, rank=rank_idx)
 
                     if sampled_token not in step_dict:
                         # TODO: fix rank
                         step_dict[sampled_token] = Logprob(
-                            logprob=sampled_logprob.item(),
-                            rank=len(topk_tokens_list) + 1
+                            logprob=sampled_logprob.item(), rank=len(topk_tokens_list) + 1
                         )
                     token_log_probs.append(step_dict)
-
+
             # Wrap in list for non-beam search (beam_width=1)
             token_log_probs = [token_log_probs]
-
+
             request.py_result.append_log_probs(token_log_probs)
 
     def finish_if_reason(
@@ -2518,16 +2546,18 @@ def _process_requests(
                 device=logits_cuda.device, non_blocking=True
             )
             logprobs_cuda = F.log_softmax(
-                logits_cuda[logprobs_logit_indices_cuda].to(dtype=torch.float32, non_blocking=True),
+                logits_cuda[logprobs_logit_indices_cuda].to(
+                    dtype=torch.float32, non_blocking=True
+                ),
                 dim=-1,
             )
 
-            max_k = max(max(1, req.py_num_logprobs) for req in requests if req.py_num_logprobs is not None)
-            topk_vals_cuda, topk_indices_cuda = torch.topk(
-                logprobs_cuda,
-                k=max_k,
-                dim=-1
+            max_k = max(
+                max(1, req.py_num_logprobs)
+                for req in requests
+                if req.py_num_logprobs is not None
             )
+            topk_vals_cuda, topk_indices_cuda = torch.topk(logprobs_cuda, k=max_k, dim=-1)
             # Use a single D2H copy to reduce overheads
             topk_vals = torch.empty_like(topk_vals_cuda, device="cpu", pin_memory=True)
             topk_indices = torch.empty_like(topk_indices_cuda, device="cpu", pin_memory=True)
@@ -2542,11 +2572,9 @@ def _process_requests(
                 # Store at least k=1 for all requests (including logprobs=0) to compute ranks
                 k_for_req = max(1, req.py_num_logprobs)
                 # NB: Assigning views on memory which is being filled asynchronously
-                req.py_topk_logprobs_vals = topk_vals[
-                    current_offset:next_offset, : k_for_req
-                ]
+                req.py_topk_logprobs_vals = topk_vals[current_offset:next_offset, :k_for_req]
                 req.py_topk_logprobs_indices = topk_indices[
-                    current_offset:next_offset, : k_for_req
+                    current_offset:next_offset, :k_for_req
                 ]
 
         # context requests do not have multiple input beams, but they need multiple output beams
@@ -2581,15 +2609,16 @@ def _process_requests(
 
             # Build offsets for the GROUPED order
             grouped_num_steps = req_num_steps[batch_req_indices]
-            grouped_offsets = torch.cat([
-                torch.zeros((1,), dtype=torch.int32, pin_memory=True),
-                grouped_num_steps.cumsum(dim=0)[:-1]
-            ])
+            grouped_offsets = torch.cat(
+                [
+                    torch.zeros((1,), dtype=torch.int32, pin_memory=True),
+                    grouped_num_steps.cumsum(dim=0)[:-1],
+                ]
+            )
 
             # Reverse mapping: original_req_id → position in grouped result
             req_to_grouped_pos = {
-                orig_id.item(): grouped_pos
-                for grouped_pos, orig_id in enumerate(batch_req_indices)
+                orig_id.item(): grouped_pos for grouped_pos, orig_id in enumerate(batch_req_indices)
             }
 
             for req_id in range(len(requests)):
@@ -2599,7 +2628,9 @@ def _process_requests(
                 if logprobs_idx == 0:
                     start_offset = 0
                 else:
-                    start_offset = sum(req_num_steps[logprobs_req_indices[:logprobs_idx]].tolist())
+                    start_offset = sum(
+                        req_num_steps[logprobs_req_indices[:logprobs_idx]].tolist()
+                    )
 
                 num_steps_this_req = req_num_steps[req_id].item()
                 end_offset = start_offset + num_steps_this_req
@@ -2610,8 +2641,12 @@ def _process_requests(
 
                 sampled_tokens_this_req = sampled_tokens_cuda[grouped_start:grouped_end]
 
-                step_indices = torch.arange(start_offset, end_offset, device=logprobs_cuda.device)
-                sampled_logprobs_cuda = logprobs_cuda[step_indices, sampled_tokens_this_req.long()]
+                step_indices = torch.arange(
+                    start_offset, end_offset, device=logprobs_cuda.device
+                )
+                sampled_logprobs_cuda = logprobs_cuda[
+                    step_indices, sampled_tokens_this_req.long()
+                ]
 
                 sampled_logprobs_cpu = sampled_logprobs_cuda.to(device="cpu", non_blocking=True)
                 sampled_logprobs_list.append((req_id, sampled_logprobs_cpu))
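
Note: the handle_logprobs change above returns, for every decoding step, a dict mapping token id to a (logprob, rank) record. Ranks 1..k come from the top-k result and the sampled token is appended at rank k+1 when it falls outside the top-k; request.py_result.append_log_probs then receives this list wrapped in an outer list, one entry per beam. The following is a minimal standalone sketch of that construction, using toy tensors and a stand-in Logprob record rather than the TRT-LLM API:

# Minimal sketch only; Logprob below is a stand-in record, not the TRT-LLM class.
from typing import NamedTuple

import torch


class Logprob(NamedTuple):
    logprob: float
    rank: int


def build_step_dicts(sampled_tokens, sampled_logprobs, topk_indices, topk_vals):
    """Build one {token_id: Logprob} dict per decoding step.

    Top-k entries get ranks 1..k; if the sampled token is not in the top-k,
    it is appended with rank k + 1, mirroring the fallback in the diff.
    """
    token_log_probs = []
    for token, token_lp, topk_tok, topk_lp in zip(
        sampled_tokens, sampled_logprobs, topk_indices, topk_vals
    ):
        step = {
            t: Logprob(logprob=lp, rank=r)
            for r, (t, lp) in enumerate(zip(topk_tok.tolist(), topk_lp.tolist()), start=1)
        }
        if token not in step:
            step[token] = Logprob(logprob=token_lp.item(), rank=len(step) + 1)
        token_log_probs.append(step)
    return token_log_probs


if __name__ == "__main__":
    logits = torch.randn(3, 8)  # 3 steps, vocabulary of 8
    logprobs = torch.log_softmax(logits, dim=-1)
    topk_vals, topk_idx = torch.topk(logprobs, k=2, dim=-1)
    sampled = torch.tensor([1, 4, 7])
    sampled_lp = logprobs[torch.arange(3), sampled]
    print(build_step_dicts(sampled.tolist(), sampled_lp, topk_idx, topk_vals))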

tests/unittest/_torch/sampler/test_logits_logprobs.py

Lines changed: 39 additions & 29 deletions
@@ -2,6 +2,7 @@
 
 import pytest
 import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
 from utils.llm_data import llm_models_root
 from utils.util import force_ampere
 
@@ -390,41 +391,50 @@ def test_logprobs_with_grouped_samplings_strategies(logprobs_k: int):
         f"the wrong token position."
 
 
-# def test_logprobs_match_hf_tp2():
-# """Compare TensorRT-LLM logprobs against HuggingFace reference."""
-# model_path = os.path.join(llm_models_root(), "llama-models-v2", "TinyLlama-1.1B-Chat-v1.0")
-# llm = LLM(
-# model=model_path,
-# tensor_parallel_size=2,
-# )
+def test_logprobs_match_hf_tp2():
+    model_path = os.path.join(llm_models_root(), "llama-models-v2",
+                              "TinyLlama-1.1B-Chat-v1.0")
+    llm = LLM(
+        model=model_path,
+        tensor_parallel_size=2,
+    )
+
+    prompts = ["The future of the AI is"]
 
-# sampling_params = SamplingParams(
-# max_tokens=10,
-# temperature=0,
-# logprobs=0,
-# )
+    sampling_params = SamplingParams(
+        max_tokens=10,
+        temperature=1.0,
+        logprobs=0,
+    )
 
-# hf_model = AutoModelForCausalLM.from_pretrained(modesl_path, torch_dtype=torch.bfloat16).to("cuda")
-# hf_tokenizer = AutoTokenizer.from_pretrained(model_path)
+    hf_model = AutoModelForCausalLM.from_pretrained(
+        model_path, torch_dtype=torch.bfloat16).to("cuda")
+    hf_tokenizer = AutoTokenizer.from_pretrained(model_path)
 
-# output = list(llm.generate(prompts, sampling_params=sampling_params))[0]
+    output = list(llm.generate(prompts, sampling_params=sampling_params))[0]
 
-# trtllm_token_ids = output.outputs[0].token_ids
-# trtllm_logprobs = torch.tensor([list(lp.values())[0].logprob for lp in output.outputs[0].logprobs])
+    trtllm_token_ids = output.outputs[0].token_ids
+    trtllm_logprobs = torch.tensor(
+        [list(lp.values())[0].logprob for lp in output.outputs[0].logprobs])
 
-# base_ids = hf_tokenizer.encode(prompts[0], return_tensors="pt").to("cuda")
-# hf_logprobs = []
+    base_ids = hf_tokenizer.encode(prompts[0], return_tensors="pt").to("cuda")
+    hf_logprobs = []
 
-# for i, token_id in enumerate(trtllm_token_ids):
-# input_ids = torch.cat([base_ids, torch.tensor(trtllm_token_ids[:i], device="cuda").unsqueeze(0)], dim=1) if i > 0 else base_ids
-# with torch.no_grad():
-# logits = hf_model(input_ids).logits[0, -1, :]
-# hf_logprobs.append(torch.log_softmax(logits, dim=-1)[token_id].item())
+    for i, token_id in enumerate(trtllm_token_ids):
+        if i > 0:
+            prev_tokens = torch.tensor([trtllm_token_ids[:i]], device="cuda")
+            input_ids = torch.cat([base_ids, prev_tokens], dim=1)
+        else:
+            input_ids = base_ids
+        with torch.no_grad():
+            logits = hf_model(input_ids).logits[0, -1, :]
+        hf_logprobs.append(torch.log_softmax(logits, dim=-1)[token_id].item())
 
-# hf_logprobs = torch.tensor(hf_logprobs)
+    hf_logprobs = torch.tensor(hf_logprobs)
 
-# print(f"\nTensorRT-LLM logprobs: {trtllm_logprobs}")
-# print(f"HuggingFace logprobs: {hf_logprobs}")
+    print(f"\nTensorRT-LLM logprobs: {trtllm_logprobs}")
+    print(f"HuggingFace logprobs: {hf_logprobs}")
+    print(f"Diff: {(trtllm_logprobs - hf_logprobs).abs()}")
 
-# max_diff = (trtllm_logprobs - hf_logprobs).abs().max().item()
-# assert max_diff < 0.1, f"Max logprob diff {max_diff:.4f} exceeds threshold"
+    max_diff = (trtllm_logprobs - hf_logprobs).abs().max().item()
+    assert max_diff < 0.15, f"Max logprob diff {max_diff:.4f} exceeds threshold"
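
Note: the enabled test above runs one HuggingFace forward pass per generated token. An equivalent single-pass reference computation is possible; the sketch below is an illustration only, reusing the hf_model, base_ids and trtllm_token_ids names from the test body, and is not part of the commit:

# Sketch only: compute all reference logprobs with one HF forward pass.
# Assumes hf_model, base_ids and trtllm_token_ids as defined in the test above.
import torch

gen_ids = torch.tensor([trtllm_token_ids], device="cuda")  # [1, N] generated ids
full_ids = torch.cat([base_ids, gen_ids], dim=1)           # prompt + generated tokens
with torch.no_grad():
    log_probs = torch.log_softmax(hf_model(full_ids).logits, dim=-1)  # [1, T, vocab]

# Logits at position t predict the token at position t + 1, so generated
# token i is scored at position prompt_len - 1 + i.
prompt_len = base_ids.shape[1]
positions = torch.arange(prompt_len - 1, full_ids.shape[1] - 1, device="cuda")
hf_logprobs_single_pass = log_probs[0, positions, gen_ids[0]]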
