diff --git a/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md b/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md index 58e467b20a..2e4daeab0b 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md +++ b/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md @@ -1,6 +1,6 @@ # Developer Guide -This guide will introduce how to add new task types to Trinity-RFT and provide relevant development guidelines. +This guide introduces how to add new workflows to Trinity-RFT and provides relevant development guidelines. ```{note} Trinity-RFT is still under development, and the following interfaces may change. Please read this section in conjunction with the latest code. @@ -8,9 +8,9 @@ Trinity-RFT is still under development, and the following interfaces may change. --- -## Creating New Task Types +## Creating New Workflows -Trinity-RFT supports developers in registering new task types (e.g., multi-round interaction scenarios). Below are the steps for creating a new task type. +Trinity-RFT allows developers to register new workflows (e.g., for multi-turn interactions or agentic scenarios). Below are the steps to create a new workflow: --- @@ -18,24 +18,35 @@ Trinity-RFT supports developers in registering new task types (e.g., multi-round Before starting development, it's important to understand several core concepts: -- **Task**: Represents a data structure that can be converted into a `Workflow`. The `Task` data format may vary significantly depending on the type of task: - - **Math problems**: `Task` contains the problem description and the standard answer. - - **Programming scenarios**: `Task` includes the problem description, test cases, runtime environment, and other complex information. -- **Workflow**: Can be understood as the running state of a `Task`, defining the interaction flow between Agents and Environments, including logic similar to _Rollout_ and _Reward_ calculations in other frameworks. 
After execution, it generates `Experience`. Trinity-RFT has several built-in `Workflows`:
-  - `MathWorkflow`: For math scenarios, submits problems to LLM, parses results, and calculates scores (rewards).
+- **Task** ({class}`trinity.common.workflows.Task`): Represents a data structure that can be converted into a `Workflow`. The content of the `Task` varies depending on the task type:
+  - **Math problems**: A `Task` contains the problem description and the golden answer.
+  - **Programming scenarios**: A `Task` includes the problem description, test cases, runtime environment, and other complex information.
+
+
+- **Workflow** ({class}`trinity.common.workflows.Workflow`): Can be understood as the running state of a `Task`. It defines the interaction flow between Agents and Environments, including logic similar to _Rollout_ and _Reward_ calculations in other frameworks. After execution, it generates a list of `Experience`. Trinity-RFT includes several built-in workflows:
+  - `MathWorkflow` ({class}`trinity.common.workflows.MathWorkflow`): For math scenarios, submits problems to the LLM, parses LLM responses, and calculates scores (rewards).
+  - `WebShopWorkflow` ({class}`trinity.common.workflows.WebShopWorkflow`): For webshop scenarios, it involves multi-turn interaction with the environment.
   - `CodeWorkflow` (Coming soon): For coding scenarios, executes returned code, runs tests, and calculates rewards based on test results.
   - ...
+
+- **Experience** ({class}`trinity.common.experience.Experience`): The output of running a `Workflow`. The internal data format depends on the training algorithm used.
For example, for common PPO/GRPO algorithms, `Experience` includes lists of token IDs, action masks (identifying which tokens were generated by the LLM), log probabilities, rewards, etc.

---

### Step 1: Prepare Task Dataset

-Each `Task` contains various parameters needed to initialize the `Workflow`. Due to significant differences in initialization parameters across different `Workflows`, the following example uses a math problem scenario.
+The task dataset is loaded via the `buffer.explorer_input.taskset` configuration entry in your YAML config file.
+To handle differences in `Task` contents, Trinity-RFT provides a unified `Task` interface containing the following fields:

-In the math problem scenario, the `Task` dataset can be a `jsonl` file, where each line’s JSON contains `question` and `answer` fields representing the problem description and standard answer, respectively.
+  - **`workflow`** (`str`): The registered name of your workflow class. You can specify it in `buffer.explorer_input.taskset.default_workflow_type` of your YAML config file.
+  - **`reward_fn`** (`Optional[str]`): The registered name of your reward function. You can specify it in `buffer.explorer_input.taskset.default_reward_fn_type`. Note that some workflows already include built-in reward calculation; in such cases, you can omit this field.
+  - **`raw_task`** (`Dict`): A record of raw data in `Dict` format. For highly customized workflows, you can directly use `raw_task` to initialize your `Workflow` instance without relying on the following fields.
+  - **`format_args`** ({class}`trinity.common.config.FormatConfig`): Parameters to facilitate the construction of `Workflow` instances. For example, `prompt_key` and `response_key` can be used to get the prompt and response from `raw_task`. These settings come from the YAML configuration file and can be set in `buffer.explorer_input.taskset.format`.
+  - **`rollout_args`** ({class}`trinity.common.config.GenerationConfig`): Parameters that control the rollout process, such as `temperature`. This field also comes from the YAML configuration file and can be set in `buffer.explorer_input.taskset.rollout_args`.
+
+In the math problem scenario, the `Task` dataset can be a `jsonl` file, where each line contains JSON with `question` and `answer` fields representing the problem description and standard answer, respectively. For example:

```
{"question": "1+1=", "answer": "2"}
@@ -43,15 +54,33 @@ In the math problem scenario, the `Task` dataset can be a `jsonl` file, where ea
...
```

+Example configuration snippet:
+
+```yaml
+# some config
+buffer:
+  explorer_input:
+    taskset:
+      default_workflow_type: "math_workflow"
+      path: "/PATH/TO/FILE/DIR"
+      format:
+        prompt_key: "question"
+        response_key: "answer"
+      rollout_args:
+        temperature: 1.0
+# some other configs
+```
+
+In this example, each task object's `raw_task` is a `Dict` with two keys (`question` and `answer`). The `MathWorkflow` uses the `prompt_key` and `response_key` to extract the question and answer from the `raw_task`, and uses the `rollout_args` to generate the responses.
+
+
---

-### Step 2: Write Workflow
+### Step 2: Implement a New Workflow

-The core of creating a new task type is writing a new `Workflow`, whose base class interface is as follows:
+The `Workflow` base class interface is as follows:

```python
-# import some packages
-
 class Workflow(ABC):

     def __init__(
@@ -68,39 +97,48 @@ class Workflow(ABC):
        """Run the workflow and return a list of Experiences."""
```

-Developers can register their own `Workflow` through the `WORKFLOWS.register_module` method, but need to ensure that the name does not conflict with existing `Workflow` classes.
-```python
-# import some packages
-from trinity.common.workflows.workflow import WORKFLOWS

+#### Initializing Your Workflow

-@WORKFLOWS.register_module("my_workflow")
-class MyWorkflow(Workflow):
-    pass
-```

+During initialization, `Workflow` receives the following parameters:
+
+- `model` ({class}`trinity.common.models.model.ModelWrapper`): The model being trained, which provides an interface similar to OpenAI, capable of receiving a list of conversation messages and returning content generated by the LLM (including reply text `response_text`, full sequence token ids `tokens`, prompt part token length `prompt_length`, and a list of output token logprobs `logprobs`).
+- `task` ({class}`trinity.common.workflows.Task`): A single data item from the task dataset.
+- `auxiliary_models` (`List[openai.OpenAI]`): A list of auxiliary models not involved in training. All are provided via OpenAI-compatible APIs.

-#### Initialization Parameters
-When initializing, `Workflow` receives the following parameters:
-- `model`: The model being trained, which provides an interface similar to OpenAI, capable of receiving a list of conversation messages and returning content generated by the LLM (including reply text `response_text`, full sequence token ids `tokens`, prompt part token length `prompt_length`, and a list of output token logprobs `logprobs`).
-- `task`: An instance of `Task`, which is generated by one line of data from the `Task` dataset. The `raw_task` field contains the `Dict` format source data, which can be used to construct the `Workflow` instance.
-The `rollout_args` field contains the parameters for the rollout process, such as `n`, `temperature`, `top_k` and `top_p`.
-- `auxiliary_models`: A list of auxiliary models, which will not be trained. All of them provide OpenAI compatible API.
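To make the `Task` fields concrete, here is a tiny framework-free sketch of how a workflow can pull its prompt and answer out of `task.raw_task`, driven by `format_args` instead of hard-coded keys. The `SimpleNamespace` objects are illustrative stand-ins for the real `Task` and `FormatConfig` instances, not Trinity-RFT classes:

```python
from types import SimpleNamespace

# Illustrative stand-ins for Task / FormatConfig (not the real classes)
task = SimpleNamespace(
    raw_task={"question": "1+1=", "answer": "2"},
    format_args=SimpleNamespace(prompt_key="question", response_key="answer"),
)

# A Workflow.__init__ can resolve its fields through format_args
# instead of hard-coding the dataset keys:
prompt = task.raw_task[task.format_args.prompt_key]
answer = task.raw_task[task.format_args.response_key]
```

This pattern lets the same workflow class serve datasets whose column names differ, with only a YAML change.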
```{tip}
-The `model` also provided an OpenAI compatible API, you can switch to it by setting `explorer.rollout_model.enable_openai_api` to `true` in your config file and use `model.get_openai_client()` to get an `openai.OpenAI` instance in your workflow.
+You can switch to using the OpenAI API by setting `explorer.rollout_model.enable_openai_api` to `true` in your config file and calling `model.get_openai_client()` to get an `openai.OpenAI` instance in your workflow.
```

-#### Example Code
-Below is a simple example demonstrating how to implement a math problem `Workflow`:
+Here’s an example of initializing a simple workflow using only `raw_task` and `rollout_args`. In more complex cases, you can use `format_args` for further customization.

```python
-@WORKFLOWS.register_module("example_workflow")
 class ExampleWorkflow(Workflow):

-    def __init__(self, model: ModelWrapper, task: Task, **kwargs):
-        super().__init__(model, **kwargs)
+    def __init__(self, model: ModelWrapper, task: Task, auxiliary_models: List):
+        super().__init__(model, task, auxiliary_models)
         self.question = task.raw_task.get("question")
         self.answer = task.raw_task.get("answer")
+        self.rollout_args = task.rollout_args
+        # Optional: If you want to use the OpenAI API in your workflow
+        # self.openai_client = self.model.get_openai_client()
+```
+
+#### Implementing the `run` Method
+
+The `run` method is the core of your workflow. It returns a list of `Experience`.
+Below is a simple implementation for a math workflow.
+
+We first call the model to generate multiple responses using the provided question and rollout arguments.
+Then we calculate the reward for each response using the `calculate_reward` function.
+Finally, we construct a list of `Experience` objects from the responses and rewards and return it.
+
+
```python
+class ExampleWorkflow(Workflow):
+
+    # the __init__ method shown above

     def calculate_reward(self, response: str, truth: str) -> float:
         if response == truth:
@@ -109,28 +147,50 @@ class ExampleWorkflow(Workflow):
             return 0.0

     def run(self) -> List[Experience]:
-        response = self.model.chat(
+        # call the model to generate multiple responses
+        responses = self.model.chat(
             [
                 {
                     "role": "user",
                     "content": f"Question:\n{self.question}",
                 }
             ],
-            n=self.task.rollout_args.n,
-            temperature=self.task.rollout_args.temperature,
+            n=self.rollout_args.n,
+            temperature=self.rollout_args.temperature,
         )
-        reward: float = self.calculate_reward(response.response_text, self.answer)
-        return [
-            Experience(
-                tokens=response.tokens,
-                prompt_length=response.prompt_length,
-                reward=reward,
-                logprobs=response.logprobs,
+        experiences = []
+        for response in responses:
+            # calculate reward
+            reward: float = self.calculate_reward(response.response_text, self.answer)
+            # construct Experience
+            experiences.append(
+                Experience(
+                    tokens=response.tokens,
+                    prompt_length=response.prompt_length,
+                    reward=reward,
+                    logprobs=response.logprobs,
+                )
             )
-        ]
+        return experiences
```

-For some heavy workflows, the initialization process may be time-consuming.
+#### Registering Your Workflow
+
+Register your workflow using the `WORKFLOWS.register_module` decorator.
+Ensure the name does not conflict with existing workflows.
+
+```python
+# import some packages
+from trinity.common.workflows.workflow import WORKFLOWS
+
+@WORKFLOWS.register_module("example_workflow")
+class ExampleWorkflow(Workflow):
+    pass
+```
+
+#### Avoid Re-initialization
+
+For heavy workflows, re-initializing every time can incur extra computational cost.
In this case, you can implement the `resettable` and `reset` methods to avoid re-initialization.
```python
@@ -147,23 +207,81 @@ class ExampleWorkflow(Workflow):
         self.answer = task.raw_task.get("answer")
```

+
+#### Full Code Example
+
+```python
+@WORKFLOWS.register_module("example_workflow")
+class ExampleWorkflow(Workflow):
+
+    def __init__(self, model: ModelWrapper, task: Task, auxiliary_models: List):
+        super().__init__(model, task, auxiliary_models)
+        self.question = task.raw_task.get("question")
+        self.answer = task.raw_task.get("answer")
+        self.rollout_args = task.rollout_args
+
+    def calculate_reward(self, response: str, truth: str) -> float:
+        if response == truth:
+            return 1.0
+        else:
+            return 0.0
+
+    def run(self) -> List[Experience]:
+        # call the model to generate multiple responses
+        responses = self.model.chat(
+            [
+                {
+                    "role": "user",
+                    "content": f"Question:\n{self.question}",
+                }
+            ],
+            n=self.rollout_args.n,
+            temperature=self.rollout_args.temperature,
+        )
+        experiences = []
+        for response in responses:
+            # calculate reward
+            reward: float = self.calculate_reward(response.response_text, self.answer)
+            # construct Experience
+            experiences.append(
+                Experience(
+                    tokens=response.tokens,
+                    prompt_length=response.prompt_length,
+                    reward=reward,
+                    logprobs=response.logprobs,
+                )
+            )
+        return experiences
+
+    def resettable(self):
+        return True
+
+    def reset(self, task: Task):
+        self.question = task.raw_task.get("question")
+        self.answer = task.raw_task.get("answer")
+```
+
+
---

-### Step 3: Modify Configuration File
+### Step 3: Use Your Workflow

-After completing the development of the `Workflow`, you need to modify the configuration file to set the `default_workflow_type` in the `buffer.explorer_input` domain to the newly registered `Workflow` name.
+After implementing and registering your workflow, update the configuration file: set `default_workflow_type` in the `buffer.explorer_input.taskset` section to the newly registered `Workflow` name.
```yaml
 buffer:
   # Other fields
   explorer_input:
     taskset:
-      name: example_task
-      storage_type: file
       path: /path/to/taskset
-      # Other fields
-      default_workflow_type: example_workflow
-# Other fields
+      default_workflow_type: example_workflow
+      # Other fields
```

+Now you can run your workflow in Trinity-RFT using the command:
+
+```
+trinity run --config <config_file_path>
+```

---
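Before launching a run, it can help to sanity-check that every line of the task dataset parses and carries the keys named in `format`. The standalone helper below is a sketch; the function name and default key names are illustrative, not part of Trinity-RFT:

```python
import json

def check_taskset(path: str, prompt_key: str = "question", response_key: str = "answer") -> int:
    """Validate a jsonl task file; return the number of task lines found."""
    count = 0
    with open(path, encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                continue  # skip blank lines
            record = json.loads(line)  # raises on malformed JSON
            for key in (prompt_key, response_key):
                if key not in record:
                    raise KeyError(f"line {lineno}: missing key {key!r}")
            count += 1
    return count
```

Running this against the taskset path from your config before `trinity run` catches malformed lines early, instead of mid-training.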