Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions .github/workflows/docker/docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ services:
- RAY_ADDRESS=auto
- CHECKPOINT_ROOT_DIR=/mnt/checkpoints
- DATA_ROOT_DIR=/mnt/data
- MODEL_PATH=/mnt/checkpoints/Qwen2.5-1.5B-Instruct
- MODEL_PATH=/mnt/models/Qwen3-1.7B
- CHECKPOINT_PATH=/mnt/checkpoints
working_dir: /workspace
networks:
- trinity-network
Expand All @@ -32,7 +33,8 @@ services:
- HF_ENDPOINT=https://hf-mirror.com
- CHECKPOINT_ROOT_DIR=/mnt/checkpoints
- DATA_ROOT_DIR=/mnt/data
- MODEL_PATH=/mnt/checkpoints/Qwen2.5-1.5B-Instruct
- MODEL_PATH=/mnt/models/Qwen3-1.7B
- CHECKPOINT_PATH=/mnt/checkpoints
working_dir: /workspace
volumes:
- trinity-volume:/mnt
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/unittest.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ jobs:
- name: Run unittest
working-directory: trinity-${{ github.run_id }}/.github/workflows/docker
run: |
docker compose exec trinity-node-1 pytest tests --ignore=tests/data --ctrf report.json
docker compose exec trinity-node-1 pytest tests -v -s --ignore=tests/data --ctrf report.json

