Skip to content

Commit a95f1da — "fix vllm test" (parent: 66fc66a)

File tree

2 files changed: +10 additions, −4 deletions

tests/explorer/workflow_test.py

Lines changed: 5 additions & 4 deletions
@@ -475,7 +475,7 @@ def test_workflow_repeatable(self, workflow_cls) -> None:
         (DummyAsyncMultiTurnWorkflow,),
     ],
 )
-class MultiTurnWorkflowTest(unittest.TestCase):
+class MultiTurnWorkflowTest(unittest.IsolatedAsyncioTestCase):
     def setUp(self):
         # configure the model
         self.config = get_template_config()
@@ -490,7 +490,8 @@ def setUp(self):
         self.engines, self.auxiliary_engines = create_inference_models(self.config)
         self.model_wrapper = ModelWrapper(self.engines[0], engine_type="vllm", enable_history=True)

-    def test_multi_turn_workflow(self):
+    async def test_multi_turn_workflow(self):
+        await asyncio.gather(*[engine.prepare.remote() for engine in self.engines])
         task = Task(
             workflow=self.workflow_cls,
             repeat_times=3,
@@ -500,7 +501,7 @@ def test_multi_turn_workflow(self):
         workflow = task.to_workflow(self.model_wrapper)
         workflow.set_repeat_times(2, run_id_base=0)
         if workflow.asynchronous:
-            answer = asyncio.run(workflow.run_async())
+            answer = await workflow.run_async()
         else:
             answer = workflow.run()
         self.assertEqual(len(answer), 2)
@@ -727,7 +728,7 @@ async def test_workflow_with_openai(self):
         config.explorer.rollout_model.enable_history = True
         config.check_and_update()
         engines, auxiliary_engines = create_inference_models(config)
-
+        await asyncio.gather(*[engine.prepare.remote() for engine in engines])
         runner = WorkflowRunner(
             config,
             model=engines[0],

trinity/common/models/vllm_model.py

Lines changed: 5 additions & 0 deletions
@@ -554,9 +554,11 @@ async def run_api_server(self) -> bool:
             success (bool): Whether the API server is started successfully.
         """
         if not self.config.enable_openai_api:
+            self.logger.info("OpenAI API server is not enabled. Skipping...")
             return False  # Not enabled

         if self.api_server_host is not None and self.api_server_port is not None:
+            self.logger.info("OpenAI API server is already running. Skipping...")
             return True  # already running

         from trinity.common.models.vllm_patch.api_patch import (
@@ -588,6 +590,9 @@ def get_api_server_url(self) -> Optional[str]:
         """
         if not self._prepared:
             raise RuntimeError("Model is not prepared. Please call `prepare()` first.")
+        if self.api_server_host is None or self.api_server_port is None:
+            # openai api is not enabled
+            return None
         return f"http://{self.api_server_host}:{self.api_server_port}"

     async def reset_prefix_cache(self) -> None:

0 commit comments

Comments
 (0)