6 changes: 4 additions & 2 deletions docs/sphinx_doc/source/tutorial/trinity_configs.md
@@ -321,9 +321,10 @@ explorer:
max_retry_times: 2
env_vars: {}
rollout_model:
engine_type: vllm_async
engine_type: vllm
engine_num: 1
tensor_parallel_size: 1
enable_history: False
auxiliary_models:
- model_path: /PATH/TO/MODEL
tensor_parallel_size: 1
@@ -336,9 +337,10 @@ explorer:
- `max_timeout`: Maximum time (in seconds) for a workflow to complete.
- `max_retry_times`: Maximum number of retries for a workflow.
- `env_vars`: Environment variables to be set for every workflow runner.
- `rollout_model.engine_type`: Type of inference engine. Options: `vllm_async` (recommended), `vllm`.
- `rollout_model.engine_type`: Type of inference engine. For now, only `vllm_async` and `vllm` are supported; they are equivalent and both use the asynchronous engine. In future versions, only `vllm` may be kept for simplicity.
- `rollout_model.engine_num`: Number of inference engines.
- `rollout_model.tensor_parallel_size`: Degree of tensor parallelism.
- `rollout_model.enable_history`: Whether to record model call history. If set to `True`, the model wrapper automatically records the experiences returned by model calls. Periodically extract the history via `extract_experience_from_history` to avoid out-of-memory issues (a minimal usage sketch is given after this list). Default is `False`.
- `auxiliary_models`: Additional models used for custom workflows.
- `eval_interval`: Interval (in steps) for evaluating the model.
- `eval_on_startup`: Whether to evaluate the model on startup. More precisely, at step 0 with the original model, so it will not be triggered when restarting.
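
For illustration, the following is a minimal sketch of how `enable_history` might be used. It mirrors the calls in the project's test suite; the `config` object and the surrounding setup are assumptions for the example, not a prescribed recipe.

```python
from trinity.common.models import create_inference_models
from trinity.common.models.model import ModelWrapper

# Assumption: `config` is an already-loaded Trinity-RFT config with
# explorer.rollout_model.enable_history set to True.
engines, auxiliary_engines = create_inference_models(config)
model_wrapper = ModelWrapper(engines[0], model_type="vllm_async", enable_history=True)

# Each generate/chat call is recorded automatically while history is enabled.
responses = model_wrapper.generate(["Hello, world!"], n=2, temperature=1.0)

# Periodically drain the recorded experiences to avoid out-of-memory issues.
experiences = model_wrapper.extract_experience_from_history()
```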
@@ -122,6 +122,7 @@ During initialization, `Workflow` receives the following parameters:

```{tip}
You can switch to using the OpenAI API by setting `explorer.rollout_model.enable_openai_api` to `true` in your config file and calling `model.get_openai_client()` to get an `openai.OpenAI` instance in your workflow.
The `model` field to use when calling the OpenAI API can be obtained via `openai_client.models.list().data[0].id`.
```
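
For reference, a minimal sketch of the OpenAI-API path might look as follows. It mirrors the calls used in the project's tests; `model` is assumed to be the workflow's model wrapper and the config is assumed to set `explorer.rollout_model.enable_openai_api` to `true`.

```python
# Assumption: `model` is the workflow's model wrapper with the OpenAI API enabled.
openai_client = model.get_openai_client()

# Discover the served model id from the API instead of hard-coding a path.
model_id = openai_client.models.list().data[0].id

response = openai_client.chat.completions.create(
    model=model_id,
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is your name?"},
    ],
    n=1,
)
print(response.choices[0].message.content)
```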

Here's an example of initializing a simple workflow using only `raw_task` and `rollout_args`. In more complex cases, you can use `format_args` for further customization.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "trinity-rft"
version = "0.2.0"
version = "0.2.1.dev0"
authors = [
{name="Trinity-RFT Team", email="[email protected]"},
]
200 changes: 110 additions & 90 deletions tests/common/vllm_test.py
@@ -2,9 +2,10 @@
import unittest

import torch
from parameterized import parameterized_class
from transformers import AutoTokenizer

from tests.tools import RayUnittestBase, get_template_config
from tests.tools import RayUnittestBase, RayUnittestBaseAysnc, get_template_config
from trinity.common.models import create_inference_models
from trinity.common.models.model import ModelWrapper
from trinity.common.models.utils import (
@@ -82,12 +83,59 @@ def get_model_path() -> str:
"""


class BaseTestModelWrapper:
def test_generate(self):
@parameterized_class(
("tensor_parallel_size", "engine_num", "use_v1", "repeat_times", "enable_history", "use_async"),
[
(1, 2, False, 2, True, False),
(2, 2, False, 1, False, True),
(2, 2, True, 2, True, False),
(1, 2, True, 1, False, True),
(2, 1, True, 3, True, True),
],
)
class ModelWrapperTest(RayUnittestBaseAysnc):
def setUp(self):
# configure the model
self.config = get_template_config()
self.config.mode = "explore"
self.config.model.model_path = get_model_path()
self.config.explorer.rollout_model.engine_num = self.engine_num
self.config.explorer.rollout_model.tensor_parallel_size = self.tensor_parallel_size
self.config.explorer.rollout_model.use_v1 = self.use_v1
self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE
self.config.algorithm.repeat_times = self.repeat_times
self.config.explorer.rollout_model.enable_history = self.enable_history
self.config.check_and_update()
self.engines, self.auxiliary_engines = create_inference_models(self.config)
self.model_wrapper = ModelWrapper(
self.engines[0], model_type="vllm_async", enable_history=self.enable_history
)

async def test_generate(
self,
):
prompts = ["Hello, world!", "Hello, my name is"]
n = self.config.algorithm.repeat_times
results = self.model_wrapper.generate(prompts, n=n, temperature=1.0)
self.assertEqual(len(results), len(prompts) * n)
if self.use_async:
generate_results = await self.model_wrapper.generate_async(
prompts, n=n, temperature=1.0
)
else:
generate_results = self.model_wrapper.generate(prompts, n=n, temperature=1.0)
self.assertEqual(len(generate_results), len(prompts) * n)
if self.config.explorer.rollout_model.enable_history:
history_experiences = self.model_wrapper.extract_experience_from_history(
clear_history=False
)
self.assertEqual(len(history_experiences), len(generate_results))
for exp, history_exp in zip(generate_results, history_experiences):
self.assertEqual(exp.response_text, history_exp.response_text)
self.assertEqual(exp.tokens.tolist(), history_exp.tokens.tolist())
self.assertEqual(exp.prompt_length, history_exp.prompt_length)
self.assertEqual(exp.logprobs.tolist(), history_exp.logprobs.tolist())
else:
with self.assertRaises(ValueError):
self.model_wrapper.extract_experience_from_history(clear_history=False)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What's the weather like today?"},
@@ -97,15 +145,32 @@ def test_generate(self):
},
{"role": "user", "content": "OK, thanks!"},
]
results = self.model_wrapper.chat(messages, n=n, temperature=1.0)
if self.use_async:
results = await self.model_wrapper.chat_async(messages, n=n, temperature=1.0)
else:
results = self.model_wrapper.chat(messages, n=n, temperature=1.0)
self.assertEqual(len(results), n)
if self.config.explorer.rollout_model.enable_history:
history_experiences = self.model_wrapper.extract_experience_from_history()
self.assertEqual(len(history_experiences) - len(generate_results), len(results))
for exp, history_exp in zip(results, history_experiences[len(generate_results) :]):
self.assertEqual(exp.response_text, history_exp.response_text)
self.assertEqual(exp.tokens.tolist(), history_exp.tokens.tolist())
self.assertEqual(exp.prompt_length, history_exp.prompt_length)
self.assertEqual(exp.logprobs.tolist(), history_exp.logprobs.tolist())
for result in results:
input_logprobs = result.logprobs[: result.prompt_length]
output_logprobs = result.logprobs[result.prompt_length :]
self.assertTrue(torch.all(input_logprobs == 0))
self.assertTrue(torch.any(output_logprobs != 0))
logprobs = self.model_wrapper.logprobs(results[0].tokens.tolist())
if self.use_async:
logprobs = await self.model_wrapper.logprobs_async(results[0].tokens.tolist())
else:
logprobs = self.model_wrapper.logprobs(results[0].tokens.tolist())
self.assertEqual(logprobs.shape[0], results[0].tokens.shape[0])
if self.config.explorer.rollout_model.enable_history:
history_experiences = self.model_wrapper.extract_experience_from_history()
self.assertTrue(len(history_experiences) == 0)
messages.append(
{
"role": "assistant",
@@ -128,84 +193,9 @@ def test_generate(self):
self.assertTrue(torch.equal(result_dict["assistant_masks"][0], exp.action_mask))
self.assertTrue(torch.equal(result_dict["input_ids"][0], exp.tokens))
self.assertRaises(ValueError, self.model_wrapper.get_openai_client)


class TestModelWrapperSyncV0(BaseTestModelWrapper, RayUnittestBase):
def setUp(self):
self.config = get_template_config()
self.config.mode = "explore"
self.config.model.model_path = get_model_path()
self.config.explorer.rollout_model.engine_type = "vllm"
self.config.explorer.rollout_model.tensor_parallel_size = 1
self.config.explorer.rollout_model.engine_num = 2
self.config.explorer.rollout_model.use_v1 = False
self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE
self.config.algorithm.repeat_times = 2
self.config.check_and_update()
self.engines, self.auxiliary_engines = create_inference_models(self.config)
self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm")


class TestModelWrapperAsyncV0(BaseTestModelWrapper, RayUnittestBase):
def setUp(self):
self.config = get_template_config()
self.config.mode = "explore"
self.config.model.model_path = get_model_path()
self.config.explorer.rollout_model.engine_type = "vllm_async"
self.config.explorer.rollout_model.engine_num = 2
self.config.explorer.rollout_model.tensor_parallel_size = 1
self.config.explorer.rollout_model.use_v1 = False
self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE
self.config.algorithm.repeat_times = 2
self.config.check_and_update()
self.engines, self.auxiliary_engines = create_inference_models(self.config)
self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async")


class TestModelWrapperAsyncTPV0(BaseTestModelWrapper, RayUnittestBase):
def setUp(self):
self.config = get_template_config()
self.config.mode = "explore"
self.config.model.model_path = get_model_path()
self.config.explorer.rollout_model.engine_type = "vllm_async"
self.config.explorer.rollout_model.engine_num = 2
self.config.explorer.rollout_model.tensor_parallel_size = 2
self.config.explorer.rollout_model.use_v1 = False
self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE
self.config.check_and_update()
self.engines, self.auxiliary_engines = create_inference_models(self.config)
self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async")


class TestModelWrapperAsyncTPV1(BaseTestModelWrapper, RayUnittestBase):
def setUp(self):
self.config = get_template_config()
self.config.mode = "explore"
self.config.model.model_path = get_model_path()
self.config.explorer.rollout_model.engine_type = "vllm_async"
self.config.explorer.rollout_model.engine_num = 2
self.config.explorer.rollout_model.tensor_parallel_size = 2
self.config.explorer.rollout_model.use_v1 = True
self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE
self.config.algorithm.repeat_times = 2
self.config.check_and_update()
self.engines, self.auxiliary_engines = create_inference_models(self.config)
self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async")


class TestModelWrapperAsyncV1(BaseTestModelWrapper, RayUnittestBase):
def setUp(self):
self.config = get_template_config()
self.config.mode = "explore"
self.config.model.model_path = get_model_path()
self.config.explorer.rollout_model.engine_type = "vllm_async"
self.config.explorer.rollout_model.engine_num = 2
self.config.explorer.rollout_model.tensor_parallel_size = 1
self.config.explorer.rollout_model.use_v1 = True
self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE
self.config.check_and_update()
self.engines, self.auxiliary_engines = create_inference_models(self.config)
self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async")
if self.config.explorer.rollout_model.enable_history:
history_experiences = self.model_wrapper.extract_experience_from_history()
self.assertTrue(len(history_experiences) == 0)


class TestAPIServer(RayUnittestBase):
@@ -221,21 +211,25 @@ def setUp(self):
self.config.explorer.rollout_model.enable_openai_api = True
self.config.check_and_update()
self.engines, self.auxiliary_engines = create_inference_models(self.config)
self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async")
self.model_wrapper = ModelWrapper(
self.engines[0], model_type="vllm_async", enable_history=True
)
self.model_wrapper_no_history = ModelWrapper(
self.engines[0], model_type="vllm_async", enable_history=False
)

def test_api(self):
openai_client = self.model_wrapper.get_openai_client()
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is your name?"},
]
response = openai_client.chat.completions.create(
model=self.config.model.model_path, messages=messages, n=1
)
model_id = openai_client.models.list().data[0].id
response = openai_client.chat.completions.create(model=model_id, messages=messages, n=1)
self.assertEqual(1, len(response.choices))
self.assertTrue(len(response.choices[0].message.content) > 0)
response = openai_client.chat.completions.create(
model=self.config.model.model_path,
model=model_id,
messages=messages,
n=2,
temperature=0.5,
Expand All @@ -246,6 +240,32 @@ def test_api(self):
self.assertTrue(response.choices[0].logprobs is not None)
self.assertEqual(0, len(response.choices[0].logprobs.content[0].top_logprobs))
self.assertTrue(response.choices[0].logprobs.content[0].logprob < 0)
self.assertTrue(hasattr(response, "prompt_token_ids"))
self.assertTrue(len(response.prompt_token_ids) > 0)
self.assertTrue(hasattr(response.choices[0], "token_ids"))
self.assertTrue(len(response.choices[0].token_ids) > 0)
exps = self.model_wrapper.extract_experience_from_history()
self.assertEqual(len(exps), 3)
response = openai_client.chat.completions.create(
model=model_id,
messages=messages,
n=4,
temperature=0.5,
logprobs=True,
top_logprobs=0,
)
exps = self.model_wrapper.extract_experience_from_history()
self.assertEqual(len(exps), 4)
self.assertEqual(len(self.model_wrapper.extract_experience_from_history()), 0)
response = self.model_wrapper_no_history.get_openai_client().chat.completions.create(
model=model_id, messages=messages, n=2
)
self.assertEqual(2, len(response.choices))
self.assertTrue(hasattr(response.choices[0], "token_ids"))
self.assertTrue(len(response.choices[0].token_ids) > 0)
with self.assertRaises(ValueError):
self.model_wrapper_no_history.extract_experience_from_history()
self.assertEqual(len(self.model_wrapper_no_history.history), 0)


class TestTokenizer(unittest.TestCase):
6 changes: 3 additions & 3 deletions tests/explorer/scheduler_test.py
@@ -1,7 +1,7 @@
import asyncio
import time
import unittest
from typing import List, Tuple
from typing import List

import ray
import torch
@@ -98,8 +98,8 @@ def init_process_group(
def has_api_server(self) -> bool:
return True

def api_server_ready(self) -> Tuple[str, str]:
return "http://localhosts:12345", "placeholder"
def api_server_ready(self) -> str:
return "http://localhosts:12345"


def generate_tasks(
2 changes: 1 addition & 1 deletion trinity/__init__.py
@@ -1,4 +1,4 @@
# -*- coding: utf-8 -*-
"""Trinity-RFT (Reinforcement Fine-Tuning)"""

__version__ = "0.2.0"
__version__ = "0.2.1.dev0"
5 changes: 3 additions & 2 deletions trinity/algorithm/sample_strategy/mix_sample_strategy.py
@@ -4,7 +4,6 @@

import numpy as np
import torch
from verl.trainer.ppo.ray_trainer import DataProto

from trinity.algorithm.sample_strategy.sample_strategy import (
SAMPLE_STRATEGY,
@@ -85,7 +84,9 @@ def default_args(cls) -> Dict:
}


def to_data_proto_mix(experiences: Experiences, is_expert_mask: torch.tensor) -> DataProto:
def to_data_proto_mix(experiences: Experiences, is_expert_mask: torch.tensor):
from verl.trainer.ppo.ray_trainer import DataProto

attention_mask = experiences.attention_masks
cumsum = torch.cumsum(attention_mask, dim=-1)
position_ids = torch.clip(cumsum - 1, 0, None).long()
10 changes: 3 additions & 7 deletions trinity/common/config.py
@@ -203,6 +203,9 @@ class InferenceModelConfig:
# For Qwen3
enable_thinking: bool = False

# For history recording
enable_history: bool = False

# For OpenAI API
enable_openai_api: bool = False

@@ -310,8 +313,6 @@ class ExplorerConfig:
name: str = EXPLORER_NAME
# for workflow runner
# number of workflow runners.
# For sync engine (vllm), it should be `1`.
# For async engine (vllm_async), it could be a large number.
runner_per_model: int = 8 # number of runners per each rollout model
max_timeout: int = 1800 # wait each task for 30 minutes
max_retry_times: int = 2 # retry each task for 2 times if it fails or timeout
@@ -719,11 +720,6 @@ def check_and_update(self) -> None: # noqa: C901
self.model.critic_model_path = self.model.model_path

# check explorer
if (
self.explorer.rollout_model.engine_type != "vllm_async"
and self.explorer.rollout_model.enable_openai_api
):
raise ValueError("OpenAI API server only support `vllm_async` engine.")
if self.explorer.rollout_model.max_prompt_tokens is None:
self.explorer.rollout_model.max_prompt_tokens = self.model.max_prompt_tokens
if self.explorer.rollout_model.max_response_tokens is None:
Expand Down