diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md index 09377e1f66..8cb8856fbc 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md @@ -223,6 +223,8 @@ The configuration for each task dataset is defined as follows: - `temperature`: The temperature for sampling. - `default_workflow_type`: Type of workflow logic applied to this dataset. If not specified, the `buffer.default_workflow_type` is used. - `default_reward_fn_type`: Reward function used during exploration. If not specified, the `buffer.default_reward_fn_type` is used. +- `workflow_args`: A dictionary of arguments used to supplement dataset-level parameters. + ### Trainer Input diff --git a/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md b/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md index 4fbc1e468b..da5b447cdc 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md +++ b/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md @@ -45,6 +45,15 @@ To handle differences in `Task` contents, Trinity-RFT provides a unified `Task` - **`raw_task`** (`Dict`): An record of raw data in `Dict` format. For highly customized workflow, you can directly use `raw_task` to initialize your `Workflow` instance without relying on the following fields. - **`format_args`** ({class}`trinity.common.config.FormatConfig`): Parameters to facilitate the construction of `Workflow` instances. For example, the `prompt_key` and `response_key` can be used to get the prompt and response from `raw_task`. These settings come from the YAML configuration file and can be set in `buffer.explorer_input.task_set.format`. - **`rollout_args`** ({class}`trinity.common.config.GenerationConfig`): Parameters that control the rollout process, such as `temperature`. 
This field also comes from the YAML configuration file and can be set in `buffer.explorer_input.task_set.rollout_args`. + - **`workflow_args`** (`Dict`): A dictionary of parameters to facilitate the construction of `Workflow` instances. Provides more flexibility than `format_args` and `rollout_args` by using a dictionary. This field also comes from the YAML configuration file and can be set in `buffer.explorer_input.task_set.workflow_args`. Normally, you do not need to set this field. + +```{tip} +`workflow`, `workflow_args` and `raw_task` provide different levels of customization. + +- `workflow` provides the global settings for all tasks that use the same workflow. (Global Level) +- `workflow_args` can be set for each task dataset, allowing different task datasets using the same workflow to behave differently. (Dataset Level) +- `raw_task` provides the ability to customize the behavior of each task, which is the most flexible. (Data Sample Level) +``` In the math problem scenario, the `Task` dataset can be a `jsonl` file, where each line contains JSON with `question` and `answer` fields representing the problem description and standard answer, respectively. 
For example: diff --git a/tests/explorer/workflow_test.py b/tests/explorer/workflow_test.py index 8cce2f9e85..0812fb5e6e 100644 --- a/tests/explorer/workflow_test.py +++ b/tests/explorer/workflow_test.py @@ -5,7 +5,7 @@ from unittest.mock import MagicMock from tests.tools import get_unittest_dataset_config -from trinity.common.workflows import MathWorkflow +from trinity.common.workflows import MathWorkflow, Workflow from trinity.common.workflows.workflow import Task @@ -15,6 +15,33 @@ class MockResponse: reward: float = 0.0 +class DummyWorkflow(Workflow): + def __init__(self, model, task: Task, auxiliary_models=None): + super().__init__(model, task, auxiliary_models) + self.obj = task.raw_task + self.output_format = task.workflow_args["output_format"] + + @property + def resettable(self): + return True + + def reset(self, task: Task): + self.obj = task.raw_task + self.output_format = task.workflow_args["output_format"] + + def run(self): + if self.output_format == "json": + import json + + return [json.dumps(self.obj)] + elif self.output_format == "yaml": + import yaml + + return [yaml.safe_dump(self.obj)] + else: + raise ValueError("Invalid output format") + + class WorkflowTest(unittest.TestCase): def test_math_workflow(self) -> None: model = MagicMock() @@ -150,3 +177,18 @@ def test_gsm8k_workflow(self) -> None: self.assertEqual(experiences[1].reward, -0.1) self.assertEqual(experiences[2].reward, -0.1) self.assertEqual(experiences[3].reward, 1.1) + + def test_workflow_resettable(self) -> None: + model = MagicMock() + json_task = Task( + workflow=DummyWorkflow, raw_task={"a": 1}, workflow_args={"output_format": "json"} + ) + yaml_task = Task( + workflow=DummyWorkflow, raw_task={"a": 1}, workflow_args={"output_format": "yaml"} + ) + workflow = json_task.to_workflow(model) + answer = workflow.run() + self.assertEqual(answer[0], '{"a": 1}') + workflow.reset(yaml_task) + answer = workflow.run() + self.assertEqual(answer[0], "a: 1\n") diff --git 
a/trinity/buffer/reader/file_reader.py b/trinity/buffer/reader/file_reader.py index 316d3ae297..cb69b5e017 100644 --- a/trinity/buffer/reader/file_reader.py +++ b/trinity/buffer/reader/file_reader.py @@ -221,6 +221,7 @@ def read(self, strategy: Optional[ReadStrategy] = None): workflow=workflow_class, format_args=self.meta.format, rollout_args=self.meta.rollout_args, + workflow_args=self.meta.workflow_args, is_eval=self.meta.task_type == TaskType.EVAL, reward_fn=reward_fn, raw_task=sample, diff --git a/trinity/common/config.py b/trinity/common/config.py index 3feec8e2ea..6bde79a5c9 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -81,6 +81,7 @@ class StorageConfig: default_workflow_type: Optional[str] = None default_reward_fn_type: Optional[str] = None rollout_args: GenerationConfig = field(default_factory=GenerationConfig) + workflow_args: dict = field(default_factory=dict) # ! DO NOT SET, automatically set from algorithm.algorithm_type algorithm_type: Optional[AlgorithmType] = None diff --git a/trinity/common/workflows/__init__.py b/trinity/common/workflows/__init__.py index 92bf29a64e..f5b1c9a7b9 100644 --- a/trinity/common/workflows/__init__.py +++ b/trinity/common/workflows/__init__.py @@ -3,10 +3,11 @@ from .envs.alfworld.alfworld_workflow import AlfworldWorkflow from .envs.sciworld.sciworld_workflow import SciWorldWorkflow from .envs.webshop.webshop_workflow import WebShopWorkflow -from .workflow import WORKFLOWS, MathWorkflow, SimpleWorkflow, Task +from .workflow import WORKFLOWS, MathWorkflow, SimpleWorkflow, Task, Workflow __all__ = [ "Task", + "Workflow", "WORKFLOWS", "SimpleWorkflow", "MathWorkflow", diff --git a/trinity/common/workflows/workflow.py b/trinity/common/workflows/workflow.py index 169ad63279..fc4a87556b 100644 --- a/trinity/common/workflows/workflow.py +++ b/trinity/common/workflows/workflow.py @@ -28,8 +28,9 @@ class Task: """A Task class that defines a task and its associated reward function / workflow.""" workflow: 
Type[Workflow] - format_args: FormatConfig + format_args: FormatConfig = field(default_factory=FormatConfig) rollout_args: GenerationConfig = field(default_factory=GenerationConfig) + workflow_args: dict = field(default_factory=dict) is_eval: bool = False reward_fn: Optional[Type[RewardFn]] = None raw_task: Optional[dict] = None # The raw data sample diff --git a/trinity/utils/eval_utils.py b/trinity/utils/eval_utils.py index e3aa216eda..e80afaf59b 100644 --- a/trinity/utils/eval_utils.py +++ b/trinity/utils/eval_utils.py @@ -15,12 +15,19 @@ def simple_answer_parser(response: str) -> str: return parse(response) -def find_boxed_answer(string): +def find_boxed_answer(raw_answer, timeout=10): """ - Find answers from solutions where the answers are enclosed in LaTeX's `\boxed` tag + Find answers from solutions where the answers are enclosed in LaTeX's `\\boxed` tag + + Args: + raw_answer (`str`): raw answer from model + timeout (`int`): timeout in seconds for regex + + Returns: + `str`: answer if found, otherwise None """ pattern = r"\\boxed\s*(({(?:\\.|[^{}]|(?2))*})|(.))" - res = re.findall(pattern, string) + res = re.findall(pattern, raw_answer, timeout=timeout) if res: answer = res[-1][0] # regard the last boxed as the answer if answer.startswith("{"):