Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions trinity/common/models/vllm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,6 @@ def __init__(
)
self.default_sampling_params = vllm.SamplingParams(
n=1,
temperature=0.0,
max_tokens=config.max_response_tokens,
min_tokens=config.min_response_tokens,
truncate_prompt_tokens=config.max_prompt_tokens,
skip_special_tokens=True,
Expand Down
18 changes: 10 additions & 8 deletions trinity/common/workflows/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,13 +190,7 @@ def process_messages_to_experience(self, messages, reward, info={}) -> Experienc
return experience


@WORKFLOWS.register_module("simple_workflow")
class SimpleWorkflow(Workflow):
"""A workflow for simple single-round task."""

can_reset: bool = True
can_repeat: bool = True

class BaseSimpleWorkflow(Workflow):
def __init__(
self,
*,
Expand Down Expand Up @@ -246,6 +240,14 @@ def format_messages(self):
messages.append({"role": "assistant", "content": self.reply_prefix})
return messages


@WORKFLOWS.register_module("simple_workflow")
class SimpleWorkflow(BaseSimpleWorkflow):
"""A workflow for simple single-round task."""

can_reset: bool = True
can_repeat: bool = True

def run(self) -> List[Experience]:
# TODO: Optimize the generate function
messages = self.format_messages()
Expand All @@ -272,7 +274,7 @@ def run(self) -> List[Experience]:


@WORKFLOWS.register_module("async_simple_workflow")
class AsyncSimpleWorkflow(Workflow):
class AsyncSimpleWorkflow(BaseSimpleWorkflow):
is_async: bool = True

async def run_async(self) -> List[Experience]:
Expand Down