Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions trinity/common/models/vllm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,14 @@ def __init__(
)
self.default_sampling_params = vllm.SamplingParams(
n=1,
temperature=0.0,
temperature=config.temperature,
max_tokens=config.max_response_tokens,
min_tokens=config.min_response_tokens,
truncate_prompt_tokens=config.max_prompt_tokens,
skip_special_tokens=True,
include_stop_str_in_output=False,
output_kind=RequestOutputKind.FINAL_ONLY,
logprobs=0,
logprobs=config.logprobs,
ignore_eos=config.ignore_eos,
)
self.enable_thinking = config.enable_thinking
Expand Down
18 changes: 10 additions & 8 deletions trinity/common/workflows/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,13 +190,7 @@ def process_messages_to_experience(self, messages, reward, info={}) -> Experienc
return experience


@WORKFLOWS.register_module("simple_workflow")
class SimpleWorkflow(Workflow):
"""A workflow for simple single-round task."""

can_reset: bool = True
can_repeat: bool = True

class BaseSimpleWorkflow(Workflow):
def __init__(
self,
*,
Expand Down Expand Up @@ -246,6 +240,14 @@ def format_messages(self):
messages.append({"role": "assistant", "content": self.reply_prefix})
return messages


@WORKFLOWS.register_module("simple_workflow")
class SimpleWorkflow(BaseSimpleWorkflow):
"""A workflow for simple single-round task."""

can_reset: bool = True
can_repeat: bool = True

def run(self) -> List[Experience]:
# TODO: Optimize the generate function
messages = self.format_messages()
Expand All @@ -272,7 +274,7 @@ def run(self) -> List[Experience]:


@WORKFLOWS.register_module("async_simple_workflow")
class AsyncSimpleWorkflow(Workflow):
class AsyncSimpleWorkflow(BaseSimpleWorkflow):
is_async: bool = True

async def run_async(self) -> List[Experience]:
Expand Down