- name: Upload test results
uses: actions/upload-artifact@v4
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ requires-python = ">=3.10"
dependencies = [
"verl==0.3.0.post1",
"ray==2.43.0",
"vllm==0.8.3",
"vllm>=0.8.3",
"tensordict==0.6.2",
"wandb",
"omegaconf",
Expand Down
49 changes: 49 additions & 0 deletions tests/buffer/queue_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import unittest

import ray
import torch

from trinity.buffer.reader.queue_reader import QueueReader
from trinity.buffer.writer.queue_writer import QueueWriter
from trinity.common.config import BufferConfig, DatasetConfig
from trinity.common.constants import AlgorithmType, StorageType
from trinity.common.experience import Experience


class TestQueueBuffer(unittest.TestCase):
def setUp(self):
ray.init(ignore_reinit_error=True)

def test_queue_buffer(self):
total_num = 8
put_batch_size = 2
read_batch_size = 4
meta = DatasetConfig(
name="test_buffer",
algorithm_type=AlgorithmType.PPO,
storage_type=StorageType.QUEUE,
)
config = BufferConfig(
max_retry_times=3,
max_retry_interval=1,
read_batch_size=read_batch_size,
)
writer = QueueWriter(meta, config)
reader = QueueReader(meta, config)
exps = [
Experience(
tokens=torch.tensor([float(j) for j in range(i + 1)]),
prompt_length=i,
reward=float(i),
logprobs=torch.tensor([0.1]),
)
for i in range(1, put_batch_size + 1)
]
for _ in range(total_num // put_batch_size):
writer.write(exps)
writer.finish()
for _ in range(total_num // read_batch_size):
exps = reader.read()
self.assertEqual(len(exps), read_batch_size)
print(f"finish read {read_batch_size} experience")
self.assertRaises(StopIteration, reader.read)
10 changes: 6 additions & 4 deletions tests/buffer/sql_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@

class TestSQLBuffer(unittest.TestCase):
def test_create_sql_buffer(self) -> None:
put_batch_size = 4
read_batch_size = 2
total_num = 8
put_batch_size = 2
read_batch_size = 4
meta = DatasetConfig(
name="test_buffer",
algorithm_type=AlgorithmType.PPO,
Expand All @@ -39,7 +40,8 @@ def test_create_sql_buffer(self) -> None:
)
for i in range(1, put_batch_size + 1)
]
sql_writer.write(exps)
for _ in range(put_batch_size // read_batch_size):
for _ in range(total_num // put_batch_size):
sql_writer.write(exps)
for _ in range(total_num // read_batch_size):
exps = sql_reader.read()
self.assertEqual(len(exps), read_batch_size)
8 changes: 3 additions & 5 deletions tests/common/config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,16 @@
import os
import unittest

from tests.tools import get_template_config
from trinity.common.config import load_config

config_yaml_path = os.path.join(os.path.dirname(__file__), "tmp", "template_config.yaml")


class TestConfig(unittest.TestCase):
def test_load_default_config(self):
config = load_config(config_yaml_path)
print(config.data)
config = get_template_config()
config.check_and_update()
self.assertIsNotNone(config.trainer.trainer_config)
self.assertEqual(config.trainer.trainer_config.trainer.n_gpus_per_node, 4)
self.assertEqual(config.trainer.trainer_config.trainer.n_gpus_per_node, 2)
self.assertEqual(config.trainer.trainer_config.trainer.nnodes, 1)
self.assertEqual(config.trainer.trainer_config.trainer.project_name, config.monitor.project)
self.assertEqual(config.trainer.trainer_config.trainer.experiment_name, config.monitor.name)
Expand Down
Empty file.
15 changes: 9 additions & 6 deletions tests/common/vllm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,14 @@
import torch
from transformers import AutoTokenizer

from trinity.common.config import load_config
from tests.tools import get_template_config
from trinity.common.models import create_rollout_models
from trinity.common.models.model import ModelWrapper
from trinity.common.models.utils import (
tokenize_and_mask_messages_default,
tokenize_and_mask_messages_hf,
)

config_dir = os.path.join(os.path.dirname(__file__), "tmp", "template_config.yaml")


def get_model_path() -> str:
path = os.environ.get("MODEL_PATH")
Expand Down Expand Up @@ -101,7 +99,12 @@ def test_generate(self):
]
results = self.model_wrapper.chat(messages)
self.assertEqual(len(results), self.config.explorer.repeat_times)
logprobs = self.model_wrapper.logprobs(results[0].tokens)
for result in results:
input_logprobs = result.logprobs[: result.prompt_length]
output_logprobs = result.logprobs[result.prompt_length :]
self.assertTrue(torch.all(input_logprobs == 0))
self.assertTrue(torch.any(output_logprobs != 0))
logprobs = self.model_wrapper.logprobs(results[0].tokens.tolist())
self.assertEqual(logprobs.shape[0], results[0].tokens.shape[0])
messages.append(
{
Expand Down Expand Up @@ -129,7 +132,7 @@ def test_generate(self):
class TestModelWrapperSync(BaseTestModelWrapper, unittest.TestCase):
def setUp(self):
ray.init(ignore_reinit_error=True)
self.config = load_config(config_dir)
self.config = get_template_config()
self.config.model.model_path = get_model_path()
self.config.explorer.engine_type = "vllm"
self.config.explorer.engine_num = 1
Expand All @@ -141,7 +144,7 @@ def setUp(self):
class TestModelWrapperAsync(BaseTestModelWrapper, unittest.TestCase):
def setUp(self):
ray.init(ignore_reinit_error=True)
self.config = load_config(config_dir)
self.config = get_template_config()
self.config.model.model_path = get_model_path()
self.config.explorer.engine_type = "vllm_async"
self.config.explorer.engine_num = 1
Expand Down
61 changes: 61 additions & 0 deletions tests/explorer/explorer_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""Tests for explorer."""
import os
import unittest
from datetime import datetime

import ray

from tests.tools import (
TensorBoardParser,
get_checkpoint_path,
get_model_path,
get_template_config,
get_unittest_dataset_config,
)
from trinity.cli.launcher import explore
from trinity.common.constants import MonitorType


class BaseExplorerCase:
def setUp(self):
ray.init(ignore_reinit_error=True)
self.config = get_template_config()
self.config.model.model_path = get_model_path()
self.config.explorer.engine_type = "vllm_async"
self.config.explorer.repeat_times = 2
self.config.monitor.monitor_type = MonitorType.TENSORBOARD
self.config.monitor.project = "Trinity-unittest"
self.config.model.checkpoint_path = get_checkpoint_path()
self.config.synchronizer.sync_iteration_interval = 5
self.config.explorer.eval_interval = 10
self.config.trainer.eval_interval = 10


class TestExplorerCountdownEval(BaseExplorerCase, unittest.TestCase):
def test_explorer(self):
self.config.data = get_unittest_dataset_config("countdown")
self.config.monitor.name = f"explore-eval-{datetime.now().strftime('%Y%m%d%H%M%S')}"
self.config.check_and_update()
explore(self.config)
parser = TensorBoardParser(os.path.join(self.config.monitor.job_dir, "tensorboard"))
rollout_metrics = parser.metric_list("rollout")
self.assertTrue(len(rollout_metrics) > 0)
eval_metrics = parser.metric_list("eval")
self.assertTrue(len(eval_metrics) > 0)
self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 20)
self.assertEqual(parser.metric_max_step(eval_metrics[0]), 20)


class TestExplorerCountdownNoEval(BaseExplorerCase, unittest.TestCase):
def test_explorer(self):
self.config.data = get_unittest_dataset_config("countdown")
self.config.monitor.name = f"explore-no-eval-{datetime.now().strftime('%Y%m%d%H%M%S')}"
self.config.data.eval_split = None
self.config.check_and_update()
explore(self.config)
parser = TensorBoardParser(os.path.join(self.config.monitor.job_dir, "tensorboard"))
rollout_metrics = parser.metric_list("rollout")
self.assertTrue(len(rollout_metrics) > 0)
eval_metrics = parser.metric_list("eval")
self.assertTrue(len(eval_metrics) == 0)
self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 20)
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ mode: both
data:
dataset_path: ''
total_epochs: 1
batch_size: 32
batch_size: 4
train_split: 'train'
eval_split: ''
default_workflow_type: ''
Expand All @@ -12,34 +12,31 @@ model:
max_prompt_tokens: 2048
max_response_tokens: 2048
checkpoint_path: ''
cluster:
cluster: # 2 for explorer, 2 for trainer
node_num: 1
gpu_per_node: 8
gpu_per_node: 4
buffer:
read_batch_size: 32
max_retry_times: 3
max_retry_interval: 1
explorer:
engine_type: vllm_async
engine_num: 2
runner_num: 16
repeat_times: 2
tensor_parallel_size: 2
runner_num: 4
repeat_times: 1
tensor_parallel_size: 1
enable_prefix_caching: false
enforce_eager: true
dtype: bfloat16
temperature: 0.2
top_p: 0.95
temperature: 1.0
top_p: 1.0
top_k: -1
seed: 42
logprobs: 0
backend: nccl
use_ray: false
max_pending_requests: 5
max_waiting_steps: 1
trainer:
trainer_type: verl
trainer_config_path: tests/common/tmp/template_verl_config.yaml
trainer_config_path: tests/template/verl_config.yaml
monitor:
project: unittest
name: test
Expand Down
5 changes: 5 additions & 0 deletions tests/template/data/countdown/test.jsonl
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{"question": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [44, 19, 35], create an equation that equals 98. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in <think> </think> tags. And return the final answer in <answer> </answer> tags, for example <answer> (1 + 2) / 3 </answer>.\nAssistant: Let me solve this step by step.\n<think>", "answer": "{\"numbers\": [44, 19, 35], \"target\": 98}"}
{"question": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [63, 95, 96], create an equation that equals 64. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in <think> </think> tags. And return the final answer in <answer> </answer> tags, for example <answer> (1 + 2) / 3 </answer>.\nAssistant: Let me solve this step by step.\n<think>", "answer": "{\"numbers\": [63, 95, 96], \"target\": 64}"}
{"question": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [95, 11, 56], create an equation that equals 28. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in <think> </think> tags. And return the final answer in <answer> </answer> tags, for example <answer> (1 + 2) / 3 </answer>.\nAssistant: Let me solve this step by step.\n<think>", "answer": "{\"numbers\": [95, 11, 56], \"target\": 28}"}
{"question": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [19, 74, 45], create an equation that equals 48. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in <think> </think> tags. And return the final answer in <answer> </answer> tags, for example <answer> (1 + 2) / 3 </answer>.\nAssistant: Let me solve this step by step.\n<think>", "answer": "{\"numbers\": [19, 74, 45], \"target\": 48}"}
{"question": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [49, 41, 73], create an equation that equals 17. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in <think> </think> tags. And return the final answer in <answer> </answer> tags, for example <answer> (1 + 2) / 3 </answer>.\nAssistant: Let me solve this step by step.\n<think>", "answer": "{\"numbers\": [49, 41, 73], \"target\": 17}"}
Loading