diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py
index f05e50498e..0f0ef93e76 100644
--- a/tests/common/vllm_test.py
+++ b/tests/common/vllm_test.py
@@ -1,7 +1,6 @@
 import os
 import unittest
 
-import ray
 import torch
 from transformers import AutoTokenizer
 
@@ -131,12 +130,12 @@ def test_generate(self):
 
 class TestModelWrapperSyncV0(BaseTestModelWrapper, RayUnittestBase):
     def setUp(self):
-        ray.init(ignore_reinit_error=True)
         self.config = get_template_config()
         self.config.model.model_path = get_model_path()
         self.config.explorer.engine_type = "vllm"
         self.config.explorer.tensor_parallel_size = 1
         self.config.explorer.engine_num = 2
+        self.config.explorer.repeat_times = 2
         self.config.explorer.chat_template = CHAT_TEMPLATE
         self.engines = create_rollout_models(self.config)
         self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm")
@@ -144,12 +143,12 @@ def setUp(self):
 
 class TestModelWrapperAsyncV0(BaseTestModelWrapper, RayUnittestBase):
     def setUp(self):
-        ray.init(ignore_reinit_error=True)
         self.config = get_template_config()
         self.config.model.model_path = get_model_path()
         self.config.explorer.engine_type = "vllm_async"
         self.config.explorer.engine_num = 2
         self.config.explorer.tensor_parallel_size = 1
+        self.config.explorer.repeat_times = 2
         self.config.explorer.use_v1 = False
         self.config.explorer.chat_template = CHAT_TEMPLATE
         self.engines = create_rollout_models(self.config)
@@ -158,7 +157,6 @@ def setUp(self):
 
 class TestModelWrapperAsyncTPV0(BaseTestModelWrapper, RayUnittestBase):
     def setUp(self):
-        ray.init(ignore_reinit_error=True)
         self.config = get_template_config()
         self.config.model.model_path = get_model_path()
         self.config.explorer.engine_type = "vllm_async"
@@ -172,12 +170,12 @@ def setUp(self):
 
 class TestModelWrapperAsyncTPV1(BaseTestModelWrapper, RayUnittestBase):
     def setUp(self):
-        ray.init(ignore_reinit_error=True)
         self.config = get_template_config()
         self.config.model.model_path = get_model_path()
         self.config.explorer.engine_type = "vllm_async"
         self.config.explorer.engine_num = 2
         self.config.explorer.tensor_parallel_size = 2
+        self.config.explorer.repeat_times = 2
         self.config.explorer.use_v1 = True
         self.config.explorer.chat_template = CHAT_TEMPLATE
         self.engines = create_rollout_models(self.config)
@@ -186,7 +184,6 @@ def setUp(self):
 
 class TestModelWrapperAsyncV1(BaseTestModelWrapper, RayUnittestBase):
     def setUp(self):
-        ray.init(ignore_reinit_error=True)
         self.config = get_template_config()
         self.config.model.model_path = get_model_path()
         self.config.explorer.engine_type = "vllm_async"
diff --git a/trinity/common/models/vllm_async_model.py b/trinity/common/models/vllm_async_model.py
index 4134161ef5..50a774d8d4 100644
--- a/trinity/common/models/vllm_async_model.py
+++ b/trinity/common/models/vllm_async_model.py
@@ -12,6 +12,7 @@
 import ray
 import torch
 import vllm
+from vllm.sampling_params import RequestOutputKind
 
 from trinity.common.config import Config
 from trinity.common.experience import Experience
@@ -61,6 +62,7 @@ def __init__(
             truncate_prompt_tokens=config.model.max_prompt_tokens,
             skip_special_tokens=True,
             include_stop_str_in_output=False,
+            output_kind=RequestOutputKind.FINAL_ONLY,
             logprobs=config.explorer.logprobs,
         )
         self.request_id = 0
@@ -148,10 +150,8 @@ async def generate_async(self, prompt: str, **kwargs) -> List[Experience]:
         Returns:
             A list of experiences.
         """
-        request_id = self.request_id
-        self.request_id += 1
         async with self.semaphore:
-            output = await self._generate_internal(request_id=request_id, prompt=prompt, **kwargs)
+            output = await self._generate_internal(prompt=prompt, **kwargs)
         experiences = [
             Experience(
                 tokens=torch.cat(
@@ -186,11 +186,8 @@ async def generate_async(self, prompt: str, **kwargs) -> List[Experience]:
 
     async def logprobs_async(self, token_ids: List[int]) -> torch.Tensor:
         """Calculate the logprobs of the given tokens in async."""
-        request_id = self.request_id
-        self.request_id += 1
         async with self.semaphore:
             output = await self._generate_internal(
-                request_id=request_id,
                 prompt={"prompt_token_ids": token_ids},
                 n=1,
                 max_tokens=1,
@@ -205,10 +202,11 @@ async def logprobs_async(self, token_ids: List[int]) -> torch.Tensor:
             dtype=torch.float32,
         )
 
-    async def _generate_internal(self, request_id: int, prompt: Any, **kwargs) -> Any:
+    async def _generate_internal(self, prompt: Any, **kwargs) -> Any:
         # Send the request to the LLM engine.
+        self.request_id += 1
         stream = self.async_llm.generate(
-            request_id=str(request_id),
+            request_id=str(self.request_id),
             prompt=prompt,
             sampling_params=self._create_sampling_params(**kwargs),
         )