Skip to content

Commit 73c81b7

Browse files
authored
Fix default_sampling_params and simple_workflow (#373)
1 parent: 90b55e8 · commit: 73c81b7

File tree

2 files changed

+12
-10
lines changed

2 files changed

+12
-10
lines changed

trinity/common/models/vllm_model.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -62,14 +62,14 @@ def __init__(
6262
)
6363
self.default_sampling_params = vllm.SamplingParams(
6464
n=1,
65-
temperature=0.0,
65+
temperature=config.temperature,
6666
max_tokens=config.max_response_tokens,
6767
min_tokens=config.min_response_tokens,
6868
truncate_prompt_tokens=config.max_prompt_tokens,
6969
skip_special_tokens=True,
7070
include_stop_str_in_output=False,
7171
output_kind=RequestOutputKind.FINAL_ONLY,
72-
logprobs=0,
72+
logprobs=config.logprobs,
7373
ignore_eos=config.ignore_eos,
7474
)
7575
self.enable_thinking = config.enable_thinking

trinity/common/workflows/workflow.py

Lines changed: 10 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -190,13 +190,7 @@ def process_messages_to_experience(self, messages, reward, info={}) -> Experienc
190190
return experience
191191

192192

193-
@WORKFLOWS.register_module("simple_workflow")
194-
class SimpleWorkflow(Workflow):
195-
"""A workflow for simple single-round task."""
196-
197-
can_reset: bool = True
198-
can_repeat: bool = True
199-
193+
class BaseSimpleWorkflow(Workflow):
200194
def __init__(
201195
self,
202196
*,
@@ -246,6 +240,14 @@ def format_messages(self):
246240
messages.append({"role": "assistant", "content": self.reply_prefix})
247241
return messages
248242

243+
244+
@WORKFLOWS.register_module("simple_workflow")
245+
class SimpleWorkflow(BaseSimpleWorkflow):
246+
"""A workflow for simple single-round task."""
247+
248+
can_reset: bool = True
249+
can_repeat: bool = True
250+
249251
def run(self) -> List[Experience]:
250252
# TODO: Optimize the generate function
251253
messages = self.format_messages()
@@ -272,7 +274,7 @@ def run(self) -> List[Experience]:
272274

273275

274276
@WORKFLOWS.register_module("async_simple_workflow")
275-
class AsyncSimpleWorkflow(Workflow):
277+
class AsyncSimpleWorkflow(BaseSimpleWorkflow):
276278
is_async: bool = True
277279

278280
async def run_async(self) -> List[Experience]:

0 commit comments

Comments
 (0)