Skip to content

Commit e773560

Browse files
authored
Add workflow_args for fine-grained control (#73)
1 parent 77244d8 commit e773560

File tree

8 files changed

+70
-6
lines changed

8 files changed

+70
-6
lines changed

docs/sphinx_doc/source/tutorial/trinity_configs.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,8 @@ The configuration for each task dataset is defined as follows:
223223
- `temperature`: The temperature for sampling.
224224
- `default_workflow_type`: Type of workflow logic applied to this dataset. If not specified, the `buffer.default_workflow_type` is used.
225225
- `default_reward_fn_type`: Reward function used during exploration. If not specified, the `buffer.default_reward_fn_type` is used.
226+
- `workflow_args`: A dictionary of arguments used to supplement dataset-level parameters.
227+
226228

227229
### Trainer Input
228230

docs/sphinx_doc/source/tutorial/trinity_programming_guide.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,15 @@ To handle differences in `Task` contents, Trinity-RFT provides a unified `Task`
4545
- **`raw_task`** (`Dict`): A record of raw data in `Dict` format. For highly customized workflows, you can directly use `raw_task` to initialize your `Workflow` instance without relying on the following fields.
4646
- **`format_args`** ({class}`trinity.common.config.FormatConfig`): Parameters to facilitate the construction of `Workflow` instances. For example, the `prompt_key` and `response_key` can be used to get the prompt and response from `raw_task`. These settings come from the YAML configuration file and can be set in `buffer.explorer_input.task_set.format`.
4747
- **`rollout_args`** ({class}`trinity.common.config.GenerationConfig`): Parameters that control the rollout process, such as `temperature`. This field also comes from the YAML configuration file and can be set in `buffer.explorer_input.task_set.rollout_args`.
48+
- **`workflow_args`** (`Dict`): A dictionary of parameters to facilitate the construction of `Workflow` instances. Because it is a free-form dictionary, it provides more flexibility than `format_args` and `rollout_args`. This field also comes from the YAML configuration file and can be set in `buffer.explorer_input.task_set.workflow_args`. Normally, you do not need to set this field.
49+
50+
```{tip}
51+
`workflow`, `workflow_args` and `raw_task` provide different levels of customization.
52+
53+
- `workflow` provides the global settings for all tasks that use the same workflow. (Global Level)
54+
- `workflow_args` can be set for each task dataset, allowing different task datasets using the same workflow to behave differently. (Dataset Level)
55+
- `raw_task` provides the ability to customize the behavior of each task, which is the most flexible. (Data Sample Level)
56+
```
4857

4958
In the math problem scenario, the `Task` dataset can be a `jsonl` file, where each line contains JSON with `question` and `answer` fields representing the problem description and standard answer, respectively. For example:
5059

tests/explorer/workflow_test.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from unittest.mock import MagicMock
66

77
from tests.tools import get_unittest_dataset_config
8-
from trinity.common.workflows import MathWorkflow
8+
from trinity.common.workflows import MathWorkflow, Workflow
99
from trinity.common.workflows.workflow import Task
1010

1111

@@ -15,6 +15,33 @@ class MockResponse:
1515
reward: float = 0.0
1616

1717

18+
class DummyWorkflow(Workflow):
19+
def __init__(self, model, task: Task, auxiliary_models=None):
20+
super().__init__(model, task, auxiliary_models)
21+
self.obj = task.raw_task
22+
self.output_format = task.workflow_args["output_format"]
23+
24+
@property
25+
def resettable(self):
26+
return True
27+
28+
def reset(self, task: Task):
29+
self.obj = task.raw_task
30+
self.output_format = task.workflow_args["output_format"]
31+
32+
def run(self):
33+
if self.output_format == "json":
34+
import json
35+
36+
return [json.dumps(self.obj)]
37+
elif self.output_format == "yaml":
38+
import yaml
39+
40+
return [yaml.safe_dump(self.obj)]
41+
else:
42+
raise ValueError("Invalid output format")
43+
44+
1845
class WorkflowTest(unittest.TestCase):
1946
def test_math_workflow(self) -> None:
2047
model = MagicMock()
@@ -150,3 +177,18 @@ def test_gsm8k_workflow(self) -> None:
150177
self.assertEqual(experiences[1].reward, -0.1)
151178
self.assertEqual(experiences[2].reward, -0.1)
152179
self.assertEqual(experiences[3].reward, 1.1)
180+
181+
def test_workflow_resettable(self) -> None:
182+
model = MagicMock()
183+
json_task = Task(
184+
workflow=DummyWorkflow, raw_task={"a": 1}, workflow_args={"output_format": "json"}
185+
)
186+
yaml_task = Task(
187+
workflow=DummyWorkflow, raw_task={"a": 1}, workflow_args={"output_format": "yaml"}
188+
)
189+
workflow = json_task.to_workflow(model)
190+
answer = workflow.run()
191+
self.assertEqual(answer[0], '{"a": 1}')
192+
workflow.reset(yaml_task)
193+
answer = workflow.run()
194+
self.assertEqual(answer[0], "a: 1\n")

trinity/buffer/reader/file_reader.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,7 @@ def read(self, strategy: Optional[ReadStrategy] = None):
221221
workflow=workflow_class,
222222
format_args=self.meta.format,
223223
rollout_args=self.meta.rollout_args,
224+
workflow_args=self.meta.workflow_args,
224225
is_eval=self.meta.task_type == TaskType.EVAL,
225226
reward_fn=reward_fn,
226227
raw_task=sample,

trinity/common/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ class StorageConfig:
8484
default_workflow_type: Optional[str] = None
8585
default_reward_fn_type: Optional[str] = None
8686
rollout_args: GenerationConfig = field(default_factory=GenerationConfig)
87+
workflow_args: dict = field(default_factory=dict)
8788

8889
# ! DO NOT SET, automatically set from algorithm.algorithm_type
8990
algorithm_type: Optional[AlgorithmType] = None

trinity/common/workflows/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
from .envs.alfworld.alfworld_workflow import AlfworldWorkflow
44
from .envs.sciworld.sciworld_workflow import SciWorldWorkflow
55
from .envs.webshop.webshop_workflow import WebShopWorkflow
6-
from .workflow import WORKFLOWS, MathWorkflow, SimpleWorkflow, Task
6+
from .workflow import WORKFLOWS, MathWorkflow, SimpleWorkflow, Task, Workflow
77

88
__all__ = [
99
"Task",
10+
"Workflow",
1011
"WORKFLOWS",
1112
"SimpleWorkflow",
1213
"MathWorkflow",

trinity/common/workflows/workflow.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,9 @@ class Task:
2828
"""A Task class that defines a task and its associated reward function / workflow."""
2929

3030
workflow: Type[Workflow]
31-
format_args: FormatConfig
31+
format_args: FormatConfig = field(default_factory=FormatConfig)
3232
rollout_args: GenerationConfig = field(default_factory=GenerationConfig)
33+
workflow_args: dict = field(default_factory=dict)
3334
is_eval: bool = False
3435
reward_fn: Optional[Type[RewardFn]] = None
3536
raw_task: Optional[dict] = None # The raw data sample

trinity/utils/eval_utils.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,19 @@ def simple_answer_parser(response: str) -> str:
1515
return parse(response)
1616

1717

18-
def find_boxed_answer(string):
18+
def find_boxed_answer(raw_answer, timeout=10):
1919
"""
20-
Find answers from solutions where the answers are enclosed in LaTeX's `\boxed` tag
20+
Find answers from solutions where the answers are enclosed in LaTeX's `\\boxed` tag
21+
22+
Args:
23+
raw_answer (`str`): raw answer from model
24+
timeout (`int`): timeout in seconds for regex
25+
26+
Returns:
27+
`str`: answer if found, otherwise None
2128
"""
2229
pattern = r"\\boxed\s*(({(?:\\.|[^{}]|(?2))*})|(.))"
23-
res = re.findall(pattern, string)
30+
res = re.findall(pattern, raw_answer, timeout=timeout)
2431
if res:
2532
answer = res[-1][0] # regard the last boxed as the answer
2633
if answer.startswith("{"):

0 commit comments

Comments
 (0)