11 changes: 11 additions & 0 deletions 0_encode.sh
@@ -0,0 +1,11 @@
model_path='/nvme3/interns1-mini-remote'


CUDA_VISIBLE_DEVICES=1 lmdeploy serve api_server \
${model_path} \
--tp 1 \
--role Encoder \
--backend pytorch \
--server-port 23334 \
--proxy-url http://0.0.0.0:8001 \
--log-level INFO
12 changes: 12 additions & 0 deletions 0_pd.sh
@@ -0,0 +1,12 @@
model_path='/nvme3/interns1-mini-remote'


CUDA_VISIBLE_DEVICES=2 lmdeploy serve api_server \
${model_path} \
--tp 1 \
--role Hybrid \
--backend pytorch \
--server-port 23335 \
--proxy-url http://0.0.0.0:8001 \
--disable-vision-encoder \
--log-level INFO
40 changes: 40 additions & 0 deletions 0_proxy.sh
@@ -0,0 +1,40 @@
lmdeploy serve proxy --server-name 0.0.0.0 --server-port 8001 --routing-strategy "min_expected_latency" --serving-strategy Hybrid --log-level DEBUG

curl -X POST http://0.0.0.0:8001/distserve/connection_warmup

curl http://0.0.0.0:8001/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "/nvme3/interns1-mini-remote",
"messages": [
{
"role": "user",
"content": "Hello! How are you?"
}
]
}'


curl http://0.0.0.0:8001/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "/nvme3/interns1-mini-remote",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "What is in this image?"
},
{
"type": "image_url",
"image_url": {
"url": "https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg"
}
}
]
}
],
"max_tokens": 200
}'
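
For reference, the same requests can be issued from Python through any OpenAI-compatible client; the minimal sketch below targets the proxy started by 0_proxy.sh. The `openai` package and the dummy api_key are assumptions for illustration, not part of this PR.

# Minimal sketch: query the lmdeploy proxy via an OpenAI-compatible client.
from openai import OpenAI

client = OpenAI(base_url='http://0.0.0.0:8001/v1', api_key='dummy')

response = client.chat.completions.create(
    model='/nvme3/interns1-mini-remote',
    messages=[{
        'role': 'user',
        'content': [
            {'type': 'text', 'text': 'What is in this image?'},
            {'type': 'image_url', 'image_url': {
                'url': 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg'
            }},
        ],
    }],
    max_tokens=200,
)
print(response.choices[0].message.content)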
2 changes: 1 addition & 1 deletion lmdeploy/cli/utils.py
@@ -594,7 +594,7 @@ def role(parser):
return parser.add_argument('--role',
type=str,
default='Hybrid',
choices=['Hybrid', 'Prefill', 'Decode'],
choices=['Hybrid', 'Prefill', 'Decode', 'Encoder'],
help='Hybrid for Non-Disaggregated Engine; '
'Prefill for Disaggregated Prefill Engine; '
'Decode for Disaggregated Decode Engine')
3 changes: 2 additions & 1 deletion lmdeploy/messages.py
@@ -8,7 +8,7 @@
from pydantic.dataclasses import dataclass as pydantic_dataclass

from lmdeploy.pytorch.disagg.config import EngineRole, MigrationBackend
from lmdeploy.pytorch.disagg.conn.protocol import MigrationRequest
from lmdeploy.pytorch.disagg.conn.protocol import EncoderResult, MigrationRequest

from .tokenizer import Tokenizer
from .utils import get_logger
@@ -116,6 +116,7 @@ class GenerationConfig:
with_cache: bool = False
preserve_cache: bool = False
migration_request: Optional[MigrationRequest] = None
encoder_result: Optional[EncoderResult] = None

def convert_stop_bad_words_to_ids(self, tokenizer: Tokenizer):
"""Convert stop_words/bad_sords to ids and append the ids to
1 change: 1 addition & 0 deletions lmdeploy/multimodal/__init__.py
@@ -0,0 +1 @@
# Copyright (c) OpenMMLab. All rights reserved.
1 change: 1 addition & 0 deletions lmdeploy/multimodal/engine/__init__.py
@@ -0,0 +1 @@
# Copyright (c) OpenMMLab. All rights reserved.
132 changes: 132 additions & 0 deletions lmdeploy/multimodal/engine/cache_engine.py
@@ -0,0 +1,132 @@
# Copyright (c) OpenMMLab. All rights reserved.
import json
from typing import List, Optional, Tuple

import torch

from lmdeploy.pytorch.disagg.backend.backend import MIGRATION_BACKENDS
from lmdeploy.pytorch.disagg.backend.base import MigrationBackendImpl
from lmdeploy.pytorch.disagg.conn.protocol import DistServeInitRequest, DistServeKVTransferEndpointInfo
from lmdeploy.pytorch.disagg.messages import DistServeRegisterMRMessage
from lmdeploy.utils import get_logger

logger = get_logger('lmdeploy')

FEATURE_BLOCK_SHAPE = (256, 4096)


class EncoderCacheEngine:
"""Manages the memory pool for image features.

This engine allocates and manages a contiguous block of GPU memory
to store image embeddings transferred from an encoder. It is adapted for
an encoder-LLM separated architecture.

Args:
num_gpu_blocks (int): Number of feature blocks to allocate in the GPU pool.
rank (int): Distributed rank.
tp_rank (int): Tensor parallelism rank.
world_size (int): Distributed world size.
"""

def __init__(
self,
num_gpu_blocks: int = 128,
rank: int = 0,
tp_rank: int = 0,
world_size: int = 1,
) -> None:
self.world_size = world_size
self.rank = rank
self.tp_rank = tp_rank

self.feature_dtype = torch.bfloat16
self._num_gpu_blocks = num_gpu_blocks

self.encoder_gpu_cache = self._allocate_gpu_cache()

self.migration_backend_impl: Optional[MigrationBackendImpl] = None

self.cache_stream = torch.cuda.Stream()  # dedicated stream for feature cache transfers
assert self.cache_stream != torch.cuda.current_stream()
self.events = torch.cuda.Event()

# for memory block management
self.free_blocks = list(range(num_gpu_blocks))
logger.debug(f'Initialize feature cache engine with {self.num_gpu_blocks} gpu blocks.')

@property
def free_block_count(self) -> int:
"""Number of free blocks available in the cache."""
return len(self.free_blocks)

@property
def gpu_cache(self) -> torch.Tensor:
"""The GPU feature pool tensor."""
return self.encoder_gpu_cache

@property
def num_gpu_blocks(self) -> int:
"""Number of GPU blocks."""
return self._num_gpu_blocks

@staticmethod
def get_feature_block_shape() -> Tuple[int, int]:
"""Get the shape of a single image feature block."""
return FEATURE_BLOCK_SHAPE

def _allocate_cache(self, num_blocks: int, device: torch.device) -> torch.Tensor:
"""Allocate the memory pool on the specified device."""
block_shape = self.get_feature_block_shape()

# allocate a large contiguous tensor as the encoder cache
encoder_cache = torch.empty(
size=(num_blocks, *block_shape),
dtype=self.feature_dtype,
device=device,
)
return encoder_cache

def _allocate_gpu_cache(self) -> torch.Tensor:
"""Allocate the feature pool on the GPU."""
return self._allocate_cache(self.num_gpu_blocks, 'cuda')

@classmethod
def get_cache_block_size(cls) -> int:
"""Get the memory size in bytes of a single feature block."""

shape = cls.get_feature_block_shape()
dtype = torch.bfloat16

meta_tensor = torch.empty(shape, dtype=dtype, device='meta')
return meta_tensor.numel() * meta_tensor.element_size()

""" Methods for Disaggregation Begin. """

def p2p_initialize(self, migration_init_request: DistServeInitRequest) -> List[DistServeKVTransferEndpointInfo]:
if not self.migration_backend_impl:
self.migration_backend_impl: MigrationBackendImpl = MIGRATION_BACKENDS.module_dict['DLSlime']()
migration_init_request.rank = self.rank
self.migration_backend_impl.p2p_initialize(migration_init_request)

t = self.encoder_gpu_cache
if t.numel() > 0:
register_mr_request = DistServeRegisterMRMessage(
protocol=migration_init_request.protocol,
remote_engine_id=migration_init_request.remote_engine_id,
mr_key='encoder_cache',  # fixed memory-registration key for the encoder cache pool
addr=t.data_ptr(),
offset=t.storage_offset(),
length=t.numel() * t.itemsize)
self.migration_backend_impl.register_memory_region(register_mr_request)

return [
DistServeKVTransferEndpointInfo(protocol=migration_init_request.protocol,
endpoint_info=json.dumps(
self.migration_backend_impl.endpoint_info(
migration_init_request.remote_engine_id,
migration_init_request.protocol)))
]

def p2p_connect(self, remote_engine_id: str, migration_conn_request: List[DistServeKVTransferEndpointInfo]):
self.migration_backend_impl.p2p_connect(remote_engine_id, migration_conn_request[self.tp_rank])

""" Methods for Disaggregation End. """
158 changes: 158 additions & 0 deletions lmdeploy/multimodal/engine/engine.py
@@ -0,0 +1,158 @@
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
import copy
from typing import Dict, List, Optional

from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse

from lmdeploy.messages import PytorchEngineConfig
from lmdeploy.pytorch.disagg.conn.engine_conn import EngineP2PConnection
from lmdeploy.pytorch.disagg.conn.protocol import (DistServeConnectionRequest, DistServeConnectionResponse,
DistServeConnectionStatus, DistServeDropConnectionRequest,
DistServeEngineEndpointInfo, DistServeInitRequest,
DistServeInitResponse)
from lmdeploy.utils import get_logger

from ..models.builder import load_mm_model
from .model_agent import build_model_agent
from .post_process import PostProcessor
from .pre_process import PreProcessor

logger = get_logger('lmdeploy')


class MultiModalEngine():
"""The multi-modal async inference engine of lmdeploy."""

def __init__(self,
model_path: str,
chat_template: object,
tokenizer: object,
engine_config: PytorchEngineConfig = None,
trust_remote_code: bool = True) -> None:
# make sure the engine config exists
if engine_config is None:
engine_config = PytorchEngineConfig()
self.engine_config = copy.deepcopy(engine_config)
self.chat_template = chat_template
self.tokenizer = tokenizer

# build model
self.model = load_mm_model(model_path, backend_config=self.engine_config)

# build model agent
self.model_agent = build_model_agent(self.model)
self.model_agent.init()

# init pre / post processor
self.post_processor = PostProcessor(self.model_agent)
self.pre_processor = PreProcessor(self.model_agent, self.post_processor)

self.engine_conn = EngineP2PConnection(self)

def start_loop(self):
"""Start async loops."""
# invoked in the api server startup event, where a running event loop has already been started by uvicorn.run();
# therefore we don't create a new event loop manually and simply start the loops of each module
self.pre_processor.start_loop()
self.post_processor.start_loop()
self.model_agent.start_loop()

def close(self):
"""Close the engine and release resources."""
self.pre_processor.close()
self.post_processor.close()
self.model_agent.close()

@classmethod
def from_pretrained(cls,
pretrained_model_name_or_path: str,
chat_template: object,
tokenizer: object,
engine_config: PytorchEngineConfig = None,
trust_remote_code: bool = True,
**kwargs):
"""Create a MultiModalEngine instance."""
return cls(model_path=pretrained_model_name_or_path,
chat_template=chat_template,
tokenizer=tokenizer,
engine_config=engine_config,
trust_remote_code=trust_remote_code)

async def encode(self, messages, session_id: int):
"""Async encode."""
future = asyncio.Future()

# future will later be set in post-processor
self.pre_processor.process(session_id, messages, future)

return await future

# TODO: change this, put into pre-processor?
async def wrap_for_pytorch(
self,
messages: List[Dict],
chat_template,
tokenizer,
sequence_start,
tools: Optional[List[object]] = None,
enable_thinking: Optional[bool] = None,
) -> List[Dict]:
"""
Args:
messages (List[Dict]): a list of message, which is supposed to be
the output of `preprocess`
Returns:
a dict which will be passed to pytorch engine_instance's forward.
The dict is like the following:
Dict(
'prompt': 'the prompt after applying chat template'
'input_ids': [],
'multimodal': {
'pixel_values': torch.Tensor,
...
]
)
"""
result = self.model.to_pytorch(messages,
chat_template,
tokenizer,
sequence_start,
tools=tools,
enable_thinking=enable_thinking)
# clear data
for i, message in enumerate(messages):
if isinstance(message['content'], List):
messages[i]['preprocess'] = None
return result

def p2p_initialize(self, init_request: DistServeInitRequest):
"""Initialize p2p connection.

FIXME: This method is synchronous (`def`).
The standard PytorchEngine (in multi-process mode) has a synchronous
`p2p_initialize` that acts as an RPC bridge to an async worker.
To maintain a compatible interface for the `AsyncEngine` adapter,
this single-process engine also provides a synchronous implementation.
"""
kv_eps = self.model_agent.cache_engine.p2p_initialize(init_request)
# encoder has no zmq communication for now; return a dummy address
zmq_addr = 'tcp://0.0.0.0:65001'
resp = DistServeInitResponse(
status=DistServeConnectionStatus.SUCCESS,
engine_endpoint_info=DistServeEngineEndpointInfo(zmq_address=zmq_addr),
kvtransfer_endpoint_info=kv_eps,
)
return JSONResponse(jsonable_encoder(resp.model_dump()))

def p2p_connect(self, conn_request: DistServeConnectionRequest):
self.model_agent.cache_engine.p2p_connect(
conn_request.remote_engine_id,
conn_request.remote_kvtransfer_endpoint_info,
)
resp = DistServeConnectionResponse(status=DistServeConnectionStatus.SUCCESS)
return JSONResponse(jsonable_encoder(resp.model_dump()))

async def p2p_drop_connect(self, drop_conn_request: DistServeDropConnectionRequest):
return self.engine_conn.p2p_drop_connect(drop_conn_request)
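
A rough sketch of how a caller might drive this engine from an existing event loop; the chat_template, tokenizer and messages objects and the model path are placeholders, not part of this PR:

# Sketch: create the multimodal engine, start its loops, and encode one request.
from lmdeploy.multimodal.engine.engine import MultiModalEngine


async def run_encoder(chat_template, tokenizer, messages):
    # must be awaited from a running event loop (e.g. the one started by uvicorn)
    engine = MultiModalEngine.from_pretrained('/nvme3/interns1-mini-remote',
                                              chat_template=chat_template,
                                              tokenizer=tokenizer)
    engine.start_loop()  # starts the pre-processor, post-processor and model-agent loops
    try:
        # the awaited future is resolved by the post-processor with the encoded result
        return await engine.encode(messages, session_id=0)
    finally:
        engine.close()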