diff --git a/0_encode.sh b/0_encode.sh new file mode 100644 index 0000000000..7942ca091e --- /dev/null +++ b/0_encode.sh @@ -0,0 +1,11 @@ +model_path='/nvme3/interns1-mini-remote' + + +CUDA_VISIBLE_DEVICES=1 lmdeploy serve api_server \ + ${model_path} \ + --tp 1 \ + --role Encoder \ + --backend pytorch \ + --server-port 23334 \ + --proxy-url http://0.0.0.0:8001 \ + --log-level INFO diff --git a/0_pd.sh b/0_pd.sh new file mode 100644 index 0000000000..54c132f9fe --- /dev/null +++ b/0_pd.sh @@ -0,0 +1,12 @@ +model_path='/nvme3/interns1-mini-remote' + + +CUDA_VISIBLE_DEVICES=2 lmdeploy serve api_server \ + ${model_path} \ + --tp 1 \ + --role Hybrid \ + --backend pytorch \ + --server-port 23335 \ + --proxy-url http://0.0.0.0:8001 \ + --disable-vision-encoder \ + --log-level INFO diff --git a/0_proxy.sh b/0_proxy.sh new file mode 100644 index 0000000000..ebcaa2f294 --- /dev/null +++ b/0_proxy.sh @@ -0,0 +1,40 @@ +lmdeploy serve proxy --server-name 0.0.0.0 --server-port 8001 --routing-strategy "min_expected_latency" --serving-strategy Hybrid --log-level DEBUG + +curl -X POST http://0.0.0.0:8001/distserve/connection_warmup + +curl http://0.0.0.0:8001/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "/nvme3/interns1-mini-remote", + "messages": [ + { + "role": "user", + "content": "Hello! How are you?" + } + ] + }' + + +curl http://0.0.0.0:8001/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "/nvme3/interns1-mini-remote", + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "What is in this image?" + }, + { + "type": "image_url", + "image_url": { + "url": "https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg" + } + } + ] + } + ], + "max_tokens": 200 + }' diff --git a/lmdeploy/cli/utils.py b/lmdeploy/cli/utils.py index bfd94182d0..a7b2e6a778 100644 --- a/lmdeploy/cli/utils.py +++ b/lmdeploy/cli/utils.py @@ -594,7 +594,7 @@ def role(parser): return parser.add_argument('--role', type=str, default='Hybrid', - choices=['Hybrid', 'Prefill', 'Decode'], + choices=['Hybrid', 'Prefill', 'Decode', 'Encoder'], help='Hybrid for Non-Disaggregated Engine; ' 'Prefill for Disaggregated Prefill Engine; ' 'Decode for Disaggregated Decode Engine') diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py index c078d97d75..4c243cde80 100644 --- a/lmdeploy/messages.py +++ b/lmdeploy/messages.py @@ -8,7 +8,7 @@ from pydantic.dataclasses import dataclass as pydantic_dataclass from lmdeploy.pytorch.disagg.config import EngineRole, MigrationBackend -from lmdeploy.pytorch.disagg.conn.protocol import MigrationRequest +from lmdeploy.pytorch.disagg.conn.protocol import EncoderResult, MigrationRequest from .tokenizer import Tokenizer from .utils import get_logger @@ -116,6 +116,7 @@ class GenerationConfig: with_cache: bool = False preserve_cache: bool = False migration_request: Optional[MigrationRequest] = None + encoder_result: Optional[EncoderResult] = None def convert_stop_bad_words_to_ids(self, tokenizer: Tokenizer): """Convert stop_words/bad_sords to ids and append the ids to diff --git a/lmdeploy/multimodal/__init__.py b/lmdeploy/multimodal/__init__.py new file mode 100644 index 0000000000..ef101fec61 --- /dev/null +++ b/lmdeploy/multimodal/__init__.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
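For reference, a minimal sketch (not part of this patch) of how the new `encoder_result` field on `GenerationConfig` might be populated once the encoder has produced features; the token ids, session id and block ids below are hypothetical placeholders:

```python
# Illustrative only: the proxy/api_server is expected to fill these fields from the
# encoder's response before forwarding the request to the PD engine.
from lmdeploy.messages import GenerationConfig
from lmdeploy.pytorch.disagg.conn.protocol import EncoderResult, MigrationProtocol

encoder_result = EncoderResult(
    token_ids=[1, 10, 11, 12],          # prompt token ids produced on the encoder side
    image_mask=[0, 1, 1, 1],            # 1 marks positions occupied by image tokens
    protocol=MigrationProtocol.RDMA,
    remote_engine_id='http://0.0.0.0:23334',   # the encoder api_server from 0_encode.sh
    remote_session_id=0,
    remote_block_ids=[0, 1],            # feature blocks to pull from the encoder cache
)
gen_config = GenerationConfig(max_new_tokens=200, encoder_result=encoder_result)
```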
diff --git a/lmdeploy/multimodal/engine/__init__.py b/lmdeploy/multimodal/engine/__init__.py new file mode 100644 index 0000000000..ef101fec61 --- /dev/null +++ b/lmdeploy/multimodal/engine/__init__.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. diff --git a/lmdeploy/multimodal/engine/cache_engine.py b/lmdeploy/multimodal/engine/cache_engine.py new file mode 100644 index 0000000000..cf4fa5343f --- /dev/null +++ b/lmdeploy/multimodal/engine/cache_engine.py @@ -0,0 +1,132 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import json +from typing import List, Optional, Tuple + +import torch + +from lmdeploy.pytorch.disagg.backend.backend import MIGRATION_BACKENDS +from lmdeploy.pytorch.disagg.backend.base import MigrationBackendImpl +from lmdeploy.pytorch.disagg.conn.protocol import DistServeInitRequest, DistServeKVTransferEndpointInfo +from lmdeploy.pytorch.disagg.messages import DistServeRegisterMRMessage +from lmdeploy.utils import get_logger + +logger = get_logger('lmdeploy') + +FEATURE_BLOCK_SHAPE = (256, 4096) + + +class EncoderCacheEngine: + """Manages the memory pool for image features. + + This engine allocates and manages a contiguous block of GPU memory + to store image embeddings transferred from an encoder. It is adapted for + an encoder-LLM separated architecture. + Args: + rank (int): Distributed rank. + tp_rank (int): Tensor parallelism rank. + world_size (int): Distributed world size. + """ + + def __init__( + self, + num_gpu_blocks: int = 128, + rank: int = 0, + tp_rank: int = 0, + world_size: int = 1, + ) -> None: + self.world_size = world_size + self.rank = rank + self.tp_rank = tp_rank + + self.feature_dtype = torch.bfloat16 + self._num_gpu_blocks = num_gpu_blocks + + self.encoder_gpu_cache = self._allocate_gpu_cache() + + self.migration_backend_impl: Optional[MigrationBackendImpl] = None + + self.cache_stream = torch.cuda.Stream() + assert self.cache_stream != torch.cuda.current_stream() + self.events = torch.cuda.Event() + + # for memory block management + self.free_blocks = list(range(num_gpu_blocks)) + logger.debug(f'Initialize feature cache engine with {self.num_gpu_blocks} gpu blocks.') + + @property + def free_block_count(self) -> int: + """Number of free blocks available in the cache.""" + return len(self.free_blocks) + + @property + def gpu_cache(self) -> torch.Tensor: + """The GPU feature pool tensor.""" + return self.encoder_gpu_cache + + @property + def num_gpu_blocks(self) -> int: + """Number of GPU blocks.""" + return self._num_gpu_blocks + + @staticmethod + def get_feature_block_shape() -> Tuple[int, int]: + """Get the shape of a single image feature block.""" + return FEATURE_BLOCK_SHAPE + + def _allocate_cache(self, num_blocks: int, device: torch.device) -> torch.Tensor: + """Allocate the memory pool on the specified device.""" + block_shape = self.get_feature_block_shape() + + # allocate a large contiguous tensor as the encoder cache + encoder_cache = torch.empty( + size=(num_blocks, *block_shape), + dtype=self.feature_dtype, + device=device, + ) + return encoder_cache + + def _allocate_gpu_cache(self) -> torch.Tensor: + """Allocate the feature pool on the GPU.""" + return self._allocate_cache(self.num_gpu_blocks, 'cuda') + + @classmethod + def get_cache_block_size(cls) -> int: + """Get the memory size in bytes of a single feature block.""" + + shape = cls.get_feature_block_shape() + dtype = torch.bfloat16 + + meta_tensor = torch.empty(shape, dtype=dtype, device='meta') + return meta_tensor.numel() * meta_tensor.element_size() + + 
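To make the block geometry above concrete, a quick sanity check (a sketch that only restates the constants hard-coded in this file):

```python
import torch

# FEATURE_BLOCK_SHAPE = (256, 4096) in bfloat16 -> 256 * 4096 * 2 bytes = 2 MiB per block
block_bytes = 256 * 4096 * torch.empty((), dtype=torch.bfloat16).element_size()
assert block_bytes == 2 * 1024 * 1024

# With the default num_gpu_blocks=128 the pool costs 128 * 2 MiB = 256 MiB of GPU memory;
# an image that expands to 7 * 256 = 1792 vision tokens would occupy 7 blocks.
pool_bytes = 128 * block_bytes
assert pool_bytes == 256 * 1024 * 1024
```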
""" Methods for Disaggregation Begin. """ + + def p2p_initialize(self, migration_init_request: DistServeInitRequest) -> List[DistServeKVTransferEndpointInfo]: + if not self.migration_backend_impl: + self.migration_backend_impl: MigrationBackendImpl = MIGRATION_BACKENDS.module_dict['DLSlime']() + migration_init_request.rank = self.rank + self.migration_backend_impl.p2p_initialize(migration_init_request) + + t = self.encoder_gpu_cache + if t.numel() > 0: + register_mr_request = DistServeRegisterMRMessage( + protocol=migration_init_request.protocol, + remote_engine_id=migration_init_request.remote_engine_id, + mr_key='encoder_cache', # fix memory registration key + addr=t.data_ptr(), + offset=t.storage_offset(), + length=t.numel() * t.itemsize) + self.migration_backend_impl.register_memory_region(register_mr_request) + + return [ + DistServeKVTransferEndpointInfo(protocol=migration_init_request.protocol, + endpoint_info=json.dumps( + self.migration_backend_impl.endpoint_info( + migration_init_request.remote_engine_id, + migration_init_request.protocol))) + ] + + def p2p_connect(self, remote_engine_id: str, migration_conn_request: List[DistServeKVTransferEndpointInfo]): + self.migration_backend_impl.p2p_connect(remote_engine_id, migration_conn_request[self.tp_rank]) + + """ Methods for Disaggregation End. """ diff --git a/lmdeploy/multimodal/engine/engine.py b/lmdeploy/multimodal/engine/engine.py new file mode 100644 index 0000000000..ea1eb9924b --- /dev/null +++ b/lmdeploy/multimodal/engine/engine.py @@ -0,0 +1,158 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import asyncio +import copy +from typing import Dict, List, Optional + +from fastapi.encoders import jsonable_encoder +from fastapi.responses import JSONResponse + +from lmdeploy.messages import PytorchEngineConfig +from lmdeploy.pytorch.disagg.conn.engine_conn import EngineP2PConnection +from lmdeploy.pytorch.disagg.conn.protocol import (DistServeConnectionRequest, DistServeConnectionResponse, + DistServeConnectionStatus, DistServeDropConnectionRequest, + DistServeEngineEndpointInfo, DistServeInitRequest, + DistServeInitResponse) +from lmdeploy.utils import get_logger + +from ..models.builder import load_mm_model +from .model_agent import build_model_agent +from .post_process import PostProcessor +from .pre_process import PreProcessor + +logger = get_logger('lmdeploy') + + +class MultiModalEngine(): + """The multi-modal async inference engine of lmdeploy.""" + + def __init__(self, + model_path: str, + chat_template: object, + tokenizer: object, + engine_config: PytorchEngineConfig = None, + trust_remote_code: bool = True) -> None: + # make sure engine config exist + if engine_config is None: + engine_config = PytorchEngineConfig() + self.engine_config = copy.deepcopy(engine_config) + self.chat_template = chat_template + self.tokenizer = tokenizer + + # build model + self.model = load_mm_model(model_path, backend_config=self.engine_config) + + # build model agent + self.model_agent = build_model_agent(self.model) + self.model_agent.init() + + # init pre / post processor + self.post_processor = PostProcessor(self.model_agent) + self.pre_processor = PreProcessor(self.model_agent, self.post_processor) + + self.engine_conn = EngineP2PConnection(self) + + def start_loop(self): + """Start async loops.""" + # invoked in api server start up event, where we already have running event loop started by uvicorn.run() + # therefore we don't create a new event loop manually, simply start loops for each module + self.pre_processor.start_loop() + 
self.post_processor.start_loop() + self.model_agent.start_loop() + + def close(self): + """Close the engine and release resources.""" + self.pre_processor.close() + self.post_processor.close() + self.model_agent.close() + + @classmethod + def from_pretrained(cls, + pretrained_model_name_or_path: str, + chat_template: object, + tokenizer: object, + engine_config: PytorchEngineConfig = None, + trust_remote_code: bool = True, + **kwargs): + """Create a MultiModalEngine instance.""" + return cls(model_path=pretrained_model_name_or_path, + chat_template=chat_template, + tokenizer=tokenizer, + engine_config=engine_config, + trust_remote_code=trust_remote_code) + + async def encode(self, messages, session_id: int): + """Async encode.""" + future = asyncio.Future() + + # future will later be set in post-processor + self.pre_processor.process(session_id, messages, future) + + return await future + + # TODO: change this, put into pre-processor? + async def wrap_for_pytorch( + self, + messages: List[Dict], + chat_template, + tokenizer, + sequence_start, + tools: Optional[List[object]] = None, + enable_thinking: Optional[bool] = None, + ) -> List[Dict]: + """ + Args: + messages (List[Dict]): a list of message, which is supposed to be + the output of `preprocess` + Returns: + a dict which will be passed to pytorch engine_instance's forward. + The dict is like the following: + Dict( + 'prompt': 'the prompt after applying chat template' + 'input_ids': [], + 'multimodal': { + 'pixel_values': torch.Tensor, + ... + ] + ) + """ + result = self.model.to_pytorch(messages, + chat_template, + tokenizer, + sequence_start, + tools=tools, + enable_thinking=enable_thinking) + # clear data + for i, message in enumerate(messages): + if isinstance(message['content'], List): + messages[i]['preprocess'] = None + return result + + def p2p_initialize(self, init_request: DistServeInitRequest): + """Initialize p2p connection. + + FIXME: This method is synchronous (`def`). + The standard PytorchEngine (in multi-process mode) has a synchronous + `p2p_initialize` that acts as an RPC bridge to an async worker. + To maintain a compatible interface for the `AsyncEngine` adapter, + this single-process engine also provides a synchronous implementation. + """ + kv_eps = self.model_agent.cache_engine.p2p_initialize(init_request) + # encoder has no zmq communication for now; return a dummy address + zmq_addr = 'tcp://0.0.0.0:65001' + resp = DistServeInitResponse( + status=DistServeConnectionStatus.SUCCESS, + engine_endpoint_info=DistServeEngineEndpointInfo(zmq_address=zmq_addr), + kvtransfer_endpoint_info=kv_eps, + ) + return JSONResponse(jsonable_encoder(resp.model_dump())) + + def p2p_connect(self, conn_request: DistServeConnectionRequest): + self.model_agent.cache_engine.p2p_connect( + conn_request.remote_engine_id, + conn_request.remote_kvtransfer_endpoint_info, + ) + resp = DistServeConnectionResponse(status=DistServeConnectionStatus.SUCCESS) + return JSONResponse(jsonable_encoder(resp.model_dump())) + + async def p2p_drop_connect(self, drop_conn_request: DistServeDropConnectionRequest): + return self.engine_conn.p2p_drop_connect(drop_conn_request) diff --git a/lmdeploy/multimodal/engine/model_agent.py b/lmdeploy/multimodal/engine/model_agent.py new file mode 100644 index 0000000000..cd7e8ab5e4 --- /dev/null +++ b/lmdeploy/multimodal/engine/model_agent.py @@ -0,0 +1,229 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
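Stepping back to the engine above: a usage sketch of `MultiModalEngine.encode` (illustrative, assuming an already constructed `engine` and a running event loop):

```python
import asyncio

async def encode_once(engine, messages, session_id=0):
    # encode() parks an asyncio.Future in the PreProcessor/PostProcessor pipeline and
    # resolves once the vision features have been written into the encoder cache.
    annotated = await engine.encode(messages, session_id)
    # the returned messages carry the allocated cache block ids under 'block_ids',
    # which are later forwarded to the PD engine as part of an EncoderResult
    return annotated
```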
+import asyncio +from typing import List + +import torch + +from lmdeploy.utils import get_logger + +from .cache_engine import EncoderCacheEngine + +logger = get_logger('lmdeploy') + + +def _try_to_cuda(data, non_blocking: bool = True): + """Recursively traverses a data structure and moves all torch.Tensors to + the configured device.""" + if data is None: + return None + if isinstance(data, torch.Tensor): + return data.to('cuda', non_blocking=non_blocking) + if isinstance(data, list): + return [_try_to_cuda(item, non_blocking) for item in data] + if isinstance(data, tuple): + return tuple(_try_to_cuda(item, non_blocking) for item in data) + if isinstance(data, dict): + return {key: _try_to_cuda(value, non_blocking) for key, value in data.items()} + return data + + +def _try_to_cpu(data): + """Recursively traverses a data structure and moves all torch.Tensors to + the CPU.""" + if data is None: + return None + if isinstance(data, torch.Tensor): + return data.cpu() + if isinstance(data, list): + return [_try_to_cpu(item) for item in data] + if isinstance(data, tuple): + return tuple(_try_to_cpu(item) for item in data) + if isinstance(data, dict): + return {key: _try_to_cpu(value) for key, value in data.items()} + return data + + +class BaseModelAgent: + + def __init__(self, model): + + # PreProcessor -> h2d loop + self._pre_in_que = asyncio.Queue() + # h2d loop -> forward loop + self._in_que = asyncio.Queue() + # forward loop -> d2h loop + self._out_que = asyncio.Queue() + # d2h loop -> PostProcessor + self._post_proc_que = asyncio.Queue() + + # backpressure signal between h2d loop <-> forward loop + self.has_inputs = asyncio.Event() + + # CUDA streams + self.in_stream = torch.cuda.Stream() + self.out_stream = torch.cuda.Stream() + self.forward_stream = torch.cuda.Stream() + + self.model = model + self.device = 'cuda' + + async def make_batch(self): + # TODO: fix for multi-batch + requests = [] + + req = await self._pre_in_que.get() + requests.append(req) + + return requests[0] + + async def async_model_forward(self): + """Model forward.""" + while True: + # wait for inputs + session_id, forward_inputs = await self._in_que.get() + print(f'get session_id: {session_id}') + print(f'get forward inputs from _in_que: {forward_inputs}') + self.next_inputs = None + + with torch.cuda.stream(self.forward_stream): + feats, allocated_blocks = self._forward_impl(forward_inputs) + + # event for async fetch outputs + event = torch.cuda.Event() + event.record() + + # put inside out_que + out = dict( + session_id=session_id, + feats=feats, + block_ids=allocated_blocks, + event=event, + ) + self._out_que.put_nowait(out) + + # reset events, for h2d prepare the next round inputs + self.has_inputs.set() + + async def h2d_loop(self): + """Host to device loop. + + preprocess inputs and put them into in_que. copy inputs to device in a different stream. + """ + while True: + await self.has_inputs.wait() + + session_id, forward_inputs = await self.make_batch() + print(f'check forward_inputs: {forward_inputs}') + + # use a different stream to copy h2d + with torch.cuda.stream(self.in_stream): + forward_inputs = _try_to_cuda(forward_inputs) + + # put inputs inside in_que, reset has_inputs + self._in_que.put_nowait((session_id, forward_inputs)) + self.has_inputs.clear() + + async def d2h_loop(self): + """Device to host loop. + + copy outputs from device to host. put outputs into post processing queue. 
+ """ + while True: + out = await self._out_que.get() + + # check event periodically + event = out.pop('event') + while not event.query(): + await asyncio.sleep(0.001) + + # use a different stream to copy d2h + with torch.cuda.stream(self.out_stream): + out = _try_to_cpu(out) + + self._post_proc_que.put_nowait(out) + + def start_loop(self): + """Start event loop.""" + event_loop = asyncio.get_event_loop() + + # set for the first batch + self.has_inputs.set() + + # forward task + logger.info('Create task MultiModal ModelAgent ForwardLoop.') + self._forward_task = event_loop.create_task(self.async_model_forward(), name='ModelAgentForwardLoop') + + # preprocess inputs task + logger.info('Create task MultiModal ModelAgent Preprocess.') + self._preprocess_task = event_loop.create_task(self.h2d_loop(), name='ModelAgentPreprocess') + + # postprocess outputs task + logger.info('Create task MultiModal ModelAgent Postprocess.') + self._postprocess_task = event_loop.create_task(self.d2h_loop(), name='ModelAgentPostprocess') + + loop_tasks: list[asyncio.Task] = [self._forward_task, self._preprocess_task, self._postprocess_task] + + # binding done callback + self._add_loop_tasks_done_callback(loop_tasks) + + @staticmethod + def _add_loop_tasks_done_callback(tasks: List[asyncio.Task]): + """Add loop tasks done callback.""" + + def __task_callback(task: asyncio.Task) -> None: + """Raise exception on finish.""" + task_name = task.get_name() + try: + task.result() + except asyncio.CancelledError: + logger.debug(f'Task <{task_name}> cancelled.') + return + except Exception: + logger.exception(f'Task <{task_name}> failed') + finally: + for task in tasks: + if not task.done(): + task.cancel() + + for task in tasks: + task.add_done_callback(__task_callback) + + def build_cache_engine(self): + cache_engine = EncoderCacheEngine() + self.cache_engine = cache_engine + + def _forward_impl(self, inputs): + """Model forward implementation.""" + feats = self.model.forward(inputs) + + # put feat into encoder cache + feats = feats[0] # FIXME + num_required_blocks = feats.shape[0] // 256 + if len(self.cache_engine.free_blocks) < num_required_blocks: + raise RuntimeError('Not enough free blocks in cache engine') + allocated_blocks = self.cache_engine.free_blocks[:num_required_blocks] + + # move into dedicated mm cache pool + # TODO: we dont want copy, better to just write into that memory region + # but current transformers get_image_features() returns a new tensor, seems no way to achieve this + for i in range(num_required_blocks): + src_chunk = feats[i * 256:(i + 1) * 256, :] + dst_block_id = allocated_blocks[i] + self.cache_engine.gpu_cache[dst_block_id].copy_(src_chunk) + print(f'=> allocated blocks: {allocated_blocks}') + + return feats, allocated_blocks + + def init(self): + self.build_cache_engine() + + def close(self): + self.cache_engine = None + self.model = None + torch.cuda.empty_cache() + + +def build_model_agent(model): + """Build model agent.""" + model_agent = BaseModelAgent(model=model, ) + return model_agent diff --git a/lmdeploy/multimodal/engine/post_process.py b/lmdeploy/multimodal/engine/post_process.py new file mode 100644 index 0000000000..5c9ea70e53 --- /dev/null +++ b/lmdeploy/multimodal/engine/post_process.py @@ -0,0 +1,51 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import asyncio
+from typing import Dict, List
+
+from lmdeploy.utils import get_logger
+
+logger = get_logger('lmdeploy')
+
+
+class PostProcessor():
+
+    def __init__(self, model_agent):
+        print('=> PostProcessor init')
+        self.model_agent = model_agent
+        self._loop_task = None
+
+        # session_id -> (messages, future)
+        self._future_store: Dict[int, asyncio.Future] = {}
+
+    def add_future(self, session_id, messages, future):
+        self._future_store[session_id] = (messages, future)
+
+    def start_loop(self):
+        if not hasattr(self, '_loop_task') or self._loop_task is None:
+            logger.info('Starting PostProcessor loop')
+            self._loop_task = asyncio.create_task(self.async_loop())
+
+    def post_process(self, messages: List[Dict]):
+        # TODO: implement model-specific post process logic
+        return messages
+
+    async def async_loop(self):
+        while True:
+            out = await self.model_agent._post_proc_que.get()
+            print(f'=> PostProcessor got data: {out}')
+
+            out = self.post_process(out)
+            print(f'=> PostProcessor post-processed data: {out}')
+
+            session_id = out.pop('session_id', None)
+            messages, future = self._future_store.pop(session_id, (None, None))
+            if future is None or future.done():
+                continue
+            messages[0]['block_ids'] = out['block_ids']
+            print(f'=> PostProcessor setting future result: {messages}')
+            future.set_result(messages)
+
+    def close(self):
+        """Cancel the background loop task."""
+        if self._loop_task and not self._loop_task.done():
+            self._loop_task.cancel()
+            logger.info('PostProcessor loop cancelled.')
diff --git a/lmdeploy/multimodal/engine/pre_process.py b/lmdeploy/multimodal/engine/pre_process.py
new file mode 100644
index 0000000000..ba04c5bbc7
--- /dev/null
+++ b/lmdeploy/multimodal/engine/pre_process.py
@@ -0,0 +1,167 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import asyncio
+from typing import Dict, List
+
+from lmdeploy.utils import get_logger
+from lmdeploy.vl.utils import load_image
+
+logger = get_logger('lmdeploy')
+
+
+class PreProcessor():
+
+    def __init__(self, model_agent, post_processor):
+        print('=> PreProcessor init')
+        self._in_que = asyncio.Queue()
+        self.model_agent = model_agent
+        self.post_processor = post_processor
+
+        self._loop_task = None
+
+    @staticmethod
+    def collect_images(messages):
+        """Gather all images along with their respective parameters from the
+        messages and compile them into a single list. Each image is converted
+        to RGB color space.
+
+        Args:
+            messages (List[Tuple[Image, Dict]]): a list of images with their
+                corresponding parameters
+        """  # noqa
+        images = []
+        for message in messages:
+            content = message['content']
+            if not isinstance(content, List):
+                continue
+            images.extend([(x['image'], {
+                k: v
+                for k, v in x.items() if k not in {'type', 'image'}
+            }) for x in content if x['type'] == 'image'])
+        return images
+
+    @classmethod
+    async def async_convert_to_pil_images(cls, messages: List[Dict]) -> List[Dict]:
+        """Scan the provided messages to find image URLs or base64-encoded
+        image data. Loads the images into Pillow image objects.
+ + Args: + messages (List[Dict]): a user request of GPT4V message format + """ + if isinstance(messages, Dict): + messages = [messages] + assert isinstance(messages, List) + + out_messages = [None] * len(messages) + + def _inner_call(i, in_messages, out_messages): + role = in_messages[i]['role'] + content = in_messages[i]['content'] + assert role in ['system', 'user', 'assistant'], \ + f'unsupported role "{role}"' + if role != 'user' or isinstance(content, str): + # the content is a user's prompt or an assistant's prompt, + # returning it directly + out_messages[i] = in_messages[i] + return + # the role is a user and the content is a list, in which there + # might be image_url or image_data + assert isinstance(content, List) + message = dict(role=role, content=[]) + for item in content: + # image url or base64-encoded image data + if item['type'] == 'image_url': + """ + convert the following item: + { + 'type': 'image_url', + 'image_url': { + 'url': 'image url or base64-encoded image data', + 'key': 'value' # parameters used in image processing + ... + } + } + to: + { + 'type': 'image', + 'image': Pillow.Image, + 'key': 'value' # parameters used in image processing + ... + } + """ # noqa + data = item['image_url'].copy() + try: + url = data.pop('url') + image = load_image(url) + data.update(type='image', image=image) + message['content'].append(data) + except KeyError: + logger.error(f'invalid format {message}') + elif item['type'] == 'image_data': + """ + convert the following item: + { + 'type': 'image_data', + 'image_data': { + 'data': Pillow.Image, + 'key': 'value' # parameters used in image processing + ... + } + } + to: + { + 'type': 'image', + 'image': Pillow.Image, + 'key': 'value' # parameters used in image processing + ... + } + """ # noqa + data = item['image_data'].copy() + try: + image = data.pop('data') + data.update(type='image', image=image) + message['content'].append(data) + except KeyError: + logger.error(f'invalid format {message}') + elif item['type'] == 'text': + message['content'].append(item) + else: + logger.error(f'unexpected content type {message}') + out_messages[i] = message + + await asyncio.gather(*[ + asyncio.get_event_loop().run_in_executor(None, _inner_call, i, messages, out_messages) + for i in range(len(messages)) + ]) + return out_messages + + def start_loop(self): + """Creates a task for the given coroutine.""" + if not hasattr(self, '_loop_task') or self._loop_task is None: + logger.info('Starting PreProcessor loop') + self._loop_task = asyncio.create_task(self.async_loop()) + + async def async_loop(self): + while True: + session_id, messages = await self._in_que.get() + + messages = await self.async_convert_to_pil_images(messages) + print(f'after convert msg: {messages}') + + proc_inputs = self.model_agent.model.preprocess(messages) + print(f'after preproc msg: {proc_inputs}') + + # TODO: process to get token ids, image mask + + self.model_agent._pre_in_que.put_nowait((session_id, proc_inputs)) + + def process(self, session_id, messages, future): + if messages is not None: + self._in_que.put_nowait((session_id, messages)) + + self.post_processor.add_future(session_id, messages, future) + + def close(self): + """Cancel the background loop task.""" + if self._loop_task and not self._loop_task.done(): + self._loop_task.cancel() + logger.info('PreProcessor loop cancelled.') diff --git a/lmdeploy/multimodal/models/__init__.py b/lmdeploy/multimodal/models/__init__.py new file mode 100644 index 0000000000..ef101fec61 --- /dev/null +++ 
b/lmdeploy/multimodal/models/__init__.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. diff --git a/lmdeploy/multimodal/models/base.py b/lmdeploy/multimodal/models/base.py new file mode 100644 index 0000000000..c28dcaaab0 --- /dev/null +++ b/lmdeploy/multimodal/models/base.py @@ -0,0 +1,253 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABC, abstractmethod +from typing import Dict, List, Union + +import numpy as np +from mmengine import Registry +from transformers import AutoConfig, AutoTokenizer + +from lmdeploy.archs import get_model_arch +from lmdeploy.utils import get_logger + +logger = get_logger('lmdeploy') + +BASE_MODELS = Registry('base_model') + + +class BaseModel(ABC): + """Abstract base model class in the multimodal engine.""" + _arch: Union[str, List[str]] = None + + def __init__(self, + model_path: str, + with_llm: bool = False, + max_memory: Dict[int, int] = None, + hf_config: AutoConfig = None, + backend: str = ''): + """init.""" + self.model_path = model_path + self.with_llm = with_llm + self.max_memory = max_memory + self.backend = backend + if hf_config is None: + _, hf_config = get_model_arch(model_path) + self.hf_config = hf_config + self.image_token_id = self.get_pad_token_id(model_path, hf_config) or 0 + + def get_pad_token_id(self, model_path, hf_config): + """Get pad_token_id from hf_config or tokenizer.""" + pad_token_id = getattr(hf_config, 'pad_token_id', None) + if pad_token_id is None: + try: + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + pad_token_id = getattr(tokenizer, 'pad_token_id', None) + except Exception as e: + print(e) + pass + return pad_token_id + + @abstractmethod + def build_preprocessor(self, ): + """Build the preprocessor.""" + raise NotImplementedError() + + def build_model(self, ): + """Build the vision part of a VLM model when backend is turbomind. + + But when `with_llm=True`, load the whole VLM model + """ + raise NotImplementedError() + + @abstractmethod + def preprocess(self, messages: List[Dict]) -> List[Dict]: + """Preprocess multimodal data in the messages. + + The derived class, + i.e., a specific vision model, takes the charge of image preprocessing + and the result management. + It can integrate the result into the messages list, or insert it to + the individual image item. + Args: + message(Dict): multimodal data in a dict, which is as follows: + [ + {'role': 'user', 'content': 'user prompt'}, + {'role': 'assisant', 'content': 'AI reponse'}, + { + 'role': 'user', + 'content': [ + { + 'type': 'text', + 'text': 'string', + }, + { + 'type': 'image', + 'image': pillow.Image, + 'key1': value1, + ... + }, + { + 'type': 'image', + 'image': pillow.Image, + 'key1': value1, + ... + }, + ... + ] + } + {....} + ] + Returns: + the message list with preprocessing results included, which is + determined by the derived classes + """ # noqa + raise NotImplementedError() + + # @abstractmethod + # def postprocess(self, model_outputs: torch.Tensor, processed_inputs: List[Dict]) -> List[Dict]: + # """ + # Takes the model outputs and performs post-process. + # """ + # raise NotImplementedError() + + def forward(self, messages: List[Dict], max_batch_size: int = 1) -> List[Dict]: + """Extract image feature. ONLY implement it when the backend is + turbomind engine. 
+ + Args: + messages(List[Dict]): the outputs of `preprocess` + max_batch_size(int): the max batch size when forwarding vision + model + Return: + the message list with forwarding results included, which is + determined by the derived classes + """ + if self.backend == 'turbomind': + raise NotImplementedError() + + def to_pytorch(self, messages, chat_template, tokenizer, sequence_start, **kwargs): + """Pack the preprocessing results in a format compatible with what is + required by pytorch engine. ONLY implement it when the backend is + pytorch engine. + + Args: + messages(List[Dict]): the output of `preprocess` + chat_template: the chat template defined in `lmdeploy/model.py` + tokenzer: the tokenizer model + sequence_start: starting flag of a sequence + """ + if self.backend == 'pytorch': + raise NotImplementedError() + + def to_turbomind(self, messages, chat_template, tokenizer, sequence_start, **kwargs): + """Pack the forwarding results in a format compatible with what is + required by turbomind engine. ONLY implement it when the backend is + turbomind engine. + + Args: + messages(List[Dict]): the output of `preprocess` + chat_template: the chat template defined in `lmdeploy/model.py` + tokenzer: the tokenizer model + sequence_start: starting flag of a sequence + """ + if self.backend == 'turbomind': + raise NotImplementedError() + + @staticmethod + def collect_images(messages): + """Gather all images along with their respective parameters from the + messages and compile them into a single list. Each image is converted + to RGB color space. + + Args: + messages (List[Tuple[Image, Dict]]): a list of images with their + corresponding parameters + """ # noqa + images = [] + for message in messages: + content = message['content'] + if not isinstance(content, List): + continue + images.extend([(x['image'], { + k: v + for k, v in x.items() if k not in {'type', 'image'} + }) for x in content if x['type'] == 'image']) + return images + + def to_pytorch_aux(self, messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start): + """Auxiliary function to pack the preprocessing results in a format + compatible with what is required by pytorch engine. 
+ + Args: + messages(List[Dict]): the output of `preprocess` + prompt(str): the prompt after applying chat template + IMAGE_TOKEN(str): a placeholder where image tokens will be + inserted + tokenzer: the tokenizer model + sequence_start: starting flag of a sequence + """ + # collect all preprocessing result from messages + preps = [x['content'] for x in messages if x['role'] == 'preprocess'] + assert len(preps) == 1 + preps = preps[0] + + # split prompt into segments and validate data + segs = prompt.split(IMAGE_TOKEN) + assert len(segs) == len(preps) + 1, (f'the number of {IMAGE_TOKEN} is not equal ' + f'to input images, {len(segs) - 1} vs {len(preps)}') + + # calculate the image token offset for each image + input_ids = [] + for i, seg in enumerate(segs): + if i > 0 and i <= len(preps): + preps[i - 1].update(offset=len(input_ids)) + image_tokens = preps[i - 1]['image_tokens'] + assert self.image_token_id == preps[i - 1]['image_token_id'] + input_ids.extend([self.image_token_id] * image_tokens) + token_ids = tokenizer.encode(seg, add_bos=((i == 0) and sequence_start)) + input_ids.extend(token_ids) + + return dict(prompt=prompt, input_ids=input_ids, multimodal=preps) + + def to_turbomind_aux(self, messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start): + """Auxiliary function to pack the forwarding results in a format + compatible with what is required by turbomind engine. + + Args: + messages(List[Dict]): the output of `preprocess` + prompt(str): the prompt after applying chat template + IMAGE_TOKEN(str): a placeholder where image tokens will be + inserted + tokenzer: the tokenizer model + sequence_start: starting flag of a sequence + """ + # collect image features from messages + features = [x['content'] for x in messages if x['role'] == 'forward'] + features = features[0] + features = [x.cpu().numpy() for x in features] + # split prompt into segments and validate data + segs = prompt.split(IMAGE_TOKEN) + assert len(segs) == len(features) + 1, (f'the number of {IMAGE_TOKEN} is not equal ' + f'to input images, {len(segs) - 1} vs {len(features)}') + + # tokenizer prompt, and get input_embeddings and input_embedding_ranges + input_ids = [] + begins = [] + ends = [] + for i, seg in enumerate(segs): + if i > 0 and i <= len(features): + image_dim = features[i - 1].shape[0] + begins.append(len(input_ids)) + ends.append(begins[-1] + image_dim) + input_ids.extend([self.image_token_id] * image_dim) + seg_ids = tokenizer.encode(seg, add_bos=((i == 0) and sequence_start)) + input_ids.extend(seg_ids) + ranges = np.stack([begins, ends], axis=1).tolist() + return dict(prompt=prompt, input_ids=input_ids, input_embeddings=features, input_embedding_ranges=ranges) + + @classmethod + def match(cls, config: AutoConfig): + """Check whether the config match the model.""" + arch = config.architectures[0] if config.architectures else None + if arch and (arch == cls._arch or arch in cls._arch): + return True + return False diff --git a/lmdeploy/multimodal/models/builder.py b/lmdeploy/multimodal/models/builder.py new file mode 100644 index 0000000000..7c3f4a052c --- /dev/null +++ b/lmdeploy/multimodal/models/builder.py @@ -0,0 +1,55 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
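To illustrate what `to_pytorch_aux` above produces, a small worked example with hypothetical tokenizer output (a single image that expands to 256 placeholder tokens):

```python
# Hypothetical numbers: the text before the image encodes to [5101], the trailing
# newline to [364], image_token_id is 151667, and preprocess reported image_tokens=256.
image_token_id = 151667
expected = dict(
    prompt='describe <IMAGE_TOKEN>\n',
    input_ids=[5101] + [image_token_id] * 256 + [364],   # placeholder expanded in place
    multimodal=[dict(image_tokens=256, image_token_id=image_token_id, offset=1)],
)
# offset = number of tokens emitted before the image, i.e. len([5101]) == 1
assert expected['multimodal'][0]['offset'] == 1
```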
+import os +from typing import Optional, Union + +import torch + +from lmdeploy.archs import get_model_arch +from lmdeploy.messages import PytorchEngineConfig, TurbomindEngineConfig +from lmdeploy.multimodal.models.base import BASE_MODELS +from lmdeploy.utils import get_logger, get_model + +from .internvl3_hf import InternVL3VisionModel # noqa F401 + +logger = get_logger('lmdeploy') + + +def load_mm_model(model_path: str, + backend: str = '', + with_llm: bool = False, + backend_config: Optional[Union[TurbomindEngineConfig, PytorchEngineConfig]] = None): + """Load multimodal model. + + Args: + model_path(str): the path or repo_id from model hub of the model + backend(str): the name of inference backend + with_llm(bool): load LLM model or not. Set it to False for VLM + inference scenarios and True for VLM quantization + backend_config: the config of the inference engine + """ + if not os.path.exists(model_path): + revision = getattr(backend_config, 'revision', None) + download_dir = getattr(backend_config, 'download_dir', None) + model_path = get_model(model_path, revision=revision, download_dir=download_dir) + + max_memory = None + if not with_llm: + tp = getattr(backend_config, 'tp', 1) + max_memory = {i: torch.cuda.mem_get_info(i)[0] for i in range(tp)} + + _, hf_config = get_model_arch(model_path) + kwargs = dict(model_path=model_path, with_llm=with_llm, max_memory=max_memory, hf_config=hf_config, backend=backend) + + for name, module in BASE_MODELS.module_dict.items(): + try: + if module.match(hf_config): + logger.info(f'matching multimodal model: {name}') + model = module(**kwargs) + model.build_preprocessor() + model.build_model() + return model + except Exception as e: + logger.error(f'build multimodal model {name} failed, {e}') + raise + + raise ValueError(f'unsupported multimodal model with config {hf_config}') diff --git a/lmdeploy/multimodal/models/internvl3_hf.py b/lmdeploy/multimodal/models/internvl3_hf.py new file mode 100644 index 0000000000..088eac509e --- /dev/null +++ b/lmdeploy/multimodal/models/internvl3_hf.py @@ -0,0 +1,212 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
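A usage sketch for the builder above (the model path matches the launch scripts at the top of this patch; `backend` is left at its default):

```python
from lmdeploy.messages import PytorchEngineConfig
from lmdeploy.multimodal.models.builder import load_mm_model

# The registry is scanned and InternVL3VisionModel.match() claims any config whose
# architecture appears in its `_arch`; build_preprocessor()/build_model() run eagerly.
model = load_mm_model('/nvme3/interns1-mini-remote',
                      backend_config=PytorchEngineConfig(tp=1))
```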
+ +# TODO: may consider separate similar to transformers +# - Internvl +# - configuration_internvl.py +# - modeling_internvl.py +# - processing_internvl.py + +# but this may bring too many files, so currently we just put all things together + +from typing import Dict, List, Optional + +import torch +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoProcessor +from transformers.processing_utils import ImagesKwargs, ProcessingKwargs + +from lmdeploy.utils import get_logger +from lmdeploy.vl.model.utils import disable_logging + +from .base import BASE_MODELS, BaseModel + +logger = get_logger('lmdeploy') + + +class InternVLImagesKwargs(ImagesKwargs, total=False): + crop_to_patches: Optional[bool] + min_patches: Optional[int] + max_patches: Optional[int] + + +class InternVLProcessorKwargs(ProcessingKwargs, total=False): + images_kwargs: InternVLImagesKwargs + _defaults = { + 'text_kwargs': { + 'padding': False, + }, + 'images_kwargs': { + 'crop_to_patches': True, + }, + 'videos_kwargs': {}, + } + + +@BASE_MODELS.register_module() +class InternVL3VisionModel(BaseModel): + """Internvl3 vision model.""" + + _arch = ['InternVLForConditionalGeneration', 'InternS1ForConditionalGeneration'] + + def __init__(self, + model_path: str, + with_llm: bool = False, + max_memory: Dict[int, int] = None, + hf_config: AutoConfig = None, + backend: str = ''): + super().__init__(model_path, with_llm, max_memory, hf_config, backend) + self.arch = hf_config.architectures[0] + + def build_preprocessor(self): + self.processor = AutoProcessor.from_pretrained(self.model_path, trust_remote_code=True) + tokenizer = self.processor.tokenizer + self.image_token_id = tokenizer.context_image_token_id + self.image_tokens_per_patch = self.processor.image_seq_length + self.tokenizer_init_kwargs = tokenizer.init_kwargs + + def build_model(self): + """Build the vision part of a VLM model when backend is turbomind, or + load the whole VLM model when `self.with_llm==True`""" + from accelerate import init_empty_weights + with init_empty_weights(): + if self.arch == 'InternVLForConditionalGeneration': + model = AutoModel.from_config(self.hf_config, trust_remote_code=True) + # if not self.with_llm: + # print('delete language model') + del model.language_model + elif self.arch == 'InternS1ForConditionalGeneration': + model = AutoModelForCausalLM.from_config(self.hf_config, trust_remote_code=True) + # if not self.with_llm: + # print('delete language model') + del model.model.language_model + else: + raise ValueError(f'unsupported model arch {self.arch}') + + model.half() + from accelerate import load_checkpoint_and_dispatch + with disable_logging(): + load_checkpoint_and_dispatch( + model=model, + checkpoint=self.model_path, + # device_map='auto' if not self.with_llm else {'': 'cpu'}, + device_map='auto', + max_memory=self.max_memory, + no_split_module_classes=['InternVLVisionLayer', 'InternS1VisionLayer'], + dtype=torch.half) + # We need eval mode to freeze the weights in model, thus, + # avoid randomness in inference. 
+        self.model = model.eval()
+
+    def preprocess(self, messages: List[Dict]) -> List[Dict]:
+        """Refers to `super().preprocess()` for spec."""
+        from transformers.image_utils import make_flat_list_of_images
+        output_kwargs = self.processor._merge_kwargs(
+            InternVLProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer_init_kwargs,
+            **{
+                'return_tensors': 'pt',
+                'add_special_tokens': False
+            },
+        )
+        images = self.collect_images(messages)
+        images = [image.convert('RGB') for image, _ in images]
+        num_image = len(images)
+        images = make_flat_list_of_images(images)
+        image_inputs = self.processor.image_processor(images, **output_kwargs['images_kwargs'])
+        image_num_patches = image_inputs.pop('num_patches').cpu().numpy().tolist()
+        image_pixel_values = image_inputs.pop('pixel_values')
+        outputs = []
+        cum_num_patches = 0
+        for idx in range(num_image):
+            cur_num_patches = image_num_patches[idx]
+            pixel_values = image_pixel_values[cum_num_patches:cum_num_patches + cur_num_patches, ...]
+            cum_num_patches += cur_num_patches
+            data = dict(pixel_values=pixel_values,
+                        image_tokens=self.image_tokens_per_patch * cur_num_patches,
+                        image_token_id=self.image_token_id)
+            outputs.append(data)
+
+        return outputs
+
+    @torch.no_grad()
+    def forward(self, processed_inputs: List[Dict]) -> List[torch.Tensor]:
+        # FIXME: consider batch?
+        outputs = []
+        pixel_values = [x['pixel_values'] for x in processed_inputs]
+        split = [x.shape[0] for x in pixel_values]
+        pixel_values = torch.cat(pixel_values, dim=0)
+        pixel_values = pixel_values.to(dtype=torch.float16)
+        logger.info(f'vision forward shape: {pixel_values.shape}')
+        feats = self.model.get_image_features(
+            pixel_values,
+            vision_feature_layer=self.hf_config.vision_feature_layer,
+            vision_feature_select_strategy=self.hf_config.vision_feature_select_strategy,
+        )
+        feats = torch.split(feats, split, dim=0)
+        outputs.extend([x.reshape(-1, x.shape[-1]) for x in feats])
+        return outputs
+
+    @staticmethod
+    def proc_messages(
+        messages,
+        chat_template,
+        sequence_start,
+        tools: Optional[List[object]] = None,
+        enable_thinking: Optional[bool] = None,
+    ):
+        """Apply chat template to get the prompt."""
+        prompt_messages = []
+        IMAGE_TOKEN = '<IMAGE_TOKEN>'
+        for message in messages:
+            if isinstance(message['content'], str):
+                prompt_messages.append(message)
+                continue
+            elif message['role'] in ['preprocess', 'forward']:
+                continue
+            n_images = len([1 for x in message['content'] if x['type'] == 'image'])
+            content = [x.get('text', '') for x in message['content'] if x['type'] == 'text']
+            prompt = content[0]
+            if IMAGE_TOKEN in prompt and f'<img>{IMAGE_TOKEN}' not in prompt:
+                prompt = prompt.replace(f'{IMAGE_TOKEN}', f'<img>{IMAGE_TOKEN}</img>')
+                prompt = prompt.replace('</img><img>', '')
+                prompt = prompt.replace('<img><img>', '<img>')
+                prompt = prompt.replace('</img></img>', '</img>')
+            elif IMAGE_TOKEN not in prompt:
+                prompt = f'<img>{IMAGE_TOKEN * n_images}</img>\n' + prompt
+            else:
+                pass
+            prompt_messages.append(dict(role='user', content=prompt))
+        prompt = chat_template.messages2prompt(prompt_messages,
+                                               sequence_start,
+                                               tools=tools,
+                                               enable_thinking=enable_thinking)
+        return prompt, IMAGE_TOKEN
+
+    def to_pytorch(self,
+                   messages,
+                   chat_template,
+                   tokenizer,
+                   sequence_start,
+                   tools: Optional[List[object]] = None,
+                   enable_thinking: Optional[bool] = None,
+                   **kwargs):
+        prompt, IMAGE_TOKEN = self.proc_messages(messages,
+                                                 chat_template,
+                                                 sequence_start,
+                                                 tools=tools,
+                                                 enable_thinking=enable_thinking)
+        return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start)
+
+    def to_turbomind(self,
+                     messages,
+
chat_template, + tokenizer, + sequence_start, + tools: Optional[List[object]] = None, + enable_thinking: Optional[bool] = None, + **kwargs): + prompt, IMAGE_TOKEN = self.proc_messages(messages, + chat_template, + sequence_start, + tools=tools, + enable_thinking=enable_thinking) + return self.to_turbomind_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start) diff --git a/lmdeploy/pytorch/backends/cuda/graph_runner.py b/lmdeploy/pytorch/backends/cuda/graph_runner.py index deb6c66bfd..7259154fc7 100644 --- a/lmdeploy/pytorch/backends/cuda/graph_runner.py +++ b/lmdeploy/pytorch/backends/cuda/graph_runner.py @@ -234,12 +234,14 @@ def __call__(self, **kwargs): def prepare_inputs_for_generation( self, past_key_values: List[List[torch.Tensor]], + encoder_cache: torch.Tensor, inputs_embeds: torch.Tensor = None, context: StepContext = None, ): """Prepare inputs.""" return self.model.prepare_inputs_for_generation( past_key_values=past_key_values, + encoder_cache=encoder_cache, inputs_embeds=inputs_embeds, context=context, ) diff --git a/lmdeploy/pytorch/backends/graph_runner.py b/lmdeploy/pytorch/backends/graph_runner.py index a88872f2bd..aea8a3b976 100644 --- a/lmdeploy/pytorch/backends/graph_runner.py +++ b/lmdeploy/pytorch/backends/graph_runner.py @@ -56,12 +56,14 @@ def get_logits(self, hidden_states: torch.Tensor): def prepare_inputs_for_generation( self, past_key_values: List[List[torch.Tensor]], + encoder_cache: torch.Tensor, inputs_embeds: torch.Tensor = None, context: StepContext = None, ): """Prepare inputs.""" return self.model.prepare_inputs_for_generation( past_key_values, + encoder_cache, inputs_embeds, context, ) diff --git a/lmdeploy/pytorch/disagg/config.py b/lmdeploy/pytorch/disagg/config.py index f4dd002231..f79a3718fa 100644 --- a/lmdeploy/pytorch/disagg/config.py +++ b/lmdeploy/pytorch/disagg/config.py @@ -35,6 +35,7 @@ class EngineRole(enum.Enum): Hybrid = enum.auto() Prefill = enum.auto() Decode = enum.auto() + Encoder = enum.auto() class MigrationBackend(enum.Enum): diff --git a/lmdeploy/pytorch/disagg/conn/epd_proxy_conn.py b/lmdeploy/pytorch/disagg/conn/epd_proxy_conn.py new file mode 100644 index 0000000000..a388fa5301 --- /dev/null +++ b/lmdeploy/pytorch/disagg/conn/epd_proxy_conn.py @@ -0,0 +1,324 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
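Since `prepare_inputs_for_generation` now takes an extra positional argument in the graph runners above, model-side overrides have to accept it as well. A hedged sketch of the required shape (`MyModelForCausalLM` is a hypothetical class, not an existing lmdeploy model):

```python
class MyModelForCausalLM:  # hypothetical model wired through GraphRunner

    def prepare_inputs_for_generation(self, past_key_values, encoder_cache, inputs_embeds=None, context=None):
        # encoder_cache is the (num_blocks, 256, 4096) feature pool filled via epd_migrate;
        # image positions can be gathered from it instead of running a local vision encoder.
        return dict(past_key_values=past_key_values,
                    encoder_cache=encoder_cache,
                    inputs_embeds=inputs_embeds,
                    context=context)
```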
+import asyncio +import enum +import os +from collections import defaultdict +from typing import Dict, Optional, Set, Tuple + +import aiohttp +import requests + +from lmdeploy.logger import get_logger +from lmdeploy.pytorch.disagg.config import DistServeEngineConfig, EngineRole +from lmdeploy.pytorch.disagg.conn.protocol import (DistServeCacheFreeRequest, DistServeConnectionRequest, + DistServeConnectionResponse, DistServeDropConnectionRequest, + DistServeInitRequest, DistServeInitResponse) +from lmdeploy.pytorch.disagg.messages import EPDConnectionMessage + +logger = get_logger('lmdeploy') + +# Parse timeout env (string -> float) safely +_raw_timeout = os.getenv('AIOHTTP_TIMEOUT', None) +try: + AIOHTTP_TIMEOUT: Optional[float] = float(_raw_timeout) if _raw_timeout else None +except ValueError: # fallback silently and log + logger.warning(f'Invalid AIOHTTP_TIMEOUT value: {_raw_timeout}, fallback to None') + AIOHTTP_TIMEOUT = None + + +class EPDConnectionStatus(enum.Enum): + Disconnected = enum.auto() + Connected = enum.auto() + Connecting = enum.auto() + + +class EPDConnectionState: + """EPDConnectionState (simple state holder with one event).""" + + def __init__(self, status: EPDConnectionStatus, event: asyncio.Event): + self.status = status + self.event = event + + async def wait(self): + await self.event.wait() + + def set_status(self, status: EPDConnectionStatus): + self.status = status + + +def get_server_api(url: str, api: str): + return f'{url}/{api}' + + +class EPDConnectionPool: + """Constructing the link of E & PD engine for the migration of Encoder + cache. + + Note: we use Peer to Peer transportation in KVCache migration. + Note: Lazy link construction is supported, which perform connection + at the first LLM request. As a result, we don't need to construct + PD Communication group when start a engine server. + Note: we perform simple fault tolerance by checkpointing the session_id of a + request which is under migrating and will trigger `gc` when the decode + instanceis crushed. + TODO (JimyMa): By now, only engines with same parallel configuration can be + correctly connected. + """ + + # Maximum concurrent connections​​ + CONN_SEMAPHORE_SIZE = 2048 + + def __init__(self): + # all encode, prefill and decode instances + # TODO (JimyMa): Maybe encoding instances + self.prefill_decode_endpoints: Set[str] = set() + self.encode_endpoints: Set[str] = set() + + # Links of EPD Connection. + self.pool: Dict[Tuple[str, str], EPDConnectionState] = {} + + # put migrating session to `self.migration_session_shelf` for increasing fault tolerance + # if a session is finished, then pop it from `self.migration_session_shelf` + # if a decode instance is disconnected, then gc all blocks of these sessions in prefill instance. + # use tuple (left, right) as key to align with drop() usage + self.migration_session_shelf: Dict[Tuple[str, str], Set[int]] = defaultdict(set) + + # conn_perform handler queue + self.waiting_conn: asyncio.Queue[Tuple[EPDConnectionMessage, asyncio.Event]] = asyncio.Queue() + + # conn Registry Lock + self.conn_lock = asyncio.Lock() + + # Connection Retry when failure + self.max_retry_cnt = 8 + + # trigger signal when conn request arrive. 
+        self.conn_req_event = asyncio.Event()
+
+        # conn initialized signal
+        self.initialized = False
+
+    def reg_instance(self, role: EngineRole, endpoint: str):
+        if role == EngineRole.Prefill:
+            self.prefill_decode_endpoints.add(endpoint)
+        elif role == EngineRole.Encoder:
+            self.encode_endpoints.add(endpoint)
+        else:
+            raise ValueError(f'Unsupported role: {role}')
+
+    def dereg_instance(self, endpoint: str):
+        # Symmetric cleanup for both roles
+        if endpoint in self.encode_endpoints:
+            dropped_key = [k for k in self.pool.keys() if k[0] == endpoint]
+            for k in dropped_key:
+                self.drop(k)
+            self.encode_endpoints.remove(endpoint)
+        elif endpoint in self.prefill_decode_endpoints:
+            dropped_key = [k for k in self.pool.keys() if k[1] == endpoint]
+            for k in dropped_key:
+                self.drop(k)
+            # TODO(JimyMa): handle side-effect by kvcache migration
+            self.prefill_decode_endpoints.remove(endpoint)
+
+    async def connect(self, conn_req: EPDConnectionMessage):
+
+        async def get_engine_config(server_endpoint):
+            async with self.conn_sem:
+                async with self.conn_sess.get(
+                        get_server_api(server_endpoint, 'distserve/engine_info'),
+                        timeout=self.aiotimeout,
+                ) as resp:
+                    result = await resp.json()
+                    # model_validate_json expects a JSON string; result is already dict
+                    logger.info(f'engine info from {server_endpoint}: {result}')
+                    return DistServeEngineConfig.model_validate_json(result)
+
+        async def p2p_initialize(server_endpoint, init_request: DistServeInitRequest) -> DistServeInitResponse:
+            async with self.conn_sem:
+                async with self.conn_sess.post(
+                        get_server_api(server_endpoint, 'distserve/p2p_initialize'),
+                        json=init_request.model_dump(mode='json'),
+                        timeout=self.aiotimeout,
+                ) as resp:
+                    result = await resp.json()
+                    logger.info(f'P2P Initialize response from {server_endpoint}: {result}')
+                    return DistServeInitResponse.model_validate(result)
+
+        async def p2p_connect(server_endpoint, conn_request: DistServeConnectionRequest) -> DistServeConnectionResponse:
+            async with self.conn_sem:
+                async with self.conn_sess.post(
+                        get_server_api(server_endpoint, 'distserve/p2p_connect'),
+                        json=conn_request.model_dump(mode='json'),
+                        timeout=self.aiotimeout,
+                ) as resp:
+                    result = await resp.json()
+                    return DistServeConnectionResponse.model_validate(result)
+
+        async def conn_worker(conn_req: EPDConnectionMessage, conn_event: asyncio.Event):
+            # try:
+            link = (conn_req.e_url, conn_req.pd_url)
+            logger.debug(f'{link} connecting...')
+            # Step 1. Get Remote Engine Configuration
+            prefill_decode_engine_configs = await get_engine_config(conn_req.pd_url)
+            encode_engine_config = await get_engine_config(conn_req.e_url)
+            print(f'prefill_decode_engine_configs: {prefill_decode_engine_configs}')
+            print(f'encode_engine_config: {encode_engine_config}')
+
+            # most fields of the encoder's engine config are empty
+
+            # Step 2.
Construct Initialize Configuration + prefill_decode_init_req = DistServeInitRequest( + protocol=conn_req.protocol, + local_engine_id=conn_req.pd_url, + local_engine_config=prefill_decode_engine_configs, + remote_engine_id=conn_req.e_url, + remote_engine_config=encode_engine_config, + rdma_config=conn_req.rdma_config, + nvlink_config=conn_req.nvlink_config, + ) + encode_init_req = DistServeInitRequest( + protocol=conn_req.protocol, + local_engine_id=conn_req.e_url, + local_engine_config=encode_engine_config, + remote_engine_id=conn_req.pd_url, + remote_engine_config=prefill_decode_engine_configs, + rdma_config=conn_req.rdma_config, + nvlink_config=conn_req.nvlink_config, + ) + + print(f'prefill_decode_init_req: {prefill_decode_init_req}') + print(f'encode_init_req: {encode_init_req}') + prefill_decode_init_resp = await p2p_initialize(conn_req.pd_url, prefill_decode_init_req) + encode_init_resp = await p2p_initialize(conn_req.e_url, encode_init_req) + + # Step 3. Connection + encode_endpoint_conn_reqs = DistServeConnectionRequest( + protocol=conn_req.protocol, + remote_engine_id=conn_req.pd_url, + remote_engine_endpoint_info=prefill_decode_init_resp.engine_endpoint_info, + remote_kvtransfer_endpoint_info=prefill_decode_init_resp.kvtransfer_endpoint_info) + prefill_decode_endpoint_conn_reqs = DistServeConnectionRequest( + protocol=conn_req.protocol, + remote_engine_id=conn_req.e_url, + remote_engine_endpoint_info=encode_init_resp.engine_endpoint_info, + remote_kvtransfer_endpoint_info=encode_init_resp.kvtransfer_endpoint_info) + print(f'encode_endpoint_conn_reqs: {encode_endpoint_conn_reqs}') + print(f'prefill_decode_endpoint_conn_reqs: {prefill_decode_endpoint_conn_reqs}') + await p2p_connect(conn_req.pd_url, prefill_decode_endpoint_conn_reqs) + await p2p_connect(conn_req.e_url, encode_endpoint_conn_reqs) + self.pool[link].set_status(EPDConnectionStatus.Connected) + logger.debug(f'{(conn_req.e_url, conn_req.pd_url)} connected') + # except Exception as e: + # self.pool[link].set_status(EPDConnectionStatus.Disconnected) + # logger.error(f'ep connection error: {e}') + conn_event.set() + + async def wait_for_conn(conn_req: EPDConnectionMessage, conn_event: asyncio.Event): + await self.pool[(conn_req.e_url, conn_req.pd_url)].event.wait() + conn_event.set() + + async def _perform_conn(): + logger.debug('perform_conn start') + while True: + if self.waiting_conn.empty(): + await self.conn_req_event.wait() + + self.conn_req_event.clear() + + while not self.waiting_conn.empty(): + conn_req, conn_event = self.waiting_conn.get_nowait() + link = (conn_req.e_url, conn_req.pd_url) + if link not in self.pool: + self.pool[link] = EPDConnectionState( + EPDConnectionStatus.Disconnected, + conn_event, + ) + if self.pool[link].status == EPDConnectionStatus.Connecting: + asyncio.create_task(wait_for_conn(conn_req, conn_event)) + elif self.pool[link].status == EPDConnectionStatus.Disconnected: + self.pool[link].set_status(EPDConnectionStatus.Connecting) + asyncio.create_task(conn_worker(conn_req, conn_event)) + + if not self.initialized: + loop = asyncio.get_event_loop() + loop.create_task(_perform_conn()) + self.conn_sem = asyncio.Semaphore(self.CONN_SEMAPHORE_SIZE) + self.conn_sess = aiohttp.ClientSession( + connector=aiohttp.TCPConnector(limit_per_host=256), + timeout=aiohttp.ClientTimeout(total=AIOHTTP_TIMEOUT), + ) + self.aiotimeout = aiohttp.ClientTimeout(total=AIOHTTP_TIMEOUT) + self.initialized = True + + print(f'EPDConnectionPool connect called: {conn_req.e_url} <-> {conn_req.pd_url}') + 
self.reg_instance(EngineRole.Encoder, conn_req.e_url) + self.reg_instance(EngineRole.Prefill, conn_req.pd_url) + + cnt = 0 + while cnt < self.max_retry_cnt: + if self.is_connected(conn_req.e_url, conn_req.pd_url): + return + if cnt > 0: + logger.warning(f'EPD connection failure, retry cnt: {cnt}') + # simple incremental backoff + await asyncio.sleep(min(1.0, 0.2 * cnt)) + conn_event = asyncio.Event() + self.waiting_conn.put_nowait((conn_req, conn_event)) + self.conn_req_event.set() + await conn_event.wait() + cnt += 1 + async with self.conn_lock: + if (conn_req.e_url, conn_req.pd_url) in self.pool: + self.pool[conn_req.e_url, conn_req.pd_url].set_status(EPDConnectionStatus.Disconnected) + raise TimeoutError('EPDConnection Failure') + + def is_connected(self, e_url: str, pd_url: str): + link = self.pool.get((e_url, pd_url), None) + if not link: + return False + return link.status == EPDConnectionStatus.Connected + + def drop(self, ep_key: Tuple[str, str]): + left = ep_key[0] + right = ep_key[1] + + def cache_free(server_endpoint, cache_free_request: DistServeCacheFreeRequest) -> None: + try: + requests.post(get_server_api(server_endpoint, 'distserve/free_cache'), + json=cache_free_request.model_dump(mode='json')) + except Exception as e: + logger.warning(f'error cache block free {server_endpoint, cache_free_request}. ErrorMsg: {str(e)}') + + def drop_connect(server_endpoint: str, p2p_disconnect_request: DistServeDropConnectionRequest): + try: + requests.post(get_server_api(server_endpoint, 'distserve/p2p_drop_connect'), + json=p2p_disconnect_request.model_dump(mode='json')) + except Exception as e: + logger.warning(f'error drop connect {server_endpoint, p2p_disconnect_request}. ErrorMsg: {str(e)}') + + # trigger gc + logger.warning('cache block gc triggered.') + try: + for session_id in self.migration_session_shelf[(left, right)]: + cache_free(left, DistServeCacheFreeRequest(remote_engine_id=left, remote_session_id=session_id)) + except Exception as e: + logger.warning(f'gc error, ErrorMsg: {str(e)}') + finally: + self.migration_session_shelf.pop((left, right), None) + + # trigger p2p disconnect + logger.warning('drop connection triggered.') + try: + drop_connect(left, DistServeDropConnectionRequest(engine_id=left, remote_engine_id=right)) + drop_connect(right, DistServeDropConnectionRequest(engine_id=right, remote_engine_id=left)) + except Exception as e: + logger.warning(f'p2p disconnect error, ErrorMsg: {str(e)}') + + self.pool.pop((left, right), None) + + async def close(self): + if getattr(self, 'initialized', False): + try: + await self.conn_sess.close() + except Exception as e: + logger.warning(f'EPDConnectionPool close error: {e}') diff --git a/lmdeploy/pytorch/disagg/conn/protocol.py b/lmdeploy/pytorch/disagg/conn/protocol.py index aa47789497..54af27a0fa 100644 --- a/lmdeploy/pytorch/disagg/conn/protocol.py +++ b/lmdeploy/pytorch/disagg/conn/protocol.py @@ -77,6 +77,17 @@ class DistServeConnectionResponse(BaseModel): status: DistServeConnectionStatus +class EncoderResult(BaseModel): + + token_ids: List[int] + image_mask: List[int] + + protocol: MigrationProtocol + remote_engine_id: str + remote_session_id: int + remote_block_ids: List[int] + + class MigrationRequest(BaseModel): protocol: MigrationProtocol diff --git a/lmdeploy/pytorch/disagg/conn/proxy_conn.py b/lmdeploy/pytorch/disagg/conn/proxy_conn.py index a07d281248..f0ae916251 100644 --- a/lmdeploy/pytorch/disagg/conn/proxy_conn.py +++ b/lmdeploy/pytorch/disagg/conn/proxy_conn.py @@ -147,6 +147,7 @@ async def 
p2p_connect(server_endpoint, conn_request: DistServeConnectionRequest) timeout=self.aiotimeout, ) as resp: result = await resp.json() + logger.info(f'=> p2p_connect response: {result}') return DistServeConnectionResponse.model_validate(result) async def conn_worker(conn_req: PDConnectionMessage, conn_event: asyncio.Event): @@ -161,6 +162,7 @@ async def conn_worker(conn_req: PDConnectionMessage, conn_event: asyncio.Event): assert prefill_engine_config.tp_size == decode_engine_config.tp_size # Step 2. Construct Initialize Configuration + logger.info(f'=> check conn_req: {conn_req}') prefill_init_req = DistServeInitRequest( protocol=conn_req.protocol, local_engine_id=conn_req.p_url, @@ -183,6 +185,8 @@ async def conn_worker(conn_req: PDConnectionMessage, conn_event: asyncio.Event): prefill_init_resp = await p2p_initialize(conn_req.p_url, prefill_init_req) decode_init_resp = await p2p_initialize(conn_req.d_url, decode_init_req) + logger.info(f'=> p2p init, prefill_init_resp: \n{prefill_init_resp}\n') + logger.info(f'=> p2p init, decode_init_resp: \n{decode_init_resp}\n') # Step 3. Connection prefill_endpoint_conn_reqs = DistServeConnectionRequest( protocol=conn_req.protocol, diff --git a/lmdeploy/pytorch/disagg/messages.py b/lmdeploy/pytorch/disagg/messages.py index 9dac0b0391..437ddbeebd 100644 --- a/lmdeploy/pytorch/disagg/messages.py +++ b/lmdeploy/pytorch/disagg/messages.py @@ -38,6 +38,15 @@ class PDConnectionMessage(BaseModel): nvlink_config: Optional[DistServeNVLinkConfig] = None +class EPDConnectionMessage(BaseModel): + e_url: str + pd_url: str + protocol: MigrationProtocol = MigrationProtocol.RDMA + tcp_config: Optional[DistServeTCPConfig] = None + rdma_config: Optional[DistServeRDMAConfig] = None + nvlink_config: Optional[DistServeNVLinkConfig] = None + + class DistServeRegisterMRMessage(BaseModel): protocol: MigrationProtocol diff --git a/lmdeploy/pytorch/engine/cache_engine.py b/lmdeploy/pytorch/engine/cache_engine.py index d8ec198349..d4b231277e 100644 --- a/lmdeploy/pytorch/engine/cache_engine.py +++ b/lmdeploy/pytorch/engine/cache_engine.py @@ -62,6 +62,8 @@ def __init__( # Initialize the cache. self.local_gpu_cache = self.allocate_gpu_cache() self.local_cpu_cache = self.allocate_cpu_cache() + # FIXME: hardcode cache size for interns1 series + self.encoder_gpu_cache = torch.empty(size=(128, 256, 4096), dtype=torch.bfloat16, device='cuda') self.migration_backend_impl: Optional[MigrationBackendImpl] = None @@ -317,7 +319,7 @@ def get_cache_block_size(cls, total = num_layers * (mem_key_block + mem_value_block) return total - """ Metheds for PD Disaggregation Begin. """ + """ Methods for PD Disaggregation Begin. 
""" def p2p_initialize(self, migration_init_request: DistServeInitRequest) -> DistServeKVTransferEndpointInfo: if not self.migration_backend_impl: @@ -334,6 +336,19 @@ def p2p_initialize(self, migration_init_request: DistServeInitRequest) -> DistSe offset=t.storage_offset(), length=t.numel() * t.itemsize) self.migration_backend_impl.register_memory_region(register_mr_request) + + # register memory region for encoder cache, otherwise cannot perform RDMA transfer + if self.encoder_gpu_cache.numel() > 0: + logger.info('p2p_init encoder_cache') + register_mr_request = DistServeRegisterMRMessage( + protocol=migration_init_request.protocol, + remote_engine_id=migration_init_request.remote_engine_id, + mr_key='encoder_cache', # Use the fixed mr key, same as the one in encoder_cache_engine + addr=self.encoder_gpu_cache.data_ptr(), + offset=self.encoder_gpu_cache.storage_offset(), + length=self.encoder_gpu_cache.numel() * self.encoder_gpu_cache.itemsize) + self.migration_backend_impl.register_memory_region(register_mr_request) + return DistServeKVTransferEndpointInfo(protocol=migration_init_request.protocol, endpoint_info=json.dumps( self.migration_backend_impl.endpoint_info( @@ -383,4 +398,40 @@ def get_assignment_batch(mr_key, block_ids, assignment_len, layer_stride, remote batch=assignment_batch, )) - """ Metheds for PD Disaggregation End. """ + async def epd_migrate(self, migration_execution_inputs: MigrationExecutionBatch): + """Handles the migration of the Multi-Modal (MM) cache.""" + if not self.migration_backend_impl: + logger.error('Migration backend is not initialized. Cannot perform EPD migration.') + return + + if self.encoder_gpu_cache.numel() == 0: + logger.warning('MM GPU cache is not allocated or is empty. Skipping EPD migration.') + return + + _, tokens_per_image, hidden_size = self.encoder_gpu_cache.shape + assignment_len = tokens_per_image * hidden_size * self.encoder_gpu_cache.element_size() + + assignment_batch: List[AssignmentInstruct] = [] + mr_key = 'encoder_cache' # Use the fixed mr key, same as the one in encoder_cache_engine + + for _, blocks_to_migration in migration_execution_inputs.requests: + for source_idx, target_idx in blocks_to_migration: + source_offset = source_idx * assignment_len + target_offset = target_idx * assignment_len + instruction = AssignmentInstruct(mr_key=mr_key, + target_offset=target_offset, + source_offset=source_offset, + length=assignment_len) + assignment_batch.append(instruction) + + if assignment_batch: + remote_engine_id = migration_execution_inputs.requests[0][0] + logger.debug(f'Migrating {len(assignment_batch)} MM feature blocks to {remote_engine_id}.') + await self.migration_backend_impl.p2p_migrate( + MigrationAssignment( + protocol=migration_execution_inputs.protocol, + remote_engine_id=remote_engine_id, + batch=assignment_batch, + )) + + """ Methods for PD Disaggregation End. 
""" diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index 4bea7e5a64..e9c5567b7a 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -417,8 +417,10 @@ def __init__(self, # for PD Disaggregation # For migrating prefill request to decode engine self.migration_event: asyncio.Event = None + # For encoder result migration + self.epd_migration_event: asyncio.Event = None # For backpressure prefill request when cache is full - self.perfill_watermark_event: asyncio.Event = None + self.prefill_watermark_event: asyncio.Event = None self.engine_conn = EngineP2PConnection(self) @@ -583,13 +585,17 @@ def _on_add_message(self, reqs: List[Request], **kwargs): logger.warning('Vision encoder has not been loaded, multimodal inputs will be ignored.') continue - result = self.input_processor.preprocess_input(input_ids, input_multimodals) + # skip preprocess if encoder results exist + if req_data.get('encoder_result') is None: + result = self.input_processor.preprocess_input(input_ids, input_multimodals) + input_ids = result.input_ids + input_multimodals = result.input_multimodals - input_ids = result.input_ids - input_multimodals = result.input_multimodals - - req_data['token_ids'] = input_ids - req_data['input_multimodals'] = input_multimodals + req_data['token_ids'] = input_ids + req_data['input_multimodals'] = input_multimodals + else: + req_data['input_multimodals'] = None + logger.info('Ignore multimodal inputs since encoder results exist.') if len(valid_reqs) > 0: self._add_message(valid_reqs) @@ -620,6 +626,9 @@ def __update_max_new_tokens(msg): sampling_param = req.data['sampling_param'] if len(sess.sequences) == 0: migration_request = req.data.get('migration_request') + encoder_result = req.data.get('encoder_result') + logger.info(f'=> add msg, migration_request {migration_request}') + logger.info(f'=> add msg, encoder_result {encoder_result}') assert len(req.data['token_ids']) > 0, ('Empty input is not allowed.') sess.add_sequence(req.data['token_ids'], sampling_param=sampling_param, @@ -627,6 +636,7 @@ def __update_max_new_tokens(msg): multimodals=req.data.get('input_multimodals'), input_embeddings=req.data.get('input_embeddings', ), migration_request=migration_request, + encoder_result=encoder_result, resp_cache=req.data.get('with_cache'), preserve_cache=req.data.get('preserve_cache')) msg = next(iter(sess.sequences.values())) @@ -635,6 +645,11 @@ def __update_max_new_tokens(msg): if migration_request: self.scheduler._set_message_status(msg, MessageStatus.WAITING_MIGRATION) self.migration_event.set() + # if have encoder results here, skip encoding, directly proceed to prefill + if encoder_result: + logger.info('=> set waiting EPD migration') + self.scheduler._set_message_status(msg, MessageStatus.WAITING_EPD_MIGRATION) + self.epd_migration_event.set() else: msg = next(iter(sess.sequences.values())) msg.update_token_ids( @@ -696,6 +711,11 @@ def __has_values(input_multimodals): return True return False + # has_encoder_result = any([msg.encoder_result is not None for msg in messages]) + # # FIXME: any special treatment for encoder_result? 
+ # if has_encoder_result: + # pass + + has_embedding = any([len(msg.history_embeddings) > 0 for msg in messages]) if has_embedding: has_embedding = any([len(msg.input_embeddings) > 0 for msg in messages]) @@ -767,6 +787,7 @@ def create_model_inputs(self, messages: SeqList, is_prefill: bool): # model_metas model_metas = [msg.model_meta for msg in messages] + encoder_results = [msg.encoder_result for msg in messages] # create model inputs for all required fields model_inputs = ModelInputs( @@ -780,6 +801,7 @@ max_kv_seqlen=max_kv_seqlen, sum_kv_seqlen=sum_kv_seqlen, model_metas=model_metas, + encoder_results=encoder_results, ) # adapters @@ -1054,6 +1076,54 @@ async def _async_loop_migration(self, resp_que: asyncio.Queue, has_runable_event # release coroutine for decoding await asyncio.sleep(.5) + @torch.inference_mode() + async def _async_loop_epd_migration(self, resp_que: asyncio.Queue, has_runable_event: asyncio.Event): + """Async loop for encoder-prefill migration.""" + while True: + epd_migration_running = self.scheduler._schedule_epd_migration() + if not epd_migration_running and not self.scheduler.has_epd_migration_waiting(): + await self.epd_migration_event.wait() + elif epd_migration_running: + self.epd_migration_event.clear() + for msg in epd_migration_running: + logger.info('performing epd migrations.') + migration_execution_requests: List[Tuple[int, List[Tuple[int, int]]]] = [] + epd_migration_request = msg.encoder_result + encoder_block_ids = epd_migration_request.remote_block_ids + + # FIXME: only a single request is tested for now; we simply reuse the encoder block ids + # ideally the prefill block ids should come from the scheduler for this msg + prefill_block_ids = epd_migration_request.remote_block_ids + + assert len(encoder_block_ids) == len(prefill_block_ids), ( + f'#encoder block ids ({len(encoder_block_ids)}) must equal ' + f'#prefill block ids ({len(prefill_block_ids)}) ' + f'num_token_ids: {msg.num_token_ids}') + migration_execution_requests.append(( + epd_migration_request.remote_engine_id, + list(zip(encoder_block_ids, prefill_block_ids)), + )) + migration_inputs = MigrationExecutionBatch(protocol=epd_migration_request.protocol, + requests=migration_execution_requests) + logger.info(f'migrating encoder cache for session: {msg.session_id} begin') + await self.executor.epd_migrate(migration_inputs) + logger.info(f'migrating encoder cache for session: {msg.session_id} done') + # TODO: freeing the remote encoder cache via zmq is not implemented yet; left as future work + # await self.engine_conn.zmq_send(remote_engine_id=epd_migration_request.remote_engine_id, + # remote_session_id=epd_migration_request.remote_session_id) + + # After migration, the sequences are ready for prefill.
We change their status to WAITING + # later they will be scheduled by self.scheduler._schedule_prefill() and proceed to the prefill stage + self.scheduler.lock_running_epd_migration(epd_migration_running) + for msg in epd_migration_running: + self.scheduler._set_message_status(msg, MessageStatus.WAITING) + self.scheduler.unlock_running_epd_migration(epd_migration_running) + + has_runable_event.set() + else: + # release coroutine for other tasks + await asyncio.sleep(.5) + @torch.inference_mode() async def _async_loop_main( self, @@ -1078,10 +1148,11 @@ async def _async_loop_main( forward_event.clear() scheduler.collect_migration_done() + scheduler.collect_epd_migration_done() forward_inputs, next_running = await inputs_maker.send_next_inputs() if next_running is None: # TODO (JimyMa): add watermark check event instead of async sleep. - # self.perfill_watermark_event.wait() + # self.prefill_watermark_event.wait() logger.warning(f'no next prefill running request, Maybe cache is full, ' f'free gpu cache blocks: {scheduler.block_manager.get_num_free_gpu_blocks()}, ' f'total gpu cache blocks: {scheduler.block_manager.num_gpu_blocks}') @@ -1101,6 +1172,7 @@ # pre-forward before get last token if idx == num_loops - 1: scheduler.collect_migration_done() + scheduler.collect_epd_migration_done() forward_inputs, next_running = await inputs_maker.prefetch_next_inputs() # send output @@ -1167,6 +1239,7 @@ async def async_loop(self): # migration task self.migration_event = asyncio.Event() + self.epd_migration_event = asyncio.Event() logger.info('Starting executor.') self.executor.start(forward_event) @@ -1195,6 +1268,14 @@ ) loop_tasks.append(loop_migration) + # TODO: only create this coroutine when in EPD mode + logger.info('Starting async task EPDMigrationLoop.') + loop_epd_migration = event_loop.create_task( + self._async_loop_epd_migration(resp_que, has_runable_event=has_runable_event), + name='MainLoopEPDMigration', + ) + loop_tasks.append(loop_epd_migration) + # binding done callback self._add_loop_tasks_done_callback(loop_tasks) self._loop_main = loop_main diff --git a/lmdeploy/pytorch/engine/engine_instance.py index 041d8a042e..102b92d755 100644 --- a/lmdeploy/pytorch/engine/engine_instance.py +++ b/lmdeploy/pytorch/engine/engine_instance.py @@ -139,6 +139,7 @@ async def async_stream_infer(self, adapter_name=adapter_name, input_multimodals=multimodal, migration_request=gen_config.migration_request, + encoder_result=gen_config.encoder_result, with_cache=gen_config.with_cache, preserve_cache=gen_config.preserve_cache, ) diff --git a/lmdeploy/pytorch/engine/executor/ray_executor.py index 327d56a5ca..0915ebf3cc 100644 --- a/lmdeploy/pytorch/engine/executor/ray_executor.py +++ b/lmdeploy/pytorch/engine/executor/ray_executor.py @@ -593,4 +593,8 @@ async def migrate(self, batch: MigrationExecutionBatch): jobs = (worker.migrate.remote(batch) for worker in self.workers) return await asyncio.gather(*jobs) + async def epd_migrate(self, batch: MigrationExecutionBatch): + jobs = (worker.epd_migrate.remote(batch) for worker in self.workers) + return await asyncio.gather(*jobs) + """ PD Disaggregation API Begin """ diff --git a/lmdeploy/pytorch/engine/executor/uni_executor.py index 283a8cddc7..c43cc6ad50 100644 --- a/lmdeploy/pytorch/engine/executor/uni_executor.py +++ b/lmdeploy/pytorch/engine/executor/uni_executor.py
@@ -117,4 +117,8 @@ async def migrate(self, batch: MigrationExecutionBatch): """KV Cache Migration.""" return await self.model_agent.cache_engine.migrate(batch) + async def epd_migrate(self, batch: MigrationExecutionBatch): + """Encoder Cache Migration.""" + return await self.model_agent.cache_engine.epd_migrate(batch) + """ PD Disaggregation API End """ diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index d19485e83e..0507167b53 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -243,6 +243,7 @@ def model_forward( ) input_dict = model.prepare_inputs_for_generation( past_key_values=cache_engine.gpu_cache, + encoder_cache=cache_engine.encoder_gpu_cache, context=context, ) output = model(**input_dict) diff --git a/lmdeploy/pytorch/engine/request.py b/lmdeploy/pytorch/engine/request.py index 466e102e22..c1c0a545cc 100644 --- a/lmdeploy/pytorch/engine/request.py +++ b/lmdeploy/pytorch/engine/request.py @@ -48,7 +48,7 @@ class Request: def _run_until_complete(future: Awaitable): - """Run untile complete.""" + """Run until complete.""" try: event_loop = asyncio.get_event_loop() except Exception: diff --git a/lmdeploy/pytorch/messages.py b/lmdeploy/pytorch/messages.py index 101ed62546..8aa974845a 100644 --- a/lmdeploy/pytorch/messages.py +++ b/lmdeploy/pytorch/messages.py @@ -7,7 +7,7 @@ from torch import Tensor from lmdeploy.messages import EngineEvent, EventType, GenerationConfig, LogitsProcessor -from lmdeploy.pytorch.disagg.conn.protocol import MigrationRequest +from lmdeploy.pytorch.disagg.conn.protocol import EncoderResult, MigrationRequest from lmdeploy.pytorch.multimodal.data_type import MultiModalInputs from lmdeploy.utils import get_logger @@ -159,6 +159,11 @@ class MessageStatus(enum.Enum): MIGRATION_LOCKED = enum.auto() MIGRATION_DONE = enum.auto() + WAITING_EPD_MIGRATION = enum.auto() + RUNNING_EPD_MIGRATION = enum.auto() + EPD_MIGRATION_LOCKED = enum.auto() + EPD_MIGRATION_DONE = enum.auto() + SeqMap = Dict[int, 'SchedulerSequence'] @@ -257,6 +262,7 @@ def add_sequence(self, multimodals: MultiModalInputs = None, input_embeddings: List[InputEmbeddings] = None, migration_request: Optional[MigrationRequest] = None, + encoder_result: Optional[EncoderResult] = None, resp_cache: bool = False, preserve_cache: bool = False) -> 'SchedulerSequence': """Add a new message.""" @@ -269,6 +275,7 @@ def add_sequence(self, sampling_param=sampling_param, adapter_name=adapter_name, migration_request=migration_request, + encoder_result=encoder_result, resp_cache=resp_cache, preserve_cache=preserve_cache) seq.update_token_ids( @@ -495,6 +502,7 @@ class SchedulerSequence: migration_request: Optional[MigrationRequest] = None resp_cache: bool = False preserve_cache: bool = False + encoder_result: Optional[EncoderResult] = None # For logging engine_events: List[EngineEvent] = field(default_factory=list) diff --git a/lmdeploy/pytorch/model_inputs.py b/lmdeploy/pytorch/model_inputs.py index a377c9d4d6..7c7d801e3b 100644 --- a/lmdeploy/pytorch/model_inputs.py +++ b/lmdeploy/pytorch/model_inputs.py @@ -142,6 +142,7 @@ class ModelInputs: cross_length: torch.LongTensor = None history_cross_length: torch.LongTensor = None model_metas: List[Dict[str, Any]] = None + encoder_results: List[Dict[str, Any]] = None dp_meta: 'DPMeta' = None enable_microbatch: bool = False @@ -254,6 +255,7 @@ def __make_next_vision_inputs(flatten_mms: List, start: int): local_adapter_ids=self.local_adapter_ids, vision_inputs=vision_inputs, 
model_metas=self.model_metas, + encoder_results=self.encoder_results, cross_length=cross_length, history_cross_length=history_cross_length, ) @@ -319,6 +321,7 @@ class StepContext: cross_attn_metadata: Any = None kv_quant_policy: Literal[0, 4, 8] = 0 model_metas: List[Dict[str, Any]] = None + encoder_results: List[Dict[str, Any]] = None dp_meta: DPMeta = None enable_microbatch: bool = False @@ -385,6 +388,7 @@ def new( vision_inputs=inputs.vision_inputs, kv_quant_policy=kv_quant_policy, model_metas=inputs.model_metas, + encoder_results=inputs.encoder_results, cross_seqlens=cross_seqlens, cross_kv_seqlens=cross_kv_seqlens, dp_meta=inputs.dp_meta, diff --git a/lmdeploy/pytorch/models/internvl3_hf.py b/lmdeploy/pytorch/models/internvl3_hf.py index 6e760dbeac..6cde7c0468 100644 --- a/lmdeploy/pytorch/models/internvl3_hf.py +++ b/lmdeploy/pytorch/models/internvl3_hf.py @@ -578,25 +578,42 @@ def forward( input_ids: torch.Tensor, position_ids: torch.Tensor, past_key_values: List[List[torch.Tensor]], + vision_embeddings: torch.Tensor = None, attn_metadata: Any = None, pixel_values: torch.Tensor = None, image_mask: torch.Tensor = None, inputs_embeds: torch.Tensor = None, **kwargs, ): - if inputs_embeds is None and pixel_values is not None: - # extract feature - self._mark_dynamic_once(pixel_values, [0]) - vit_embeds = self.get_image_features( - pixel_values, - self.vision_feature_layer, - self.vision_feature_select_strategy, - ) - lang_embeds = self.get_input_embeddings()(input_ids) - lang_embeds.masked_scatter_(image_mask[..., None], vit_embeds) - - inputs_embeds = lang_embeds - input_ids = None + if False: + if inputs_embeds is None and pixel_values is not None: + # extract feature + self._mark_dynamic_once(pixel_values, [0]) + vit_embeds = self.get_image_features( + pixel_values, + self.vision_feature_layer, + self.vision_feature_select_strategy, + ) + lang_embeds = self.get_input_embeddings()(input_ids) + lang_embeds.masked_scatter_(image_mask[..., None], vit_embeds) + + inputs_embeds = lang_embeds + input_ids = None + else: + if inputs_embeds is None and vision_embeddings is not None and image_mask is not None: + print('Using encoder_cache as vit_embeds !!!!!') + print(f'input_ids: {input_ids.shape}') + # use cached feature + vit_embeds = vision_embeddings + lang_embeds = self.get_input_embeddings()(input_ids) + print(f'lang_embeds.shape: {lang_embeds.shape}') + print(f'vit_embeds.shape: {vit_embeds.shape}') + print(f'image_mask.shape: {image_mask.shape}') + print(f'image_mask[..., None].shape: {image_mask[..., None].shape}') + lang_embeds.masked_scatter_(image_mask[..., None], vit_embeds) + + inputs_embeds = lang_embeds + input_ids = None if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError('You must specify exactly one of input_ids or inputs_embeds') @@ -614,6 +631,7 @@ def forward( def prepare_inputs_for_generation( self, past_key_values: List[List[torch.Tensor]], + encoder_cache: torch.Tensor = None, inputs_embeds: torch.Tensor = None, context: StepContext = None, ): @@ -646,10 +664,27 @@ def prepare_inputs_for_generation( inputs_embeds = self.get_input_embeddings()(input_ids) inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds) + if not context.is_decoding: + if context.encoder_results is not None and context.encoder_results[0] is not None: + # FIXME: pick 0 index for now, should fix for batch > 1 + image_mask = context.encoder_results[0].image_mask + image_mask = torch.tensor(image_mask, device=input_ids.device, dtype=torch.bool) + 
remote_block_ids = context.encoder_results[0].remote_block_ids + vision_embeddings = encoder_cache[remote_block_ids] + print(f'len(remote_block_ids): {len(remote_block_ids)}') + print(f'remote_block_ids: {remote_block_ids}') + print(f'vision_embeddings.shape: {vision_embeddings.shape}') + print(f'vision_embeddings: {vision_embeddings}') + # FIXME: we need to change the input_ids here, or maybe even earlier + # since multi-modal requests input_ids has image token ids, different from the others + encoder_input_ids = context.encoder_results[0].token_ids + print(f'encoder_input_ids: {len(encoder_input_ids)}') + return dict( input_ids=input_ids, position_ids=position_ids, past_key_values=past_key_values, + vision_embeddings=vision_embeddings, attn_metadata=attn_metadata, pixel_values=pixel_values, image_mask=image_mask, diff --git a/lmdeploy/pytorch/paging/scheduler.py b/lmdeploy/pytorch/paging/scheduler.py index 8144ac52a1..1e05246edc 100644 --- a/lmdeploy/pytorch/paging/scheduler.py +++ b/lmdeploy/pytorch/paging/scheduler.py @@ -98,6 +98,24 @@ def migration_done(self): seq_map = self.seq_manager.get_sequences(MessageStatus.MIGRATION_DONE) return list(seq_map.values()) + @property + def waiting_epd_migration(self): + """Get waiting sequence.""" + seq_map = self.seq_manager.get_sequences(MessageStatus.WAITING_EPD_MIGRATION) + return list(seq_map.values()) + + @property + def running_epd_migration(self): + """Get running sequence.""" + seq_map = self.seq_manager.get_sequences(MessageStatus.RUNNING_EPD_MIGRATION) + return list(seq_map.values()) + + @property + def epd_migration_done(self): + """Get migration done sequence.""" + seq_map = self.seq_manager.get_sequences(MessageStatus.EPD_MIGRATION_DONE) + return list(seq_map.values()) + def build_eviction_helper(self, eviction_type: str): if eviction_type == 'copy': logger.warning('`copy` eviction has been deprecated, ' @@ -182,6 +200,48 @@ def _reorder_migrating(): return running_migration + @logging_timer('ScheduleEPMigration', logger) + def _schedule_epd_migration(self): + running_epd_migration: SeqList = [] + migrating_token_count = 0 + + def _to_running(seq: SchedulerSequence): + """To running.""" + seq.status = MessageStatus.RUNNING_EPD_MIGRATION + running_epd_migration.append(seq) + nonlocal migrating_token_count + migrating_token_count += seq.num_token_ids + + def __evict_for_seq(seq: SchedulerSequence, waiting): + """Evict until can append.""" + from itertools import chain + + hanging = reversed(self.hanging) + waiting = reversed(waiting) + evictable = list(chain(hanging, waiting)) + return self.eviction_helper.evict_for_seq(seq, evictable, 0) + + def _reorder_migrating(): + """Reorder waiting.""" + return sorted(self.waiting_epd_migration, key=lambda seq: seq.arrive_time) + + waiting_epd_migration = _reorder_migrating() + print(f'=> check waiting EPD migration: {waiting_epd_migration}') + + max_batches = self.scheduler_config.max_batches - self.num_running() - self.num_locked() + while len(waiting_epd_migration) > 0 and len(running_epd_migration) < max_batches: + seq = waiting_epd_migration.pop(0) + self.block_trie.match(waiting_epd_migration) + if not __evict_for_seq(seq, waiting_epd_migration): + break + + # allocate session memory + self.block_manager.allocate(seq) + _to_running(seq) + + print(f'=> check running EPD migration: {running_epd_migration}') + return running_epd_migration + @logging_timer('SchedulePrefilling', logger) def _schedule_prefill(self, prealloc_size: int = 0): """Schedule for prefilling.""" @@ -343,7 +403,7 @@ def 
end_session(self, session_id: int): def has_unfinished(self): """Check if there are any unfinished message.""" - return self.has_running() or self.has_waiting() or self.has_migration_done() + return self.has_running() or self.has_waiting() or self.has_migration_done() or self.has_epd_migration_done() def has_running(self): return self.num_running() > 0 @@ -363,6 +423,15 @@ def has_migration_waiting(self): def has_migration_done(self): return self.num_migration_done() > 0 + def has_epd_migration_running(self): + return self.num_epd_migration_running() > 0 + + def has_epd_migration_waiting(self): + return self.num_epd_migration_waiting() > 0 + + def has_epd_migration_done(self): + return self.num_epd_migration_done() > 0 + def get_block_tables(self, seqs: SeqList): """Get block table of the sequences.""" return [self.block_manager.get_block_table(seq) for seq in seqs] @@ -395,6 +464,18 @@ def num_migration_waiting(self): """Num waiting.""" return self.seq_manager.num_sequences(MessageStatus.WAITING_MIGRATION) + def num_epd_migration_running(self): + """Num EPD migration running.""" + return self.seq_manager.num_sequences(MessageStatus.RUNNING_EPD_MIGRATION) + + def num_epd_migration_done(self): + """Num EPD migration done.""" + return self.seq_manager.num_sequences(MessageStatus.EPD_MIGRATION_DONE) + + def num_epd_migration_waiting(self): + """Num EPD migration waiting.""" + return self.seq_manager.num_sequences(MessageStatus.WAITING_EPD_MIGRATION) + def num_locked(self): """Num locked.""" return self.seq_manager.num_sequences(MessageStatus.LOCKED) @@ -422,11 +503,28 @@ def unlock_running_migration(self, locked: SeqList): if seq.status == MessageStatus.MIGRATION_LOCKED: self._set_message_status(seq, MessageStatus.MIGRATION_DONE) + def lock_running_epd_migration(self, running: SeqList): + """Lock running EPD migration sequence.""" + for seq in running: + if seq.status == MessageStatus.RUNNING_EPD_MIGRATION: + self._set_message_status(seq, MessageStatus.EPD_MIGRATION_LOCKED) + + def unlock_running_epd_migration(self, locked: SeqList): + """Unlock running EPD migration.""" + for seq in locked: + if seq.status == MessageStatus.EPD_MIGRATION_LOCKED: + self._set_message_status(seq, MessageStatus.EPD_MIGRATION_DONE) + def collect_migration_done(self): migration_done = self.migration_done for seq in migration_done: self._set_message_status(seq, MessageStatus.RUNNING) + def collect_epd_migration_done(self): + epd_migration_done = self.epd_migration_done + for seq in epd_migration_done: + self._set_message_status(seq, MessageStatus.RUNNING) + @property def schedule_metrics(self): return ScheduleMetrics( diff --git a/lmdeploy/pytorch/strategies/ar/sequence.py b/lmdeploy/pytorch/strategies/ar/sequence.py index 91a3335f18..ca30517faa 100644 --- a/lmdeploy/pytorch/strategies/ar/sequence.py +++ b/lmdeploy/pytorch/strategies/ar/sequence.py @@ -5,7 +5,7 @@ from torch import Tensor -from lmdeploy.pytorch.disagg.conn.protocol import MigrationRequest +from lmdeploy.pytorch.disagg.conn.protocol import EncoderResult, MigrationRequest from lmdeploy.pytorch.engine.model_agent import BatchedOutputs from lmdeploy.pytorch.messages import (InputEmbeddings, MessageStatus, MultiModalInputs, SamplingParam, SchedulerSequence, SchedulerSession, UpdateTokenMode, _to_ndarray) @@ -81,6 +81,7 @@ def make_sequence(self, sampling_param: 'SamplingParam' = None, adapter_name: str = None, migration_request: Optional[MigrationRequest] = None, + encoder_result: Optional[EncoderResult] = None, resp_cache: bool = False, preserve_cache: 
bool = False) -> 'SchedulerSequence': """Make sequence.""" @@ -89,6 +90,7 @@ def make_sequence(self, sampling_param=sampling_param, adapter_name=adapter_name, migration_request=migration_request, + encoder_result=encoder_result, resp_cache=resp_cache, preserve_cache=preserve_cache) diff --git a/lmdeploy/pytorch/strategies/base/sequence.py index 408a3cc15e..ea8f533700 100644 --- a/lmdeploy/pytorch/strategies/base/sequence.py +++ b/lmdeploy/pytorch/strategies/base/sequence.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, List, Optional -from lmdeploy.pytorch.disagg.conn.protocol import MigrationRequest +from lmdeploy.pytorch.disagg.conn.protocol import EncoderResult, MigrationRequest if TYPE_CHECKING: from lmdeploy.pytorch.engine.model_agent import BatchedOutputs @@ -19,6 +19,7 @@ def make_sequence(self, sampling_param: 'SamplingParam' = None, adapter_name: str = None, migration_request: Optional[MigrationRequest] = None, + encoder_result: Optional[EncoderResult] = None, resp_cache: bool = False, preserve_cache: bool = False) -> 'SchedulerSequence': """Make sequence.""" diff --git a/lmdeploy/pytorch/strategies/dllm/sequence.py index ab004a2b63..06be439b87 100644 --- a/lmdeploy/pytorch/strategies/dllm/sequence.py +++ b/lmdeploy/pytorch/strategies/dllm/sequence.py @@ -7,7 +7,7 @@ from torch import Tensor from lmdeploy.pytorch import consts -from lmdeploy.pytorch.disagg.conn.protocol import MigrationRequest +from lmdeploy.pytorch.disagg.conn.protocol import EncoderResult, MigrationRequest from lmdeploy.pytorch.engine.model_agent import BatchedOutputs from lmdeploy.pytorch.messages import (HistoryTokenIds, InputEmbeddings, MessageStatus, MultiModalInputs, SamplingParam, SchedulerSession, UpdateTokenMode, _to_ndarray) @@ -206,6 +206,7 @@ def make_sequence(self, sampling_param: 'SamplingParam' = None, adapter_name: str = None, migration_request: Optional[MigrationRequest] = None, + encoder_result: Optional[EncoderResult] = None, resp_cache: bool = False, preserve_cache: bool = False) -> 'SchedulerSequenceDLLM': """Make sequence.""" diff --git a/lmdeploy/serve/async_engine.py index 43bde069a7..2cfac1ea21 100644 --- a/lmdeploy/serve/async_engine.py +++ b/lmdeploy/serve/async_engine.py @@ -909,6 +909,41 @@ def is_error(status): # manually end pytorch session await inst.async_end(session_id) + async def encode_generate(self, messages, session_id: int, **kwargs): + """Perform encoding.""" + if not hasattr(self.engine, 'encode'): + raise NotImplementedError('encode() is not implemented for the current backend') + + encoder_result = await self.engine.encode(messages, session_id) + print(f'check encoder_result: {encoder_result}') + remote_block_ids = encoder_result[0]['block_ids'] + + # FIXME: for simplicity, we reuse the existing get_prompt_input() function + # to obtain input_ids and to calculate the image_mask, + # but get_prompt_input() invokes the vl_encoder preprocess(), which duplicates the encode() call above + # we should adopt a new function that returns input_ids and image_mask only + prompt = messages + self.request_logger.log_prompt(session_id=session_id, prompt=prompt) + prompt_input = await self._get_prompt_input(prompt, + do_preprocess=True, + sequence_start=True, + adapter_name=None, + **kwargs) + prompt = prompt_input['prompt'] + input_ids = prompt_input['input_ids'] + + # get image_mask + image_token_id =
prompt_input['multimodal'][0]['image_token_id'] + image_mask = [1 if x == image_token_id else 0 for x in prompt_input['input_ids']] + + # pack results together and return to api server + return { + 'token_ids': input_ids, + 'image_mask': image_mask, + 'remote_session_id': session_id, + 'remote_block_ids': remote_block_ids + } + def _run(self, fn=None, coro=None, loop=None): assert (fn or coro) and not (fn and coro) loop = loop or self.internal_thread.loop diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py index 75472399fd..66123c26cd 100644 --- a/lmdeploy/serve/openai/api_server.py +++ b/lmdeploy/serve/openai/api_server.py @@ -25,9 +25,9 @@ from lmdeploy.messages import GenerationConfig, LogitsProcessor, PytorchEngineConfig, TurbomindEngineConfig from lmdeploy.metrics.metrics_processor import metrics_processor from lmdeploy.model import ChatTemplateConfig -from lmdeploy.pytorch.disagg.config import DistServeEngineConfig +from lmdeploy.pytorch.disagg.config import DistServeEngineConfig, EngineRole from lmdeploy.pytorch.disagg.conn.protocol import (DistServeCacheFreeRequest, DistServeConnectionRequest, - DistServeDropConnectionRequest, DistServeInitRequest, + DistServeDropConnectionRequest, DistServeInitRequest, EncoderResult, MigrationRequest) from lmdeploy.serve.async_engine import AsyncEngine from lmdeploy.serve.openai.harmony_utils import GptOssChatParser @@ -356,9 +356,13 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque migration_request = json_request.pop('migration_request', None) with_cache = json_request.pop('with_cache', False) preserve_cache = json_request.pop('preserve_cache', False) + encoder_result = json_request.pop('encoder_result', None) if migration_request: migration_request = MigrationRequest.model_validate(migration_request) - + if encoder_result: + encoder_result = EncoderResult.model_validate(encoder_result) + print(f'=> api server, migration_request: \n{migration_request}\n') + print(f'=> api server, encoder_result: \n{encoder_result}\n') if request.session_id == -1: VariableInterface.session_id += 1 request.session_id = VariableInterface.session_id @@ -368,6 +372,17 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque if VariableInterface.async_engine.id2step.get(request.session_id, 0) != 0: return create_error_response(HTTPStatus.BAD_REQUEST, f'The session_id {request.session_id!r} is occupied.') + # if encoder, we only do encoding and return + engine_role = VariableInterface.async_engine.backend_config.role + if engine_role == EngineRole.Encoder: + encoder_result = await VariableInterface.async_engine.encode_generate( + request.messages, + request.session_id, + ) + # TODO: use CompleteResponse prototype + print(f'api_server, v1/completion, encoder_result: {encoder_result}') + return encoder_result + model_name = request.model adapter_name = None if model_name != VariableInterface.async_engine.model_name: @@ -418,6 +433,7 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque random_seed=random_seed, spaces_between_special_tokens=request.spaces_between_special_tokens, migration_request=migration_request, + encoder_result=encoder_result, with_cache=with_cache, preserve_cache=preserve_cache, ) diff --git a/lmdeploy/serve/openai/protocol.py b/lmdeploy/serve/openai/protocol.py index e78579eded..a6cc807e09 100644 --- a/lmdeploy/serve/openai/protocol.py +++ b/lmdeploy/serve/openai/protocol.py @@ -7,6 +7,8 @@ import shortuuid from pydantic 
import BaseModel, ConfigDict, Field +from lmdeploy.pytorch.disagg.conn.protocol import EncoderResult + class ErrorResponse(BaseModel): """Error responses.""" @@ -152,6 +154,7 @@ class ChatCompletionRequest(BaseModel): enable_thinking: Optional[bool] = None return_token_ids: Optional[bool] = False include_stop_str_in_output: Optional[bool] = False + encoder_result: Optional[EncoderResult] = None class FunctionCall(BaseModel): diff --git a/lmdeploy/serve/proxy/proxy.py b/lmdeploy/serve/proxy/proxy.py index d796bc64a1..d2e5719b26 100644 --- a/lmdeploy/serve/proxy/proxy.py +++ b/lmdeploy/serve/proxy/proxy.py @@ -22,9 +22,10 @@ from pydantic import BaseModel, Field from lmdeploy.pytorch.disagg.config import DistServeRDMAConfig, EngineRole, RDMALinkType, ServingStrategy -from lmdeploy.pytorch.disagg.conn.protocol import MigrationProtocol, MigrationRequest +from lmdeploy.pytorch.disagg.conn.epd_proxy_conn import EPDConnectionPool +from lmdeploy.pytorch.disagg.conn.protocol import EncoderResult, MigrationProtocol, MigrationRequest from lmdeploy.pytorch.disagg.conn.proxy_conn import PDConnectionPool -from lmdeploy.pytorch.disagg.messages import PDConnectionMessage +from lmdeploy.pytorch.disagg.messages import EPDConnectionMessage, PDConnectionMessage from lmdeploy.serve.openai.api_server import check_api_key, create_error_response from lmdeploy.serve.openai.protocol import ModelCard # noqa: E501 from lmdeploy.serve.openai.protocol import ChatCompletionRequest, CompletionRequest, ModelList, ModelPermission @@ -108,6 +109,7 @@ def __init__(self, self.migration_protocol = MigrationProtocol[migration_protocol] self.rdma_config = DistServeRDMAConfig(with_gdr=with_gdr, link_type=RDMALinkType[link_type]) self.pd_connection_pool = PDConnectionPool() + self.ep_connection_pool = EPDConnectionPool() self.dummy_prefill = False def get_nodes(self, role: EngineRole) -> Dict: @@ -126,6 +128,10 @@ def prefill_nodes(self): def decode_nodes(self): return self.get_nodes(EngineRole.Decode) + @property + def encoder_nodes(self): + return self.get_nodes(EngineRole.Encoder) + def update_config_file(self): """Update the config file.""" nodes = copy.deepcopy(self.nodes) @@ -504,6 +510,17 @@ async def connection_warmup(): rdma_config=node_manager.rdma_config, )) for p_url in node_manager.prefill_nodes for d_url in node_manager.decode_nodes ]) + logger.info(f'encoder nodes: {node_manager.encoder_nodes}\nprefill nodes: {node_manager.hybrid_nodes}') + # FIXME: we set pd_urls to use hybrid nodes, since we start LLM server in hybrid role + await asyncio.gather(*[ + node_manager.ep_connection_pool.connect( + EPDConnectionMessage( + e_url=e_url, + pd_url=pd_url, + protocol=node_manager.migration_protocol, + rdma_config=node_manager.rdma_config, + )) for e_url in node_manager.encoder_nodes for pd_url in node_manager.hybrid_nodes + ]) return JSONResponse({'SUCCESS': True}) @@ -576,21 +593,88 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque return check_response if node_manager.serving_strategy == ServingStrategy.Hybrid: - node_url = node_manager.get_node_url(request.model) - if not node_url: - return node_manager.handle_unavailable_model(request.model) + # Helper: decide whether we need encoder stage + def _need_encoder(msgs: List[Dict]) -> bool: + try: + for m in msgs: + content = m.get('content') + # user role + list content -> possible multimodal + if m.get('role') == 'user' and isinstance(content, list): + for item in content: + if isinstance(item, dict) and item.get('type') in ['image_url', 
'image_data', 'image']: + return True + return False + except Exception as e: # noqa + logger.warning(f'encoder detect failed, fallback no-encoder: {e}') + return False - logger.info(f'A request is dispatched to {node_url}') request_dict = request.model_dump() + # 1. Encoder stage (only if encoder node exists & messages contain images) + encoder_url = None + if len(node_manager.encoder_nodes): + if _need_encoder(request_dict.get('messages', [])): + encoder_url = node_manager.get_node_url(request.model, EngineRole.Encoder) + if not encoder_url: + logger.warning( + 'Encoder nodes registered but no suitable encoder node found for model; skip encoder stage.') + else: + logger.info(f'Encoder stage dispatched to {encoder_url}') + enc_start = node_manager.pre_call(encoder_url) + # fall back to /v1/chat/completions if first fails + encoder_info = await node_manager.generate(request_dict, encoder_url, '/v1/chat/completions') + print(encoder_info) + encoder_info = json.loads(encoder_info) + remote_session_id = encoder_info['remote_session_id'] + remote_block_ids = encoder_info['remote_block_ids'] + remote_token_ids = encoder_info['token_ids'] + image_mask = encoder_info['image_mask'] + request_dict['encoder_result'] = EncoderResult( + token_ids=remote_token_ids, + image_mask=image_mask, + protocol=node_manager.migration_protocol, + remote_engine_id=encoder_url, + remote_session_id=remote_session_id, + remote_block_ids=remote_block_ids).model_dump(mode='json') + + # simple heuristic: if returns timeout structure (bytes) keep original + try: + # enc_json = json.loads(encoder_info) + enc_json = request_dict + except Exception: + # try alternative endpoint if first not json (maybe 404 HTML) + alt_text = await node_manager.generate(request_dict, encoder_url, '/v1/chat/completions') + try: + enc_json = json.loads(alt_text) + # encoder_response_text = alt_text + except Exception: + logger.error('Encoder stage failed: cannot parse JSON; skip encoder stage') + enc_json = None + node_manager.post_call(encoder_url, enc_start) + if enc_json and isinstance(enc_json, dict) and 'encoder_result' in enc_json: + # Replace messages with encoder returned (likely empty) to avoid double encoding + request_dict['messages'] = enc_json.get('messages', []) + request_dict['encoder_result'] = enc_json['encoder_result'] + else: + logger.warning('Encoder response lacks encoder_result, skip passing encoder_result.') + logger.info(f'Post-encoder request dict: {request_dict}') + # 2. 
Hybrid (LLM) generation stage + node_url = node_manager.get_node_url(request.model, EngineRole.Hybrid) + if not node_url: + return node_manager.handle_unavailable_model(request.model) + logger.info(f'LLM stage dispatched to {node_url}' + (f' (after encoder {encoder_url})' if encoder_url else '')) start = node_manager.pre_call(node_url) if request.stream is True: response = node_manager.stream_generate(request_dict, node_url, '/v1/chat/completions') background_task = node_manager.create_background_tasks(node_url, start) return StreamingResponse(response, background=background_task) else: - response = await node_manager.generate(request_dict, node_url, '/v1/chat/completions') + response_text = await node_manager.generate(request_dict, node_url, '/v1/chat/completions') node_manager.post_call(node_url, start) - return JSONResponse(json.loads(response)) + try: + return JSONResponse(json.loads(response_text)) + except Exception: + logger.error('Failed to parse LLM response JSON, returning raw text') + return JSONResponse({'raw': response_text}) elif node_manager.serving_strategy == ServingStrategy.DistServe: request_dict = request.model_dump() @@ -621,6 +705,8 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque if not node_manager.dummy_prefill: if not node_manager.pd_connection_pool.is_connected(p_url, d_url): + # FIXME: here perform connections! we need to add similar logic for encode connect + # currently we connect and warmup manually through /distserve/connection_warmup await node_manager.pd_connection_pool.connect( PDConnectionMessage( p_url=p_url, @@ -662,6 +748,7 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque raise ValueError(f'No serving strategy named {node_manager.serving_strategy}') +# TODO: also change to /v1/completions, similar to /v1/chat/completions @app.post('/v1/completions', dependencies=[Depends(check_api_key)]) async def completions_v1(request: CompletionRequest, raw_request: Request = None): """Completion API similar to OpenAI's API. diff --git a/lmdeploy/serve/vl_async_engine.py b/lmdeploy/serve/vl_async_engine.py index a784e67e74..a51abd3628 100644 --- a/lmdeploy/serve/vl_async_engine.py +++ b/lmdeploy/serve/vl_async_engine.py @@ -7,6 +7,7 @@ from lmdeploy.messages import PytorchEngineConfig, TurbomindEngineConfig, VisionConfig from lmdeploy.model import BaseChatTemplate +from lmdeploy.pytorch.disagg.config import EngineRole from lmdeploy.serve.async_engine import AsyncEngine from lmdeploy.utils import get_logger, try_import_deeplink from lmdeploy.vl.engine import ImageEncoder @@ -38,6 +39,15 @@ def __init__(self, 'please specify chat template as guided in https://lmdeploy.readthedocs.io/en/latest/inference/vl_pipeline.html#set-chat-template' # noqa: E501 ) + # if the server started as encoder, we replace with mm engine + # TODO: find a way to disable LLM engine initialization and weight loading + if self.backend_config.role == EngineRole.Encoder: + from lmdeploy.multimodal.engine.engine import MultiModalEngine + self.engine = MultiModalEngine.from_pretrained(pretrained_model_name_or_path=model_path, + chat_template=self.chat_template, + tokenizer=self.tokenizer, + engine_config=backend_config) + @classmethod def _convert_prompts(cls, prompts: Union[VLPromptType, List[Dict], List[VLPromptType], List[List[Dict]]]): """Convert prompts to openai GPT4V format."""
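
Note on the wire format: in Hybrid serving with a registered Encoder node, the proxy forwards the encoder's output to the LLM node by attaching an encoder_result field to the chat-completions request body, shaped like the EncoderResult model added in lmdeploy/pytorch/disagg/conn/protocol.py above. A minimal sketch of that payload, assuming this patch is applied; every concrete value below (token ids, block ids, URL, session id) is made up for illustration:

from lmdeploy.pytorch.disagg.conn.protocol import EncoderResult, MigrationProtocol

# Illustrative values only; real values come from the Encoder node's response.
encoder_result = EncoderResult(
    token_ids=[9707, 151665, 151665, 151665, 11],  # prompt ids, image placeholders included
    image_mask=[0, 1, 1, 1, 0],                    # 1 where the token is an image placeholder
    protocol=MigrationProtocol.RDMA,
    remote_engine_id='http://0.0.0.0:23334',       # encoder server that holds the cached features
    remote_session_id=1,
    remote_block_ids=[0],                          # feature-pool blocks to pull via epd_migrate
)
# The proxy then sets request_dict['encoder_result'] = encoder_result.model_dump(mode='json')
# before dispatching the request to the Hybrid node.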
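
The image_mask itself is a 0/1 list aligned with the prompt token ids (1 at image-placeholder positions); on the LLM side it drives the masked_scatter_ in internvl3_hf.py that pastes the cached vision features into the language embeddings. A small sketch of that alignment, with a made-up image token id and the 4096 hidden size from the hard-coded feature shape:

import torch

IMAGE_TOKEN_ID = 151665  # hypothetical placeholder id; the real id comes from the model's processor
input_ids = [9707, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 11]
image_mask = [1 if x == IMAGE_TOKEN_ID else 0 for x in input_ids]  # same rule as encode_generate()

mask = torch.tensor(image_mask, dtype=torch.bool)
lang_embeds = torch.zeros(len(input_ids), 4096)           # stand-in for get_input_embeddings()(input_ids)
vit_embeds = torch.ones(int(mask.sum()), 4096)            # stand-in for features read from encoder_cache
lang_embeds.masked_scatter_(mask[..., None], vit_embeds)  # image rows now hold the vision features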
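
Finally, the epd_migrate() helpers in both cache engines map feature-pool block indices to byte offsets inside the single memory region registered under mr_key='encoder_cache'. With the pool hard-coded in this patch (128 blocks of 256 x 4096 bfloat16 features), each block spans 2 MiB, so block i starts at byte i * 2 MiB on both the encoder and the LLM side. A short sketch of that arithmetic; the block index pairs are hypothetical:

import torch

# Same shape and dtype as the hard-coded encoder_gpu_cache; 'meta' avoids real allocation.
encoder_gpu_cache = torch.empty((128, 256, 4096), dtype=torch.bfloat16, device='meta')
_, tokens_per_image, hidden_size = encoder_gpu_cache.shape
block_bytes = tokens_per_image * hidden_size * encoder_gpu_cache.element_size()
assert block_bytes == 2 * 1024 * 1024  # 2 MiB per image-feature block

# The same index -> offset rule is used when building AssignmentInstruct entries.
for source_idx, target_idx in [(3, 7), (4, 8)]:  # hypothetical (encoder, prefill) block pairs
    source_offset = source_idx * block_bytes
    target_offset = target_idx * block_bytes
    print(f'block {source_idx} -> {target_idx}: offsets {source_offset} -> {target_offset}')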