diff --git a/.github/workflows/docker/docker-compose.yaml b/.github/workflows/docker/docker-compose.yaml index e04a722aeb..e8468a48ca 100644 --- a/.github/workflows/docker/docker-compose.yaml +++ b/.github/workflows/docker/docker-compose.yaml @@ -8,7 +8,8 @@ services: - RAY_ADDRESS=auto - CHECKPOINT_ROOT_DIR=/mnt/checkpoints - DATA_ROOT_DIR=/mnt/data - - MODEL_PATH=/mnt/checkpoints/Qwen2.5-1.5B-Instruct + - MODEL_PATH=/mnt/models/Qwen3-1.7B + - CHECKPOINT_PATH=/mnt/checkpoints working_dir: /workspace networks: - trinity-network @@ -32,7 +33,8 @@ services: - HF_ENDPOINT=https://hf-mirror.com - CHECKPOINT_ROOT_DIR=/mnt/checkpoints - DATA_ROOT_DIR=/mnt/data - - MODEL_PATH=/mnt/checkpoints/Qwen2.5-1.5B-Instruct + - MODEL_PATH=/mnt/models/Qwen3-1.7B + - CHECKPOINT_PATH=/mnt/checkpoints working_dir: /workspace volumes: - trinity-volume:/mnt diff --git a/.github/workflows/unittest.yaml b/.github/workflows/unittest.yaml index 11f49640f7..ceca92d18c 100644 --- a/.github/workflows/unittest.yaml +++ b/.github/workflows/unittest.yaml @@ -36,7 +36,7 @@ jobs: - name: Run unittest working-directory: trinity-${{ github.run_id }}/.github/workflows/docker run: | - docker compose exec trinity-node-1 pytest tests --ignore=tests/data --ctrf report.json + docker compose exec trinity-node-1 pytest tests -v -s --ignore=tests/data --ctrf report.json - name: Upload test results uses: actions/upload-artifact@v4 diff --git a/pyproject.toml b/pyproject.toml index d2a52385d6..45944d095e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,8 +20,8 @@ classifiers = [ requires-python = ">=3.10" dependencies = [ "verl==0.3.0.post1", - "ray==2.43.0", - "vllm==0.8.3", + "ray[default]==2.43.0", + "vllm>=0.8.3", "tensordict==0.6.2", "wandb", "omegaconf", diff --git a/tests/buffer/queue_test.py b/tests/buffer/queue_test.py new file mode 100644 index 0000000000..058d69a126 --- /dev/null +++ b/tests/buffer/queue_test.py @@ -0,0 +1,49 @@ +import unittest + +import ray +import torch + +from trinity.buffer.reader.queue_reader import QueueReader +from trinity.buffer.writer.queue_writer import QueueWriter +from trinity.common.config import BufferConfig, DatasetConfig +from trinity.common.constants import AlgorithmType, StorageType +from trinity.common.experience import Experience + + +class TestQueueBuffer(unittest.TestCase): + def setUp(self): + ray.init(ignore_reinit_error=True) + + def test_queue_buffer(self): + total_num = 8 + put_batch_size = 2 + read_batch_size = 4 + meta = DatasetConfig( + name="test_buffer", + algorithm_type=AlgorithmType.PPO, + storage_type=StorageType.QUEUE, + ) + config = BufferConfig( + max_retry_times=3, + max_retry_interval=1, + read_batch_size=read_batch_size, + ) + writer = QueueWriter(meta, config) + reader = QueueReader(meta, config) + exps = [ + Experience( + tokens=torch.tensor([float(j) for j in range(i + 1)]), + prompt_length=i, + reward=float(i), + logprobs=torch.tensor([0.1]), + ) + for i in range(1, put_batch_size + 1) + ] + for _ in range(total_num // put_batch_size): + writer.write(exps) + writer.finish() + for _ in range(total_num // read_batch_size): + exps = reader.read() + self.assertEqual(len(exps), read_batch_size) + print(f"finish read {read_batch_size} experience") + self.assertRaises(StopIteration, reader.read) diff --git a/tests/buffer/sql_test.py b/tests/buffer/sql_test.py index 57761a4341..13eb657585 100644 --- a/tests/buffer/sql_test.py +++ b/tests/buffer/sql_test.py @@ -14,8 +14,9 @@ class TestSQLBuffer(unittest.TestCase): def test_create_sql_buffer(self) -> None: - 
put_batch_size = 4 - read_batch_size = 2 + total_num = 8 + put_batch_size = 2 + read_batch_size = 4 meta = DatasetConfig( name="test_buffer", algorithm_type=AlgorithmType.PPO, @@ -39,7 +40,8 @@ def test_create_sql_buffer(self) -> None: ) for i in range(1, put_batch_size + 1) ] - sql_writer.write(exps) - for _ in range(put_batch_size // read_batch_size): + for _ in range(total_num // put_batch_size): + sql_writer.write(exps) + for _ in range(total_num // read_batch_size): exps = sql_reader.read() self.assertEqual(len(exps), read_batch_size) diff --git a/tests/common/config_test.py b/tests/common/config_test.py index 2dceac2b96..6d035576aa 100644 --- a/tests/common/config_test.py +++ b/tests/common/config_test.py @@ -3,18 +3,16 @@ import os import unittest +from tests.tools import get_template_config from trinity.common.config import load_config -config_yaml_path = os.path.join(os.path.dirname(__file__), "tmp", "template_config.yaml") - class TestConfig(unittest.TestCase): def test_load_default_config(self): - config = load_config(config_yaml_path) - print(config.data) + config = get_template_config() config.check_and_update() self.assertIsNotNone(config.trainer.trainer_config) - self.assertEqual(config.trainer.trainer_config.trainer.n_gpus_per_node, 4) + self.assertEqual(config.trainer.trainer_config.trainer.n_gpus_per_node, 2) self.assertEqual(config.trainer.trainer_config.trainer.nnodes, 1) self.assertEqual(config.trainer.trainer_config.trainer.project_name, config.monitor.project) self.assertEqual(config.trainer.trainer_config.trainer.experiment_name, config.monitor.name) diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py index ab39cc5c66..aeda2d0c76 100644 --- a/tests/common/vllm_test.py +++ b/tests/common/vllm_test.py @@ -5,7 +5,7 @@ import torch from transformers import AutoTokenizer -from trinity.common.config import load_config +from tests.tools import RayUnittestBase, get_template_config from trinity.common.models import create_rollout_models from trinity.common.models.model import ModelWrapper from trinity.common.models.utils import ( @@ -13,8 +13,6 @@ tokenize_and_mask_messages_hf, ) -config_dir = os.path.join(os.path.dirname(__file__), "tmp", "template_config.yaml") - def get_model_path() -> str: path = os.environ.get("MODEL_PATH") @@ -101,7 +99,12 @@ def test_generate(self): ] results = self.model_wrapper.chat(messages) self.assertEqual(len(results), self.config.explorer.repeat_times) - logprobs = self.model_wrapper.logprobs(results[0].tokens) + for result in results: + input_logprobs = result.logprobs[: result.prompt_length] + output_logprobs = result.logprobs[result.prompt_length :] + self.assertTrue(torch.all(input_logprobs == 0)) + self.assertTrue(torch.any(output_logprobs != 0)) + logprobs = self.model_wrapper.logprobs(results[0].tokens.tolist()) self.assertEqual(logprobs.shape[0], results[0].tokens.shape[0]) messages.append( { @@ -126,10 +129,10 @@ def test_generate(self): self.assertTrue(torch.equal(result_dict["input_ids"][0], exp.tokens)) -class TestModelWrapperSync(BaseTestModelWrapper, unittest.TestCase): +class TestModelWrapperSync(BaseTestModelWrapper, RayUnittestBase): def setUp(self): ray.init(ignore_reinit_error=True) - self.config = load_config(config_dir) + self.config = get_template_config() self.config.model.model_path = get_model_path() self.config.explorer.engine_type = "vllm" self.config.explorer.engine_num = 1 @@ -138,10 +141,18 @@ def setUp(self): self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm") -class 
TestModelWrapperAsync(BaseTestModelWrapper, unittest.TestCase): +class TestModelWrapperAsync(BaseTestModelWrapper, RayUnittestBase): + @classmethod + def setUpClass(cls): + ray.init(ignore_reinit_error=True) + + @classmethod + def tearDownClass(cls): + ray.shutdown() + def setUp(self): ray.init(ignore_reinit_error=True) - self.config = load_config(config_dir) + self.config = get_template_config() self.config.model.model_path = get_model_path() self.config.explorer.engine_type = "vllm_async" self.config.explorer.engine_num = 1 @@ -151,6 +162,14 @@ def setUp(self): class TestTokenizer(unittest.TestCase): + @classmethod + def setUpClass(cls): + ray.init(ignore_reinit_error=True) + + @classmethod + def tearDownClass(cls): + ray.shutdown() + def test_assistant_token_mask(self): messages = [ {"role": "system", "content": "You are a helpful assistant."}, diff --git a/tests/explorer/explorer_test.py b/tests/explorer/explorer_test.py new file mode 100644 index 0000000000..e0334dacc3 --- /dev/null +++ b/tests/explorer/explorer_test.py @@ -0,0 +1,63 @@ +"""Tests for explorer.""" +import os +from abc import abstractmethod +from datetime import datetime + +from tests.tools import ( + RayUnittestBase, + TensorBoardParser, + get_checkpoint_path, + get_model_path, + get_template_config, + get_unittest_dataset_config, +) +from trinity.cli.launcher import explore +from trinity.common.constants import MonitorType + + +class BaseExplorerCase(RayUnittestBase): + def setUp(self): + self.config = get_template_config() + self.config.model.model_path = get_model_path() + self.config.explorer.engine_type = "vllm_async" + self.config.explorer.repeat_times = 2 + self.config.monitor.monitor_type = MonitorType.TENSORBOARD + self.config.monitor.project = "Trinity-unittest" + self.config.model.checkpoint_path = get_checkpoint_path() + self.config.synchronizer.sync_iteration_interval = 2 + self.config.explorer.eval_interval = 4 + self.config.trainer.eval_interval = 4 + + @abstractmethod + def test_explorer(self): + """Test explorer""" + + +class TestExplorerCountdownEval(BaseExplorerCase): + def test_explorer(self): + self.config.data = get_unittest_dataset_config("countdown") + self.config.monitor.name = f"explore-eval-{datetime.now().strftime('%Y%m%d%H%M%S')}" + self.config.check_and_update() + explore(self.config) + parser = TensorBoardParser(os.path.join(self.config.monitor.job_dir, "tensorboard")) + rollout_metrics = parser.metric_list("rollout") + self.assertTrue(len(rollout_metrics) > 0) + eval_metrics = parser.metric_list("eval") + self.assertTrue(len(eval_metrics) > 0) + self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 8) + self.assertEqual(parser.metric_max_step(eval_metrics[0]), 8) + + +class TestExplorerCountdownNoEval(BaseExplorerCase): + def test_explorer(self): + self.config.data = get_unittest_dataset_config("countdown") + self.config.monitor.name = f"explore-no-eval-{datetime.now().strftime('%Y%m%d%H%M%S')}" + self.config.data.eval_split = None + self.config.check_and_update() + explore(self.config) + parser = TensorBoardParser(os.path.join(self.config.monitor.job_dir, "tensorboard")) + rollout_metrics = parser.metric_list("rollout") + self.assertTrue(len(rollout_metrics) > 0) + eval_metrics = parser.metric_list("eval") + self.assertTrue(len(eval_metrics) == 0) + self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 8) diff --git a/tests/explorer/runner_pool_test.py b/tests/explorer/runner_pool_test.py index 9f3597216a..2ee03fbc2a 100644 --- a/tests/explorer/runner_pool_test.py +++ 
b/tests/explorer/runner_pool_test.py @@ -31,7 +31,7 @@ def run(self) -> List[Experience]: if "timeout" in self.error_type: time.sleep(self.seconds) elif self.error_type == "exception": - raise RuntimeError("Exception occurred") + raise ValueError("Exception occurred") elif self.error_type == "exit": exit(1) return [Experience(tokens=torch.zeros(5), prompt_length=2, prompt_text=self.error_type)] @@ -107,19 +107,20 @@ def test_runner_pool(self): tasks=tasks, ) - # The expected return order is: `exception` -> `timeout_5` -> `success` -> (`timeout_100` and `timeout_101`) -> `exit` + # The expected return order is: `exception` -> `timeout_2` -> `success` -> (`timeout_100` and `timeout_101`) -> `exit` # 1. `exception` st = time.time() status = pool.get_next_unorder() et = time.time() - self.assertTrue(et - st < 5) + self.assertTrue(et - st < 2) + print(f"First task use time: {et - st}") self.assertEqual(len(status), 1) self.assertFalse(status[0].ok) # 2. `timeout_2` st = time.time() status = pool.get_next_unorder() et = time.time() - self.assertTrue(et - st < 3) + self.assertTrue(et - st > 2) self.assertEqual(len(status), 1) self.assertTrue(status[0].ok) # 3. `success` diff --git a/tests/common/tmp/template_config.yaml b/tests/template/config.yaml similarity index 71% rename from tests/common/tmp/template_config.yaml rename to tests/template/config.yaml index e163623680..0eb84f4fb7 100644 --- a/tests/common/tmp/template_config.yaml +++ b/tests/template/config.yaml @@ -2,7 +2,7 @@ mode: both data: dataset_path: '' total_epochs: 1 - batch_size: 32 + batch_size: 4 train_split: 'train' eval_split: '' default_workflow_type: '' @@ -12,34 +12,31 @@ model: max_prompt_tokens: 2048 max_response_tokens: 2048 checkpoint_path: '' -cluster: +cluster: # 2 for explorer, 2 for trainer node_num: 1 - gpu_per_node: 8 + gpu_per_node: 4 buffer: - read_batch_size: 32 max_retry_times: 3 max_retry_interval: 1 explorer: engine_type: vllm_async engine_num: 2 - runner_num: 16 - repeat_times: 2 - tensor_parallel_size: 2 + runner_num: 4 + repeat_times: 1 + tensor_parallel_size: 1 enable_prefix_caching: false enforce_eager: true dtype: bfloat16 - temperature: 0.2 - top_p: 0.95 + temperature: 1.0 + top_p: 1.0 top_k: -1 seed: 42 logprobs: 0 backend: nccl use_ray: false - max_pending_requests: 5 - max_waiting_steps: 1 trainer: trainer_type: verl - trainer_config_path: tests/common/tmp/template_verl_config.yaml + trainer_config_path: tests/template/verl_config.yaml monitor: project: unittest name: test diff --git a/tests/template/data/countdown/test.jsonl b/tests/template/data/countdown/test.jsonl new file mode 100644 index 0000000000..dacd253db6 --- /dev/null +++ b/tests/template/data/countdown/test.jsonl @@ -0,0 +1,4 @@ +{"question": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [44, 19, 35], create an equation that equals 98. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in tags. And return the final answer in tags, for example (1 + 2) / 3 .\nAssistant: Let me solve this step by step.\n", "answer": "{\"numbers\": [44, 19, 35], \"target\": 98}"} +{"question": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. 
The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [63, 95, 96], create an equation that equals 64. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in tags. And return the final answer in tags, for example (1 + 2) / 3 .\nAssistant: Let me solve this step by step.\n", "answer": "{\"numbers\": [63, 95, 96], \"target\": 64}"} +{"question": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [95, 11, 56], create an equation that equals 28. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in tags. And return the final answer in tags, for example (1 + 2) / 3 .\nAssistant: Let me solve this step by step.\n", "answer": "{\"numbers\": [95, 11, 56], \"target\": 28}"} +{"question": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [19, 74, 45], create an equation that equals 48. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in tags. And return the final answer in tags, for example (1 + 2) / 3 .\nAssistant: Let me solve this step by step.\n", "answer": "{\"numbers\": [19, 74, 45], \"target\": 48}"} diff --git a/tests/template/data/countdown/train.jsonl b/tests/template/data/countdown/train.jsonl new file mode 100644 index 0000000000..3e9adf0092 --- /dev/null +++ b/tests/template/data/countdown/train.jsonl @@ -0,0 +1,16 @@ +{"question": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [44, 19, 35], create an equation that equals 98. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in tags. And return the final answer in tags, for example (1 + 2) / 3 .\nAssistant: Let me solve this step by step.\n", "answer": "{\"numbers\": [44, 19, 35], \"target\": 98}"} +{"question": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [63, 95, 96], create an equation that equals 64. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in tags. And return the final answer in tags, for example (1 + 2) / 3 .\nAssistant: Let me solve this step by step.\n", "answer": "{\"numbers\": [63, 95, 96], \"target\": 64}"} +{"question": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [95, 11, 56], create an equation that equals 28. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in tags. 
And return the final answer in tags, for example (1 + 2) / 3 .\nAssistant: Let me solve this step by step.\n", "answer": "{\"numbers\": [95, 11, 56], \"target\": 28}"} +{"question": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [19, 74, 45], create an equation that equals 48. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in tags. And return the final answer in tags, for example (1 + 2) / 3 .\nAssistant: Let me solve this step by step.\n", "answer": "{\"numbers\": [19, 74, 45], \"target\": 48}"} +{"question": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [49, 41, 73], create an equation that equals 17. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in tags. And return the final answer in tags, for example (1 + 2) / 3 .\nAssistant: Let me solve this step by step.\n", "answer": "{\"numbers\": [49, 41, 73], \"target\": 17}"} +{"question": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [48, 28, 42], create an equation that equals 62. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in tags. And return the final answer in tags, for example (1 + 2) / 3 .\nAssistant: Let me solve this step by step.\n", "answer": "{\"numbers\": [48, 28, 42], \"target\": 62}"} +{"question": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [83, 78, 1, 39], create an equation that equals 82. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in tags. And return the final answer in tags, for example (1 + 2) / 3 .\nAssistant: Let me solve this step by step.\n", "answer": "{\"numbers\": [83, 78, 1, 39], \"target\": 82}"} +{"question": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [67, 21, 31, 20], create an equation that equals 98. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in tags. And return the final answer in tags, for example (1 + 2) / 3 .\nAssistant: Let me solve this step by step.\n", "answer": "{\"numbers\": [67, 21, 31, 20], \"target\": 98}"} +{"question": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [96, 74, 94, 54], create an equation that equals 40. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in tags. 
And return the final answer in tags, for example (1 + 2) / 3 .\nAssistant: Let me solve this step by step.\n", "answer": "{\"numbers\": [96, 74, 94, 54], \"target\": 40}"} +{"question": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [81, 84, 62], create an equation that equals 65. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in tags. And return the final answer in tags, for example (1 + 2) / 3 .\nAssistant: Let me solve this step by step.\n", "answer": "{\"numbers\": [81, 84, 62], \"target\": 65}"} +{"question": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [2, 54, 17], create an equation that equals 35. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in tags. And return the final answer in tags, for example (1 + 2) / 3 .\nAssistant: Let me solve this step by step.\n", "answer": "{\"numbers\": [2, 54, 17], \"target\": 35}"} +{"question": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [26, 71, 58, 38], create an equation that equals 40. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in tags. And return the final answer in tags, for example (1 + 2) / 3 .\nAssistant: Let me solve this step by step.\n", "answer": "{\"numbers\": [26, 71, 58, 38], \"target\": 40}"} +{"question": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [21, 39, 34, 36], create an equation that equals 16. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in tags. And return the final answer in tags, for example (1 + 2) / 3 .\nAssistant: Let me solve this step by step.\n", "answer": "{\"numbers\": [21, 39, 34, 36], \"target\": 16}"} +{"question": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [19, 25, 89], create an equation that equals 95. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in tags. And return the final answer in tags, for example (1 + 2) / 3 .\nAssistant: Let me solve this step by step.\n", "answer": "{\"numbers\": [19, 25, 89], \"target\": 95}"} +{"question": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [8, 62, 43], create an equation that equals 27. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in tags. 
And return the final answer in tags, for example (1 + 2) / 3 .\nAssistant: Let me solve this step by step.\n", "answer": "{\"numbers\": [8, 62, 43], \"target\": 27}"} +{"question": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [74, 5, 20, 88], create an equation that equals 50. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in tags. And return the final answer in tags, for example (1 + 2) / 3 .\nAssistant: Let me solve this step by step.\n", "answer": "{\"numbers\": [74, 5, 20, 88], \"target\": 50}"} diff --git a/tests/template/verl_config.yaml b/tests/template/verl_config.yaml new file mode 100644 index 0000000000..c902b0d98e --- /dev/null +++ b/tests/template/verl_config.yaml @@ -0,0 +1,145 @@ +actor_rollout_ref: + hybrid_engine: True + model: + external_lib: null + override_config: { } + enable_gradient_checkpointing: True + use_remove_padding: True + actor: + strategy: fsdp # This is for backward-compatibility + ppo_mini_batch_size: 4 + # ppo_micro_batch_size: 8 # will be deprecated, use ppo_micro_batch_size_per_gpu + ppo_micro_batch_size_per_gpu: 1 + use_dynamic_bsz: True + ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length} + grad_clip: 1.0 + clip_ratio: 0.2 + entropy_coeff: 0.001 + use_kl_loss: False # True for GRPO + kl_loss_coef: 0.001 # for grpo + kl_loss_type: low_var_kl # for grpo + ppo_epochs: 1 + shuffle: False + ulysses_sequence_parallel_size: 1 # sp size + checkpoint: + contents: ['model', 'hf_model', 'optimizer', 'extra'] # with 'hf_model' you can save the whole model in HF format; for now only the sharded model checkpoint is used, to save space + optim: + lr: 1e-6 + lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime + # min_lr_ratio: null # only useful for warmup with cosine + warmup_style: constant # select from constant/cosine + total_training_steps: -1 # must be overridden by the program + fsdp_config: + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + param_offload: False + optimizer_offload: False + fsdp_size: -1 + # --- below: opmd --- + alg_type: ppo # ppo / opmd / pairwise_opmd + tau: 0.000 # strength of regularization w.r.t. 
old / ref policy + opmd_baseline: mean # mean / logavgexp, applicable to opmd + use_uid: False # True / False, applicable to pairwise_opmd + ref: + fsdp_config: + param_offload: False + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + # log_prob_micro_batch_size: 4 # will be deprecated, use log_prob_micro_batch_size_per_gpu + log_prob_micro_batch_size_per_gpu: 1 + log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} + log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} + ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size + rollout: + name: vllm + use_fire_sampling: False # https://arxiv.org/abs/2410.21236 + prompt_length: ${data.max_prompt_length} # not used for open-source + response_length: ${data.max_response_length} + # for vllm rollout + dtype: bfloat16 # should align with FSDP + gpu_memory_utilization: 0.4 + ignore_eos: False + enforce_eager: True + free_cache_engine: True + load_format: dummy_dtensor + tensor_model_parallel_size: 2 + max_num_batched_tokens: 8192 + max_model_len: null + max_num_seqs: 1024 + # log_prob_micro_batch_size: 8 # will be deprecated, use log_prob_micro_batch_size_per_gpu + log_prob_micro_batch_size_per_gpu: 1 + log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} + log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} + disable_log_stats: True + enable_chunked_prefill: True # could get higher throughput + # for hf rollout + do_sample: True + # number of responses (i.e. num sample times) + n: 1 # > 1 for grpo + +critic: + strategy: fsdp + optim: + lr: 1e-5 + lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime + # min_lr_ratio: null # only useful for warmup with cosine + warmup_style: constant # select from constant/cosine + total_training_steps: -1 # must be overridden by the program + model: + tokenizer_path: ${actor_rollout_ref.model.path} + override_config: { } + external_lib: ${actor_rollout_ref.model.external_lib} + enable_gradient_checkpointing: True + use_remove_padding: False + fsdp_config: + param_offload: False + optimizer_offload: False + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + fsdp_size: -1 + ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size} + # ppo_micro_batch_size: 8 # will be deprecated, use ppo_micro_batch_size_per_gpu + ppo_micro_batch_size_per_gpu: 1 + forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu} + use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} + ppo_max_token_len_per_gpu: 32768 # (${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}) * 2 + forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu} + ulysses_sequence_parallel_size: 1 # sp size + ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs} + shuffle: ${actor_rollout_ref.actor.shuffle} + grad_clip: 1.0 + cliprange_value: 0.5 + +algorithm: + gamma: 1.0 + lam: 1.0 + adv_estimator: gae + kl_penalty: kl # how to estimate kl divergence + kl_ctrl: + type: fixed + kl_coef: 0.001 + +trainer: + balance_batch: True + total_epochs: 10 + # total_training_steps: null + project_name: TinyZero + experiment_name: trinity-qwen2.5-1.5b + logger: [ 'wandb' ] + val_generations_to_log_to_wandb: 0 + nnodes: 1 + n_gpus_per_node: 2 + save_freq: 20 + # auto: find the last ckpt to resume; if none is found, start from scratch
+ resume_mode: auto # or auto or resume_path if + critic_warmup: 0 + default_hdfs_dir: null + remove_previous_ckpt_in_save: False + del_local_ckpt_after_load: False + default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} + val_before_train: False + max_actor_ckpt_to_keep: 1 + max_critic_ckpt_to_keep: 1 diff --git a/tests/tools.py b/tests/tools.py new file mode 100644 index 0000000000..ca8fcda4c6 --- /dev/null +++ b/tests/tools.py @@ -0,0 +1,105 @@ +import os +import unittest +from collections import defaultdict +from typing import Dict, List + +import ray +from tensorboard.backend.event_processing.event_accumulator import EventAccumulator + +from trinity.common.config import Config, DataConfig, FormatConfig, load_config + + +def get_template_config() -> Config: + config_path = os.path.join(os.path.dirname(__file__), "template", "config.yaml") + return load_config(config_path) + + +def get_model_path() -> str: + path = os.environ.get("MODEL_PATH") + if not path: + raise EnvironmentError( + "Please set `export MODEL_PATH=` before running this test." + ) + return path + + +def get_checkpoint_path() -> str: + path = os.environ.get("CHECKPOINT_PATH") + if not path: + raise EnvironmentError( + "Please set `export CHECKPOINT_PATH=` before running this test." + ) + return path + + +def get_unittest_dataset_config(dataset_name: str = "countdown") -> DataConfig: + """Countdown sample dataset for 8 iterations""" + if dataset_name == "countdown": + return DataConfig( + total_epochs=2, + batch_size=4, + default_workflow_type="math_workflow", + default_reward_fn_type="countdown_reward", + dataset_path=os.path.join(os.path.dirname(__file__), "template", "data", "countdown"), + train_split="train", + eval_split="test", + format_config=FormatConfig( + prompt_key="question", + response_key="answer", + ), + ) + else: + raise ValueError(f"Unknown dataset name: {dataset_name}") + + +class TensorBoardParser: + def __init__(self, log_dir: str): + self.log_dir = log_dir + self._event_files = self._find_event_files(log_dir) + self._metrics = self._load_metrics() + + def _find_event_files(self, log_dir: str) -> List[str]: + event_files = [] + for root, _, files in os.walk(log_dir): + for f in files: + if f.startswith("events.out.tfevents."): + event_files.append(os.path.join(root, f)) + return event_files + + def _load_metrics(self) -> Dict[str, Dict[int, float]]: + metrics = defaultdict(dict) + + for event_file in self._event_files: + ea = EventAccumulator(event_file) + ea.Reload() + tags = ea.Tags()["scalars"] + for tag in tags: + scalars = ea.Scalars(tag) + for scalar in scalars: + step = scalar.step + value = scalar.value + if step not in metrics[tag] or value > metrics[tag][step]: + metrics[tag][step] = value + return dict(metrics) + + def metric_exist(self, metric_name: str) -> bool: + return metric_name in self._metrics + + def metric_max_step(self, metric_name: str) -> int: + if not self.metric_exist(metric_name): + raise ValueError(f"Metric '{metric_name}' does not exist.") + steps = list(self._metrics[metric_name].keys()) + return max(steps) + + def metric_list(self, metric_prefix: str) -> List[str]: + return [name for name in self._metrics if name.startswith(metric_prefix)] + + +class RayUnittestBase(unittest.TestCase): + @classmethod + def setUpClass(cls): + ray.init(ignore_reinit_error=True) + + @classmethod + def tearDownClass(cls): + ray.shutdown()
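Note: `TensorBoardParser` above walks a log directory for `events.out.tfevents.*` files and keeps, per scalar tag and step, the maximum observed value. A minimal usage sketch (the log path is hypothetical; the tests below derive it from `config.monitor.job_dir`):

```python
from tests.tools import TensorBoardParser

# Hypothetical log dir; the tests use os.path.join(config.monitor.job_dir, "tensorboard").
parser = TensorBoardParser("/tmp/unittest-job/tensorboard")
rollout_metrics = parser.metric_list("rollout")  # all scalar tags starting with "rollout"
if rollout_metrics:
    # Highest step at which the first rollout metric was logged.
    print(parser.metric_max_step(rollout_metrics[0]))
```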
diff --git a/tests/common/tmp/template_verl_config.yaml b/tests/trainer/__init__.py similarity index 100% rename from tests/common/tmp/template_verl_config.yaml rename to tests/trainer/__init__.py diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py new file mode 100644 index 0000000000..a62d8fe5d1 --- /dev/null +++ b/tests/trainer/trainer_test.py @@ -0,0 +1,76 @@ +"""Tests for trainer.""" +import os +import shutil +from abc import abstractmethod +from datetime import datetime + +import ray + +from tests.tools import ( + RayUnittestBase, + TensorBoardParser, + get_checkpoint_path, + get_model_path, + get_template_config, + get_unittest_dataset_config, +) +from trinity.cli.launcher import both +from trinity.common.constants import MonitorType + + +class BaseTrainerCase(RayUnittestBase): + def setUp(self): + ray.init(ignore_reinit_error=True) + self.config = get_template_config() + self.config.model.model_path = get_model_path() + self.config.trainer.engine_type = "vllm_async" + self.config.trainer.repeat_times = 3 + self.config.monitor.monitor_type = MonitorType.TENSORBOARD + self.config.model.checkpoint_path = os.path.join( + get_checkpoint_path(), f"train-{datetime.now().strftime('%Y%m%d%H%M%S')}" + ) + self.config.synchronizer.sync_iteration_interval = 2 + self.config.synchronizer.sync_method = "online" + self.config.explorer.eval_interval = 4 + self.config.trainer.eval_interval = 4 + + @abstractmethod + def test_trainer(self): + """Test the trainer.""" + + +class TestTrainerCountdown(BaseTrainerCase): + def test_trainer(self): + """Test the trainer.""" + self.config.data = get_unittest_dataset_config("countdown") + self.config.check_and_update() + self.config.trainer.trainer_config.trainer.save_freq = 8 + both(self.config) + # check tensorboard + parser = TensorBoardParser(os.path.join(self.config.monitor.job_dir, "tensorboard")) + rollout_metrics = parser.metric_list("rollout") + self.assertTrue(len(rollout_metrics) > 0) + self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 8) + eval_metrics = parser.metric_list("eval") + self.assertTrue(len(eval_metrics) > 0) + self.assertEqual(parser.metric_max_step(eval_metrics[0]), 8) + actor_metrics = parser.metric_list("actor") + self.assertTrue(len(actor_metrics) > 0) + self.assertEqual(parser.metric_max_step(actor_metrics[0]), 8) + response_metrics = parser.metric_list("response_length") + self.assertTrue(len(response_metrics) > 0) + self.assertEqual(parser.metric_max_step(response_metrics[0]), 8) + # check checkpoint + from trinity.common.models.utils import get_checkpoint_dir_with_iteration + + checkpoint_dir = get_checkpoint_dir_with_iteration( + checkpoint_root_path=self.config.model.checkpoint_path, + trainer_type=self.config.trainer.trainer_type, + iteration_num=None, + ) + self.assertTrue(os.path.exists(checkpoint_dir)) + self.assertTrue(checkpoint_dir.endswith("step_8")) + + def tearDown(self): + # remove dir only when the test passed + shutil.rmtree(self.config.model.checkpoint_path) diff --git a/trinity/buffer/buffer_writer.py b/trinity/buffer/buffer_writer.py index 6d92042258..ac245f50b6 100644 --- a/trinity/buffer/buffer_writer.py +++ b/trinity/buffer/buffer_writer.py @@ -9,3 +9,7 @@ class BufferWriter(ABC): @abstractmethod def write(self, data: List) -> None: """Write to buffer.""" + + @abstractmethod + def finish(self) -> None: + """Finish writing.""" diff --git a/trinity/buffer/queue.py b/trinity/buffer/queue.py index 943a15ebe2..0f135af0bc 100644 --- a/trinity/buffer/queue.py +++ b/trinity/buffer/queue.py @@ -14,6 +14,8 @@ class QueueActor: """An asyncio.Queue based queue 
actor.""" + FINISH_MESSAGE = "$FINISH$" + def __init__(self, dataset_config: DatasetConfig, config: BufferConfig) -> None: self.config = config self.capacity = getattr(config, "capacity", 10000) @@ -35,11 +37,17 @@ async def put_batch(self, exp_list: List) -> None: if self.sql_writer is not None: self.sql_writer.write(exp_list) + async def finish(self) -> None: + """Stop the queue.""" + await self.queue.put(self.FINISH_MESSAGE) + async def get_batch(self, batch_size: int) -> List: """Get batch of experience.""" batch = [] while True: exp_list = await self.queue.get() + if exp_list == self.FINISH_MESSAGE: + raise StopAsyncIteration() batch.extend(exp_list) if len(batch) >= batch_size: break diff --git a/trinity/buffer/reader/queue_reader.py b/trinity/buffer/reader/queue_reader.py index 3379708da4..8ff4ef4870 100644 --- a/trinity/buffer/reader/queue_reader.py +++ b/trinity/buffer/reader/queue_reader.py @@ -27,4 +27,8 @@ def __init__(self, meta: DatasetConfig, config: BufferConfig): def read(self, strategy: Optional[ReadStrategy] = None) -> List: if strategy is not None and strategy != ReadStrategy.FIFO: raise NotImplementedError(f"Read strategy {strategy} not supported for Queue Reader.") - return ray.get(self.queue.get_batch.remote(self.config.read_batch_size)) + try: + exps = ray.get(self.queue.get_batch.remote(self.config.read_batch_size)) + except StopAsyncIteration: + raise StopIteration() + return exps diff --git a/trinity/buffer/writer/queue_writer.py b/trinity/buffer/writer/queue_writer.py index d31b7a029a..9d4f2a83fd 100644 --- a/trinity/buffer/writer/queue_writer.py +++ b/trinity/buffer/writer/queue_writer.py @@ -1,7 +1,8 @@ """Writer of the Queue buffer.""" - from typing import List +import ray + from trinity.buffer.buffer_writer import BufferWriter from trinity.buffer.queue import QueueActor from trinity.common.config import BufferConfig, DatasetConfig @@ -23,4 +24,7 @@ def __init__(self, meta: DatasetConfig, config: BufferConfig): ).remote(meta, config) def write(self, data: List) -> None: - self.queue.put_batch.remote(data) + ray.get(self.queue.put_batch.remote(data)) + + def finish(self): + ray.get(self.queue.finish.remote()) diff --git a/trinity/buffer/writer/sql_writer.py b/trinity/buffer/writer/sql_writer.py index bf3ce728a4..a2abb1e399 100644 --- a/trinity/buffer/writer/sql_writer.py +++ b/trinity/buffer/writer/sql_writer.py @@ -40,3 +40,7 @@ def write(self, data: list) -> None: with retry_session(self.session, self.max_retry_times, self.max_retry_interval) as session: experience_models = [self.table_model_cls.from_experience(exp) for exp in data] session.add_all(experience_models) + + def finish(self) -> None: + # TODO: implement this + pass diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py index 9f9c799872..dda42ecaad 100644 --- a/trinity/cli/launcher.py +++ b/trinity/cli/launcher.py @@ -83,6 +83,10 @@ def both(config: Config) -> None: explore_continue, explore_iter_num = ray.get(ref_explore) train_continue, train_iter_num = ray.get(ref_train) if not explore_continue: + # If explore finished, the trainer may not have enough experiences to continue, + # which will cause the trainer be blocked. So we stop the training process + # immediately. + # TODO: use a more elegant way to stop the training process. 
logger.info("Explorer finished, stopping...") break if not train_continue: diff --git a/trinity/common/config.py b/trinity/common/config.py index bf3b215745..89693f11d3 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -149,6 +149,10 @@ class ExplorerConfig: # for rollout tokneize chat_template: Optional[str] = None + # for evaluation + # TODO: remove trainer.eval_interval + eval_interval: int = 100 + # for vLLM tensor_parallel_size: int = 1 enable_prefix_caching: bool = False @@ -322,6 +326,11 @@ def check_and_update(self) -> None: print( f"Warning: eval_interval is not a multiple of sync_iteration_interval; adjusted to the nearest integer={self.trainer.eval_interval}." ) + if self.explorer.eval_interval != self.trainer.eval_interval: + self.explorer.eval_interval = self.trainer.eval_interval + print( + f"Warning: explorer.eval_interval is not equal to trainer.eval_interval; adjusted to the same value={self.trainer.eval_interval}." + ) # check monitor if not self.monitor.cache_root_dir: diff --git a/trinity/common/models/vllm_async_model.py b/trinity/common/models/vllm_async_model.py index 7256a9089d..babd84ee5d 100644 --- a/trinity/common/models/vllm_async_model.py +++ b/trinity/common/models/vllm_async_model.py @@ -205,10 +205,7 @@ async def _generate_internal(self, request_id: int, prompt: Any, **kwargs) -> An # request_output.prompt = request.prompt return request_output - raise RuntimeError( - "[vLLM] The request is not finished. This should not happen. " - "Please report this issue to the Ray team." - ) + raise RuntimeError("[vLLM] The request is not finished. This should not happen.") async def convert_messages_to_experience_async(self, messages: List[dict]) -> Experience: """Convert a list of messages into an experience.""" @@ -219,7 +216,7 @@ async def convert_messages_to_experience_async(self, messages: List[dict]) -> Ex token_ids, action_mask = self.action_mask_method( self.tokenizer, messages, self.chat_template ) - logprobs = await self.logprobs_async(token_ids=token_ids) + logprobs = await self.logprobs_async(token_ids=token_ids.tolist()) return Experience( tokens=token_ids, prompt_length=len(token_ids), diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py index 0eaef5e3a3..ca205d7617 100644 --- a/trinity/common/models/vllm_model.py +++ b/trinity/common/models/vllm_model.py @@ -252,7 +252,7 @@ def convert_messages_to_experience(self, messages: List[dict]) -> Experience: token_ids, action_mask = self.action_mask_method( self.tokenizer, messages, self.chat_template ) - logprobs = self.logprobs(token_ids=token_ids) + logprobs = self.logprobs(token_ids=token_ids.tolist()) return Experience( tokens=token_ids, prompt_length=len(token_ids), diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py index 46676e62a8..f7598bbea0 100644 --- a/trinity/common/verl_config.py +++ b/trinity/common/verl_config.py @@ -133,7 +133,7 @@ class Rollout: @dataclass class ActorRolloutRef: - hybrid_engine: bool = False + hybrid_engine: bool = True model: ActorModel = field(default_factory=ActorModel) actor: Actor = field(default_factory=Actor) ref: Ref = field(default_factory=Ref) @@ -165,7 +165,7 @@ class Critic: use_dynamic_bsz: bool = False ppo_max_token_len_per_gpu: int = 0 forward_max_token_len_per_gpu: int = 0 - ulysses_sequence_parallel_size: int = 0 + ulysses_sequence_parallel_size: int = 1 ppo_epochs: int = 0 shuffle: bool = False grad_clip: float = 0.0 diff --git a/trinity/explorer/explorer.py 
diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index 6f6c44b69c..fc3a34ca55 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -8,6 +8,7 @@ import ray import torch +from trinity.buffer import get_buffer_writer from trinity.common.config import Config from trinity.common.constants import ROLLOUT_WEIGHT_SYNC_GROUP_NAME, TaskType from trinity.common.models import create_rollout_models @@ -33,6 +34,10 @@ def __init__(self, config: Config): self.iteration = explorer_meta.get("latest_iteration", 0) self.config = config self.models = create_rollout_models(config) + self.experience_buffer = get_buffer_writer( + self.config.buffer.train_dataset, # type: ignore + self.config.buffer, + ) self.taskset = TaskSet.load( self.config.data, explorer_meta.get("latest_task_index", 0), TaskType.EXPLORE ) @@ -150,10 +155,13 @@ def get_weight(self, name: str) -> torch.Tensor: def explore(self) -> None: """Explore the entire dataset.""" while True: - explore_status, _ = self.explore_step() + explore_status, explore_iter = self.explore_step() if not explore_status: break self.sync_weight() + if explore_iter % self.config.explorer.eval_interval == 0: + self.eval() + self.logger.info("Evaluation finished.") self.logger.info("Explorer finished.") def explore_step(self) -> Tuple[bool, int]: @@ -180,6 +188,7 @@ def explore_step(self) -> Tuple[bool, int]: tasks = [next(self.task_iter) for _ in range(task_num_per_step)] # type: ignore self.runner_pool.run_tasks(tasks) except StopIteration: + self.experience_buffer.finish() self.logger.warning("No more tasks in the task set. Stop exploring.") return False, self.iteration @@ -213,7 +222,7 @@ def explore_step(self) -> Tuple[bool, int]: current_task_index=self.taskset.index, ) - self.logger.info("Explore step finished.") + self.logger.info(f"Explore iteration {self.iteration} finished.") return True, self.iteration def eval(self) -> bool: diff --git a/trinity/explorer/runner_pool.py b/trinity/explorer/runner_pool.py index d60746ac90..e7641387c8 100644 --- a/trinity/explorer/runner_pool.py +++ b/trinity/explorer/runner_pool.py @@ -52,6 +52,7 @@ def _create_actors(self, num: int = 1): for _ in range(num): engine_index = self.engine_status.index(min(self.engine_status)) new_actor = WorkflowRunner.remote(self.config, self.models[engine_index]) + ray.get(new_actor.__ray_ready__.remote()) self.engine_status[engine_index] += 1 self.actor_to_engine_index[new_actor] = engine_index self._return_actor(new_actor)
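The new `ray.get(new_actor.__ray_ready__.remote())` call makes `_create_actors` block until each `WorkflowRunner` has actually finished `__init__`, so the per-engine bookkeeping only counts runners that constructed successfully. The same readiness pattern in isolation (`SlowWorker` is a made-up stand-in):

```python
import time

import ray


@ray.remote
class SlowWorker:
    def __init__(self):
        time.sleep(5)  # stand-in for heavy setup, e.g. loading a model


ray.init(ignore_reinit_error=True)
worker = SlowWorker.remote()
# __ray_ready__ is a no-op method defined on every Ray actor; ray.get on it
# blocks until the constructor has finished (or re-raises its exception).
ray.get(worker.__ray_ready__.remote())
```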
Stop training.") + return False, 0 # TODO: get the actual iteration number return self.engine.train_rft_iteration( Experiences.gather_experiences( exps,