
Commit 992b273

[https://nvbugs/5387375] fix(scaffolding): fix scaffolding aime test in test_e2e (NVIDIA#6140)
Signed-off-by: Zhenhuan Chen <[email protected]>
1 parent 200ea9e commit 992b273

File tree

12 files changed, +46 -70 lines changed

examples/scaffolding/run_best_of_n_with_reward.py

Lines changed: 1 addition & 1 deletion
@@ -60,7 +60,7 @@ def main():
     prompts = [query]

     results = llm.generate(prompts)
-    print(results[0].output.output_str)
+    print(results[0].outputs[0].text)
     llm.shutdown(shutdown_workers=True)
     print(f'main shut down done')
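
The same access pattern applies wherever a ScaffoldingResult is consumed. A minimal sketch, assuming `llm` is an already-constructed ScaffoldingLlm (its setup is not part of this commit) and the query string is a placeholder:

    def run_query(llm, query: str) -> str:
        results = llm.generate([query])
        # Completions now live in a list of outputs; `.text` replaces the old
        # `output.output_str` access.
        text = results[0].outputs[0].text
        llm.shutdown(shutdown_workers=True)
        return text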

examples/scaffolding/run_majority_vote_aime24.py

Lines changed: 2 additions & 3 deletions
@@ -101,9 +101,8 @@ def main():
         result = results[i]
         test_case = test_dataset[i]
         ref_answer = int(test_case["answer"])
-        result.result()
-        output = result.output
-        extracted_answer = extract_answer_from_boxed(output.output_str)
+        output = result.outputs[0]
+        extracted_answer = extract_answer_from_boxed(output.text)
         try:
             # print(f"[QUESTION]:\n{prompt}\n\n[OUTPUT]\n\n{output.output_str}\n\n")
             answer = int(extracted_answer)

tensorrt_llm/scaffolding/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -12,7 +12,6 @@

 __all__ = [
     "ScaffoldingLlm",
-    "ScaffoldingOutput",
     "ParallelProcess",
     "Controller",
     "NativeGenerationController",

tensorrt_llm/scaffolding/controller.py

Lines changed: 7 additions & 6 deletions
@@ -1,7 +1,7 @@
 import copy
 from abc import ABC
 from enum import Enum
-from typing import Any, List, Mapping
+from typing import Any, List, Mapping, Tuple

 import torch
 from torch.nn import functional as F
@@ -231,13 +231,14 @@ def process(self,
                                 generation_kwargs_list)

         candidates = [tasks[0].output_str for tasks in tasks_list]
-        result = self.majority_vote(candidates, **majority_vote_kwargs)
+        majority_index, majority_answer = self.majority_vote(
+            candidates, **majority_vote_kwargs)

-        assert isinstance(result, str), "majority_vote failed"
+        assert isinstance(majority_answer, str), "majority_vote failed"
         # The task returned by majority vote does not have output_tokens and logits.
-        tasks[0].output_str = result
+        tasks[0].result = tasks_list[majority_index][0].result

-    def majority_vote(self, candidates: List[str], **kwargs) -> str:
+    def majority_vote(self, candidates: List[str], **kwargs) -> Tuple[int, str]:
         return get_digit_majority_vote_result(candidates)


@@ -292,7 +293,7 @@ def process(self,

         best_task, best_idx = self.select_best(generation_tasks, reward_values,
                                                **select_best_kwargs)
-        task.output_str = best_task.output_str
+        task.result = best_task.result

     def select_best(self, tasks: List[Task], reward_values, **kwargs) -> Task:
         max_index = torch.argmax(torch.tensor(reward_values)).item()
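
The `majority_vote` hook now returns an `(index, answer)` pair so the controller can adopt the winning task's full GenerationResult (tokens, logprobs) rather than patching only its text. A hypothetical subclass sketch under that new contract, assuming the enclosing class shown above is importable as MajorityVoteController:

    from typing import List, Tuple

    from tensorrt_llm.scaffolding import MajorityVoteController  # assumed export
    from tensorrt_llm.scaffolding.math_utils import get_digit_majority_vote_result

    class StrippedMajorityVoteController(MajorityVoteController):

        def majority_vote(self, candidates: List[str],
                          **kwargs) -> Tuple[int, str]:
            # Illustrative tweak only: normalize whitespace before voting,
            # then return (winning_candidate_index, majority_answer).
            return get_digit_majority_vote_result(
                [candidate.strip() for candidate in candidates])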

tensorrt_llm/scaffolding/math_utils.py

Lines changed: 18 additions & 16 deletions
@@ -1,5 +1,4 @@
 import re
-from collections import Counter
 from typing import List


@@ -59,28 +58,31 @@ def get_majority_result(
         result_extractor=lambda x: x,
         result_validator=lambda x: True,
 ):
-    valid_answers_and_results = [(result, result_extractor(result))
-                                 for result in results
-                                 if result_validator(result) is True
-                                 and result_extractor(result) is not None]
-    if len(valid_answers_and_results) == 0:
+    extract_answers = [result_extractor(result) for result in results]
+    valid_answers = [
+        result for result in extract_answers
+        if result is not None and result_validator(result) is True
+    ]
+    if len(valid_answers) == 0:
         return None, None

-    majority_result = Counter(valid_answers_and_results).most_common(1)[0][0]
-    # return result and extracted result
-    return majority_result[0], majority_result[1]
+    answer_counts = {}
+    for answer in valid_answers:
+        answer_counts[answer] = answer_counts.get(answer, 0) + 1
+    majority_answer = max(answer_counts, key=answer_counts.get)
+    majority_index = next(
+        filter(lambda x: x[1] == majority_answer,
+               enumerate(extract_answers)))[0]
+    return majority_index, majority_answer


 def get_digit_majority_vote_result(results: List[str]) -> str:

     def is_digit(result: str):
-        extracted_answer = extract_answer_from_boxed(result)
-        if extracted_answer is None:
-            return False
-        return extracted_answer.isdigit()
+        return result.isdigit()

-    vote_result = get_majority_result(
+    index, extract_answer = get_majority_result(
         results,
         result_extractor=extract_answer_from_boxed,
-        result_validator=is_digit)[0]
-    return vote_result if vote_result else results[0]
+        result_validator=is_digit)
+    return (index, extract_answer) if extract_answer else (0, None)
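
For illustration, a small worked example of the new return contract. The candidate strings are made up, and it assumes extract_answer_from_boxed pulls "42" out of \boxed{42}:

    from tensorrt_llm.scaffolding.math_utils import get_digit_majority_vote_result

    candidates = [
        r"The result is \boxed{42}.",   # extracted answer: "42"
        r"So we get \boxed{41}.",       # extracted answer: "41"
        r"Therefore \boxed{42}.",       # extracted answer: "42"
    ]
    index, answer = get_digit_majority_vote_result(candidates)
    # index == 0   -- first candidate whose extracted answer matches the majority
    # answer == "42"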

tensorrt_llm/scaffolding/result.py

Lines changed: 1 addition & 9 deletions
@@ -1,23 +1,15 @@
 import asyncio
-from dataclasses import dataclass
 from typing import Mapping, Optional

 from tensorrt_llm.executor.result import GenerationResult


-@dataclass(slots=True)
-class ScaffoldingOutput:
-
-    def __init__(self):
-        self.output_str = None
-
-
 class ScaffoldingResult:

     def __init__(self, streaming_event: Optional[asyncio.Event] = None):
         super().__init__()
         self.aqueue = asyncio.Queue()
-        self.cur_output = None
+        self.cur_output: GenerationResult = None
         self._done = False
         self.task_collections = None
         self.streaming_event = streaming_event

tensorrt_llm/scaffolding/scaffolding_llm.py

Lines changed: 1 addition & 1 deletion
@@ -82,7 +82,7 @@ async def _handle_task_list(self,
         ]
         await asyncio.gather(*async_tasks)
         for task in tasks:
-            if task.streaming:
+            if getattr(task, 'streaming', False):
                 await request.result.set_output_async(task.result)
                 self.streaming_event.clear()
                 await self.streaming_event.wait()

tensorrt_llm/scaffolding/task.py

Lines changed: 13 additions & 14 deletions
@@ -62,8 +62,6 @@ class GenerationTask(Task):
     worker_tag: Union[str, "Controller.WorkerTag"] = None

     # result field
-    _outputs: Optional[List[dict]] = None
-
     # link to TRTLLM's GenerationResult, for async update in streaming mode
     _result: Optional[GenerationResult] = None

@@ -74,35 +72,36 @@ def result(self) -> GenerationResult:
     @result.setter
     def result(self, result: GenerationResult) -> None:
         self._result = result
-        self._outputs = result.outputs
+
+    @property
+    def outputs(self) -> Optional[List[dict]]:
+        return self._result.outputs if self._result else None

     @property
     def output_tokens(self) -> List[int]:
-        return self._outputs[
-            0].token_ids if self.result and self._outputs else None
+        return self._result.outputs[0].token_ids if self._result else None

     @property
     def output_str(self) -> Optional[str]:
-        return self._outputs[0].text if self.result and self._outputs else None
+        return self._result.outputs[0].text if self._result else None

     @output_str.setter
     def output_str(self, output) -> Optional[str]:
-        assert self.result and self._outputs
-        self._outputs[0].text = output
+        assert self.result
+        self._result.outputs[0].text = output

     @property
     def cumulative_logprob(self) -> Optional[float]:
-        return self._outputs[
-            0].cumulative_logprob if self.result and self._outputs else None
+        return self._result.outputs[
+            0].cumulative_logprob if self._result else None

     @property
     def logprobs(self) -> Optional[List[float]]:
-        return self._outputs[
-            0].logprobs if self.result and self._outputs else None
+        return self._result.outputs[0].logprobs if self._result else None

     @property
     def context_logits(self) -> Optional[torch.Tensor]:
-        return self.result.context_logits if self.result else None
+        return self._result.context_logits if self._result else None

     @staticmethod
     def create_from_prompt(prompt: str) -> "GenerationTask":
@@ -113,7 +112,7 @@ def create_from_prompt(prompt: str) -> "GenerationTask":
         return task

     def create_scaffolding_output(self) -> GenerationResult:
-        return self.result
+        return self._result


 @dataclass
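
A minimal sketch of the new delegation pattern: every output accessor now reads through `_result` on demand, so later updates to the underlying GenerationResult (for example during streaming) stay visible. Here `generation_result` stands in for a GenerationResult produced by a worker; it is assumed, not shown in this commit:

    from tensorrt_llm.scaffolding.task import GenerationTask

    task = GenerationTask.create_from_prompt("2 + 2 = ?")
    task.result = generation_result   # setter now stores only _result

    # These properties delegate to _result.outputs[0] instead of a cached list:
    text = task.output_str            # generation_result.outputs[0].text
    tokens = task.output_tokens       # generation_result.outputs[0].token_ids
    logprobs = task.logprobs          # generation_result.outputs[0].logprobs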

tests/integration/test_lists/waives.txt

Lines changed: 0 additions & 1 deletion
@@ -433,7 +433,6 @@ examples/test_multimodal.py::test_llm_multimodal_general[Qwen2-VL-7B-Instruct-pp
 examples/test_multimodal.py::test_llm_fp8_multimodal_general[fp8-fp8-cnn_dailymail-Qwen2-VL-7B-Instruct-pp:1-tp:1-bfloat16-bs:1-cpp_e2e:False] SKIP (https://nvbugs/5385987)
 examples/test_multimodal.py::test_llm_multimodal_general[Phi-4-multimodal-instruct-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:1] SKIP (https://nvbugs/5385992)
 accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_tp8] SKIP (https://nvbugs/5377914)
-test_e2e.py::test_ptp_scaffolding[DeepSeek-R1-Distill-Qwen-7B-DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B] SKIP (https://nvbugs/5387375)
 examples/test_multimodal.py::test_llm_multimodal_general[kosmos-2-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:1] SKIP (https://nvbugs/5387422)
 examples/test_multimodal.py::test_llm_multimodal_general[fuyu-8b-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:1] SKIP (https://nvbugs/5387424)
 test_e2e.py::test_ptp_quickstart SKIP (https://nvbugs/5387762)

tests/unittest/scaffolding/test_bench.py

Lines changed: 3 additions & 3 deletions
@@ -13,7 +13,7 @@
 class DummyWorker(Worker):

     async def dummy_generation_handler(self, task: GenerationTask):
-        task.output_str = OUTPUT_STR
+        task.result = OUTPUT_STR
         return TaskStatus.SUCCESS

     task_handlers = {GenerationTask: dummy_generation_handler}
@@ -29,7 +29,7 @@ def before_yield(self, tasks: List[Task]):
         pass

     def after_yield(self, tasks: List[Task]):
-        self.output_len = len(tasks[0].output_str)
+        self.output_len = len(tasks[0].result)


 def test_scaffolding_benchmark():
@@ -56,6 +56,6 @@ def test_scaffolding_benchmark():

     assert len(results) == requests_num
     assert len(requests_execution_time) == requests_num
-    assert results[0].output.output_str == OUTPUT_STR
+    assert results[0].cur_output == OUTPUT_STR
     assert results[0].task_collections[
         "bench_dummy_collection"].output_len == len(OUTPUT_STR)
