diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py
index 0133586f7c..eeb90cec12 100644
--- a/trinity/common/models/vllm_model.py
+++ b/trinity/common/models/vllm_model.py
@@ -62,14 +62,14 @@ def __init__(
         )
         self.default_sampling_params = vllm.SamplingParams(
             n=1,
-            temperature=0.0,
+            temperature=config.temperature,
             max_tokens=config.max_response_tokens,
             min_tokens=config.min_response_tokens,
             truncate_prompt_tokens=config.max_prompt_tokens,
             skip_special_tokens=True,
             include_stop_str_in_output=False,
             output_kind=RequestOutputKind.FINAL_ONLY,
-            logprobs=0,
+            logprobs=config.logprobs,
             ignore_eos=config.ignore_eos,
         )
         self.enable_thinking = config.enable_thinking
diff --git a/trinity/common/workflows/workflow.py b/trinity/common/workflows/workflow.py
index 8a493e161f..91716e1688 100644
--- a/trinity/common/workflows/workflow.py
+++ b/trinity/common/workflows/workflow.py
@@ -190,13 +190,7 @@ def process_messages_to_experience(self, messages, reward, info={}) -> Experienc
         return experience
 
 
-@WORKFLOWS.register_module("simple_workflow")
-class SimpleWorkflow(Workflow):
-    """A workflow for simple single-round task."""
-
-    can_reset: bool = True
-    can_repeat: bool = True
-
+class BaseSimpleWorkflow(Workflow):
     def __init__(
         self,
         *,
@@ -246,6 +240,14 @@ def format_messages(self):
             messages.append({"role": "assistant", "content": self.reply_prefix})
         return messages
 
+
+@WORKFLOWS.register_module("simple_workflow")
+class SimpleWorkflow(BaseSimpleWorkflow):
+    """A workflow for simple single-round task."""
+
+    can_reset: bool = True
+    can_repeat: bool = True
+
     def run(self) -> List[Experience]:
         # TODO: Optimize the generate function
         messages = self.format_messages()
@@ -272,7 +274,7 @@ def run(self) -> List[Experience]:
 
 
 @WORKFLOWS.register_module("async_simple_workflow")
-class AsyncSimpleWorkflow(Workflow):
+class AsyncSimpleWorkflow(BaseSimpleWorkflow):
     is_async: bool = True
 
     async def run_async(self) -> List[Experience]:
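
The vllm_model.py hunk changes behavior: the default `vllm.SamplingParams` no longer hard-codes greedy decoding (`temperature=0.0`) and `logprobs=0`, and instead reads both from the generation config. A minimal sketch of the resulting construction, using a hypothetical `_Cfg` stand-in for the real config object (the field names are the ones read in the diff; the values here are illustrative only):

```python
import vllm
from vllm.sampling_params import RequestOutputKind

# Hypothetical stand-in for the generation config; field names match the diff.
class _Cfg:
    temperature = 1.0          # previously hard-coded to 0.0 (greedy decoding)
    logprobs = 0               # previously hard-coded to 0
    max_response_tokens = 1024
    min_response_tokens = 1
    max_prompt_tokens = 2048
    ignore_eos = False

config = _Cfg()
default_sampling_params = vllm.SamplingParams(
    n=1,
    temperature=config.temperature,        # now config-driven
    max_tokens=config.max_response_tokens,
    min_tokens=config.min_response_tokens,
    truncate_prompt_tokens=config.max_prompt_tokens,
    skip_special_tokens=True,
    include_stop_str_in_output=False,
    output_kind=RequestOutputKind.FINAL_ONLY,
    logprobs=config.logprobs,              # now config-driven
    ignore_eos=config.ignore_eos,
)
```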
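The workflow.py hunks are an extract-superclass refactor: `__init__` and `format_messages` move unchanged from `SimpleWorkflow` into a new `BaseSimpleWorkflow`, and both `SimpleWorkflow` and `AsyncSimpleWorkflow` now inherit from it, so the async variant picks up the shared prompt-formatting logic instead of duplicating it. A simplified sketch of the resulting hierarchy, with method bodies elided; the class and registry names come from the diff, while the import paths are assumptions about the repo layout:

```python
from typing import List

# Assumed module paths; Workflow, WORKFLOWS, and Experience are defined
# inside the trinity repo, not reproduced here.
from trinity.common.experience import Experience
from trinity.common.workflows.workflow import WORKFLOWS, Workflow

class BaseSimpleWorkflow(Workflow):
    # Holds the shared __init__ and format_messages moved out of SimpleWorkflow.
    def format_messages(self): ...

@WORKFLOWS.register_module("simple_workflow")
class SimpleWorkflow(BaseSimpleWorkflow):
    """A workflow for simple single-round task."""
    can_reset: bool = True
    can_repeat: bool = True
    def run(self) -> List[Experience]: ...

@WORKFLOWS.register_module("async_simple_workflow")
class AsyncSimpleWorkflow(BaseSimpleWorkflow):
    is_async: bool = True
    async def run_async(self) -> List[Experience]: ...
```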