From e8e63b91a5d8e3194f30d99d0d08a1af3e0f52ae Mon Sep 17 00:00:00 2001
From: pxc
Date: Wed, 7 May 2025 20:35:00 +0800
Subject: [PATCH 1/7] use vllm v1 engine

---
 pyproject.toml                            |  2 +-
 tests/common/vllm_test.py                 | 68 ++++++++++++++++++++++-
 tests/template/config.yaml                |  1 +
 trinity/common/config.py                  |  2 +
 trinity/common/models/__init__.py         | 59 +++++++++++---------
 trinity/common/models/vllm_async_model.py | 11 +++-
 trinity/common/models/vllm_model.py       | 12 +++-
 trinity/common/models/vllm_worker.py      | 25 +--------
 8 files changed, 123 insertions(+), 57 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 45944d095e..437aae189d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -21,7 +21,7 @@ requires-python = ">=3.10"
 dependencies = [
     "verl==0.3.0.post1",
     "ray[default]==2.43.0",
-    "vllm>=0.8.3",
+    "vllm>=0.8.5",
     "tensordict==0.6.2",
     "wandb",
     "omegaconf",
diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py
index aeda2d0c76..d3979c35b3 100644
--- a/tests/common/vllm_test.py
+++ b/tests/common/vllm_test.py
@@ -141,7 +141,7 @@ def setUp(self):
         self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm")
 
 
-class TestModelWrapperAsync(BaseTestModelWrapper, RayUnittestBase):
+class TestModelWrapperAsyncV0(BaseTestModelWrapper, RayUnittestBase):
     @classmethod
     def setUpClass(cls):
         ray.init(ignore_reinit_error=True)
@@ -156,6 +156,72 @@ def setUp(self):
         self.config.model.model_path = get_model_path()
         self.config.explorer.engine_type = "vllm_async"
         self.config.explorer.engine_num = 1
+        self.config.explorer.use_v1 = False
+        self.config.explorer.chat_template = CHAT_TEMPLATE
+        self.engines = create_rollout_models(self.config)
+        self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async")
+
+
+class TestModelWrapperAsyncMPV1(BaseTestModelWrapper, RayUnittestBase):
+    @classmethod
+    def setUpClass(cls):
+        ray.init(ignore_reinit_error=True)
+
+    @classmethod
+    def tearDownClass(cls):
+        ray.shutdown()
+
+    def setUp(self):
+        ray.init(ignore_reinit_error=True)
+        self.config = get_template_config()
+        self.config.model.model_path = get_model_path()
+        self.config.explorer.engine_type = "vllm_async"
+        self.config.explorer.engine_num = 1
+        self.config.explorer.tensor_parallel_size = 2
+        self.config.explorer.use_v1 = True
+        self.config.explorer.chat_template = CHAT_TEMPLATE
+        self.engines = create_rollout_models(self.config)
+        self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async")
+
+
+class TestModelWrapperAsyncSPV1(BaseTestModelWrapper, RayUnittestBase):
+    @classmethod
+    def setUpClass(cls):
+        ray.init(ignore_reinit_error=True)
+
+    @classmethod
+    def tearDownClass(cls):
+        ray.shutdown()
+
+    def setUp(self):
+        ray.init(ignore_reinit_error=True)
+        self.config = get_template_config()
+        self.config.model.model_path = get_model_path()
+        self.config.explorer.engine_type = "vllm_async"
+        self.config.explorer.engine_num = 1
+        self.config.explorer.tensor_parallel_size = 1
+        self.config.explorer.use_v1 = True
+        self.config.explorer.chat_template = CHAT_TEMPLATE
+        self.engines = create_rollout_models(self.config)
+        self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async")
+
+
+class TestModelWrapperAsyncTensorParallel(BaseTestModelWrapper, RayUnittestBase):
+    @classmethod
+    def setUpClass(cls):
+        ray.init(ignore_reinit_error=True)
+
+    @classmethod
+    def tearDownClass(cls):
+        ray.shutdown()
+
+    def setUp(self):
+        ray.init(ignore_reinit_error=True)
+        self.config = get_template_config()
+        self.config.model.model_path = get_model_path()
+        self.config.explorer.engine_type = "vllm_async"
+        self.config.explorer.engine_num = 2
+        self.config.explorer.tensor_parallel_size = 2
         self.config.explorer.chat_template = CHAT_TEMPLATE
         self.engines = create_rollout_models(self.config)
         self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async")
diff --git a/tests/template/config.yaml b/tests/template/config.yaml
index 81f679f1fd..76cde18dc0 100644
--- a/tests/template/config.yaml
+++ b/tests/template/config.yaml
@@ -32,6 +32,7 @@ explorer:
   logprobs: 0
   backend: nccl
   use_ray: false
+  use_v1: true
 trainer:
   trainer_type: verl
   trainer_config_path: tests/template/verl_config.yaml
diff --git a/trinity/common/config.py b/trinity/common/config.py
index 6dc3510ddc..a73b7915d9 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -172,6 +172,8 @@ class ExplorerConfig:
     use_ray: bool = False
     gpu_memory_utilization: float = 0.9
     enable_chunked_prefil: bool = False
+    use_v1: bool = True
+    bundle_indices: str = ""  # DO NOT SET this field
 
     # for workflow runner
     max_pending_requests: int = 5
diff --git a/trinity/common/models/__init__.py b/trinity/common/models/__init__.py
index fe64706e20..eeecd4b741 100644
--- a/trinity/common/models/__init__.py
+++ b/trinity/common/models/__init__.py
@@ -1,3 +1,4 @@
+from collections import defaultdict
 from typing import List
 
 from trinity.common.config import Config
@@ -12,20 +13,15 @@ def create_rollout_models(
     Each model has `tensor_parallel_size` workers.
     """
     import ray
-    import vllm
-    from ray.util.placement_group import placement_group
+    from ray.util.placement_group import placement_group, placement_group_table
     from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
 
     from trinity.common.models.vllm_async_model import vLLMAysncRolloutModel
     from trinity.common.models.vllm_model import vLLMRolloutModel
-    from trinity.utils.log import get_logger
-
-    logger = get_logger(__name__)
-
-    assert vllm.__version__ >= "0.7.3", "Trinity-RFT only supports vllm >= 0.7.3"
 
     engine_num = config.explorer.engine_num
     tensor_parallel_size = config.explorer.tensor_parallel_size
+    is_multi_process = config.explorer.tensor_parallel_size > 1
 
     vllm_engines = []
 
@@ -36,28 +32,39 @@ def create_rollout_models(
     else:
         raise ValueError(f"Unknown engine type: {config.explorer.engine_type}")
 
-    bundles = [{"GPU": tensor_parallel_size, "CPU": 1} for _ in range(engine_num)]
-    pg = placement_group(bundles)
+    bundles = [{"GPU": 1} for _ in range(engine_num * tensor_parallel_size)]
+    pg = placement_group(bundles, strategy="PACK")
     ray.get(pg.ready())
-    for i in range(engine_num):
-        logger.info(f"Creating vLLM engine {i}")
-        scheduling_strategy = None
+    vllm_engines = []
 
-        scheduling_strategy = PlacementGroupSchedulingStrategy(
-            placement_group=pg,
-            placement_group_capture_child_tasks=True,
-            placement_group_bundle_index=i,
-        )
+    # to address https://github.com/ray-project/ray/issues/51117
+    # aggregate bundles belonging to the same node
+    bundle_node_map = placement_group_table(pg)["bundles_to_node_id"]
+    node_bundle_map = defaultdict(list)
+    for bundle_id, node_id in bundle_node_map.items():
+        node_bundle_map[node_id].append(bundle_id)
 
-        vllm_engines.append(
-            engine_cls.options(  # type: ignore [attr-defined]
-                num_cpus=0,
-                num_gpus=tensor_parallel_size,
-                scheduling_strategy=scheduling_strategy,
-            ).remote(
-                config=config,
-            )
+    for node_id, bundle_ids in node_bundle_map.items():
+        assert len(bundle_ids) % tensor_parallel_size == 0, (
+            f"Node {node_id} has {len(bundle_ids)} bundles, "
+            f"which is not divisible by tensor_parallel_size({tensor_parallel_size})"
         )
-
+        for i in range(len(bundle_ids) // tensor_parallel_size):
+            bundles_for_engine = bundle_ids[
+                i * tensor_parallel_size : (i + 1) * tensor_parallel_size
+            ]
+            config.explorer.bundle_indices = ",".join([str(bid) for bid in bundles_for_engine])
+            vllm_engines.append(
+                engine_cls.options(
+                    num_cpus=0,
+                    num_gpus=0 if is_multi_process else 1,
+                    scheduling_strategy=PlacementGroupSchedulingStrategy(
+                        placement_group=pg,
+                        placement_group_bundle_index=bundles_for_engine[0],
+                    ),
+                ).remote(
+                    config=config,
+                )
+            )
     return vllm_engines
diff --git a/trinity/common/models/vllm_async_model.py b/trinity/common/models/vllm_async_model.py
index 4ad86d86ae..6beee6886e 100644
--- a/trinity/common/models/vllm_async_model.py
+++ b/trinity/common/models/vllm_async_model.py
@@ -44,6 +44,13 @@ def __init__(
         self.config = config
         if config.explorer.tensor_parallel_size != 1:
             os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+        if not vllm.envs.is_set("VLLM_USE_V1"):
+            self.logger.info(f"Using vLLM v{int(config.explorer.use_v1)} engine")
+            os.environ["VLLM_USE_V1"] = str(int(config.explorer.use_v1))
+        if config.explorer.use_v1:
+            os.environ["VLLM_RAY_PER_WORKER_GPUS"] = str(int(config.explorer.use_v1))
+            os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+            os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
         self.default_sampling_params = vllm.SamplingParams(
             n=config.explorer.repeat_times,
             temperature=config.explorer.temperature,
@@ -58,11 +65,11 @@ def __init__(
         engine_args = vllm.AsyncEngineArgs(
             model=config.model.model_path,
             enforce_eager=config.explorer.enforce_eager,
-            worker_cls="trinity.common.models.vllm_worker.VLLMWorker",
+            worker_extension_cls="trinity.common.models.vllm_worker.WorkerExtension",
             tensor_parallel_size=config.explorer.tensor_parallel_size,
             seed=config.explorer.seed,
             distributed_executor_backend=(
-                "uni" if config.explorer.tensor_parallel_size == 1 else "mp"
+                "uni" if config.explorer.tensor_parallel_size == 1 else "ray"
             ),
             max_model_len=config.model.max_prompt_tokens + config.model.max_response_tokens,
             enable_prefix_caching=config.explorer.enable_prefix_caching,
diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py
index 95aa9e805c..c6c7a94376 100644
--- a/trinity/common/models/vllm_model.py
+++ b/trinity/common/models/vllm_model.py
@@ -12,6 +12,7 @@
 
 import ray
 import torch
+import vllm
 from vllm import LLM
 from vllm.sampling_params import SamplingParams
 
@@ -32,8 +33,13 @@ class vLLMRolloutModel(InferenceModel):
     def __init__(self, config: Config, **kwargs):
         self.logger = get_logger(__name__)
         self.config = config
-        if config.explorer.tensor_parallel_size != 1:
+        if not vllm.envs.is_set("VLLM_USE_V1"):
+            self.logger.info(f"Using vLLM v{int(config.explorer.use_v1)} engine")
+            os.environ["VLLM_USE_V1"] = str(int(config.explorer.use_v1))
+        if config.explorer.use_v1:
+            os.environ["VLLM_RAY_PER_WORKER_GPUS"] = str(int(config.explorer.use_v1))
             os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+            os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
         self.default_sampling_params = SamplingParams(
             n=config.explorer.repeat_times,
             temperature=config.explorer.temperature,
@@ -48,11 +54,11 @@ def __init__(self, config: Config, **kwargs):
             # TODO: check checkpoint path
             model=config.model.model_path,
             enforce_eager=config.explorer.enforce_eager,
-            worker_cls="trinity.common.models.vllm_worker.VLLMWorker",
+            worker_extension_cls="trinity.common.models.vllm_worker.WorkerExtension",
             tensor_parallel_size=config.explorer.tensor_parallel_size,
             seed=config.explorer.seed,
             distributed_executor_backend=(
-                "uni" if config.explorer.tensor_parallel_size == 1 else "mp"
+                "uni" if config.explorer.tensor_parallel_size == 1 else "ray"
             ),
             max_model_len=config.model.max_prompt_tokens + config.model.max_response_tokens,
             enable_prefix_caching=config.explorer.enable_prefix_caching,
diff --git a/trinity/common/models/vllm_worker.py b/trinity/common/models/vllm_worker.py
index 52e0ade475..4a32628a96 100644
--- a/trinity/common/models/vllm_worker.py
+++ b/trinity/common/models/vllm_worker.py
@@ -3,7 +3,6 @@
 import ray
 import torch
 import torch.distributed
-from vllm.worker.worker import Worker
 
 from trinity.utils.distributed import init_process_group, is_ipv6_address
 from trinity.utils.log import get_logger
@@ -11,13 +10,7 @@
 logger = get_logger(__name__)
 
 
-def get_physical_gpu_id():
-    device = torch.cuda.current_device()
-    props = torch.cuda.get_device_properties(device)
-    return str(props.uuid)
-
-
-class VLLMWorker(Worker):
+class WorkerExtension:
     def init_process_group(
         self,
         master_address: str,
@@ -79,19 +72,3 @@ def update_weight(self, name, dtype, shape, empty_cache=False):
         self.model_runner.model.load_weights(weights=[(name, weight)])
 
         del weight
-
-    def update_weight_cuda_ipc(self, name, dtype, shape, ipc_handles=None, empty_cache=False):
-        assert (
-            dtype == self.model_config.dtype
-        ), f"mismatch dtype: src {dtype}, dst {self.model_config.dtype}"
-
-        handle = ipc_handles[get_physical_gpu_id()]
-        device_id = self.device.index
-        func, args = handle
-        list_args = list(args)
-        # the key is to change device id to the current device id
-        # in case two processes have different CUDA_VISIBLE_DEVICES
-        list_args[6] = device_id
-        weight = func(*list_args)
-        self.model_runner.model.load_weights(weights=[(name, weight)])
-        torch.cuda.synchronize()
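Note on PATCH 1/7: switching the placement group from one {GPU: tensor_parallel_size}
bundle per engine to one {GPU: 1} bundle per worker means bundle ids no longer map
one-to-one onto engines. The new code therefore groups bundles by the node they landed
on before slicing them into engines of tensor_parallel_size workers, which keeps every
engine's tensor-parallel workers on a single node and works around ray-project/ray#51117.
The grouping logic in isolation, run against a synthetic bundle-to-node table (the
function name and sample topology are illustrative, not part of the patch):

    from collections import defaultdict

    def slice_bundles(bundle_node_map: dict, tensor_parallel_size: int) -> list:
        # group bundle ids by the node they were placed on
        node_bundle_map = defaultdict(list)
        for bundle_id, node_id in bundle_node_map.items():
            node_bundle_map[node_id].append(bundle_id)
        # carve each node's bundles into engine-sized slices
        engines = []
        for bundle_ids in node_bundle_map.values():
            assert len(bundle_ids) % tensor_parallel_size == 0
            for i in range(len(bundle_ids) // tensor_parallel_size):
                engines.append(
                    bundle_ids[i * tensor_parallel_size : (i + 1) * tensor_parallel_size]
                )
        return engines

    # two nodes holding two single-GPU bundles each, tensor_parallel_size=2
    print(slice_bundles({0: "node-a", 1: "node-a", 2: "node-b", 3: "node-b"}, 2))
    # -> [[0, 1], [2, 3]]

Each slice becomes one engine: its first bundle id is used as the actor's
placement_group_bundle_index, and the whole slice is handed to the engine through
config.explorer.bundle_indices (hence the "DO NOT SET" warning on that field).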
From 2d42b650bf09fdfb119827a8d1af89cb75e12506 Mon Sep 17 00:00:00 2001
From: pxc
Date: Wed, 7 May 2025 20:56:06 +0800
Subject: [PATCH 2/7] simplify test

---
 tests/common/vllm_test.py | 40 ---------------------------------------
 1 file changed, 40 deletions(-)

diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py
index d3979c35b3..9cc45895a3 100644
--- a/tests/common/vllm_test.py
+++ b/tests/common/vllm_test.py
@@ -142,14 +142,6 @@ def setUp(self):
 
 
 class TestModelWrapperAsyncV0(BaseTestModelWrapper, RayUnittestBase):
-    @classmethod
-    def setUpClass(cls):
-        ray.init(ignore_reinit_error=True)
-
-    @classmethod
-    def tearDownClass(cls):
-        ray.shutdown()
-
     def setUp(self):
         ray.init(ignore_reinit_error=True)
         self.config = get_template_config()
@@ -163,14 +155,6 @@ def setUp(self):
 
 
 class TestModelWrapperAsyncMPV1(BaseTestModelWrapper, RayUnittestBase):
-    @classmethod
-    def setUpClass(cls):
-        ray.init(ignore_reinit_error=True)
-
-    @classmethod
-    def tearDownClass(cls):
-        ray.shutdown()
-
     def setUp(self):
         ray.init(ignore_reinit_error=True)
         self.config = get_template_config()
@@ -185,14 +169,6 @@ def setUp(self):
 
 
 class TestModelWrapperAsyncSPV1(BaseTestModelWrapper, RayUnittestBase):
-    @classmethod
-    def setUpClass(cls):
-        ray.init(ignore_reinit_error=True)
-
-    @classmethod
-    def tearDownClass(cls):
-        ray.shutdown()
-
     def setUp(self):
         ray.init(ignore_reinit_error=True)
        self.config = get_template_config()
@@ -207,14 +183,6 @@ def setUp(self):
 
 
 class TestModelWrapperAsyncTensorParallel(BaseTestModelWrapper, RayUnittestBase):
-    @classmethod
-    def setUpClass(cls):
-        ray.init(ignore_reinit_error=True)
-
-    @classmethod
-    def tearDownClass(cls):
-        ray.shutdown()
-
     def setUp(self):
         ray.init(ignore_reinit_error=True)
         self.config = get_template_config()
@@ -228,14 +196,6 @@ def setUp(self):
 
 
 class TestTokenizer(unittest.TestCase):
-    @classmethod
-    def setUpClass(cls):
-        ray.init(ignore_reinit_error=True)
-
-    @classmethod
-    def tearDownClass(cls):
-        ray.shutdown()
-
     def test_assistant_token_mask(self):
         messages = [
             {"role": "system", "content": "You are a helpful assistant."},

From d630671d08cbb93b54d5510028a5ef8a7e74a0b0 Mon Sep 17 00:00:00 2001
From: pxc
Date: Wed, 7 May 2025 21:06:48 +0800
Subject: [PATCH 3/7] fix tests

---
 tests/buffer/queue_test.py         | 10 +++-------
 tests/common/config_test.py        |  4 +++-
 tests/explorer/runner_pool_test.py |  1 +
 3 files changed, 7 insertions(+), 8 deletions(-)

diff --git a/tests/buffer/queue_test.py b/tests/buffer/queue_test.py
index 058d69a126..deffa9d68a 100644
--- a/tests/buffer/queue_test.py
+++ b/tests/buffer/queue_test.py
@@ -1,8 +1,6 @@
-import unittest
-
-import ray
 import torch
 
+from tests.tools import RayUnittestBase
 from trinity.buffer.reader.queue_reader import QueueReader
 from trinity.buffer.writer.queue_writer import QueueWriter
 from trinity.common.config import BufferConfig, DatasetConfig
@@ -10,16 +8,14 @@
 from trinity.common.experience import Experience
 
 
-class TestQueueBuffer(unittest.TestCase):
-    def setUp(self):
-        ray.init(ignore_reinit_error=True)
-
+class TestQueueBuffer(RayUnittestBase):
     def test_queue_buffer(self):
         total_num = 8
         put_batch_size = 2
         read_batch_size = 4
         meta = DatasetConfig(
             name="test_buffer",
+            namespace="test_namespace",
             algorithm_type=AlgorithmType.PPO,
             storage_type=StorageType.QUEUE,
         )
diff --git a/tests/common/config_test.py b/tests/common/config_test.py
index 6d035576aa..67a372fb5b 100644
--- a/tests/common/config_test.py
+++ b/tests/common/config_test.py
@@ -26,7 +26,9 @@ def test_all_examples_are_valid(self):
         example_dir = os.path.join(os.path.dirname(__file__), "..", "..", "examples")
         for example_name in os.listdir(example_dir):
             for filename in os.listdir(os.path.join(example_dir, example_name)):
-                if filename.endswith(".yaml") and not filename.startswith("train"):
+                if filename.endswith(".yaml") and not (
+                    filename.startswith("train_") or filename.startswith("verl_")
+                ):
                     print(f"Checking config: {filename}")
                     config_path = os.path.join(example_dir, example_name, filename)
                     try:
diff --git a/tests/explorer/runner_pool_test.py b/tests/explorer/runner_pool_test.py
index 7513d0c4d3..f5160d3083 100644
--- a/tests/explorer/runner_pool_test.py
+++ b/tests/explorer/runner_pool_test.py
@@ -70,6 +70,7 @@ def setUp(self):
         self.config.buffer.pad_token_id = 0
         self.config.buffer.train_dataset = DatasetConfig(
             name="test",
+            namespace="test_runner_pool",
             storage_type=StorageType.QUEUE,
             algorithm_type=AlgorithmType.PPO,
         )
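Note on PATCH 2/7 and PATCH 3/7: the deleted setUpClass/tearDownClass blocks were
redundant because these classes already mix in RayUnittestBase from tests/tools.py,
which owns the Ray lifecycle; queue_test.py now switches to the same base class.
Inferred from the tests and from the tools.py hunk in the next patch, a minimal
equivalent of that base class would be (a sketch; the real helper may do more):

    import unittest

    import ray

    class RayUnittestBase(unittest.TestCase):
        @classmethod
        def setUpClass(cls):
            # one Ray runtime per test class instead of per test method
            ray.init(ignore_reinit_error=True)

        @classmethod
        def tearDownClass(cls):
            ray.shutdown()

PATCH 3/7 also gives each QUEUE-backed DatasetConfig its own namespace so that test
cases no longer collide on a shared queue name, and narrows the example-config check
to skip both train_*.yaml and verl_*.yaml files.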
From 7452480b3d91d3b0b5b04690104ef5a7879e95ab Mon Sep 17 00:00:00 2001
From: pxc
Date: Thu, 8 May 2025 10:46:48 +0800
Subject: [PATCH 4/7] update vllm test

---
 .github/workflows/unittest.yaml |  3 +--
 tests/common/vllm_test.py       | 25 ++++++++++++++-----------
 tests/tools.py                  |  2 +-
 3 files changed, 16 insertions(+), 14 deletions(-)

diff --git a/.github/workflows/unittest.yaml b/.github/workflows/unittest.yaml
index ceca92d18c..71e957a7ce 100644
--- a/.github/workflows/unittest.yaml
+++ b/.github/workflows/unittest.yaml
@@ -16,10 +16,9 @@ jobs:
     runs-on: self-hosted
 
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@master
        with:
          path: trinity-${{ github.run_id }}
-          fetch-depth: 0
 
       - name: Setup docker compose
        working-directory: trinity-${{ github.run_id }}/.github/workflows/docker
diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py
index 9cc45895a3..e0affd2ac4 100644
--- a/tests/common/vllm_test.py
+++ b/tests/common/vllm_test.py
@@ -129,13 +129,14 @@ def test_generate(self):
         self.assertTrue(torch.equal(result_dict["input_ids"][0], exp.tokens))
 
 
-class TestModelWrapperSync(BaseTestModelWrapper, RayUnittestBase):
+class TestModelWrapperSyncV0(BaseTestModelWrapper, RayUnittestBase):
     def setUp(self):
         ray.init(ignore_reinit_error=True)
         self.config = get_template_config()
         self.config.model.model_path = get_model_path()
         self.config.explorer.engine_type = "vllm"
-        self.config.explorer.engine_num = 1
+        self.config.explorer.tensor_parallel_size = 1
+        self.config.explorer.engine_num = 2
         self.config.explorer.chat_template = CHAT_TEMPLATE
         self.engines = create_rollout_models(self.config)
         self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm")
@@ -147,49 +148,51 @@ def setUp(self):
         self.config = get_template_config()
         self.config.model.model_path = get_model_path()
         self.config.explorer.engine_type = "vllm_async"
-        self.config.explorer.engine_num = 1
+        self.config.explorer.engine_num = 2
+        self.config.explorer.tensor_parallel_size = 1
         self.config.explorer.use_v1 = False
         self.config.explorer.chat_template = CHAT_TEMPLATE
         self.engines = create_rollout_models(self.config)
         self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async")
 
 
-class TestModelWrapperAsyncMPV1(BaseTestModelWrapper, RayUnittestBase):
+class TestModelWrapperAsyncTPV0(BaseTestModelWrapper, RayUnittestBase):
     def setUp(self):
         ray.init(ignore_reinit_error=True)
         self.config = get_template_config()
         self.config.model.model_path = get_model_path()
         self.config.explorer.engine_type = "vllm_async"
-        self.config.explorer.engine_num = 1
+        self.config.explorer.engine_num = 2
         self.config.explorer.tensor_parallel_size = 2
-        self.config.explorer.use_v1 = True
+        self.config.explorer.use_v1 = False
         self.config.explorer.chat_template = CHAT_TEMPLATE
         self.engines = create_rollout_models(self.config)
         self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async")
 
 
-class TestModelWrapperAsyncSPV1(BaseTestModelWrapper, RayUnittestBase):
+class TestModelWrapperAsyncTPV1(BaseTestModelWrapper, RayUnittestBase):
     def setUp(self):
         ray.init(ignore_reinit_error=True)
         self.config = get_template_config()
         self.config.model.model_path = get_model_path()
         self.config.explorer.engine_type = "vllm_async"
-        self.config.explorer.engine_num = 1
-        self.config.explorer.tensor_parallel_size = 1
+        self.config.explorer.engine_num = 2
+        self.config.explorer.tensor_parallel_size = 2
         self.config.explorer.use_v1 = True
         self.config.explorer.chat_template = CHAT_TEMPLATE
         self.engines = create_rollout_models(self.config)
         self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async")
 
 
-class TestModelWrapperAsyncTensorParallel(BaseTestModelWrapper, RayUnittestBase):
+class TestModelWrapperAsyncSPV1(BaseTestModelWrapper, RayUnittestBase):
     def setUp(self):
         ray.init(ignore_reinit_error=True)
         self.config = get_template_config()
         self.config.model.model_path = get_model_path()
         self.config.explorer.engine_type = "vllm_async"
         self.config.explorer.engine_num = 2
-        self.config.explorer.tensor_parallel_size = 2
+        self.config.explorer.tensor_parallel_size = 1
+        self.config.explorer.use_v1 = True
         self.config.explorer.chat_template = CHAT_TEMPLATE
         self.engines = create_rollout_models(self.config)
         self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async")
diff --git a/tests/tools.py b/tests/tools.py
index ca8fcda4c6..0aacf295cd 100644
--- a/tests/tools.py
+++ b/tests/tools.py
@@ -102,4 +102,4 @@ def setUpClass(cls):
 
     @classmethod
     def tearDownClass(cls):
-        ray.shutdown()
+        ray.shutdown(_exiting_interpreter=True)
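Note on PATCH 4/7: the test class names now encode the exact engine variant -- SyncV0
(synchronous engine), AsyncV0 (async V0, single process), TPV0 and TPV1
(tensor_parallel_size=2 on V0 and V1), and SPV1 (V1 with tensor_parallel_size=1,
presumably "single process"; renamed to plain V1 in the next patch) -- and every class
now runs engine_num=2 to exercise the new bundle slicing. In tests/tools.py,
ray.shutdown(_exiting_interpreter=True) leans on a private Ray parameter that is
normally set only by Ray's own atexit hook; it makes shutdown linger briefly so worker
error output is flushed before the next test class re-initializes Ray.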
From 6a7719941112016484155e5293d771e74f49ec7d Mon Sep 17 00:00:00 2001
From: pxc
Date: Thu, 8 May 2025 11:15:22 +0800
Subject: [PATCH 5/7] compatible with V1

---
 tests/common/vllm_test.py                 |  2 +-
 tests/explorer/explorer_test.py           |  2 ++
 trinity/common/models/model.py            | 19 -----------------
 trinity/common/models/vllm_async_model.py | 26 ++++++++++------------
 trinity/common/models/vllm_model.py       |  5 -----
 5 files changed, 14 insertions(+), 40 deletions(-)

diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py
index e0affd2ac4..f05e50498e 100644
--- a/tests/common/vllm_test.py
+++ b/tests/common/vllm_test.py
@@ -184,7 +184,7 @@ def setUp(self):
         self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async")
 
 
-class TestModelWrapperAsyncSPV1(BaseTestModelWrapper, RayUnittestBase):
+class TestModelWrapperAsyncV1(BaseTestModelWrapper, RayUnittestBase):
     def setUp(self):
         ray.init(ignore_reinit_error=True)
         self.config = get_template_config()
diff --git a/tests/explorer/explorer_test.py b/tests/explorer/explorer_test.py
index e0334dacc3..7d1fbd4c88 100644
--- a/tests/explorer/explorer_test.py
+++ b/tests/explorer/explorer_test.py
@@ -37,6 +37,7 @@ class TestExplorerCountdownEval(BaseExplorerCase):
     def test_explorer(self):
         self.config.data = get_unittest_dataset_config("countdown")
         self.config.monitor.name = f"explore-eval-{datetime.now().strftime('%Y%m%d%H%M%S')}"
+        self.config.explorer.use_v1 = True
         self.config.check_and_update()
         explore(self.config)
         parser = TensorBoardParser(os.path.join(self.config.monitor.job_dir, "tensorboard"))
@@ -53,6 +54,7 @@ def test_explorer(self):
         self.config.data = get_unittest_dataset_config("countdown")
         self.config.monitor.name = f"explore-no-eval-{datetime.now().strftime('%Y%m%d%H%M%S')}"
         self.config.data.eval_split = None
+        self.config.explorer.use_v1 = False
         self.config.check_and_update()
         explore(self.config)
         parser = TensorBoardParser(os.path.join(self.config.monitor.job_dir, "tensorboard"))
diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py
index 5cc4d41cca..8225aa09fa 100644
--- a/trinity/common/models/model.py
+++ b/trinity/common/models/model.py
@@ -45,29 +45,10 @@ async def convert_messages_to_experience_async(self, messages: List[dict]) -> Ex
         """Convert a list of messages into an experience in async."""
         raise NotImplementedError
 
-    @abstractmethod
-    def sync_model(self, update_weight_args_list: List) -> bool:
-        """Sync model weights."""
-        # TODO: sync with high efficiency
-
     @abstractmethod
     def get_ckp_version(self) -> int:
         """Get the checkpoint version."""
 
-    @abstractmethod
-    def init_process_group(
-        self,
-        master_address: str,
-        master_port: int,
-        rank_offset: int,
-        world_size: int,
-        group_name: str,
-        backend: str = "nccl",
-        timeout: int = 1200,
-        update_with_checkpoint: bool = True,
-    ) -> None:
-        """Init the process group for model weights sync."""
-
     def get_address(self) -> Tuple[str, int]:
         """Get the address of the actor."""
         address = ray.util.get_node_ip_address()
diff --git a/trinity/common/models/vllm_async_model.py b/trinity/common/models/vllm_async_model.py
index 6beee6886e..51b7848602 100644
--- a/trinity/common/models/vllm_async_model.py
+++ b/trinity/common/models/vllm_async_model.py
@@ -26,6 +26,7 @@
 
 
 # TODO: merge into vLLMRolloutModel
+# TODO: remove V0 when V1 is stable
 @ray.remote
 class vLLMAysncRolloutModel(InferenceModel):
     """Wrapper around the vLLM engine to handle async requests.
@@ -249,15 +250,15 @@ def _create_sampling_params(self, **kwargs):
             setattr(params, k, v)
         return params
 
-    def sync_model(self, update_weight_args_list) -> bool:
+    async def sync_model(self, update_weight_args_list) -> bool:
         """Sync model weights to vLLM."""
         for args in update_weight_args_list:
-            self.async_llm.engine.model_executor.collective_rpc("update_weight", args=args)
+            await self.async_llm.collective_rpc("update_weight", args=args)
         self.logger.info("Sync model weights to vLLM successfully.")
         self.ckp_version += 1
         return True
 
-    def init_process_group(
+    async def init_process_group(
         self,
         master_address: str,
         master_port: int,
@@ -268,7 +269,7 @@ def init_process_group(
         timeout: int = 1200,
         update_with_checkpoint: bool = True,
     ):
-        return self.async_llm.engine.model_executor.collective_rpc(
+        return await self.async_llm.collective_rpc(
             "init_process_group",
             args=(
                 master_address,
@@ -282,24 +283,19 @@ def init_process_group(
             ),
         )
 
-    def update_weight(self, name, dtype, shape, empty_cache=False):
-        return self.async_llm.engine.model_executor.collective_rpc(
+    async def update_weight(self, name, dtype, shape, empty_cache=False):
+        return await self.async_llm.collective_rpc(
             "update_weight", args=(name, dtype, shape, empty_cache)
         )
 
-    def update_weight_cuda_ipc(self, name, dtype, shape, ipc_handles, empty_cache=False):
-        return self.async_llm.engine.model_executor.collective_rpc(
-            "update_weight_cuda_ipc", args=(name, dtype, shape, ipc_handles, empty_cache)
-        )
-
-    def reset_prefix_cache(self):
-        self.async_llm.engine.reset_prefix_cache()
+    async def reset_prefix_cache(self) -> None:
+        await self.async_llm.reset_prefix_cache()
 
     def get_ckp_version(self) -> int:
         return self.ckp_version
 
-    async def sleep(self, level: int = 1):
+    async def sleep(self, level: int = 1) -> None:
         await self.async_llm.sleep(level=level)
 
-    async def wake_up(self):
+    async def wake_up(self) -> None:
         await self.async_llm.wake_up()
diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py
index c6c7a94376..b993b82586 100644
--- a/trinity/common/models/vllm_model.py
+++ b/trinity/common/models/vllm_model.py
@@ -113,11 +113,6 @@ def init_process_group(
     def update_weight(self, name, dtype, shape, empty_cache=False):
         return self.llm.collective_rpc("update_weight", args=(name, dtype, shape, empty_cache))
 
-    def update_weight_cuda_ipc(self, name, dtype, shape, ipc_handles, empty_cache=False):
-        return self.llm.collective_rpc(
-            "update_weight_cuda_ipc", args=(name, dtype, shape, ipc_handles, empty_cache)
-        )
-
     def reset_prefix_cache(self):
         self.llm.llm_engine.reset_prefix_cache()
 
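Note on PATCH 5/7: on the V1 engine, AsyncLLM exposes collective_rpc,
reset_prefix_cache, sleep and wake_up as coroutines, so the weight-sync entry points
become async def. Since vLLMAysncRolloutModel is a Ray actor, these methods run on the
actor's event loop and remote callers keep their usual shape; a usage sketch (the
config object and the weight list come from the surrounding framework, and the
connection values are illustrative):

    import ray

    from trinity.common.models import create_rollout_models

    engines = create_rollout_models(config)  # `config` populated as in the tests
    ray.get(engines[0].init_process_group.remote(
        master_address="10.0.0.1",
        master_port=29500,
        rank_offset=1,
        world_size=3,
        group_name="rollout_sync",
    ))
    ray.get(engines[0].sync_model.remote(update_weight_args_list))

Dropping the abstract sync_model/init_process_group declarations from InferenceModel
is consistent with this: the sync (vllm) and async (vllm_async) implementations no
longer share one signature.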
From da1819b2df09fdfb119827a8d1af89cb75e12506 Mon Sep 17 00:00:00 2001
From: pxc
Date: Thu, 8 May 2025 11:31:39 +0800
Subject: [PATCH 6/7] back compatible with V0

---
 trinity/common/models/vllm_async_model.py | 25 +++++++++++++++++------
 1 file changed, 19 insertions(+), 6 deletions(-)

diff --git a/trinity/common/models/vllm_async_model.py b/trinity/common/models/vllm_async_model.py
index 51b7848602..fc257a6845 100644
--- a/trinity/common/models/vllm_async_model.py
+++ b/trinity/common/models/vllm_async_model.py
@@ -7,7 +7,7 @@
 import os
 import re
 from contextlib import nullcontext
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional
 
 import ray
 import torch
@@ -43,6 +43,7 @@ def __init__(
     ) -> None:
         self.logger = get_logger(__name__)
         self.config = config
+        self.use_v1 = config.explorer.use_v1
         if config.explorer.tensor_parallel_size != 1:
             os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
         if not vllm.envs.is_set("VLLM_USE_V1"):
@@ -250,10 +251,24 @@ def _create_sampling_params(self, **kwargs):
             setattr(params, k, v)
         return params
 
+    async def _collective_rpc(
+        self,
+        method: str,
+        timeout: Optional[float] = None,
+        args: tuple = (),
+        kwargs: Optional[dict] = None,
+    ):
+        if self.use_v1:
+            return await self.async_llm.collective_rpc(method, timeout, args, kwargs)
+        else:
+            return self.async_llm.engine.model_executor.collective_rpc(
+                method, timeout, args, kwargs
+            )
+
     async def sync_model(self, update_weight_args_list) -> bool:
         """Sync model weights to vLLM."""
         for args in update_weight_args_list:
-            await self.async_llm.collective_rpc("update_weight", args=args)
+            await self._collective_rpc("update_weight", args=args)
         self.logger.info("Sync model weights to vLLM successfully.")
         self.ckp_version += 1
         return True
@@ -269,7 +284,7 @@ async def init_process_group(
         timeout: int = 1200,
         update_with_checkpoint: bool = True,
     ):
-        return await self.async_llm.collective_rpc(
+        return await self._collective_rpc(
             "init_process_group",
             args=(
                 master_address,
@@ -284,9 +299,7 @@ async def init_process_group(
             ),
         )
 
     async def update_weight(self, name, dtype, shape, empty_cache=False):
-        return await self.async_llm.collective_rpc(
-            "update_weight", args=(name, dtype, shape, empty_cache)
-        )
+        return await self._collective_rpc("update_weight", args=(name, dtype, shape, empty_cache))
 
     async def reset_prefix_cache(self) -> None:
         await self.async_llm.reset_prefix_cache()
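Note on PATCH 6/7: _collective_rpc restores V0 support by dispatching on the flag
captured at construction time -- the awaitable AsyncLLM.collective_rpc on V1, the
synchronous engine.model_executor.collective_rpc path on V0. From the caller's side,
the whole series therefore reduces to a single switch (a sketch using the test helpers
shown above):

    config = get_template_config()             # template config, as in the tests
    config.explorer.engine_type = "vllm_async"
    config.explorer.engine_num = 2
    config.explorer.tensor_parallel_size = 2
    config.explorer.use_v1 = True              # False falls back to the V0 engine
    engines = create_rollout_models(config)

Setting use_v1 before the engines are constructed matters: the constructor exports
VLLM_USE_V1 (plus VLLM_RAY_PER_WORKER_GPUS and VLLM_ENABLE_V1_MULTIPROCESSING on V1)
only when the variable is not already present in the environment.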
From 57d17c42c7ba8a4acab6edbf9b5113a9908080d2 Mon Sep 17 00:00:00 2001
From: pxc
Date: Thu, 8 May 2025 12:18:05 +0800
Subject: [PATCH 7/7] run unittest on PR branch

---
 .github/workflows/unittest.yaml | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/.github/workflows/unittest.yaml b/.github/workflows/unittest.yaml
index 71e957a7ce..e03426c46b 100644
--- a/.github/workflows/unittest.yaml
+++ b/.github/workflows/unittest.yaml
@@ -16,9 +16,10 @@ jobs:
     runs-on: self-hosted
 
     steps:
-      - uses: actions/checkout@master
+      - uses: actions/checkout@v4
        with:
          path: trinity-${{ github.run_id }}
+          ref: refs/pull/${{ github.event.issue.number }}/head
 
       - name: Setup docker compose
        working-directory: trinity-${{ github.run_id }}/.github/workflows/docker
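Note on PATCH 7/7: this re-pins the checkout action to v4 and checks out
refs/pull/<number>/head explicitly. The workflow appears to be triggered from a
pull-request conversation (an issue_comment-style event), where
github.event.issue.number carries the PR number; refs/pull/<number>/head always points
at the PR's head commit, so the self-hosted job now tests the PR branch itself rather
than the repository's default branch.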