Skip to content

Commit ca92466

Browse files
refactor the solution of vllm integration (#60)
1 parent d628866 commit ca92466

File tree

11 files changed

+1118
-359
lines changed

11 files changed

+1118
-359
lines changed
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from typing import TYPE_CHECKING, Type
2+
from llmserve.backend.logger import get_logger
3+
4+
logger = get_logger(__name__)
5+
6+
7+
if TYPE_CHECKING:
8+
from ._base import LLMEngine
9+
10+
from .generic import GenericEngine
11+
# vLLM is an optional dependency: the generic engine must stay importable
# even when the `vllm` package (and its GPU/CUDA requirements) is absent.
# On failure, `VllmEngine` is simply not defined in this module, and
# `get_engine_cls_by_name("vllm")` will fail at lookup time.
try:
    from .vllm import VllmEngine
except ImportError:
    logger.info("Import vllm related stuff failed, please make sure 'vllm' is installed.")
16+
def get_engine_cls_by_name(name: str) -> Type["LLMEngine"]:
    """Resolve an engine class from its (case-insensitive) name.

    Looks the name up among this module's globals, trying
    ``"<name>engine"`` first (e.g. ``"vllm"`` -> ``VllmEngine``) and
    falling back to the bare ``name`` itself.

    Args:
        name: Engine name, e.g. ``"generic"`` or ``"vllm"``.

    Returns:
        The matching engine class.

    Raises:
        ValueError: If no engine with that name is registered — e.g. the
            optional vllm import failed, so ``VllmEngine`` never got defined.
    """
    lowercase_globals = {k.lower(): v for k, v in globals().items()}
    ret = lowercase_globals.get(
        f"{name.lower()}engine", lowercase_globals.get(name.lower(), None)
    )
    # Was `assert ret`, which is silently stripped under `python -O` and
    # carries no message; raise an explicit, descriptive error instead.
    if ret is None:
        raise ValueError(f"Unknown engine name: {name!r}")
    return ret
23+
24+
25+
# Public API of the engines package. NOTE(review): `VllmEngine` is only
# defined when the optional `vllm` dependency imported successfully above;
# `from <package> import *` will raise if it is listed here but undefined.
__all__ = [
    "get_engine_cls_by_name",
    "GenericEngine",
    "VllmEngine",
]
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
from abc import ABC, abstractmethod
2+
from typing import List, Optional, Any
3+
from ray.air import ScalingConfig
4+
from ray.util.placement_group import PlacementGroup
5+
from llmserve.backend.server.models import Prompt
6+
7+
from llmserve.backend.logger import get_logger
8+
9+
from typing import List, Optional
10+
from ray.air import ScalingConfig
11+
12+
from llmserve.backend.logger import get_logger
13+
from llmserve.backend.server.models import Args, Prompt
14+
import asyncio
15+
16+
logger = get_logger(__name__)
17+
18+
class LLMEngine(ABC):
    """Abstract interface for LLM serving engines.

    Concrete subclasses (e.g. ``GenericEngine``, ``VllmEngine``) implement
    how the model is launched on the Ray cluster and how a batch of
    prompts is turned into generated text.
    """

    # Engine configuration; populated by __init__ (None until then).
    args: Args = None

    def __init__(
        self,
        args: Args,
    ):
        """Store the engine configuration.

        Args:
            args: Backend configuration object used by the concrete engine.
        """
        self.args = args

    @abstractmethod
    async def launch_engine(
        self,
        scaling_config: ScalingConfig,
        placement_group: PlacementGroup,
        scaling_options: dict,
    ) -> Any:
        """Start the engine workers and load the model.

        Args:
            scaling_config: Ray AIR scaling configuration for the workers.
            placement_group: Ray placement group to schedule workers into.
            scaling_options: Extra engine-specific scheduling options.

        Returns:
            An engine-specific handle; the exact type is defined by each
            implementation (hence ``Any``).
        """
        pass

    @abstractmethod
    async def predict(
        self,
        prompts: List[Prompt],
        *,
        timeout_s: float = 60,
        start_timestamp: Optional[float] = None,
        lock: asyncio.Lock,
    ) -> List[str]:
        """Run inference over a batch of prompts.

        Args:
            prompts: Prompts to run through the model.
            timeout_s: Timeout budget in seconds — presumably measured from
                ``start_timestamp`` when given; confirm in implementations.
            start_timestamp: Optional time the request started; semantics
                are defined by each implementation.
            lock: Lock held while using the engine — NOTE(review): looks
                like it serializes access to shared engine state; confirm
                against callers.

        Returns:
            Generated strings — presumably one per input prompt; verify
            against concrete implementations.
        """
        pass

    @abstractmethod
    async def check_health(self):
        """Health probe; the pass/fail contract is defined by subclasses."""
        pass

0 commit comments

Comments
 (0)