Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/sphinx_doc/source/tutorial/trinity_configs.md
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,8 @@ The configuration for each task dataset is defined as follows:
- `temperature`: The temperature for sampling.
- `default_workflow_type`: Type of workflow logic applied to this dataset. If not specified, the `buffer.default_workflow_type` is used.
- `default_reward_fn_type`: Reward function used during exploration. If not specified, the `buffer.default_reward_fn_type` is used.
- `workflow_args`: A dictionary of arguments used to supplement dataset-level parameters.


### Trainer Input

Expand Down
9 changes: 9 additions & 0 deletions docs/sphinx_doc/source/tutorial/trinity_programming_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,15 @@ To handle differences in `Task` contents, Trinity-RFT provides a unified `Task`
- **`raw_task`** (`Dict`): An record of raw data in `Dict` format. For highly customized workflow, you can directly use `raw_task` to initialize your `Workflow` instance without relying on the following fields.
- **`format_args`** ({class}`trinity.common.config.FormatConfig`): Parameters to facilitate the construction of `Workflow` instances. For example, the `prompt_key` and `response_key` can be used to get the prompt and response from `raw_task`. These settings come from the YAML configuration file and can be set in `buffer.explorer_input.task_set.format`.
- **`rollout_args`** ({class}`trinity.common.config.GenerationConfig`): Parameters that control the rollout process, such as `temperature`. This field also comes from the YAML configuration file and can be set in `buffer.explorer_input.task_set.rollout_args`.
- **`workflow_args`** (`Dict`): A dictionary of parameters to facilitate the construction of `Workflow` instances. Provides more flexibility than `format_args` and `rollout_args` by using a dictionary. This field also comes from the YAML configuration file and can be set in `buffer.explorer_input.task_set.workflow_args`. Normally, you do not need to set this field.

```{tip}
`workflow`, `workflow_args` and `raw_task` provide different levels of customization.

- `workflow` provides the global settings for all tasks that uses the same workflow. (Global Level)
- `workflow_args` can be set for each task dataset, allowing different task datasets using the same workflow to behave differently. (Dataset Level)
- `raw_task` provides the ability to customize the behavior of each task, which is most flexible. (Data Sample Level)
```

In the math problem scenario, the `Task` dataset can be a `jsonl` file, where each line contains JSON with `question` and `answer` fields representing the problem description and standard answer, respectively. For example:

Expand Down
44 changes: 43 additions & 1 deletion tests/explorer/workflow_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from unittest.mock import MagicMock

from tests.tools import get_unittest_dataset_config
from trinity.common.workflows import MathWorkflow
from trinity.common.workflows import MathWorkflow, Workflow
from trinity.common.workflows.workflow import Task


Expand All @@ -15,6 +15,33 @@ class MockResponse:
reward: float = 0.0


class DummyWorkflow(Workflow):
def __init__(self, model, task: Task, auxiliary_models=None):
super().__init__(model, task, auxiliary_models)
self.obj = task.raw_task
self.output_format = task.workflow_args["output_format"]

@property
def resettable(self):
return True

def reset(self, task: Task):
self.obj = task.raw_task
self.output_format = task.workflow_args["output_format"]

def run(self):
if self.output_format == "json":
import json

return [json.dumps(self.obj)]
elif self.output_format == "yaml":
import yaml

return [yaml.safe_dump(self.obj)]
else:
raise ValueError("Invalid output format")


class WorkflowTest(unittest.TestCase):
def test_math_workflow(self) -> None:
model = MagicMock()
Expand Down Expand Up @@ -150,3 +177,18 @@ def test_gsm8k_workflow(self) -> None:
self.assertEqual(experiences[1].reward, -0.1)
self.assertEqual(experiences[2].reward, -0.1)
self.assertEqual(experiences[3].reward, 1.1)

def test_workflow_resettable(self) -> None:
model = MagicMock()
json_task = Task(
workflow=DummyWorkflow, raw_task={"a": 1}, workflow_args={"output_format": "json"}
)
yaml_task = Task(
workflow=DummyWorkflow, raw_task={"a": 1}, workflow_args={"output_format": "yaml"}
)
workflow = json_task.to_workflow(model)
answer = workflow.run()
self.assertEqual(answer[0], '{"a": 1}')
workflow.reset(yaml_task)
answer = workflow.run()
self.assertEqual(answer[0], "a: 1\n")
1 change: 1 addition & 0 deletions trinity/buffer/reader/file_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ def read(self, strategy: Optional[ReadStrategy] = None):
workflow=workflow_class,
format_args=self.meta.format,
rollout_args=self.meta.rollout_args,
workflow_args=self.meta.workflow_args,
is_eval=self.meta.task_type == TaskType.EVAL,
reward_fn=reward_fn,
raw_task=sample,
Expand Down
1 change: 1 addition & 0 deletions trinity/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ class StorageConfig:
default_workflow_type: Optional[str] = None
default_reward_fn_type: Optional[str] = None
rollout_args: GenerationConfig = field(default_factory=GenerationConfig)
workflow_args: dict = field(default_factory=dict)

# ! DO NOT SET, automatically set from algorithm.algorithm_type
algorithm_type: Optional[AlgorithmType] = None
Expand Down
3 changes: 2 additions & 1 deletion trinity/common/workflows/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
from .envs.alfworld.alfworld_workflow import AlfworldWorkflow
from .envs.sciworld.sciworld_workflow import SciWorldWorkflow
from .envs.webshop.webshop_workflow import WebShopWorkflow
from .workflow import WORKFLOWS, MathWorkflow, SimpleWorkflow, Task
from .workflow import WORKFLOWS, MathWorkflow, SimpleWorkflow, Task, Workflow

__all__ = [
"Task",
"Workflow",
"WORKFLOWS",
"SimpleWorkflow",
"MathWorkflow",
Expand Down
3 changes: 2 additions & 1 deletion trinity/common/workflows/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ class Task:
"""A Task class that defines a task and its associated reward function / workflow."""

workflow: Type[Workflow]
format_args: FormatConfig
format_args: FormatConfig = field(default_factory=FormatConfig)
rollout_args: GenerationConfig = field(default_factory=GenerationConfig)
workflow_args: dict = field(default_factory=dict)
is_eval: bool = False
reward_fn: Optional[Type[RewardFn]] = None
raw_task: Optional[dict] = None # The raw data sample
Expand Down
13 changes: 10 additions & 3 deletions trinity/utils/eval_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,19 @@ def simple_answer_parser(response: str) -> str:
return parse(response)


def find_boxed_answer(string):
def find_boxed_answer(raw_answer, timeout=10):
"""
Find answers from solutions where the answers are enclosed in LaTeX's `\boxed` tag
Find answers from solutions where the answers are enclosed in LaTeX's `\\boxed` tag

Args:
raw_answer (`str`): raw answer from model
timeout (`int`): timeout in seconds for regex

Returns:
`str`: answer if found, otherwise None
"""
pattern = r"\\boxed\s*(({(?:\\.|[^{}]|(?2))*})|(.))"
res = re.findall(pattern, string)
res = re.findall(pattern, raw_answer, timeout=timeout)
if res:
answer = res[-1][0] # regard the last boxed as the answer
if answer.startswith("{"):
Expand Down