6 changes: 4 additions & 2 deletions .github/workflows/docker/docker-compose.yaml
@@ -8,7 +8,8 @@ services:
       - RAY_ADDRESS=auto
       - CHECKPOINT_ROOT_DIR=/mnt/checkpoints
       - DATA_ROOT_DIR=/mnt/data
-      - MODEL_PATH=/mnt/checkpoints/Qwen2.5-1.5B-Instruct
+      - MODEL_PATH=/mnt/models/Qwen3-1.7B
+      - CHECKPOINT_PATH=/mnt/checkpoints
     working_dir: /workspace
     networks:
       - trinity-network
@@ -32,7 +33,8 @@ services:
       - HF_ENDPOINT=https://hf-mirror.com
       - CHECKPOINT_ROOT_DIR=/mnt/checkpoints
       - DATA_ROOT_DIR=/mnt/data
-      - MODEL_PATH=/mnt/checkpoints/Qwen2.5-1.5B-Instruct
+      - MODEL_PATH=/mnt/models/Qwen3-1.7B
+      - CHECKPOINT_PATH=/mnt/checkpoints
     working_dir: /workspace
     volumes:
       - trinity-volume:/mnt
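Both compose services now export `MODEL_PATH` and `CHECKPOINT_PATH` into the test containers. A minimal sketch of how test code picks these up — `get_model_path` mirrors the helper visible in `tests/common/vllm_test.py` below, while the body of `get_checkpoint_path` is an assumption (only its name appears in this diff, imported from `tests.tools` in `tests/explorer/explorer_test.py`):

```python
import os


def get_model_path() -> str:
    # MODEL_PATH is injected by docker-compose.yaml above; fail fast if absent.
    path = os.environ.get("MODEL_PATH")
    if path is None:
        raise RuntimeError("MODEL_PATH is not set")
    return path


def get_checkpoint_path() -> str:
    # CHECKPOINT_PATH is the variable newly added in this change.
    path = os.environ.get("CHECKPOINT_PATH")
    if path is None:
        raise RuntimeError("CHECKPOINT_PATH is not set")
    return path
```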
2 changes: 1 addition & 1 deletion .github/workflows/unittest.yaml
@@ -36,7 +36,7 @@ jobs:
       - name: Run unittest
         working-directory: trinity-${{ github.run_id }}/.github/workflows/docker
         run: |
-          docker compose exec trinity-node-1 pytest tests --ignore=tests/data --ctrf report.json
+          docker compose exec trinity-node-1 pytest tests -v -s --ignore=tests/data --ctrf report.json

       - name: Upload test results
         uses: actions/upload-artifact@v4
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -20,8 +20,8 @@ classifiers = [
 requires-python = ">=3.10"
 dependencies = [
     "verl==0.3.0.post1",
-    "ray==2.43.0",
-    "vllm==0.8.3",
+    "ray[default]==2.43.0",
+    "vllm>=0.8.3",
     "tensordict==0.6.2",
     "wandb",
     "omegaconf",
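`ray[default]` keeps the same pinned core but pulls in the dashboard and runtime-env extras, and the `vllm` requirement becomes a floor instead of an exact pin. A quick sanity check of a resolved environment, as a sketch (assumes the third-party `packaging` distribution is importable, which pip-based installs normally provide):

```python
from importlib.metadata import version

from packaging.version import Version

# The ray core version is unchanged; the [default] extra only adds packages.
assert version("ray") == "2.43.0"

# vllm>=0.8.3 accepts any newer release, so compare instead of equality-checking.
assert Version(version("vllm")) >= Version("0.8.3")
```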
49 changes: 49 additions & 0 deletions tests/buffer/queue_test.py
@@ -0,0 +1,49 @@
import unittest

import ray
import torch

from trinity.buffer.reader.queue_reader import QueueReader
from trinity.buffer.writer.queue_writer import QueueWriter
from trinity.common.config import BufferConfig, DatasetConfig
from trinity.common.constants import AlgorithmType, StorageType
from trinity.common.experience import Experience


class TestQueueBuffer(unittest.TestCase):
def setUp(self):
ray.init(ignore_reinit_error=True)

def test_queue_buffer(self):
total_num = 8
put_batch_size = 2
read_batch_size = 4
meta = DatasetConfig(
name="test_buffer",
algorithm_type=AlgorithmType.PPO,
storage_type=StorageType.QUEUE,
)
config = BufferConfig(
max_retry_times=3,
max_retry_interval=1,
read_batch_size=read_batch_size,
)
writer = QueueWriter(meta, config)
reader = QueueReader(meta, config)
exps = [
Experience(
tokens=torch.tensor([float(j) for j in range(i + 1)]),
prompt_length=i,
reward=float(i),
logprobs=torch.tensor([0.1]),
)
for i in range(1, put_batch_size + 1)
]
for _ in range(total_num // put_batch_size):
writer.write(exps)
writer.finish()
for _ in range(total_num // read_batch_size):
exps = reader.read()
self.assertEqual(len(exps), read_batch_size)
print(f"finish read {read_batch_size} experience")
self.assertRaises(StopIteration, reader.read)
10 changes: 6 additions & 4 deletions tests/buffer/sql_test.py
@@ -14,8 +14,9 @@

 class TestSQLBuffer(unittest.TestCase):
     def test_create_sql_buffer(self) -> None:
-        put_batch_size = 4
-        read_batch_size = 2
+        total_num = 8
+        put_batch_size = 2
+        read_batch_size = 4
         meta = DatasetConfig(
             name="test_buffer",
             algorithm_type=AlgorithmType.PPO,
@@ -39,7 +40,8 @@ def test_create_sql_buffer(self) -> None:
             )
             for i in range(1, put_batch_size + 1)
         ]
-        sql_writer.write(exps)
-        for _ in range(put_batch_size // read_batch_size):
+        for _ in range(total_num // put_batch_size):
+            sql_writer.write(exps)
+        for _ in range(total_num // read_batch_size):
             exps = sql_reader.read()
             self.assertEqual(len(exps), read_batch_size)
8 changes: 3 additions & 5 deletions tests/common/config_test.py
@@ -3,18 +3,16 @@
 import os
 import unittest

+from tests.tools import get_template_config
-from trinity.common.config import load_config

-config_yaml_path = os.path.join(os.path.dirname(__file__), "tmp", "template_config.yaml")


 class TestConfig(unittest.TestCase):
     def test_load_default_config(self):
-        config = load_config(config_yaml_path)
-        print(config.data)
+        config = get_template_config()
         config.check_and_update()
         self.assertIsNotNone(config.trainer.trainer_config)
-        self.assertEqual(config.trainer.trainer_config.trainer.n_gpus_per_node, 4)
+        self.assertEqual(config.trainer.trainer_config.trainer.n_gpus_per_node, 2)
         self.assertEqual(config.trainer.trainer_config.trainer.nnodes, 1)
         self.assertEqual(config.trainer.trainer_config.trainer.project_name, config.monitor.project)
         self.assertEqual(config.trainer.trainer_config.trainer.experiment_name, config.monitor.name)
Expand Down
35 changes: 27 additions & 8 deletions tests/common/vllm_test.py
@@ -5,16 +5,14 @@
 import torch
 from transformers import AutoTokenizer

-from trinity.common.config import load_config
+from tests.tools import RayUnittestBase, get_template_config
 from trinity.common.models import create_rollout_models
 from trinity.common.models.model import ModelWrapper
 from trinity.common.models.utils import (
     tokenize_and_mask_messages_default,
     tokenize_and_mask_messages_hf,
 )

-config_dir = os.path.join(os.path.dirname(__file__), "tmp", "template_config.yaml")
-

 def get_model_path() -> str:
     path = os.environ.get("MODEL_PATH")
@@ -101,7 +99,12 @@ def test_generate(self):
         ]
         results = self.model_wrapper.chat(messages)
         self.assertEqual(len(results), self.config.explorer.repeat_times)
-        logprobs = self.model_wrapper.logprobs(results[0].tokens)
+        for result in results:
+            input_logprobs = result.logprobs[: result.prompt_length]
+            output_logprobs = result.logprobs[result.prompt_length :]
+            self.assertTrue(torch.all(input_logprobs == 0))
+            self.assertTrue(torch.any(output_logprobs != 0))
+        logprobs = self.model_wrapper.logprobs(results[0].tokens.tolist())
         self.assertEqual(logprobs.shape[0], results[0].tokens.shape[0])
         messages.append(
             {
@@ -126,10 +129,10 @@ def test_generate(self):
         self.assertTrue(torch.equal(result_dict["input_ids"][0], exp.tokens))


-class TestModelWrapperSync(BaseTestModelWrapper, unittest.TestCase):
+class TestModelWrapperSync(BaseTestModelWrapper, RayUnittestBase):
     def setUp(self):
         ray.init(ignore_reinit_error=True)
-        self.config = load_config(config_dir)
+        self.config = get_template_config()
         self.config.model.model_path = get_model_path()
         self.config.explorer.engine_type = "vllm"
         self.config.explorer.engine_num = 1
@@ -138,10 +141,18 @@ def setUp(self):
         self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm")


-class TestModelWrapperAsync(BaseTestModelWrapper, unittest.TestCase):
+class TestModelWrapperAsync(BaseTestModelWrapper, RayUnittestBase):
+    @classmethod
+    def setUpClass(cls):
+        ray.init(ignore_reinit_error=True)
+
+    @classmethod
+    def tearDownClass(cls):
+        ray.shutdown()
+
     def setUp(self):
         ray.init(ignore_reinit_error=True)
-        self.config = load_config(config_dir)
+        self.config = get_template_config()
         self.config.model.model_path = get_model_path()
         self.config.explorer.engine_type = "vllm_async"
         self.config.explorer.engine_num = 1
@@ -151,6 +162,14 @@


 class TestTokenizer(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        ray.init(ignore_reinit_error=True)
+
+    @classmethod
+    def tearDownClass(cls):
+        ray.shutdown()
+
     def test_assistant_token_mask(self):
         messages = [
             {"role": "system", "content": "You are a helpful assistant."},
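The new assertions in `test_generate` encode the convention that `result.logprobs` is zero over prompt positions and populated only for generated tokens. A self-contained illustration with made-up values (the tensor contents below are hypothetical, not taken from any model):

```python
import torch

# Hypothetical rollout result: 3 prompt tokens, then 2 generated tokens.
prompt_length = 3
logprobs = torch.tensor([0.0, 0.0, 0.0, -0.7, -1.2])

assert torch.all(logprobs[:prompt_length] == 0)   # prompt positions are padding
assert torch.any(logprobs[prompt_length:] != 0)   # generated positions carry real logprobs
```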
63 changes: 63 additions & 0 deletions tests/explorer/explorer_test.py
@@ -0,0 +1,63 @@
"""Tests for explorer."""
import os
from abc import abstractmethod
from datetime import datetime

from tests.tools import (
RayUnittestBase,
TensorBoardParser,
get_checkpoint_path,
get_model_path,
get_template_config,
get_unittest_dataset_config,
)
from trinity.cli.launcher import explore
from trinity.common.constants import MonitorType


class BaseExplorerCase(RayUnittestBase):
def setUp(self):
self.config = get_template_config()
self.config.model.model_path = get_model_path()
self.config.explorer.engine_type = "vllm_async"
self.config.explorer.repeat_times = 2
self.config.monitor.monitor_type = MonitorType.TENSORBOARD
self.config.monitor.project = "Trinity-unittest"
self.config.model.checkpoint_path = get_checkpoint_path()
self.config.synchronizer.sync_iteration_interval = 2
self.config.explorer.eval_interval = 4
self.config.trainer.eval_interval = 4

@abstractmethod
def test_explorer(self):
"""Test explorer"""


class TestExplorerCountdownEval(BaseExplorerCase):
def test_explorer(self):
self.config.data = get_unittest_dataset_config("countdown")
self.config.monitor.name = f"explore-eval-{datetime.now().strftime('%Y%m%d%H%M%S')}"
self.config.check_and_update()
explore(self.config)
parser = TensorBoardParser(os.path.join(self.config.monitor.job_dir, "tensorboard"))
rollout_metrics = parser.metric_list("rollout")
self.assertTrue(len(rollout_metrics) > 0)
eval_metrics = parser.metric_list("eval")
self.assertTrue(len(eval_metrics) > 0)
self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 8)
self.assertEqual(parser.metric_max_step(eval_metrics[0]), 8)


class TestExplorerCountdownNoEval(BaseExplorerCase):
def test_explorer(self):
self.config.data = get_unittest_dataset_config("countdown")
self.config.monitor.name = f"explore-no-eval-{datetime.now().strftime('%Y%m%d%H%M%S')}"
self.config.data.eval_split = None
self.config.check_and_update()
explore(self.config)
parser = TensorBoardParser(os.path.join(self.config.monitor.job_dir, "tensorboard"))
rollout_metrics = parser.metric_list("rollout")
self.assertTrue(len(rollout_metrics) > 0)
eval_metrics = parser.metric_list("eval")
self.assertTrue(len(eval_metrics) == 0)
self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 8)
9 changes: 5 additions & 4 deletions tests/explorer/runner_pool_test.py
@@ -31,7 +31,7 @@ def run(self) -> List[Experience]:
         if "timeout" in self.error_type:
             time.sleep(self.seconds)
         elif self.error_type == "exception":
-            raise RuntimeError("Exception occurred")
+            raise ValueError("Exception occurred")
         elif self.error_type == "exit":
             exit(1)
         return [Experience(tokens=torch.zeros(5), prompt_length=2, prompt_text=self.error_type)]
@@ -107,19 +107,20 @@ def test_runner_pool(self):
             tasks=tasks,
         )

-        # The expected return order is: `exception` -> `timeout_5` -> `success` -> (`timeout_100` and `timeout_101`) -> `exit`
+        # The expected return order is: `exception` -> `timeout_2` -> `success` -> (`timeout_100` and `timeout_101`) -> `exit`
         # 1. `exception`
         st = time.time()
         status = pool.get_next_unorder()
         et = time.time()
-        self.assertTrue(et - st < 5)
+        self.assertTrue(et - st < 2)
         print(f"First task use time: {et - st}")
         self.assertEqual(len(status), 1)
         self.assertFalse(status[0].ok)
         # 2. `timeout_2`
         st = time.time()
         status = pool.get_next_unorder()
         et = time.time()
         self.assertTrue(et - st < 3)
+        self.assertTrue(et - st > 2)
         self.assertEqual(len(status), 1)
         self.assertTrue(status[0].ok)
         # 3. `success`
@@ -2,7 +2,7 @@ mode: both
 data:
   dataset_path: ''
   total_epochs: 1
-  batch_size: 32
+  batch_size: 4
   train_split: 'train'
   eval_split: ''
   default_workflow_type: ''
@@ -12,34 +12,31 @@ model:
   max_prompt_tokens: 2048
   max_response_tokens: 2048
   checkpoint_path: ''
-cluster:
+cluster: # 2 for explorer, 2 for trainer
   node_num: 1
-  gpu_per_node: 8
+  gpu_per_node: 4
 buffer:
   read_batch_size: 32
   max_retry_times: 3
   max_retry_interval: 1
 explorer:
   engine_type: vllm_async
   engine_num: 2
-  runner_num: 16
-  repeat_times: 2
-  tensor_parallel_size: 2
+  runner_num: 4
+  repeat_times: 1
+  tensor_parallel_size: 1
   enable_prefix_caching: false
   enforce_eager: true
   dtype: bfloat16
-  temperature: 0.2
-  top_p: 0.95
+  temperature: 1.0
+  top_p: 1.0
   top_k: -1
   seed: 42
   logprobs: 0
   backend: nccl
   use_ray: false
-  max_pending_requests: 5
-  max_waiting_steps: 1
 trainer:
   trainer_type: verl
-  trainer_config_path: tests/common/tmp/template_verl_config.yaml
+  trainer_config_path: tests/template/verl_config.yaml
 monitor:
   project: unittest
   name: test
4 changes: 4 additions & 0 deletions tests/template/data/countdown/test.jsonl
@@ -0,0 +1,4 @@
{"question": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [44, 19, 35], create an equation that equals 98. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in <think> </think> tags. And return the final answer in <answer> </answer> tags, for example <answer> (1 + 2) / 3 </answer>.\nAssistant: Let me solve this step by step.\n<think>", "answer": "{\"numbers\": [44, 19, 35], \"target\": 98}"}
{"question": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [63, 95, 96], create an equation that equals 64. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in <think> </think> tags. And return the final answer in <answer> </answer> tags, for example <answer> (1 + 2) / 3 </answer>.\nAssistant: Let me solve this step by step.\n<think>", "answer": "{\"numbers\": [63, 95, 96], \"target\": 64}"}
{"question": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [95, 11, 56], create an equation that equals 28. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in <think> </think> tags. And return the final answer in <answer> </answer> tags, for example <answer> (1 + 2) / 3 </answer>.\nAssistant: Let me solve this step by step.\n<think>", "answer": "{\"numbers\": [95, 11, 56], \"target\": 28}"}
{"question": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [19, 74, 45], create an equation that equals 48. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in <think> </think> tags. And return the final answer in <answer> </answer> tags, for example <answer> (1 + 2) / 3 </answer>.\nAssistant: Let me solve this step by step.\n<think>", "answer": "{\"numbers\": [19, 74, 45], \"target\": 48}"}