diff --git a/.github/workflows/docker/docker-compose.yaml b/.github/workflows/docker/docker-compose.yaml
index e04a722aeb..e8468a48ca 100644
--- a/.github/workflows/docker/docker-compose.yaml
+++ b/.github/workflows/docker/docker-compose.yaml
@@ -8,7 +8,8 @@ services:
- RAY_ADDRESS=auto
- CHECKPOINT_ROOT_DIR=/mnt/checkpoints
- DATA_ROOT_DIR=/mnt/data
- - MODEL_PATH=/mnt/checkpoints/Qwen2.5-1.5B-Instruct
+ - MODEL_PATH=/mnt/models/Qwen3-1.7B
+ - CHECKPOINT_PATH=/mnt/checkpoints
working_dir: /workspace
networks:
- trinity-network
@@ -32,7 +33,8 @@ services:
- HF_ENDPOINT=https://hf-mirror.com
- CHECKPOINT_ROOT_DIR=/mnt/checkpoints
- DATA_ROOT_DIR=/mnt/data
- - MODEL_PATH=/mnt/checkpoints/Qwen2.5-1.5B-Instruct
+ - MODEL_PATH=/mnt/models/Qwen3-1.7B
+ - CHECKPOINT_PATH=/mnt/checkpoints
working_dir: /workspace
volumes:
- trinity-volume:/mnt
diff --git a/.github/workflows/unittest.yaml b/.github/workflows/unittest.yaml
index 11f49640f7..ceca92d18c 100644
--- a/.github/workflows/unittest.yaml
+++ b/.github/workflows/unittest.yaml
@@ -36,7 +36,7 @@ jobs:
- name: Run unittest
working-directory: trinity-${{ github.run_id }}/.github/workflows/docker
run: |
- docker compose exec trinity-node-1 pytest tests --ignore=tests/data --ctrf report.json
+ docker compose exec trinity-node-1 pytest tests -v -s --ignore=tests/data --ctrf report.json
- name: Upload test results
uses: actions/upload-artifact@v4
diff --git a/pyproject.toml b/pyproject.toml
index d2a52385d6..45944d095e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -20,8 +20,8 @@ classifiers = [
requires-python = ">=3.10"
dependencies = [
"verl==0.3.0.post1",
- "ray==2.43.0",
- "vllm==0.8.3",
+ "ray[default]==2.43.0",
+ "vllm>=0.8.3",
"tensordict==0.6.2",
"wandb",
"omegaconf",
diff --git a/tests/buffer/queue_test.py b/tests/buffer/queue_test.py
new file mode 100644
index 0000000000..058d69a126
--- /dev/null
+++ b/tests/buffer/queue_test.py
@@ -0,0 +1,49 @@
+import unittest
+
+import ray
+import torch
+
+from trinity.buffer.reader.queue_reader import QueueReader
+from trinity.buffer.writer.queue_writer import QueueWriter
+from trinity.common.config import BufferConfig, DatasetConfig
+from trinity.common.constants import AlgorithmType, StorageType
+from trinity.common.experience import Experience
+
+
+class TestQueueBuffer(unittest.TestCase):
+ def setUp(self):
+ ray.init(ignore_reinit_error=True)
+
+ def test_queue_buffer(self):
+ total_num = 8
+ put_batch_size = 2
+ read_batch_size = 4
+ meta = DatasetConfig(
+ name="test_buffer",
+ algorithm_type=AlgorithmType.PPO,
+ storage_type=StorageType.QUEUE,
+ )
+ config = BufferConfig(
+ max_retry_times=3,
+ max_retry_interval=1,
+ read_batch_size=read_batch_size,
+ )
+ writer = QueueWriter(meta, config)
+ reader = QueueReader(meta, config)
+ exps = [
+ Experience(
+ tokens=torch.tensor([float(j) for j in range(i + 1)]),
+ prompt_length=i,
+ reward=float(i),
+ logprobs=torch.tensor([0.1]),
+ )
+ for i in range(1, put_batch_size + 1)
+ ]
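+ # Write 4 batches of 2 experiences, signal completion, then read 2 batches of 4;
+ # the read issued after the finish sentinel should raise StopIteration.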
+ for _ in range(total_num // put_batch_size):
+ writer.write(exps)
+ writer.finish()
+ for _ in range(total_num // read_batch_size):
+ exps = reader.read()
+ self.assertEqual(len(exps), read_batch_size)
+ print(f"finish read {read_batch_size} experience")
+ self.assertRaises(StopIteration, reader.read)
diff --git a/tests/buffer/sql_test.py b/tests/buffer/sql_test.py
index 57761a4341..13eb657585 100644
--- a/tests/buffer/sql_test.py
+++ b/tests/buffer/sql_test.py
@@ -14,8 +14,9 @@
class TestSQLBuffer(unittest.TestCase):
def test_create_sql_buffer(self) -> None:
- put_batch_size = 4
- read_batch_size = 2
+ total_num = 8
+ put_batch_size = 2
+ read_batch_size = 4
meta = DatasetConfig(
name="test_buffer",
algorithm_type=AlgorithmType.PPO,
@@ -39,7 +40,8 @@ def test_create_sql_buffer(self) -> None:
)
for i in range(1, put_batch_size + 1)
]
- sql_writer.write(exps)
- for _ in range(put_batch_size // read_batch_size):
+ for _ in range(total_num // put_batch_size):
+ sql_writer.write(exps)
+ for _ in range(total_num // read_batch_size):
exps = sql_reader.read()
self.assertEqual(len(exps), read_batch_size)
diff --git a/tests/common/config_test.py b/tests/common/config_test.py
index 2dceac2b96..6d035576aa 100644
--- a/tests/common/config_test.py
+++ b/tests/common/config_test.py
@@ -3,18 +3,16 @@
import os
import unittest
+from tests.tools import get_template_config
from trinity.common.config import load_config
-config_yaml_path = os.path.join(os.path.dirname(__file__), "tmp", "template_config.yaml")
-
class TestConfig(unittest.TestCase):
def test_load_default_config(self):
- config = load_config(config_yaml_path)
- print(config.data)
+ config = get_template_config()
config.check_and_update()
self.assertIsNotNone(config.trainer.trainer_config)
- self.assertEqual(config.trainer.trainer_config.trainer.n_gpus_per_node, 4)
+ self.assertEqual(config.trainer.trainer_config.trainer.n_gpus_per_node, 2)
self.assertEqual(config.trainer.trainer_config.trainer.nnodes, 1)
self.assertEqual(config.trainer.trainer_config.trainer.project_name, config.monitor.project)
self.assertEqual(config.trainer.trainer_config.trainer.experiment_name, config.monitor.name)
diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py
index ab39cc5c66..aeda2d0c76 100644
--- a/tests/common/vllm_test.py
+++ b/tests/common/vllm_test.py
@@ -5,7 +5,7 @@
import torch
from transformers import AutoTokenizer
-from trinity.common.config import load_config
+from tests.tools import RayUnittestBase, get_template_config
from trinity.common.models import create_rollout_models
from trinity.common.models.model import ModelWrapper
from trinity.common.models.utils import (
@@ -13,8 +13,6 @@
tokenize_and_mask_messages_hf,
)
-config_dir = os.path.join(os.path.dirname(__file__), "tmp", "template_config.yaml")
-
def get_model_path() -> str:
path = os.environ.get("MODEL_PATH")
@@ -101,7 +99,12 @@ def test_generate(self):
]
results = self.model_wrapper.chat(messages)
self.assertEqual(len(results), self.config.explorer.repeat_times)
- logprobs = self.model_wrapper.logprobs(results[0].tokens)
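+ # Prompt positions should carry zero logprobs; sampled output positions should not.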
+ for result in results:
+ input_logprobs = result.logprobs[: result.prompt_length]
+ output_logprobs = result.logprobs[result.prompt_length :]
+ self.assertTrue(torch.all(input_logprobs == 0))
+ self.assertTrue(torch.any(output_logprobs != 0))
+ logprobs = self.model_wrapper.logprobs(results[0].tokens.tolist())
self.assertEqual(logprobs.shape[0], results[0].tokens.shape[0])
messages.append(
{
@@ -126,10 +129,10 @@ def test_generate(self):
self.assertTrue(torch.equal(result_dict["input_ids"][0], exp.tokens))
-class TestModelWrapperSync(BaseTestModelWrapper, unittest.TestCase):
+class TestModelWrapperSync(BaseTestModelWrapper, RayUnittestBase):
def setUp(self):
ray.init(ignore_reinit_error=True)
- self.config = load_config(config_dir)
+ self.config = get_template_config()
self.config.model.model_path = get_model_path()
self.config.explorer.engine_type = "vllm"
self.config.explorer.engine_num = 1
@@ -138,10 +141,18 @@ def setUp(self):
self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm")
-class TestModelWrapperAsync(BaseTestModelWrapper, unittest.TestCase):
+class TestModelWrapperAsync(BaseTestModelWrapper, RayUnittestBase):
+ @classmethod
+ def setUpClass(cls):
+ ray.init(ignore_reinit_error=True)
+
+ @classmethod
+ def tearDownClass(cls):
+ ray.shutdown()
+
def setUp(self):
ray.init(ignore_reinit_error=True)
- self.config = load_config(config_dir)
+ self.config = get_template_config()
self.config.model.model_path = get_model_path()
self.config.explorer.engine_type = "vllm_async"
self.config.explorer.engine_num = 1
@@ -151,6 +162,14 @@ def setUp(self):
class TestTokenizer(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls):
+ ray.init(ignore_reinit_error=True)
+
+ @classmethod
+ def tearDownClass(cls):
+ ray.shutdown()
+
def test_assistant_token_mask(self):
messages = [
{"role": "system", "content": "You are a helpful assistant."},
diff --git a/tests/explorer/explorer_test.py b/tests/explorer/explorer_test.py
new file mode 100644
index 0000000000..e0334dacc3
--- /dev/null
+++ b/tests/explorer/explorer_test.py
@@ -0,0 +1,63 @@
+"""Tests for explorer."""
+import os
+from abc import abstractmethod
+from datetime import datetime
+
+from tests.tools import (
+ RayUnittestBase,
+ TensorBoardParser,
+ get_checkpoint_path,
+ get_model_path,
+ get_template_config,
+ get_unittest_dataset_config,
+)
+from trinity.cli.launcher import explore
+from trinity.common.constants import MonitorType
+
+
+class BaseExplorerCase(RayUnittestBase):
+ def setUp(self):
+ self.config = get_template_config()
+ self.config.model.model_path = get_model_path()
+ self.config.explorer.engine_type = "vllm_async"
+ self.config.explorer.repeat_times = 2
+ self.config.monitor.monitor_type = MonitorType.TENSORBOARD
+ self.config.monitor.project = "Trinity-unittest"
+ self.config.model.checkpoint_path = get_checkpoint_path()
+ self.config.synchronizer.sync_iteration_interval = 2
+ self.config.explorer.eval_interval = 4
+ self.config.trainer.eval_interval = 4
+
+ @abstractmethod
+ def test_explorer(self):
+ """Test explorer"""
+
+
+class TestExplorerCountdownEval(BaseExplorerCase):
+ def test_explorer(self):
+ self.config.data = get_unittest_dataset_config("countdown")
+ self.config.monitor.name = f"explore-eval-{datetime.now().strftime('%Y%m%d%H%M%S')}"
+ self.config.check_and_update()
+ explore(self.config)
+ parser = TensorBoardParser(os.path.join(self.config.monitor.job_dir, "tensorboard"))
+ rollout_metrics = parser.metric_list("rollout")
+ self.assertTrue(len(rollout_metrics) > 0)
+ eval_metrics = parser.metric_list("eval")
+ self.assertTrue(len(eval_metrics) > 0)
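+ # 16 train samples x 2 epochs / batch size 4 = 8 explore steps, with eval every 4 steps.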
+ self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 8)
+ self.assertEqual(parser.metric_max_step(eval_metrics[0]), 8)
+
+
+class TestExplorerCountdownNoEval(BaseExplorerCase):
+ def test_explorer(self):
+ self.config.data = get_unittest_dataset_config("countdown")
+ self.config.monitor.name = f"explore-no-eval-{datetime.now().strftime('%Y%m%d%H%M%S')}"
+ self.config.data.eval_split = None
+ self.config.check_and_update()
+ explore(self.config)
+ parser = TensorBoardParser(os.path.join(self.config.monitor.job_dir, "tensorboard"))
+ rollout_metrics = parser.metric_list("rollout")
+ self.assertTrue(len(rollout_metrics) > 0)
+ eval_metrics = parser.metric_list("eval")
+ self.assertTrue(len(eval_metrics) == 0)
+ self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 8)
diff --git a/tests/explorer/runner_pool_test.py b/tests/explorer/runner_pool_test.py
index 9f3597216a..2ee03fbc2a 100644
--- a/tests/explorer/runner_pool_test.py
+++ b/tests/explorer/runner_pool_test.py
@@ -31,7 +31,7 @@ def run(self) -> List[Experience]:
if "timeout" in self.error_type:
time.sleep(self.seconds)
elif self.error_type == "exception":
- raise RuntimeError("Exception occurred")
+ raise ValueError("Exception occurred")
elif self.error_type == "exit":
exit(1)
return [Experience(tokens=torch.zeros(5), prompt_length=2, prompt_text=self.error_type)]
@@ -107,19 +107,20 @@ def test_runner_pool(self):
tasks=tasks,
)
- # The excepted return order is: `exception` -> `timeout_5` -> `success` -> (`timeout_100`and `timeout_101`) -> `exit`
+ # The expected return order is: `exception` -> `timeout_2` -> `success` -> (`timeout_100` and `timeout_101`) -> `exit`
# 1. `exception`
st = time.time()
status = pool.get_next_unorder()
et = time.time()
- self.assertTrue(et - st < 5)
+ self.assertTrue(et - st < 2)
+ print(f"First task use time: {et - st}")
self.assertEqual(len(status), 1)
self.assertFalse(status[0].ok)
# 2. `timeout_2`
st = time.time()
status = pool.get_next_unorder()
et = time.time()
- self.assertTrue(et - st < 3)
+ self.assertTrue(et - st > 2)
self.assertEqual(len(status), 1)
self.assertTrue(status[0].ok)
# 3. `success`
diff --git a/tests/common/tmp/template_config.yaml b/tests/template/config.yaml
similarity index 71%
rename from tests/common/tmp/template_config.yaml
rename to tests/template/config.yaml
index e163623680..0eb84f4fb7 100644
--- a/tests/common/tmp/template_config.yaml
+++ b/tests/template/config.yaml
@@ -2,7 +2,7 @@ mode: both
data:
dataset_path: ''
total_epochs: 1
- batch_size: 32
+ batch_size: 4
train_split: 'train'
eval_split: ''
default_workflow_type: ''
@@ -12,34 +12,31 @@ model:
max_prompt_tokens: 2048
max_response_tokens: 2048
checkpoint_path: ''
-cluster:
+cluster: # 2 GPUs for the explorer, 2 for the trainer
node_num: 1
- gpu_per_node: 8
+ gpu_per_node: 4
buffer:
- read_batch_size: 32
max_retry_times: 3
max_retry_interval: 1
explorer:
engine_type: vllm_async
engine_num: 2
- runner_num: 16
- repeat_times: 2
- tensor_parallel_size: 2
+ runner_num: 4
+ repeat_times: 1
+ tensor_parallel_size: 1
enable_prefix_caching: false
enforce_eager: true
dtype: bfloat16
- temperature: 0.2
- top_p: 0.95
+ temperature: 1.0
+ top_p: 1.0
top_k: -1
seed: 42
logprobs: 0
backend: nccl
use_ray: false
- max_pending_requests: 5
- max_waiting_steps: 1
trainer:
trainer_type: verl
- trainer_config_path: tests/common/tmp/template_verl_config.yaml
+ trainer_config_path: tests/template/verl_config.yaml
monitor:
project: unittest
name: test
diff --git a/tests/template/data/countdown/test.jsonl b/tests/template/data/countdown/test.jsonl
new file mode 100644
index 0000000000..dacd253db6
--- /dev/null
+++ b/tests/template/data/countdown/test.jsonl
@@ -0,0 +1,4 @@
+{"question": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [44, 19, 35], create an equation that equals 98. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in tags. And return the final answer in tags, for example (1 + 2) / 3 .\nAssistant: Let me solve this step by step.\n", "answer": "{\"numbers\": [44, 19, 35], \"target\": 98}"}
+{"question": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [63, 95, 96], create an equation that equals 64. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in tags. And return the final answer in tags, for example (1 + 2) / 3 .\nAssistant: Let me solve this step by step.\n", "answer": "{\"numbers\": [63, 95, 96], \"target\": 64}"}
+{"question": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [95, 11, 56], create an equation that equals 28. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in tags. And return the final answer in tags, for example (1 + 2) / 3 .\nAssistant: Let me solve this step by step.\n", "answer": "{\"numbers\": [95, 11, 56], \"target\": 28}"}
+{"question": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [19, 74, 45], create an equation that equals 48. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in tags. And return the final answer in tags, for example (1 + 2) / 3 .\nAssistant: Let me solve this step by step.\n", "answer": "{\"numbers\": [19, 74, 45], \"target\": 48}"}
diff --git a/tests/template/data/countdown/train.jsonl b/tests/template/data/countdown/train.jsonl
new file mode 100644
index 0000000000..3e9adf0092
--- /dev/null
+++ b/tests/template/data/countdown/train.jsonl
@@ -0,0 +1,16 @@
+{"question": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [44, 19, 35], create an equation that equals 98. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in tags. And return the final answer in tags, for example (1 + 2) / 3 .\nAssistant: Let me solve this step by step.\n", "answer": "{\"numbers\": [44, 19, 35], \"target\": 98}"}
+{"question": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [63, 95, 96], create an equation that equals 64. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in tags. And return the final answer in tags, for example (1 + 2) / 3 .\nAssistant: Let me solve this step by step.\n", "answer": "{\"numbers\": [63, 95, 96], \"target\": 64}"}
+{"question": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [95, 11, 56], create an equation that equals 28. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in tags. And return the final answer in tags, for example (1 + 2) / 3 .\nAssistant: Let me solve this step by step.\n", "answer": "{\"numbers\": [95, 11, 56], \"target\": 28}"}
+{"question": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [19, 74, 45], create an equation that equals 48. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in tags. And return the final answer in tags, for example (1 + 2) / 3 .\nAssistant: Let me solve this step by step.\n", "answer": "{\"numbers\": [19, 74, 45], \"target\": 48}"}
+{"question": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [49, 41, 73], create an equation that equals 17. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in tags. And return the final answer in tags, for example (1 + 2) / 3 .\nAssistant: Let me solve this step by step.\n", "answer": "{\"numbers\": [49, 41, 73], \"target\": 17}"}
+{"question": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [48, 28, 42], create an equation that equals 62. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in tags. And return the final answer in tags, for example (1 + 2) / 3 .\nAssistant: Let me solve this step by step.\n", "answer": "{\"numbers\": [48, 28, 42], \"target\": 62}"}
+{"question": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [83, 78, 1, 39], create an equation that equals 82. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in tags. And return the final answer in tags, for example (1 + 2) / 3 .\nAssistant: Let me solve this step by step.\n", "answer": "{\"numbers\": [83, 78, 1, 39], \"target\": 82}"}
+{"question": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [67, 21, 31, 20], create an equation that equals 98. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in tags. And return the final answer in tags, for example (1 + 2) / 3 .\nAssistant: Let me solve this step by step.\n", "answer": "{\"numbers\": [67, 21, 31, 20], \"target\": 98}"}
+{"question": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [96, 74, 94, 54], create an equation that equals 40. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in tags. And return the final answer in tags, for example (1 + 2) / 3 .\nAssistant: Let me solve this step by step.\n", "answer": "{\"numbers\": [96, 74, 94, 54], \"target\": 40}"}
+{"question": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [81, 84, 62], create an equation that equals 65. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in tags. And return the final answer in tags, for example (1 + 2) / 3 .\nAssistant: Let me solve this step by step.\n", "answer": "{\"numbers\": [81, 84, 62], \"target\": 65}"}
+{"question": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [2, 54, 17], create an equation that equals 35. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in tags. And return the final answer in tags, for example (1 + 2) / 3 .\nAssistant: Let me solve this step by step.\n", "answer": "{\"numbers\": [2, 54, 17], \"target\": 35}"}
+{"question": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [26, 71, 58, 38], create an equation that equals 40. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in tags. And return the final answer in tags, for example (1 + 2) / 3 .\nAssistant: Let me solve this step by step.\n", "answer": "{\"numbers\": [26, 71, 58, 38], \"target\": 40}"}
+{"question": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [21, 39, 34, 36], create an equation that equals 16. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in tags. And return the final answer in tags, for example (1 + 2) / 3 .\nAssistant: Let me solve this step by step.\n", "answer": "{\"numbers\": [21, 39, 34, 36], \"target\": 16}"}
+{"question": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [19, 25, 89], create an equation that equals 95. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in tags. And return the final answer in tags, for example (1 + 2) / 3 .\nAssistant: Let me solve this step by step.\n", "answer": "{\"numbers\": [19, 25, 89], \"target\": 95}"}
+{"question": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [8, 62, 43], create an equation that equals 27. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in tags. And return the final answer in tags, for example (1 + 2) / 3 .\nAssistant: Let me solve this step by step.\n", "answer": "{\"numbers\": [8, 62, 43], \"target\": 27}"}
+{"question": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [74, 5, 20, 88], create an equation that equals 50. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in tags. And return the final answer in tags, for example (1 + 2) / 3 .\nAssistant: Let me solve this step by step.\n", "answer": "{\"numbers\": [74, 5, 20, 88], \"target\": 50}"}
diff --git a/tests/template/verl_config.yaml b/tests/template/verl_config.yaml
new file mode 100644
index 0000000000..c902b0d98e
--- /dev/null
+++ b/tests/template/verl_config.yaml
@@ -0,0 +1,145 @@
+actor_rollout_ref:
+ hybrid_engine: True
+ model:
+ external_lib: null
+ override_config: { }
+ enable_gradient_checkpointing: True
+ use_remove_padding: True
+ actor:
+ strategy: fsdp # This is for backward-compatibility
+ ppo_mini_batch_size: 4
+ # ppo_micro_batch_size: 8 # will be deprecated, use ppo_micro_batch_size_per_gpu
+ ppo_micro_batch_size_per_gpu: 1
+ use_dynamic_bsz: True
+ ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length}
+ grad_clip: 1.0
+ clip_ratio: 0.2
+ entropy_coeff: 0.001
+ use_kl_loss: False # True for GRPO
+ kl_loss_coef: 0.001 # for grpo
+ kl_loss_type: low_var_kl # for grpo
+ ppo_epochs: 1
+ shuffle: False
+ ulysses_sequence_parallel_size: 1 # sp size
+ checkpoint:
+ contents: ['model', 'hf_model', 'optimizer', 'extra'] # with 'hf_model' the whole model can be saved in HF format; by default only the sharded model checkpoint is saved, to save space
+ optim:
+ lr: 1e-6
+ lr_warmup_steps_ratio: 0. # the total steps will be injected at runtime
+ # min_lr_ratio: null # only useful for warmup with cosine
+ warmup_style: constant # select from constant/cosine
+ total_training_steps: -1 # must be overridden by the program
+ fsdp_config:
+ wrap_policy:
+ # transformer_layer_cls_to_wrap: None
+ min_num_params: 0
+ param_offload: False
+ optimizer_offload: False
+ fsdp_size: -1
+ # --- below: opmd ---
+ alg_type: ppo # ppo / opmd / pairwise_opmd
+ tau: 0.000 # strength of regularization w.r.t. old / ref policy
+ opmd_baseline: mean # mean / logavgexp, applicable to opmd
+ use_uid: False # True / False, applicable to pairwise_opmd
+ ref:
+ fsdp_config:
+ param_offload: False
+ wrap_policy:
+ # transformer_layer_cls_to_wrap: None
+ min_num_params: 0
+ # log_prob_micro_batch_size: 4 # will be deprecated, use log_prob_micro_batch_size_per_gpu
+ log_prob_micro_batch_size_per_gpu: 1
+ log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
+ log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
+ ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size
+ rollout:
+ name: vllm
+ use_fire_sampling: False # https://arxiv.org/abs/2410.21236
+ prompt_length: ${data.max_prompt_length} # not used for open-source
+ response_length: ${data.max_response_length}
+ # for vllm rollout
+ dtype: bfloat16 # should align with FSDP
+ gpu_memory_utilization: 0.4
+ ignore_eos: False
+ enforce_eager: True
+ free_cache_engine: True
+ load_format: dummy_dtensor
+ tensor_model_parallel_size: 2
+ max_num_batched_tokens: 8192
+ max_model_len: null
+ max_num_seqs: 1024
+ # log_prob_micro_batch_size: 8 # will be deprecated, use log_prob_micro_batch_size_per_gpu
+ log_prob_micro_batch_size_per_gpu: 1
+ log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
+ log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
+ disable_log_stats: True
+ enable_chunked_prefill: True # could get higher throughput
+ # for hf rollout
+ do_sample: True
+ # number of responses (i.e. num sample times)
+ n: 1 # > 1 for grpo
+
+critic:
+ strategy: fsdp
+ optim:
+ lr: 1e-5
+ lr_warmup_steps_ratio: 0. # the total steps will be injected at runtime
+ # min_lr_ratio: null # only useful for warmup with cosine
+ warmup_style: constant # select from constant/cosine
+ total_training_steps: -1 # must be overridden by the program
+ model:
+ tokenizer_path: ${actor_rollout_ref.model.path}
+ override_config: { }
+ external_lib: ${actor_rollout_ref.model.external_lib}
+ enable_gradient_checkpointing: True
+ use_remove_padding: False
+ fsdp_config:
+ param_offload: False
+ optimizer_offload: False
+ wrap_policy:
+ # transformer_layer_cls_to_wrap: None
+ min_num_params: 0
+ fsdp_size: -1
+ ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
+ # ppo_micro_batch_size: 8 # will be deprecated, use ppo_micro_batch_size_per_gpu
+ ppo_micro_batch_size_per_gpu: 1
+ forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu}
+ use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
+ ppo_max_token_len_per_gpu: 32768 # (${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}) * 2
+ forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu}
+ ulysses_sequence_parallel_size: 1 # sp size
+ ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs}
+ shuffle: ${actor_rollout_ref.actor.shuffle}
+ grad_clip: 1.0
+ cliprange_value: 0.5
+
+algorithm:
+ gamma: 1.0
+ lam: 1.0
+ adv_estimator: gae
+ kl_penalty: kl # how to estimate kl divergence
+ kl_ctrl:
+ type: fixed
+ kl_coef: 0.001
+
+trainer:
+ balance_batch: True
+ total_epochs: 10
+ # total_training_steps: null
+ project_name: TinyZero
+ experiment_name: trinity-qwen2.5-1.5b
+ logger: [ 'wandb' ]
+ val_generations_to_log_to_wandb: 0
+ nnodes: 1
+ n_gpus_per_node: 2
+ save_freq: 20
+ # auto: find the last ckpt to resume from; if none is found, start from scratch
+ resume_mode: auto # or disable, or resume_path if resume_from_path is set
+ critic_warmup: 0
+ default_hdfs_dir: null
+ remove_previous_ckpt_in_save: False
+ del_local_ckpt_after_load: False
+ default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}
+ val_before_train: False
+ max_actor_ckpt_to_keep: 1
+ max_critic_ckpt_to_keep: 1
diff --git a/tests/tools.py b/tests/tools.py
new file mode 100644
index 0000000000..ca8fcda4c6
--- /dev/null
+++ b/tests/tools.py
@@ -0,0 +1,105 @@
+import os
+import unittest
+from collections import defaultdict
+from typing import Dict, List
+
+import ray
+from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
+
+from trinity.common.config import Config, DataConfig, FormatConfig, load_config
+
+
+def get_template_config() -> Config:
+ config_path = os.path.join(os.path.dirname(__file__), "template", "config.yaml")
+ return load_config(config_path)
+
+
+def get_model_path() -> str:
+ path = os.environ.get("MODEL_PATH")
+ if not path:
+ raise EnvironmentError(
+ "Please set `export MODEL_PATH=` before running this test."
+ )
+ return path
+
+
+def get_checkpoint_path() -> str:
+ path = os.environ.get("CHECKPOINT_PATH")
+ if not path:
+ raise EnvironmentError(
+ "Please set `export CHECKPOINT_PATH=` before running this test."
+ )
+ return path
+
+
+def get_unittest_dataset_config(dataset_name: str = "countdown") -> DataConfig:
+ """Countdown sample dataset for 8 iterations"""
+ if dataset_name == "countdown":
+ return DataConfig(
+ total_epochs=2,
+ batch_size=4,
+ default_workflow_type="math_workflow",
+ default_reward_fn_type="countdown_reward",
+ dataset_path=os.path.join(os.path.dirname(__file__), "template", "data", "countdown"),
+ train_split="train",
+ eval_split="test",
+ format_config=FormatConfig(
+ prompt_key="question",
+ response_key="answer",
+ ),
+ )
+ else:
+ raise ValueError(f"Unknown dataset name: {dataset_name}")
+
+
+class TensorBoardParser:
+ def __init__(self, log_dir: str):
+ self.log_dir = log_dir
+ self._event_files = self._find_event_files(log_dir)
+ self._metrics = self._load_metrics()
+
+ def _find_event_files(self, log_dir: str) -> List[str]:
+ event_files = []
+ for root, _, files in os.walk(log_dir):
+ for f in files:
+ if f.startswith("events.out.tfevents."):
+ event_files.append(os.path.join(root, f))
+ return event_files
+
+ def _load_metrics(self) -> Dict[str, Dict[int, float]]:
+ metrics = defaultdict(dict)
+
+ for event_file in self._event_files:
+ ea = EventAccumulator(event_file)
+ ea.Reload()
+ tags = ea.Tags()["scalars"]
+ for tag in tags:
+ scalars = ea.Scalars(tag)
+ for scalar in scalars:
+ step = scalar.step
+ value = scalar.value
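+ # Keep the largest value recorded for a tag/step pair across all event files.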
+ if step not in metrics[tag] or value > metrics[tag][step]:
+ metrics[tag][step] = value
+ return dict(metrics)
+
+ def metric_exist(self, metric_name: str) -> bool:
+ return metric_name in self._metrics
+
+ def metric_max_step(self, metric_name: str) -> int:
+ if not self.metric_exist(metric_name):
+ raise ValueError(f"Metric '{metric_name}' does not exist.")
+ steps = list(self._metrics[metric_name].keys())
+ return max(steps)
+
+ def metric_list(self, metric_prefix: str) -> List[str]:
+ return [name for name in self._metrics if name.startswith(metric_prefix)]
+
+
+class RayUnittestBase(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls):
+ ray.init(ignore_reinit_error=True)
+
+ @classmethod
+ def tearDownClass(cls):
+ ray.shutdown()
diff --git a/tests/common/tmp/template_verl_config.yaml b/tests/trainer/__init__.py
similarity index 100%
rename from tests/common/tmp/template_verl_config.yaml
rename to tests/trainer/__init__.py
diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py
new file mode 100644
index 0000000000..a62d8fe5d1
--- /dev/null
+++ b/tests/trainer/trainer_test.py
@@ -0,0 +1,76 @@
+"""Tests for trainer."""
+import os
+import shutil
+from abc import abstractmethod
+from datetime import datetime
+
+import ray
+
+from tests.tools import (
+ RayUnittestBase,
+ TensorBoardParser,
+ get_checkpoint_path,
+ get_model_path,
+ get_template_config,
+ get_unittest_dataset_config,
+)
+from trinity.cli.launcher import both
+from trinity.common.constants import MonitorType
+
+
+class BaseTrainerCase(RayUnittestBase):
+ def setUp(self):
+ ray.init(ignore_reinit_error=True)
+ self.config = get_template_config()
+ self.config.model.model_path = get_model_path()
+ self.config.explorer.engine_type = "vllm_async"
+ self.config.explorer.repeat_times = 3
+ self.config.monitor.monitor_type = MonitorType.TENSORBOARD
+ self.config.model.checkpoint_path = os.path.join(
+ get_checkpoint_path(), f"train-{datetime.now().strftime('%Y%m%d%H%M%S')}"
+ )
+ self.config.synchronizer.sync_iteration_interval = 2
+ self.config.synchronizer.sync_method = "online"
+ self.config.explorer.eval_interval = 4
+ self.config.trainer.eval_interval = 4
+
+ @abstractmethod
+ def test_trainer(self):
+ """Test the trainer."""
+
+
+class TestTrainerCountdown(BaseTrainerCase):
+ def test_trainer(self):
+ """Test the trainer."""
+ self.config.data = get_unittest_dataset_config("countdown")
+ self.config.check_and_update()
+ self.config.trainer.trainer_config.trainer.save_freq = 8
+ both(self.config)
+ # check tensorboard
+ parser = TensorBoardParser(os.path.join(self.config.monitor.job_dir, "tensorboard"))
+ rollout_metrics = parser.metric_list("rollout")
+ self.assertTrue(len(rollout_metrics) > 0)
+ self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 8)
+ eval_metrics = parser.metric_list("eval")
+ self.assertTrue(len(eval_metrics) > 0)
+ self.assertEqual(parser.metric_max_step(eval_metrics[0]), 8)
+ actor_metrics = parser.metric_list("actor")
+ self.assertTrue(len(actor_metrics) > 0)
+ self.assertEqual(parser.metric_max_step(actor_metrics[0]), 8)
+ response_metrics = parser.metric_list("response_length")
+ self.assertTrue(len(response_metrics) > 0)
+ self.assertEqual(parser.metric_max_step(response_metrics[0]), 8)
+ # check checkpoint
+ from trinity.common.models.utils import get_checkpoint_dir_with_iteration
+
+ checkpoint_dir = get_checkpoint_dir_with_iteration(
+ checkpoint_root_path=self.config.model.checkpoint_path,
+ trainer_type=self.config.trainer.trainer_type,
+ iteration_num=None,
+ )
+ self.assertTrue(os.path.exists(checkpoint_dir))
+ self.assertTrue(checkpoint_dir.endswith("step_8"))
+
+ def tearDown(self):
+ # clean up the checkpoint dir created by this test
+ shutil.rmtree(self.config.model.checkpoint_path)
diff --git a/trinity/buffer/buffer_writer.py b/trinity/buffer/buffer_writer.py
index 6d92042258..ac245f50b6 100644
--- a/trinity/buffer/buffer_writer.py
+++ b/trinity/buffer/buffer_writer.py
@@ -9,3 +9,7 @@ class BufferWriter(ABC):
@abstractmethod
def write(self, data: List) -> None:
"""Write to buffer."""
+
+ @abstractmethod
+ def finish(self) -> None:
+ """Finish writing."""
diff --git a/trinity/buffer/queue.py b/trinity/buffer/queue.py
index 943a15ebe2..0f135af0bc 100644
--- a/trinity/buffer/queue.py
+++ b/trinity/buffer/queue.py
@@ -14,6 +14,8 @@
class QueueActor:
"""An asyncio.Queue based queue actor."""
+ FINISH_MESSAGE = "$FINISH$"
+
def __init__(self, dataset_config: DatasetConfig, config: BufferConfig) -> None:
self.config = config
self.capacity = getattr(config, "capacity", 10000)
@@ -35,11 +37,17 @@ async def put_batch(self, exp_list: List) -> None:
if self.sql_writer is not None:
self.sql_writer.write(exp_list)
+ async def finish(self) -> None:
+ """Stop the queue."""
+ await self.queue.put(self.FINISH_MESSAGE)
+
async def get_batch(self, batch_size: int) -> List:
"""Get batch of experience."""
batch = []
while True:
exp_list = await self.queue.get()
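+ # The finish sentinel ends iteration immediately; a partially filled batch is discarded.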
+ if exp_list == self.FINISH_MESSAGE:
+ raise StopAsyncIteration()
batch.extend(exp_list)
if len(batch) >= batch_size:
break
diff --git a/trinity/buffer/reader/queue_reader.py b/trinity/buffer/reader/queue_reader.py
index 3379708da4..8ff4ef4870 100644
--- a/trinity/buffer/reader/queue_reader.py
+++ b/trinity/buffer/reader/queue_reader.py
@@ -27,4 +27,8 @@ def __init__(self, meta: DatasetConfig, config: BufferConfig):
def read(self, strategy: Optional[ReadStrategy] = None) -> List:
if strategy is not None and strategy != ReadStrategy.FIFO:
raise NotImplementedError(f"Read strategy {strategy} not supported for Queue Reader.")
- return ray.get(self.queue.get_batch.remote(self.config.read_batch_size))
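+ # The queue actor raises StopAsyncIteration once the finish sentinel is consumed;
+ # translate it into StopIteration for synchronous callers.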
+ try:
+ exps = ray.get(self.queue.get_batch.remote(self.config.read_batch_size))
+ except StopAsyncIteration:
+ raise StopIteration()
+ return exps
diff --git a/trinity/buffer/writer/queue_writer.py b/trinity/buffer/writer/queue_writer.py
index d31b7a029a..9d4f2a83fd 100644
--- a/trinity/buffer/writer/queue_writer.py
+++ b/trinity/buffer/writer/queue_writer.py
@@ -1,7 +1,8 @@
"""Writer of the Queue buffer."""
-
from typing import List
+import ray
+
from trinity.buffer.buffer_writer import BufferWriter
from trinity.buffer.queue import QueueActor
from trinity.common.config import BufferConfig, DatasetConfig
@@ -23,4 +24,7 @@ def __init__(self, meta: DatasetConfig, config: BufferConfig):
).remote(meta, config)
def write(self, data: List) -> None:
- self.queue.put_batch.remote(data)
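+ # ray.get makes the write synchronous so enqueue failures surface to the caller.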
+ ray.get(self.queue.put_batch.remote(data))
+
+ def finish(self):
+ ray.get(self.queue.finish.remote())
diff --git a/trinity/buffer/writer/sql_writer.py b/trinity/buffer/writer/sql_writer.py
index bf3ce728a4..a2abb1e399 100644
--- a/trinity/buffer/writer/sql_writer.py
+++ b/trinity/buffer/writer/sql_writer.py
@@ -40,3 +40,7 @@ def write(self, data: list) -> None:
with retry_session(self.session, self.max_retry_times, self.max_retry_interval) as session:
experience_models = [self.table_model_cls.from_experience(exp) for exp in data]
session.add_all(experience_models)
+
+ def finish(self) -> None:
+ # TODO: implement this
+ pass
diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py
index 9f9c799872..dda42ecaad 100644
--- a/trinity/cli/launcher.py
+++ b/trinity/cli/launcher.py
@@ -83,6 +83,10 @@ def both(config: Config) -> None:
explore_continue, explore_iter_num = ray.get(ref_explore)
train_continue, train_iter_num = ray.get(ref_train)
if not explore_continue:
+ # If the explorer has finished, the trainer may not have enough experiences to
+ # continue, which would cause it to block. So we stop the training process
+ # immediately.
+ # TODO: use a more elegant way to stop the training process.
logger.info("Explorer finished, stopping...")
break
if not train_continue:
diff --git a/trinity/common/config.py b/trinity/common/config.py
index bf3b215745..89693f11d3 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -149,6 +149,10 @@ class ExplorerConfig:
# for rollout tokenization
chat_template: Optional[str] = None
+ # for evaluation
+ # TODO: remove trainer.eval_interval
+ eval_interval: int = 100
+
# for vLLM
tensor_parallel_size: int = 1
enable_prefix_caching: bool = False
@@ -322,6 +326,11 @@ def check_and_update(self) -> None:
print(
f"Warning: eval_interval is not a multiple of sync_iteration_interval; adjusted to the nearest integer={self.trainer.eval_interval}."
)
+ if self.explorer.eval_interval != self.trainer.eval_interval:
+ self.explorer.eval_interval = self.trainer.eval_interval
+ print(
+ f"Warning: explorer.eval_interval is not equal to trainer.eval_interval; adjusted to the same value={self.trainer.eval_interval}."
+ )
# check monitor
if not self.monitor.cache_root_dir:
diff --git a/trinity/common/models/vllm_async_model.py b/trinity/common/models/vllm_async_model.py
index 7256a9089d..babd84ee5d 100644
--- a/trinity/common/models/vllm_async_model.py
+++ b/trinity/common/models/vllm_async_model.py
@@ -205,10 +205,7 @@ async def _generate_internal(self, request_id: int, prompt: Any, **kwargs) -> An
# request_output.prompt = request.prompt
return request_output
- raise RuntimeError(
- "[vLLM] The request is not finished. This should not happen. "
- "Please report this issue to the Ray team."
- )
+ raise RuntimeError("[vLLM] The request is not finished. This should not happen.")
async def convert_messages_to_experience_async(self, messages: List[dict]) -> Experience:
"""Convert a list of messages into an experience."""
@@ -219,7 +216,7 @@ async def convert_messages_to_experience_async(self, messages: List[dict]) -> Ex
token_ids, action_mask = self.action_mask_method(
self.tokenizer, messages, self.chat_template
)
- logprobs = await self.logprobs_async(token_ids=token_ids)
+ logprobs = await self.logprobs_async(token_ids=token_ids.tolist())
return Experience(
tokens=token_ids,
prompt_length=len(token_ids),
diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py
index 0eaef5e3a3..ca205d7617 100644
--- a/trinity/common/models/vllm_model.py
+++ b/trinity/common/models/vllm_model.py
@@ -252,7 +252,7 @@ def convert_messages_to_experience(self, messages: List[dict]) -> Experience:
token_ids, action_mask = self.action_mask_method(
self.tokenizer, messages, self.chat_template
)
- logprobs = self.logprobs(token_ids=token_ids)
+ logprobs = self.logprobs(token_ids=token_ids.tolist())
return Experience(
tokens=token_ids,
prompt_length=len(token_ids),
diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py
index 46676e62a8..f7598bbea0 100644
--- a/trinity/common/verl_config.py
+++ b/trinity/common/verl_config.py
@@ -133,7 +133,7 @@ class Rollout:
@dataclass
class ActorRolloutRef:
- hybrid_engine: bool = False
+ hybrid_engine: bool = True
model: ActorModel = field(default_factory=ActorModel)
actor: Actor = field(default_factory=Actor)
ref: Ref = field(default_factory=Ref)
@@ -165,7 +165,7 @@ class Critic:
use_dynamic_bsz: bool = False
ppo_max_token_len_per_gpu: int = 0
forward_max_token_len_per_gpu: int = 0
- ulysses_sequence_parallel_size: int = 0
+ ulysses_sequence_parallel_size: int = 1
ppo_epochs: int = 0
shuffle: bool = False
grad_clip: float = 0.0
diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py
index 6f6c44b69c..fc3a34ca55 100644
--- a/trinity/explorer/explorer.py
+++ b/trinity/explorer/explorer.py
@@ -8,6 +8,7 @@
import ray
import torch
+from trinity.buffer import get_buffer_writer
from trinity.common.config import Config
from trinity.common.constants import ROLLOUT_WEIGHT_SYNC_GROUP_NAME, TaskType
from trinity.common.models import create_rollout_models
@@ -33,6 +34,10 @@ def __init__(self, config: Config):
self.iteration = explorer_meta.get("latest_iteration", 0)
self.config = config
self.models = create_rollout_models(config)
+ self.experience_buffer = get_buffer_writer(
+ self.config.buffer.train_dataset, # type: ignore
+ self.config.buffer,
+ )
self.taskset = TaskSet.load(
self.config.data, explorer_meta.get("latest_task_index", 0), TaskType.EXPLORE
)
@@ -150,10 +155,13 @@ def get_weight(self, name: str) -> torch.Tensor:
def explore(self) -> None:
"""Explore the entire dataset."""
while True:
- explore_status, _ = self.explore_step()
+ explore_status, explore_iter = self.explore_step()
if not explore_status:
break
self.sync_weight()
+ if explore_iter % self.config.explorer.eval_interval == 0:
+ self.eval()
+ self.logger.info("Evaluation finished.")
self.logger.info("Explorer finished.")
def explore_step(self) -> Tuple[bool, int]:
@@ -180,6 +188,7 @@ def explore_step(self) -> Tuple[bool, int]:
tasks = [next(self.task_iter) for _ in range(task_num_per_step)] # type: ignore
self.runner_pool.run_tasks(tasks)
except StopIteration:
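+ # Tell the experience buffer no more data is coming so blocked readers can stop.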
+ self.experience_buffer.finish()
self.logger.warning("No more tasks in the task set. Stop exploring.")
return False, self.iteration
@@ -213,7 +222,7 @@ def explore_step(self) -> Tuple[bool, int]:
current_task_index=self.taskset.index,
)
- self.logger.info("Explore step finished.")
+ self.logger.info(f"Explore iteration {self.iteration} finished.")
return True, self.iteration
def eval(self) -> bool:
diff --git a/trinity/explorer/runner_pool.py b/trinity/explorer/runner_pool.py
index d60746ac90..e7641387c8 100644
--- a/trinity/explorer/runner_pool.py
+++ b/trinity/explorer/runner_pool.py
@@ -52,6 +52,7 @@ def _create_actors(self, num: int = 1):
for _ in range(num):
engine_index = self.engine_status.index(min(self.engine_status))
new_actor = WorkflowRunner.remote(self.config, self.models[engine_index])
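+ # Block until the new runner actor is fully initialized before assigning it tasks.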
+ ray.get(new_actor.__ray_ready__.remote())
self.engine_status[engine_index] += 1
self.actor_to_engine_index[new_actor] = engine_index
self._return_actor(new_actor)
diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py
index fb35d087a1..a0c136faa1 100644
--- a/trinity/trainer/trainer.py
+++ b/trinity/trainer/trainer.py
@@ -59,7 +59,7 @@ def train_step(self, algo_type: AlgorithmType = AlgorithmType.PPO) -> Tuple[bool
train_status, train_iter_num = self.train_iteration(algo_type)
if not train_status:
return False, train_iter_num
- self.logger.info("Trainer finished.")
+ self.logger.info(f"Trainer iteration {train_iter_num} finished.")
return True, train_iter_num
def train_iteration(self, algo_type: AlgorithmType = AlgorithmType.PPO) -> Tuple[bool, int]:
@@ -86,7 +86,11 @@ def train_iteration(self, algo_type: AlgorithmType = AlgorithmType.PPO) -> Tuple
strategy = ReadStrategy(self.config.trainer.get_exp_strategy)
else:
strategy = None
- exps = self.train_buffer.read(strategy=strategy)
+ try:
+ exps = self.train_buffer.read(strategy=strategy)
+ except StopIteration:
+ self.logger.warning("No more data to train. Stop training.")
+ return False, 0 # TODO: get the actual iteration number
return self.engine.train_rft_iteration(
Experiences.gather_experiences(
exps,