From c9281d52e3d6063b5b6b8edf472945311b05ad22 Mon Sep 17 00:00:00 2001 From: zxy Date: Wed, 3 Sep 2025 15:36:39 +0800 Subject: [PATCH 01/11] initial attempts to add encoder results --- lmdeploy/messages.py | 3 +- lmdeploy/pytorch/disagg/conn/protocol.py | 11 +++ lmdeploy/pytorch/disagg/conn/proxy_conn.py | 5 + lmdeploy/pytorch/engine/cache_engine.py | 9 +- lmdeploy/pytorch/engine/engine.py | 96 ++++++++++++++++++-- lmdeploy/pytorch/engine/engine_instance.py | 1 + lmdeploy/pytorch/messages.py | 10 +- lmdeploy/pytorch/paging/scheduler.py | 101 ++++++++++++++++++++- lmdeploy/serve/openai/api_server.py | 10 +- lmdeploy/serve/openai/protocol.py | 2 + 10 files changed, 235 insertions(+), 13 deletions(-) diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py index 808c47737c..a9a627204a 100644 --- a/lmdeploy/messages.py +++ b/lmdeploy/messages.py @@ -8,7 +8,7 @@ from pydantic.dataclasses import dataclass as pydantic_dataclass from lmdeploy.pytorch.disagg.config import EngineRole, MigrationBackend -from lmdeploy.pytorch.disagg.conn.protocol import MigrationRequest +from lmdeploy.pytorch.disagg.conn.protocol import MigrationRequest, EncoderResult from .tokenizer import Tokenizer from .utils import get_logger @@ -115,6 +115,7 @@ class GenerationConfig: with_cache: bool = False preserve_cache: bool = False migration_request: Optional[MigrationRequest] = None + encoder_result: Optional[EncoderResult] = None def convert_stop_bad_words_to_ids(self, tokenizer: Tokenizer): """Convert stop_words/bad_sords to ids and append the ids to diff --git a/lmdeploy/pytorch/disagg/conn/protocol.py b/lmdeploy/pytorch/disagg/conn/protocol.py index aa47789497..3853974bf1 100644 --- a/lmdeploy/pytorch/disagg/conn/protocol.py +++ b/lmdeploy/pytorch/disagg/conn/protocol.py @@ -77,6 +77,17 @@ class DistServeConnectionResponse(BaseModel): status: DistServeConnectionStatus +class EncoderResult(BaseModel): + + token_ids : List[int] # FIXME, why we need this? + image_mask : List[int] + + protocol: MigrationProtocol # RDMA + remote_engine_id: str # encoder engine id + remote_session_id: int # specify encoder cache to free + remote_block_ids: List[int] # encoder multi-modal cache region + + class MigrationRequest(BaseModel): protocol: MigrationProtocol diff --git a/lmdeploy/pytorch/disagg/conn/proxy_conn.py b/lmdeploy/pytorch/disagg/conn/proxy_conn.py index a07d281248..a740c28284 100644 --- a/lmdeploy/pytorch/disagg/conn/proxy_conn.py +++ b/lmdeploy/pytorch/disagg/conn/proxy_conn.py @@ -119,6 +119,7 @@ def unshelf_prefill_session(self, conn_key: Tuple[str, str], session_id: int): self.migration_session_shelf[conn_key].remove(session_id) async def connect(self, conn_req: PDConnectionMessage): + # perform connection here async def get_engine_config(server_endpoint): async with self.conn_sem: @@ -147,6 +148,7 @@ async def p2p_connect(server_endpoint, conn_request: DistServeConnectionRequest) timeout=self.aiotimeout, ) as resp: result = await resp.json() + print(f'p2p_connect response: {result}') return DistServeConnectionResponse.model_validate(result) async def conn_worker(conn_req: PDConnectionMessage, conn_event: asyncio.Event): @@ -161,6 +163,7 @@ async def conn_worker(conn_req: PDConnectionMessage, conn_event: asyncio.Event): assert prefill_engine_config.tp_size == decode_engine_config.tp_size # Step 2. 
Construct Initialize Configuration + print(f'check conn_req: {conn_req}') prefill_init_req = DistServeInitRequest( protocol=conn_req.protocol, local_engine_id=conn_req.p_url, @@ -183,6 +186,8 @@ async def conn_worker(conn_req: PDConnectionMessage, conn_event: asyncio.Event): prefill_init_resp = await p2p_initialize(conn_req.p_url, prefill_init_req) decode_init_resp = await p2p_initialize(conn_req.d_url, decode_init_req) + print(f'=> p2p init, prefill_init_resp: \n{prefill_init_resp}\n') + print(f'=> p2p init, decode_init_resp: \n{decode_init_resp}\n') # Step 3. Connection prefill_endpoint_conn_reqs = DistServeConnectionRequest( protocol=conn_req.protocol, diff --git a/lmdeploy/pytorch/engine/cache_engine.py b/lmdeploy/pytorch/engine/cache_engine.py index d8ec198349..a67f69ef94 100644 --- a/lmdeploy/pytorch/engine/cache_engine.py +++ b/lmdeploy/pytorch/engine/cache_engine.py @@ -317,7 +317,7 @@ def get_cache_block_size(cls, total = num_layers * (mem_key_block + mem_value_block) return total - """ Metheds for PD Disaggregation Begin. """ + """ Methods for PD Disaggregation Begin. """ def p2p_initialize(self, migration_init_request: DistServeInitRequest) -> DistServeKVTransferEndpointInfo: if not self.migration_backend_impl: @@ -383,4 +383,9 @@ def get_assignment_batch(mr_key, block_ids, assignment_len, layer_stride, remote batch=assignment_batch, )) - """ Metheds for PD Disaggregation End. """ + async def ep_migrate(self): + # TODO, implement actual EP migration logic here + # TODO, we may consider a seperate MM cache, may not exactly be here + pass + + """ Methods for PD Disaggregation End. """ diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index 8a8452b03d..314964f39a 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -398,6 +398,8 @@ def __init__(self, # for PD Disaggregation # For migrating prefill request to decode engine self.migration_event: asyncio.Event = None + # For encoder result migration + self.ep_migration_event: asyncio.Event = None # For backpressure prefill request when cache is full self.perfill_watermark_event: asyncio.Event = None @@ -568,13 +570,19 @@ def _on_add_message(self, reqs: List[Request], **kwargs): logger.warning('Vision encoder has not been loaded, multimodal inputs will be ignored.') continue - result = self.input_processor.preprocess_input(input_ids, input_multimodals) + # FIXME, here if we deploy MLLM, will invoke input_processor to process multimodal data + # but now if encoder_result is detected, need to skip preprocess + if req_data.get('encoder_result') is None: + result = self.input_processor.preprocess_input(input_ids, input_multimodals) + input_ids = result.input_ids + input_multimodals = result.input_multimodals - input_ids = result.input_ids - input_multimodals = result.input_multimodals - - req_data['token_ids'] = input_ids - req_data['input_multimodals'] = input_multimodals + req_data['token_ids'] = input_ids + req_data['input_multimodals'] = input_multimodals + else: + # ignore multimodal inputs + req_data['input_multimodals'] = None + logger.info('Have encoder result, try to fetch from encode instance') if len(valid_reqs) > 0: self._add_message(valid_reqs) @@ -606,6 +614,9 @@ def __update_max_new_tokens(msg): return_logits = sampling_param.out_logits if len(sess.sequences) == 0: migration_request = req.data.get('migration_request') + encoder_result = req.data.get('encoder_result') + print(f'=> add msg, migration_request {migration_request}') + print(f'=> add msg, 
encoder_result {encoder_result}') assert len(req.data['token_ids']) > 0, ('Empty input is not allowed.') sess.add_sequence(req.data['token_ids'], sampling_param=sampling_param, @@ -614,6 +625,7 @@ def __update_max_new_tokens(msg): multimodals=req.data.get('input_multimodals'), input_embeddings=req.data.get('input_embeddings', ), migration_request=migration_request, + encoder_result=encoder_result, resp_cache=req.data.get('with_cache'), preserve_cache=req.data.get('preserve_cache')) msg = next(iter(sess.sequences.values())) @@ -622,6 +634,13 @@ def __update_max_new_tokens(msg): if migration_request: self.scheduler._set_message_status(msg, MessageStatus.WAITING_MIGRATION) self.migration_event.set() + # if have encoder results here, skip encoding, directly proceed to prefill + if encoder_result: + print(f'set waiting EP migration!!!') + self.scheduler._set_message_status(msg, MessageStatus.WAITING_EP_MIGRATION) + self.ep_migration_event.set() + + # FIXME, seems we need to care about tokens ids, otherwise, the token ids processed will not contain the image token id as place holder else: msg = next(iter(sess.sequences.values())) msg.update_token_ids( @@ -684,6 +703,15 @@ def __has_values(input_multimodals): return True return False + has_encoder_result = any([msg.encoder_result is not None for msg in messages]) + # FIXME, suppose we already have the image feature migrated to local cache blocks + # The only thing left here is to create an indexing tensor to tell the model where the + # image tokens are in the input sequence. + if has_encoder_result: + # we need to determine input_embeddings here, but a fake one as place holder? for prefill instance to allocate mem? + # second one is the image mask, we can directly fetch from encoder_result + pass + has_embedding = any([len(msg.history_embeddings) > 0 for msg in messages]) if has_embedding: has_embedding = any([len(msg.input_embeddings) > 0 for msg in messages]) @@ -786,6 +814,7 @@ def create_model_inputs(self, messages: SeqList, is_prefill: bool): model_inputs.history_cross_length = history_cross_length # vision inputs + # FIXME, handle vision inputs differently, directly migrate feature from encoder_results vision_model_inputs = self._create_vision_model_inputs(messages, model_inputs) model_inputs.vision_inputs = vision_model_inputs @@ -1105,6 +1134,50 @@ async def _async_loop_migration(self, resp_que: asyncio.Queue, has_runable_event # release coroutine for decoding await asyncio.sleep(.5) + @torch.inference_mode() + async def _async_loop_ep_migration(self, resp_que: asyncio.Queue, has_runable_event: asyncio.Event): + """Async loop for encoder-prefill migration.""" + while True: + ep_migration_running = self.scheduler._schedule_ep_migration() + if not ep_migration_running and not self.scheduler.has_ep_migration_waiting(): + await self.ep_migration_event.wait() + elif ep_migration_running: + self.ep_migration_event.clear() + for msg in ep_migration_running: + print(f'fake migration here') + migration_execution_requests: List[Tuple[int, List[Tuple[int, int]]]] = [] + ep_migration_request = msg.encoder_result + encoder_block_ids = ep_migration_request.remote_block_ids + prefill_block_ids = list(self.scheduler.block_manager.get_block_table(msg=msg)) + + # assert len(encoder_block_ids) == len(prefill_block_ids), ( + # f'#encoder block ids ({len(encoder_block_ids)}) must equal to ' + # f'#prefill block ids ({len(prefill_block_ids)}) ' + # f'all id length: {len(msg.num_token_ids)}') + # migration_execution_requests.append(( + # 
ep_migration_request.remote_engine_id, + # list(zip(encoder_block_ids, prefill_block_ids)), + # )) + # migration_inputs = MigrationExecutionBatch(protocol=ep_migration_request.protocol, + # requests=migration_execution_requests) + # logger.info(f'migrating encoder features for session: {msg.session_id} begin') + # await self.executor.migrate(migration_inputs) + # logger.info(f'migrating encoder features for session: {msg.session_id} done') + # await self.engine_conn.zmq_send(remote_engine_id=ep_migration_request.remote_engine_id, + # remote_session_id=ep_migration_request.remote_session_id) + + # After migration, the sequences are ready for prefill. We change their status to WAITING + # later it will be scheduled by self.scheduler.schedule_prefill() and proceed to prefill stage + self.scheduler.lock_running_ep_migration(ep_migration_running) + for msg in ep_migration_running: + self.scheduler._set_message_status(msg, MessageStatus.WAITING) + self.scheduler.unlock_running_ep_migration(ep_migration_running) + + has_runable_event.set() + else: + # release coroutine for other tasks + await asyncio.sleep(.5) + @torch.inference_mode() async def _async_loop_main( self, @@ -1129,6 +1202,7 @@ async def _async_loop_main( forward_event.clear() scheduler.collect_migration_done() + scheduler.collect_ep_migration_done() forward_inputs, next_running = await inputs_maker.send_next_inputs() if next_running is None: # TODO (JimyMa): add watermark check event instead of async sleep. @@ -1151,6 +1225,7 @@ async def _async_loop_main( # pre-forward before get last token if idx == num_loops - 1: scheduler.collect_migration_done() + scheduler.collect_ep_migration_done() forward_inputs, next_running = await inputs_maker.prefetch_next_inputs() # send output @@ -1217,6 +1292,7 @@ async def async_loop(self): # migration task self.migration_event = asyncio.Event() + self.ep_migration_event = asyncio.Event() logger.info('Starting executor.') self.executor.start(forward_event) @@ -1245,6 +1321,14 @@ async def async_loop(self): ) loop_tasks.append(loop_migration) + # TODO: modify proxy, add encoder role, only create this coroutine when in EPD mode + logger.info('Starting async task EPMigrationLoop.') + loop_ep_migration = event_loop.create_task( + self._async_loop_ep_migration(resp_que, has_runable_event=has_runable_event), + name='MainLoopEPMigration', + ) + loop_tasks.append(loop_ep_migration) + # binding done callback self._add_loop_tasks_done_callback(loop_tasks) self._loop_main = loop_main diff --git a/lmdeploy/pytorch/engine/engine_instance.py b/lmdeploy/pytorch/engine/engine_instance.py index 041d8a042e..102b92d755 100644 --- a/lmdeploy/pytorch/engine/engine_instance.py +++ b/lmdeploy/pytorch/engine/engine_instance.py @@ -139,6 +139,7 @@ async def async_stream_infer(self, adapter_name=adapter_name, input_multimodals=multimodal, migration_request=gen_config.migration_request, + encoder_result=gen_config.encoder_result, with_cache=gen_config.with_cache, preserve_cache=gen_config.preserve_cache, ) diff --git a/lmdeploy/pytorch/messages.py b/lmdeploy/pytorch/messages.py index c21db48eab..94c0f2a3e2 100644 --- a/lmdeploy/pytorch/messages.py +++ b/lmdeploy/pytorch/messages.py @@ -8,7 +8,7 @@ from torch import Tensor from lmdeploy.messages import EngineEvent, EventType, GenerationConfig, LogitsProcessor -from lmdeploy.pytorch.disagg.conn.protocol import MigrationRequest +from lmdeploy.pytorch.disagg.conn.protocol import MigrationRequest, EncoderResult from lmdeploy.pytorch.multimodal.data_type import MultiModalInputs 
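# --- Illustrative sketch (not part of the original patch) ---------------------
# The EP (encoder -> prefill) migration states added to `MessageStatus` below
# are intended to be driven in a fixed order: the engine parks requests that
# arrive with an `encoder_result` as WAITING_EP_MIGRATION, the scheduler's
# `_schedule_ep_migration` allocates blocks and marks them RUNNING_EP_MIGRATION,
# and `_async_loop_ep_migration` locks the batch and finally hands the
# sequences back to the normal prefill path as WAITING. The snippet is a
# minimal, self-contained model of that order only; the real transitions live
# in lmdeploy/pytorch/paging/scheduler.py and lmdeploy/pytorch/engine/engine.py.
from enum import Enum, auto


class _EPStage(Enum):
    WAITING_EP_MIGRATION = auto()  # request arrived carrying an encoder_result
    RUNNING_EP_MIGRATION = auto()  # blocks allocated, feature transfer in flight
    EP_MIGRATION_LOCKED = auto()   # batch protected while the loop finalizes it
    WAITING = auto()               # back in the regular prefill queue


_EP_ORDER = [
    _EPStage.WAITING_EP_MIGRATION,
    _EPStage.RUNNING_EP_MIGRATION,
    _EPStage.EP_MIGRATION_LOCKED,
    _EPStage.WAITING,
]


def _next_ep_stage(stage: _EPStage) -> _EPStage:
    """Next stage in the intended EP-migration lifecycle (WAITING is terminal here)."""
    idx = _EP_ORDER.index(stage)
    return _EP_ORDER[min(idx + 1, len(_EP_ORDER) - 1)]
# ------------------------------------------------------------------------------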
from lmdeploy.utils import get_logger @@ -152,6 +152,11 @@ class MessageStatus(enum.Enum): MIGRATION_LOCKED = enum.auto() MIGRATION_DONE = enum.auto() + WAITING_EP_MIGRATION = enum.auto() # waiting for encoder => prefill migration + RUNNING_EP_MIGRATION = enum.auto() # running encoder => prefill migration + EP_MIGRATION_LOCKED = enum.auto() # locked during encoder => prefill migration + EP_MIGRATION_DONE = enum.auto() # done encoder => prefill migration + _SEQ_COUNT = 0 @@ -236,6 +241,7 @@ def add_sequence(self, multimodals: MultiModalInputs = None, input_embeddings: List[InputEmbeddings] = None, migration_request: Optional[MigrationRequest] = None, + encoder_result: Optional[EncoderResult] = None, resp_cache: bool = False, preserve_cache: bool = False) -> 'SchedulerSequence': """Add a new message.""" @@ -260,6 +266,7 @@ def add_sequence(self, history_multimodals=HistoryMultiModals(multimodals), return_logits=return_logits, migration_request=migration_request, + encoder_result=encoder_result, resp_cache=resp_cache, preserve_cache=preserve_cache, ) @@ -472,6 +479,7 @@ class SchedulerSequence: migration_request: Optional[MigrationRequest] = None resp_cache: bool = False preserve_cache: bool = False + encoder_result: Optional[EncoderResult] = None # For logging engine_events: List[EngineEvent] = field(default_factory=list) diff --git a/lmdeploy/pytorch/paging/scheduler.py b/lmdeploy/pytorch/paging/scheduler.py index 0be35ab9be..db18fe5a70 100644 --- a/lmdeploy/pytorch/paging/scheduler.py +++ b/lmdeploy/pytorch/paging/scheduler.py @@ -94,6 +94,24 @@ def migration_done(self): seq_map = self.seq_manager.get_sequences(MessageStatus.MIGRATION_DONE) return list(seq_map.values()) + @property + def waiting_ep_migration(self): + """Get waiting sequence.""" + seq_map = self.seq_manager.get_sequences(MessageStatus.WAITING_EP_MIGRATION) + return list(seq_map.values()) + + @property + def running_ep_migration(self): + """Get running sequence.""" + seq_map = self.seq_manager.get_sequences(MessageStatus.RUNNING_EP_MIGRATION) + return list(seq_map.values()) + + @property + def ep_migration_done(self): + """Get migration done sequence.""" + seq_map = self.seq_manager.get_sequences(MessageStatus.EP_MIGRATION_DONE) + return list(seq_map.values()) + def build_eviction_helper(self, eviction_type: str): if eviction_type == 'copy': logger.warning('`copy` eviction has been deprecated, ' @@ -178,6 +196,48 @@ def _reorder_migrating(): return running_migration + @logging_timer('ScheduleEPMigration', logger) + def _schedule_ep_migration(self): + running_ep_migration: SeqList = [] + migrating_token_count = 0 + + def _to_running(seq: SchedulerSequence): + """To running.""" + seq.status = MessageStatus.RUNNING_EP_MIGRATION + running_ep_migration.append(seq) + nonlocal migrating_token_count + migrating_token_count += seq.num_token_ids + + def __evict_for_seq(seq: SchedulerSequence, waiting): + """Evict until can append.""" + from itertools import chain + + hanging = reversed(self.hanging) + waiting = reversed(waiting) + evictable = list(chain(hanging, waiting)) + return self.eviction_helper.evict_for_seq(seq, evictable, 0) + + def _reorder_migrating(): + """Reorder waiting.""" + return sorted(self.waiting_ep_migration, key=lambda seq: seq.arrive_time) + + waiting_ep_migration = _reorder_migrating() + print(f'=> check waiting EP migration: {waiting_ep_migration}') + + max_batches = self.scheduler_config.max_batches - self.num_running() - self.num_locked() + while len(waiting_ep_migration) > 0 and len(running_ep_migration) < 
max_batches: + seq = waiting_ep_migration.pop(0) + self.block_trie.match(waiting_ep_migration) + if not __evict_for_seq(seq, waiting_ep_migration): + break + + # allocate session memory + self.block_manager.allocate(seq) + _to_running(seq) + + print(f'=> check running EP migration: {running_ep_migration}') + return running_ep_migration + @logging_timer('SchedulePrefilling', logger) def _schedule_prefill(self): """Schedule for prefilling.""" @@ -339,7 +399,8 @@ def end_session(self, session_id: int): def has_unfinished(self): """Check if there are any unfinished message.""" - return self.has_running() or self.has_waiting() or self.has_migration_done() + # return self.has_running() or self.has_waiting() or self.has_migration_done() + return self.has_running() or self.has_waiting() or self.has_migration_done() or self.has_ep_migration_done() def has_running(self): return self.num_running() > 0 @@ -359,6 +420,15 @@ def has_migration_waiting(self): def has_migration_done(self): return self.num_migration_done() > 0 + def has_ep_migration_running(self): + return self.num_ep_migration_running() > 0 + + def has_ep_migration_waiting(self): + return self.num_ep_migration_waiting() > 0 + + def has_ep_migration_done(self): + return self.num_ep_migration_done() > 0 + def get_block_tables(self, seqs: SeqList): """Get block table of the sequences.""" return [self.block_manager.get_block_table(seq) for seq in seqs] @@ -391,6 +461,18 @@ def num_migration_waiting(self): """Num waiting.""" return self.seq_manager.num_sequences(MessageStatus.WAITING_MIGRATION) + def num_ep_migration_running(self): + """Num EP migration running.""" + return self.seq_manager.num_sequences(MessageStatus.RUNNING_EP_MIGRATION) + + def num_ep_migration_done(self): + """Num EP migration done.""" + return self.seq_manager.num_sequences(MessageStatus.EP_MIGRATION_DONE) + + def num_ep_migration_waiting(self): + """Num EP migration waiting.""" + return self.seq_manager.num_sequences(MessageStatus.WAITING_EP_MIGRATION) + def num_locked(self): """Num locked.""" return self.seq_manager.num_sequences(MessageStatus.LOCKED) @@ -418,11 +500,28 @@ def unlock_running_migration(self, locked: SeqList): if seq.status == MessageStatus.MIGRATION_LOCKED: self._set_message_status(seq, MessageStatus.MIGRATION_DONE) + def lock_running_ep_migration(self, running: SeqList): + """Lock running EP migration sequence.""" + for seq in running: + if seq.status == MessageStatus.RUNNING_EP_MIGRATION: + self._set_message_status(seq, MessageStatus.EP_MIGRATION_LOCKED) + + def unlock_running_ep_migration(self, locked: SeqList): + """Unlock running EP migration.""" + for seq in locked: + if seq.status == MessageStatus.EP_MIGRATION_LOCKED: + self._set_message_status(seq, MessageStatus.EP_MIGRATION_DONE) + def collect_migration_done(self): migration_done = self.migration_done for seq in migration_done: self._set_message_status(seq, MessageStatus.RUNNING) + def collect_ep_migration_done(self): + ep_migration_done = self.ep_migration_done + for seq in ep_migration_done: + self._set_message_status(seq, MessageStatus.RUNNING) + @property def schedule_metrics(self): return ScheduleMetrics( diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py index 57800467d3..d23c6b23ea 100644 --- a/lmdeploy/serve/openai/api_server.py +++ b/lmdeploy/serve/openai/api_server.py @@ -28,7 +28,7 @@ from lmdeploy.pytorch.disagg.config import DistServeEngineConfig from lmdeploy.pytorch.disagg.conn.protocol import (DistServeCacheFreeRequest, 
DistServeConnectionRequest, DistServeDropConnectionRequest, DistServeInitRequest, - MigrationRequest) + MigrationRequest, EncoderResult) from lmdeploy.serve.async_engine import AsyncEngine from lmdeploy.serve.openai.protocol import ChatCompletionResponse # noqa: E501 from lmdeploy.serve.openai.protocol import (ChatCompletionRequest, ChatCompletionResponseChoice, @@ -354,9 +354,14 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque migration_request = json_request.pop('migration_request', None) with_cache = json_request.pop('with_cache', False) preserve_cache = json_request.pop('preserve_cache', False) + encoder_result = json_request.pop('encoder_result', None) if migration_request: migration_request = MigrationRequest.model_validate(migration_request) - + if encoder_result: + encoder_result = EncoderResult.model_validate(encoder_result) + print(f'=> api server, migration_request: \n{migration_request}\n') + print(f'=> api server, encoder_result: \n{encoder_result}\n') + # import pdb; pdb.set_trace() if request.session_id == -1: VariableInterface.session_id += 1 request.session_id = VariableInterface.session_id @@ -414,6 +419,7 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque random_seed=random_seed, spaces_between_special_tokens=request.spaces_between_special_tokens, migration_request=migration_request, + encoder_result=encoder_result, with_cache=with_cache, preserve_cache=preserve_cache, ) diff --git a/lmdeploy/serve/openai/protocol.py b/lmdeploy/serve/openai/protocol.py index cd59f9d246..b3916f4557 100644 --- a/lmdeploy/serve/openai/protocol.py +++ b/lmdeploy/serve/openai/protocol.py @@ -6,6 +6,7 @@ import shortuuid from pydantic import BaseModel, Field +from lmdeploy.pytorch.disagg.conn.protocol import EncoderResult class ErrorResponse(BaseModel): @@ -148,6 +149,7 @@ class ChatCompletionRequest(BaseModel): min_new_tokens: Optional[int] = Field(default=None, examples=[None]) min_p: float = 0.0 enable_thinking: Optional[bool] = None + encoder_result: Optional[EncoderResult] = None class FunctionCall(BaseModel): From 7f03f0ed63ca3df3c832bba8476868eb89dd2b8f Mon Sep 17 00:00:00 2001 From: zxy Date: Wed, 24 Sep 2025 17:25:31 +0800 Subject: [PATCH 02/11] fix api after merge --- lmdeploy/messages.py | 2 +- lmdeploy/pytorch/strategies/ar/sequence.py | 3 ++- lmdeploy/pytorch/strategies/base/sequence.py | 3 ++- lmdeploy/pytorch/strategies/dllm/sequence.py | 3 ++- 4 files changed, 7 insertions(+), 4 deletions(-) diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py index 2c018ed5b2..5869a4a399 100644 --- a/lmdeploy/messages.py +++ b/lmdeploy/messages.py @@ -8,7 +8,7 @@ from pydantic.dataclasses import dataclass as pydantic_dataclass from lmdeploy.pytorch.disagg.config import EngineRole, MigrationBackend -from lmdeploy.pytorch.disagg.conn.protocol import MigrationRequest, EncoderResult +from lmdeploy.pytorch.disagg.conn.protocol import EncoderResult, MigrationRequest from .tokenizer import Tokenizer from .utils import get_logger diff --git a/lmdeploy/pytorch/strategies/ar/sequence.py b/lmdeploy/pytorch/strategies/ar/sequence.py index 91a3335f18..fb73ee3996 100644 --- a/lmdeploy/pytorch/strategies/ar/sequence.py +++ b/lmdeploy/pytorch/strategies/ar/sequence.py @@ -5,7 +5,7 @@ from torch import Tensor -from lmdeploy.pytorch.disagg.conn.protocol import MigrationRequest +from lmdeploy.pytorch.disagg.conn.protocol import EncoderResult, MigrationRequest from lmdeploy.pytorch.engine.model_agent import BatchedOutputs from 
lmdeploy.pytorch.messages import (InputEmbeddings, MessageStatus, MultiModalInputs, SamplingParam, SchedulerSequence, SchedulerSession, UpdateTokenMode, _to_ndarray) @@ -81,6 +81,7 @@ def make_sequence(self, sampling_param: 'SamplingParam' = None, adapter_name: str = None, migration_request: Optional[MigrationRequest] = None, + encoder_result: Optional[EncoderResult] = None, resp_cache: bool = False, preserve_cache: bool = False) -> 'SchedulerSequence': """Make sequence.""" diff --git a/lmdeploy/pytorch/strategies/base/sequence.py b/lmdeploy/pytorch/strategies/base/sequence.py index 408a3cc15e..ea8f533700 100644 --- a/lmdeploy/pytorch/strategies/base/sequence.py +++ b/lmdeploy/pytorch/strategies/base/sequence.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, List, Optional -from lmdeploy.pytorch.disagg.conn.protocol import MigrationRequest +from lmdeploy.pytorch.disagg.conn.protocol import EncoderResult, MigrationRequest if TYPE_CHECKING: from lmdeploy.pytorch.engine.model_agent import BatchedOutputs @@ -19,6 +19,7 @@ def make_sequence(self, sampling_param: 'SamplingParam' = None, adapter_name: str = None, migration_request: Optional[MigrationRequest] = None, + encoder_result: Optional[EncoderResult] = None, resp_cache: bool = False, preserve_cache: bool = False) -> 'SchedulerSequence': """Make sequence.""" diff --git a/lmdeploy/pytorch/strategies/dllm/sequence.py b/lmdeploy/pytorch/strategies/dllm/sequence.py index ab004a2b63..06be439b87 100644 --- a/lmdeploy/pytorch/strategies/dllm/sequence.py +++ b/lmdeploy/pytorch/strategies/dllm/sequence.py @@ -7,7 +7,7 @@ from torch import Tensor from lmdeploy.pytorch import consts -from lmdeploy.pytorch.disagg.conn.protocol import MigrationRequest +from lmdeploy.pytorch.disagg.conn.protocol import EncoderResult, MigrationRequest from lmdeploy.pytorch.engine.model_agent import BatchedOutputs from lmdeploy.pytorch.messages import (HistoryTokenIds, InputEmbeddings, MessageStatus, MultiModalInputs, SamplingParam, SchedulerSession, UpdateTokenMode, _to_ndarray) @@ -206,6 +206,7 @@ def make_sequence(self, sampling_param: 'SamplingParam' = None, adapter_name: str = None, migration_request: Optional[MigrationRequest] = None, + encoder_result: Optional[EncoderResult] = None, resp_cache: bool = False, preserve_cache: bool = False) -> 'SchedulerSequenceDLLM': """Make sequence.""" From ccf122f5f6a59b0256b3c19526c302d3a20cdaa0 Mon Sep 17 00:00:00 2001 From: zxy Date: Fri, 26 Sep 2025 11:13:42 +0800 Subject: [PATCH 03/11] fix --- lmdeploy/pytorch/strategies/ar/sequence.py | 1 + 1 file changed, 1 insertion(+) diff --git a/lmdeploy/pytorch/strategies/ar/sequence.py b/lmdeploy/pytorch/strategies/ar/sequence.py index fb73ee3996..ca30517faa 100644 --- a/lmdeploy/pytorch/strategies/ar/sequence.py +++ b/lmdeploy/pytorch/strategies/ar/sequence.py @@ -90,6 +90,7 @@ def make_sequence(self, sampling_param=sampling_param, adapter_name=adapter_name, migration_request=migration_request, + encoder_result=encoder_result, resp_cache=resp_cache, preserve_cache=preserve_cache) From 3a21d8f543af73930578977108aac70cebd67ade Mon Sep 17 00:00:00 2001 From: zxy Date: Thu, 9 Oct 2025 17:06:15 +0800 Subject: [PATCH 04/11] add encoder migration, integrate and run through the pipeline --- 0_encode.sh | 1 + 0_pd.sh | 12 + 0_proxy.sh | 40 ++ 0_vl_oai.py | 386 ++++++++++++++++++ .../pytorch/backends/cuda/graph_runner.py | 2 + lmdeploy/pytorch/backends/graph_runner.py | 2 + lmdeploy/pytorch/disagg/config.py | 1 + 
lmdeploy/pytorch/disagg/conn/ep_proxy_conn.py | 323 +++++++++++++++ lmdeploy/pytorch/disagg/conn/protocol.py | 12 +- lmdeploy/pytorch/disagg/messages.py | 9 + lmdeploy/pytorch/engine/cache_engine.py | 54 ++- .../pytorch/engine/encoder_cache_engine.py | 140 +++++++ lmdeploy/pytorch/engine/engine.py | 48 +-- .../pytorch/engine/executor/ray_executor.py | 4 + .../pytorch/engine/executor/uni_executor.py | 4 + lmdeploy/pytorch/engine/model_agent.py | 1 + lmdeploy/pytorch/model_inputs.py | 4 + lmdeploy/pytorch/models/internvl3_hf.py | 61 ++- lmdeploy/serve/proxy/proxy.py | 94 ++++- 19 files changed, 1144 insertions(+), 54 deletions(-) create mode 100644 0_encode.sh create mode 100644 0_pd.sh create mode 100644 0_proxy.sh create mode 100644 0_vl_oai.py create mode 100644 lmdeploy/pytorch/disagg/conn/ep_proxy_conn.py create mode 100644 lmdeploy/pytorch/engine/encoder_cache_engine.py diff --git a/0_encode.sh b/0_encode.sh new file mode 100644 index 0000000000..d897bba09c --- /dev/null +++ b/0_encode.sh @@ -0,0 +1 @@ +python 0_vl_oai.py diff --git a/0_pd.sh b/0_pd.sh new file mode 100644 index 0000000000..2ebaae1cf1 --- /dev/null +++ b/0_pd.sh @@ -0,0 +1,12 @@ +model_path="/mnt/137_nvme3/interns1-mini-remote" + + +CUDA_VISIBLE_DEVICES=2 lmdeploy serve api_server \ + $model_path \ + --server-port 23334 \ + --role Hybrid \ + --proxy-url http://0.0.0.0:8001 \ + --tp 1 \ + --backend pytorch \ + --disable-vision-encoder \ + --log-level INFO diff --git a/0_proxy.sh b/0_proxy.sh new file mode 100644 index 0000000000..46e78d276c --- /dev/null +++ b/0_proxy.sh @@ -0,0 +1,40 @@ +lmdeploy serve proxy --server-name 0.0.0.0 --server-port 8001 --routing-strategy "min_expected_latency" --serving-strategy Hybrid --log-level DEBUG + +curl -X POST http://0.0.0.0:8001/distserve/connection_warmup + +curl http://0.0.0.0:8001/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "/mnt/137_nvme3/interns1-mini-remote", + "messages": [ + { + "role": "user", + "content": "Hello! How are you?" + } + ] + }' + + +curl http://0.0.0.0:8001/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "/mnt/137_nvme3/interns1-mini-remote", + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "What is in this image?" 
+ }, + { + "type": "image_url", + "image_url": { + "url": "https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg" + } + } + ] + } + ], + "max_tokens": 200 + }' diff --git a/0_vl_oai.py b/0_vl_oai.py new file mode 100644 index 0000000000..3a560047fc --- /dev/null +++ b/0_vl_oai.py @@ -0,0 +1,386 @@ +import asyncio +import logging +import os +from contextlib import asynccontextmanager +from dataclasses import asdict, dataclass +from threading import Lock +from typing import Dict, List, Optional + +import torch +import uvicorn +from fastapi import FastAPI, HTTPException +from fastapi.encoders import jsonable_encoder +from fastapi.responses import JSONResponse, Response + +from lmdeploy import Tokenizer +from lmdeploy.archs import get_model_arch +from lmdeploy.model import ChatTemplateConfig, best_match_model +from lmdeploy.pytorch.disagg.backend.base import MigrationBackendImpl +from lmdeploy.pytorch.disagg.config import DistServeEngineConfig, EngineRole +from lmdeploy.pytorch.disagg.conn.protocol import (DistServeCacheFreeRequest, DistServeConnectionRequest, + DistServeConnectionResponse, DistServeConnectionStatus, + DistServeEngineEndpointInfo, DistServeInitRequest, + DistServeInitResponse, MigrationProtocol) +from lmdeploy.pytorch.engine.encoder_cache_engine import EncoderCacheEngine +from lmdeploy.serve.openai.launch_server import get_host_ip +from lmdeploy.serve.openai.protocol import ChatCompletionRequest +from lmdeploy.vl.model.internvl3_hf import InternVL3VisionModel +from lmdeploy.vl.utils import load_image + +os.environ['CUDA_VISIBLE_DEVICES'] = '1' + +# --- 1. 全局模型变量 --- +model_instance: InternVL3VisionModel = None # type: ignore +migration_backend_impl: Optional[MigrationBackendImpl] = None +model_path = '/mnt/137_nvme3/interns1-mini-remote' +SERVER_PORT = 8086 +chat_template_name = best_match_model(model_path.lower()) +# chat_template_config = ChatTemplateConfig(chat_template_name) +chat_template_config = ChatTemplateConfig(model_name=chat_template_name, model_path=model_path) +chat_template = chat_template_config.chat_template +tokenizer = Tokenizer(model_path) +# encoder_url = f"http://{get_host_ip()}:{SERVER_PORT}" +encoder_url = f'http://0.0.0.0:{SERVER_PORT}' +# 初始化 Cache Engine 相关变量 +cache_engine_instance: EncoderCacheEngine = None # type: ignore +NUM_GPU_BLOCKS = 128 +free_blocks: List[int] = [] +session_blocks: Dict[int, List[int]] = {} +session_counter = 0 +block_manager_lock = Lock() # 线程锁,用于安全地分配和释放块 + + +def get_model_list(): + return [model_path] + + +# --- 2. 
生命周期事件处理器 --- + + +@asynccontextmanager +async def lifespan(app: FastAPI): + # 启动时的事件 + global model_instance, cache_engine_instance, free_blocks + logger = logging.getLogger('uvicorn.error') + logger.setLevel(logging.INFO) + logger.info('模型加载中,请稍候...') + try: + cfg = get_model_arch(model_path)[1] + kwargs = dict(model_path=model_path, with_llm=False, max_memory=None, hf_config=cfg, backend='pytorch') + model_instance = InternVL3VisionModel(**kwargs) + model_instance.build_model() + model_instance.build_preprocessor() + logger.info('✅ 模型加载成功!服务器已准备就绪。') + except Exception as e: + logger.error(f'❌ 模型加载失败: {e}', exc_info=True) + raise RuntimeError(f'模型初始化失败: {e}') from e + + # TODO MigrationBackendImpl () + + # TODO 增加 memory 页表注册 + logger.info('正在初始化 Cache Engine...') + try: + # 实例化 CacheEngine + cache_engine_instance = EncoderCacheEngine(NUM_GPU_BLOCKS) + + # 初始化空闲块列表 + free_blocks = list(range(NUM_GPU_BLOCKS)) + logger.info(f'✅ Cache Engine 初始化成功,总共 {NUM_GPU_BLOCKS} 个缓存块。') + + except Exception as e: + logger.error(f'❌ Cache Engine 初始化失败: {e}', exc_info=True) + raise RuntimeError(f'Cache Engine 初始化失败: {e}') from e + + # TODO 向 proxy 发送node add + try: + import requests + engine_role = EngineRole.Encoder.value + url = 'http://0.0.0.0:8001/nodes/add' + data = {'url': f'http://0.0.0.0:{SERVER_PORT}', 'status': {'models': get_model_list(), 'role': engine_role}} + headers = {'accept': 'application/json', 'Content-Type': 'application/json'} + response = requests.post(url, headers=headers, json=data) + + if response.status_code != 200: + raise HTTPException(status_code=response.status_code, detail=response.text) + else: + logger.info('✅ 服务注册成功!') + except Exception as e: + logger.error(f'Service registration failed: {e}') + # TODO p2p initialize(warm up) + # /nvme2/share/linbinbin1/src/lmdeploy-encoder/lmdeploy/serve/openai/api_server.py PD DIs + + # TODO p2p conn + + yield # 应用运行期间 + + # 关闭时的事件(如果需要清理资源) + logger.info('🔄 正在关闭服务器...') + del model_instance + torch.cuda.empty_cache() + logger.info('模型资源已释放。') + + +# --- 3. 初始化 FastAPI 应用 --- +app = FastAPI(title='InternVL Vision Model Server (Arrow Edition)', + description='一个用于通过 InternVL3 模型为图片数组提取特征张量,并使用 Apache Arrow 高效返回结果的 API', + version='1.2.0', + lifespan=lifespan) +logger = logging.getLogger('uvicorn.error') +logger.setLevel(logging.INFO) + + +# --- 4. 辅助函数 --- +def find_forward_content(output: list) -> list: + for item in output: + if isinstance(item, dict) and item.get('role') == 'forward': + return item.get('content', []) + return [] + + +async def async_convert_to_pil_images(messages: List[Dict]) -> List[Dict]: + """Scan the provided messages to find image URLs or base64-encoded image + data. Loads the images into Pillow image objects. 
+ + Args: + messages (List[Dict]): a user request of GPT4V message format + """ + if isinstance(messages, Dict): + messages = [messages] + assert isinstance(messages, List) + + out_messages = [None] * len(messages) + + def _inner_call(i, in_messages, out_messages): + role = in_messages[i]['role'] + content = in_messages[i]['content'] + assert role in ['system', 'user', 'assistant'], \ + f'unsupported role "{role}"' + if role != 'user' or isinstance(content, str): + # the content is a user's prompt or an assistant's prompt, + # returning it directly + out_messages[i] = in_messages[i] + return + # the role is a user and the content is a list, in which there + # might be image_url or image_data + assert isinstance(content, List) + message = dict(role=role, content=[]) + for item in content: + # image url or base64-encoded image data + if item['type'] == 'image_url': + """ + convert the following item: + { + 'type': 'image_url', + 'image_url': { + 'url': 'image url or base64-encoded image data', + 'key': 'value' # parameters used in image processing + ... + } + } + to: + { + 'type': 'image', + 'image': Pillow.Image, + 'key': 'value' # parameters used in image processing + ... + } + """ # noqa + data = item['image_url'].copy() + try: + url = data.pop('url') + image = load_image(url) + data.update(type='image', image=image) + message['content'].append(data) + except KeyError: + logger.error(f'invalid format {message}') + elif item['type'] == 'image_data': + """ + convert the following item: + { + 'type': 'image_data', + 'image_data': { + 'data': Pillow.Image, + 'key': 'value' # parameters used in image processing + ... + } + } + to: + { + 'type': 'image', + 'image': Pillow.Image, + 'key': 'value' # parameters used in image processing + ... + } + """ # noqa + data = item['image_data'].copy() + try: + image = data.pop('data') + data.update(type='image', image=image) + message['content'].append(data) + except KeyError: + logger.error(f'invalid format {message}') + elif item['type'] == 'text': + message['content'].append(item) + else: + logger.error(f'unexpected content type {message}') + out_messages[i] = message + + await asyncio.gather(*[ + asyncio.get_event_loop().run_in_executor(None, _inner_call, i, messages, out_messages) + for i in range(len(messages)) + ]) + return out_messages + + +@app.get('/health') +async def health() -> Response: + """Health check.""" + return Response(status_code=200) + + +@dataclass +class EncoderResult: + token_ids: List[int] + image_mask: List[int] + # MigrationRequest 中相似的字段 + protocol: MigrationProtocol # RDMA + remote_engine_id: str # 标识 encode 引擎编号 + remote_session_id: int # 用于 encode 引擎释放指定区域 + remote_block_ids: List[int] # 从 encode 引擎读取指定区域内容 + + +# --- 5. 
API 端点:处理图片并返回特征 --- + + +@app.post('/v1/chat/completion', summary='接收 open ai 格式的请求,并且返回给 proxy') +async def process_images(request_raw: ChatCompletionRequest = None): + if model_instance is None: + raise HTTPException(status_code=503, detail='模型正在加载或加载失败,请稍后再试。') + + request = request_raw.model_dump() + messages = await async_convert_to_pil_images(request['messages']) + results = model_instance.preprocess(messages) + # print(results) + # import pdb; pdb.set_trace() + + # prompt = chat_template.messages2prompt(messages) + # input_ids = tokenizer.encode(prompt, add_bos=True) # 只包含了文本部分 + # prompt, input_ids(包含了图片 token 序列), multi_modal + # 这个是将要返回的内容 + to_pt = model_instance.to_pytorch(results, chat_template, tokenizer, True, None, None) + image_mask = [1 if x == to_pt['multimodal'][0]['image_token_id'] else 0 for x in to_pt['input_ids']] + # 这里用来获得 image embedding + output = model_instance.forward(results) + forward_content = find_forward_content(output) + # tensor_shape = forward_content[0].shape + if not forward_content: + raise HTTPException(status_code=500, detail="无法在模型输出中找到 'forward' 内容。") + # store the image embedding to gpu cache + image_embedding = forward_content[0] + image_embedding = image_embedding.to( + torch.bfloat16 + ) # FIXME: forward() is used by turbomind, which returns float16 feature, but pytorch will return bfloat16 + print(f'image_embedding shape: {image_embedding.shape}') + print(f'image_embedding: {image_embedding}') + num_required_blocks = image_embedding.shape[0] // 256 + global session_counter + allocated_block_ids = [] + session_id = -1 + with block_manager_lock: + if len(free_blocks) < num_required_blocks: + raise HTTPException(status_code=503, detail='GPU 缓存已满,请稍后再试。') + + allocated_block_ids = [free_blocks.pop() for _ in range(num_required_blocks)] + session_counter += 1 + session_id = session_counter + session_blocks[session_id] = allocated_block_ids + print('in blocks') + print(allocated_block_ids) + print(cache_engine_instance.gpu_cache[allocated_block_ids].shape) + print(cache_engine_instance.gpu_cache[allocated_block_ids]) + try: + with torch.cuda.stream(cache_engine_instance.cache_stream): + for i in range(num_required_blocks): + src_chunk = image_embedding[i * 256:(i + 1) * 256, :] + dst_block_id = allocated_block_ids[i] + cache_engine_instance.gpu_cache[dst_block_id].copy_(src_chunk) + cache_engine_instance.cache_stream.synchronize() + except Exception as e: + # 如果拷贝失败,必须归还申请的块,防止内存泄漏 + with block_manager_lock: + free_blocks.extend(allocated_block_ids) + del session_blocks[session_id] + logger.error(f'拷贝 embedding 到缓存失败: {e}') + raise HTTPException(status_code=500, detail='缓存图像 embedding 失败。') + + # 返回内容 + + # FIXME, zhouxinyu this should not be empty + # otherwise gen config related information are lost, for instance top_p, top_k, max_new_tokens + # request['messages'] = [] + encoder_result_obj = EncoderResult( + token_ids=to_pt['input_ids'], + image_mask=image_mask, + protocol=MigrationProtocol.RDMA, + remote_engine_id=encoder_url, # encode 引擎的 url + remote_session_id=session_id, # encode 阶段的 session id + remote_block_ids=allocated_block_ids # image embedding 的 memory block id + ) + request['encoder_result'] = asdict(encoder_result_obj) + + return JSONResponse(jsonable_encoder(request)) + + +@app.post('/distserve/p2p_initialize') +async def p2p_initialize(init_request: DistServeInitRequest): + kv_eps = cache_engine_instance.p2p_initialize(init_request) + # 目前 encoder 没有 zmq 通信;返回一个假地址 + zmq_addr = f'tcp://{get_host_ip()}:65001' + resp = 
DistServeInitResponse( + status=DistServeConnectionStatus.SUCCESS, + engine_endpoint_info=DistServeEngineEndpointInfo(zmq_address=zmq_addr), + kvtransfer_endpoint_info=kv_eps, + ) + return JSONResponse(jsonable_encoder(resp.model_dump())) + + +@app.post('/distserve/p2p_connect') +async def p2p_connect(conn_request: DistServeConnectionRequest): + cache_engine_instance.p2p_connect( + conn_request.remote_engine_id, + conn_request.remote_kvtransfer_endpoint_info, + ) + resp = DistServeConnectionResponse(status=DistServeConnectionStatus.SUCCESS) + return JSONResponse(jsonable_encoder(resp.model_dump())) + + +@app.post('/distserve/free_cache') +async def free_cache(free_req: DistServeCacheFreeRequest): + # Free allocated GPU blocks for a given session id + global free_blocks, session_blocks + sid = free_req.remote_session_id + with block_manager_lock: + blocks = session_blocks.pop(sid, []) + if blocks: + free_blocks.extend(blocks) + return JSONResponse({'success': True, 'freed_blocks': blocks if 'blocks' in locals() else []}) + + +@app.get('/distserve/engine_info') +async def engine_info(): + + response = DistServeEngineConfig(tp_size=1, + dp_size=1, + pp_size=1, + ep_size=1, + dp_rank=1, + block_size=256 * 4096, + num_cpu_blocks=0, + num_gpu_blocks=NUM_GPU_BLOCKS) + + return response.model_dump_json() + + +# --- 6. 运行服务器 --- +if __name__ == '__main__': + uvicorn.run(app, host='0.0.0.0', port=SERVER_PORT) diff --git a/lmdeploy/pytorch/backends/cuda/graph_runner.py b/lmdeploy/pytorch/backends/cuda/graph_runner.py index deb6c66bfd..7259154fc7 100644 --- a/lmdeploy/pytorch/backends/cuda/graph_runner.py +++ b/lmdeploy/pytorch/backends/cuda/graph_runner.py @@ -234,12 +234,14 @@ def __call__(self, **kwargs): def prepare_inputs_for_generation( self, past_key_values: List[List[torch.Tensor]], + encoder_cache: torch.Tensor, inputs_embeds: torch.Tensor = None, context: StepContext = None, ): """Prepare inputs.""" return self.model.prepare_inputs_for_generation( past_key_values=past_key_values, + encoder_cache=encoder_cache, inputs_embeds=inputs_embeds, context=context, ) diff --git a/lmdeploy/pytorch/backends/graph_runner.py b/lmdeploy/pytorch/backends/graph_runner.py index a88872f2bd..aea8a3b976 100644 --- a/lmdeploy/pytorch/backends/graph_runner.py +++ b/lmdeploy/pytorch/backends/graph_runner.py @@ -56,12 +56,14 @@ def get_logits(self, hidden_states: torch.Tensor): def prepare_inputs_for_generation( self, past_key_values: List[List[torch.Tensor]], + encoder_cache: torch.Tensor, inputs_embeds: torch.Tensor = None, context: StepContext = None, ): """Prepare inputs.""" return self.model.prepare_inputs_for_generation( past_key_values, + encoder_cache, inputs_embeds, context, ) diff --git a/lmdeploy/pytorch/disagg/config.py b/lmdeploy/pytorch/disagg/config.py index f4dd002231..f79a3718fa 100644 --- a/lmdeploy/pytorch/disagg/config.py +++ b/lmdeploy/pytorch/disagg/config.py @@ -35,6 +35,7 @@ class EngineRole(enum.Enum): Hybrid = enum.auto() Prefill = enum.auto() Decode = enum.auto() + Encoder = enum.auto() class MigrationBackend(enum.Enum): diff --git a/lmdeploy/pytorch/disagg/conn/ep_proxy_conn.py b/lmdeploy/pytorch/disagg/conn/ep_proxy_conn.py new file mode 100644 index 0000000000..52409a6668 --- /dev/null +++ b/lmdeploy/pytorch/disagg/conn/ep_proxy_conn.py @@ -0,0 +1,323 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
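# --- Illustrative usage sketch (not part of the original patch) ---------------
# How a proxy-side caller is expected to drive the pool defined in this new
# module: build an `EPConnectionMessage` (added to
# lmdeploy/pytorch/disagg/messages.py in this series) and await
# `EPConnectionPool.connect`, which performs the engine_info ->
# p2p_initialize -> p2p_connect handshake on both the encoder and the prefill
# side. URLs are placeholders taken from 0_vl_oai.py / 0_pd.sh; the real proxy
# also supplies an RDMA config, which is omitted here.
#
#   import asyncio
#
#   from lmdeploy.pytorch.disagg.conn.ep_proxy_conn import EPConnectionPool
#   from lmdeploy.pytorch.disagg.messages import EPConnectionMessage
#
#   async def warmup_ep_link():
#       pool = EPConnectionPool()
#       msg = EPConnectionMessage(
#           e_url='http://0.0.0.0:8086',    # encoder server (0_vl_oai.py)
#           p_url='http://0.0.0.0:23334',   # prefill/hybrid api_server (0_pd.sh)
#       )
#       await pool.connect(msg)             # raises TimeoutError on failure
#       assert pool.is_connected(msg.e_url, msg.p_url)
#       await pool.close()
#
#   asyncio.run(warmup_ep_link())
# ------------------------------------------------------------------------------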
+import asyncio +import enum +import os +from collections import defaultdict +from typing import Dict, Optional, Set, Tuple + +import aiohttp +import requests + +from lmdeploy.logger import get_logger +from lmdeploy.pytorch.disagg.config import DistServeEngineConfig, EngineRole +from lmdeploy.pytorch.disagg.conn.protocol import (DistServeCacheFreeRequest, DistServeConnectionRequest, + DistServeConnectionResponse, DistServeDropConnectionRequest, + DistServeInitRequest, DistServeInitResponse) +from lmdeploy.pytorch.disagg.messages import EPConnectionMessage + +logger = get_logger('lmdeploy') + +# Parse timeout env (string -> float) safely +_raw_timeout = os.getenv('AIOHTTP_TIMEOUT', None) +try: + AIOHTTP_TIMEOUT: Optional[float] = float(_raw_timeout) if _raw_timeout else None +except ValueError: # fallback silently and log + logger.warning(f'Invalid AIOHTTP_TIMEOUT value: {_raw_timeout}, fallback to None') + AIOHTTP_TIMEOUT = None + + +class EPConnectionStatus(enum.Enum): + Disconnected = enum.auto() + Connected = enum.auto() + Connecting = enum.auto() + + +class EPConnectionState: + """EPConnectionState (simple state holder with one event).""" + + def __init__(self, status: EPConnectionStatus, event: asyncio.Event): + self.status = status + self.event = event + + async def wait(self): + await self.event.wait() + + def set_status(self, status: EPConnectionStatus): + self.status = status + + +def get_server_api(url: str, api: str): + return f'{url}/{api}' + + +class EPConnectionPool: + """Constructing the link of E & P engine for the migration of KVCache. + + Note: we use Peer to Peer transportation in KVCache migration. + Note: Lazy link construction is supported, which perform connection + at the first LLM request. As a result, we don't need to construct + PD Communication group when start a engine server. + Note: we perform simple fault tolerance by checkpointing the session_id of a + request which is under migrating and will trigger `gc` when the decode + instanceis crushed. + TODO (JimyMa): By now, only engines with same parallel configuration can be + correctly connected. + """ + + # Maximum concurrent connections​​ + CONN_SEMAPHORE_SIZE = 2048 + + def __init__(self): + # all prefill and decode instances + # TODO (JimyMa): Maybe encoding instances + self.prefill_endpoints: Set[str] = set() + self.encode_endpoints: Set[str] = set() + + # Links of PD Connection. + self.pool: Dict[Tuple[str, str], EPConnectionState] = {} + + # put migrating session to `self.migration_session_shelf` for increasing fault tolerance + # if a session is finished, then pop it from `self.migration_session_shelf` + # if a decode instance is disconnected, then gc all blocks of these sessions in prefill instance. + # use tuple (left, right) as key to align with drop() usage + self.migration_session_shelf: Dict[Tuple[str, str], Set[int]] = defaultdict(set) + + # conn_perform handler queue + self.waiting_conn: asyncio.Queue[Tuple[EPConnectionMessage, asyncio.Event]] = asyncio.Queue() + + # conn Registry Lock + self.conn_lock = asyncio.Lock() + + # Connection Retry when failure + self.max_retry_cnt = 8 + + # trigger signal when conn request arrive. 
+ self.conn_req_event = asyncio.Event() + + # conn initialized signal + self.initialized = False + + def reg_instance(self, role: EngineRole, endpoint: str): + if role == EngineRole.Prefill: + self.prefill_endpoints.add(endpoint) + elif role == EngineRole.Encoder: + self.encode_endpoints.add(endpoint) + else: + raise ValueError(f'Unsupported role: {role}') + + def dereg_instance(self, endpoint: str): + # Symmetric cleanup for both roles + if endpoint in self.encode_endpoints: + dropped_key = [k for k in self.pool.keys() if k[0] == endpoint] + for k in dropped_key: + self.drop(k) + self.encode_endpoints.remove(endpoint) + elif endpoint in self.prefill_endpoints: + dropped_key = [k for k in self.pool.keys() if k[1] == endpoint] + for k in dropped_key: + self.drop(k) + # TODO(JimyMa): handle side-effect by kvcache migration + self.prefill_endpoints.remove(endpoint) + + async def connect(self, conn_req: EPConnectionMessage): + + async def get_engine_config(server_endpoint): + async with self.conn_sem: + async with self.conn_sess.get( + get_server_api(server_endpoint, 'distserve/engine_info'), + timeout=self.aiotimeout, + ) as resp: + result = await resp.json() + # model_validate_json expects a JSON string; result is already dict + logger.info(f'engine info from {server_endpoint}: {result}') + return DistServeEngineConfig.model_validate_json(result) + + async def p2p_initialize(server_endpoint, init_request: DistServeInitRequest) -> DistServeInitResponse: + async with self.conn_sem: + async with self.conn_sess.post( + get_server_api(server_endpoint, 'distserve/p2p_initialize'), + json=init_request.model_dump(mode='json'), + timeout=self.aiotimeout, + ) as resp: + result = await resp.json() + logger.info(f'P2P Initialize response from {server_endpoint}: {result}') + return DistServeInitResponse.model_validate(result) + + async def p2p_connect(server_endpoint, conn_request: DistServeConnectionRequest) -> DistServeConnectionResponse: + async with self.conn_sem: + async with self.conn_sess.post( + get_server_api(server_endpoint, 'distserve/p2p_connect'), + json=conn_request.model_dump(mode='json'), + timeout=self.aiotimeout, + ) as resp: + result = await resp.json() + return DistServeConnectionResponse.model_validate(result) + + async def conn_worker(conn_req: EPConnectionMessage, conn_event: asyncio.Event): + # try: + link = (conn_req.e_url, conn_req.p_url) + logger.debug(f'{link} connecting...') + # Step 1. Get Remote Engine Configuration + prefill_engine_config = await get_engine_config(conn_req.p_url) + encode_engine_config = await get_engine_config(conn_req.e_url) + print(f'prefill_engine_config: {prefill_engine_config}') + print(f'encode_engine_config: {encode_engine_config}') + + # encode 的 config 大部分字段为 空 + + # Step 2. 
Construct Initialize Configuration + prefill_init_req = DistServeInitRequest( + protocol=conn_req.protocol, + local_engine_id=conn_req.p_url, + local_engine_config=prefill_engine_config, + remote_engine_id=conn_req.e_url, + remote_engine_config=encode_engine_config, + rdma_config=conn_req.rdma_config, + nvlink_config=conn_req.nvlink_config, + ) + encode_init_req = DistServeInitRequest( + protocol=conn_req.protocol, + local_engine_id=conn_req.e_url, + local_engine_config=encode_engine_config, + remote_engine_id=conn_req.p_url, + remote_engine_config=prefill_engine_config, + rdma_config=conn_req.rdma_config, + nvlink_config=conn_req.nvlink_config, + ) + + print(f'prefill_init_req: {prefill_init_req}') + print(f'encode_init_req: {encode_init_req}') + prefill_init_resp = await p2p_initialize(conn_req.p_url, prefill_init_req) + encode_init_resp = await p2p_initialize(conn_req.e_url, encode_init_req) + + # Step 3. Connection + encode_endpoint_conn_reqs = DistServeConnectionRequest( + protocol=conn_req.protocol, + remote_engine_id=conn_req.p_url, + remote_engine_endpoint_info=prefill_init_resp.engine_endpoint_info, + remote_kvtransfer_endpoint_info=prefill_init_resp.kvtransfer_endpoint_info) + prefill_endpoint_conn_reqs = DistServeConnectionRequest( + protocol=conn_req.protocol, + remote_engine_id=conn_req.e_url, + remote_engine_endpoint_info=encode_init_resp.engine_endpoint_info, + remote_kvtransfer_endpoint_info=encode_init_resp.kvtransfer_endpoint_info) + print(f'encode_endpoint_conn_reqs: {encode_endpoint_conn_reqs}') + print(f'prefill_endpoint_conn_reqs: {prefill_endpoint_conn_reqs}') + await p2p_connect(conn_req.p_url, prefill_endpoint_conn_reqs) + await p2p_connect(conn_req.e_url, encode_endpoint_conn_reqs) + self.pool[link].set_status(EPConnectionStatus.Connected) + logger.debug(f'{(conn_req.e_url, conn_req.p_url)} connected') + # except Exception as e: + # self.pool[link].set_status(EPConnectionStatus.Disconnected) + # logger.error(f'ep connection error: {e}') + conn_event.set() + + async def wait_for_conn(conn_req: EPConnectionMessage, conn_event: asyncio.Event): + await self.pool[(conn_req.e_url, conn_req.p_url)].event.wait() + conn_event.set() + + async def _perform_conn(): + logger.debug('perform_conn start') + while True: + if self.waiting_conn.empty(): + await self.conn_req_event.wait() + + self.conn_req_event.clear() + + while not self.waiting_conn.empty(): + conn_req, conn_event = self.waiting_conn.get_nowait() + link = (conn_req.e_url, conn_req.p_url) + if link not in self.pool: + self.pool[link] = EPConnectionState( + EPConnectionStatus.Disconnected, + conn_event, + ) + if self.pool[link].status == EPConnectionStatus.Connecting: + asyncio.create_task(wait_for_conn(conn_req, conn_event)) + elif self.pool[link].status == EPConnectionStatus.Disconnected: + self.pool[link].set_status(EPConnectionStatus.Connecting) + asyncio.create_task(conn_worker(conn_req, conn_event)) + + if not self.initialized: + loop = asyncio.get_event_loop() + loop.create_task(_perform_conn()) + self.conn_sem = asyncio.Semaphore(self.CONN_SEMAPHORE_SIZE) + self.conn_sess = aiohttp.ClientSession( + connector=aiohttp.TCPConnector(limit_per_host=256), + timeout=aiohttp.ClientTimeout(total=AIOHTTP_TIMEOUT), + ) + self.aiotimeout = aiohttp.ClientTimeout(total=AIOHTTP_TIMEOUT) + self.initialized = True + + print(f'EPConnectionPool connect called: {conn_req.e_url} <-> {conn_req.p_url}') + self.reg_instance(EngineRole.Encoder, conn_req.e_url) + self.reg_instance(EngineRole.Prefill, conn_req.p_url) + + cnt = 0 + 
while cnt < self.max_retry_cnt: + if self.is_connected(conn_req.e_url, conn_req.p_url): + return + if cnt > 0: + logger.warning(f'EP connection failure, retry cnt: {cnt}') + # simple incremental backoff + await asyncio.sleep(min(1.0, 0.2 * cnt)) + conn_event = asyncio.Event() + self.waiting_conn.put_nowait((conn_req, conn_event)) + self.conn_req_event.set() + await conn_event.wait() + cnt += 1 + async with self.conn_lock: + if (conn_req.e_url, conn_req.p_url) in self.pool: + self.pool[conn_req.e_url, conn_req.p_url].set_status(EPConnectionStatus.Disconnected) + raise TimeoutError('EPConnection Failure') + + def is_connected(self, e_url: str, p_url: str): + link = self.pool.get((e_url, p_url), None) + if not link: + return False + return link.status == EPConnectionStatus.Connected + + def drop(self, ep_key: Tuple[str, str]): + left = ep_key[0] + right = ep_key[1] + + def cache_free(server_endpoint, cache_free_request: DistServeCacheFreeRequest) -> None: + try: + requests.post(get_server_api(server_endpoint, 'distserve/free_cache'), + json=cache_free_request.model_dump(mode='json')) + except Exception as e: + logger.warning(f'error cache block free {server_endpoint, cache_free_request}. ErrorMsg: {str(e)}') + + def drop_connect(server_endpoint: str, p2p_disconnect_request: DistServeDropConnectionRequest): + try: + requests.post(get_server_api(server_endpoint, 'distserve/p2p_drop_connect'), + json=p2p_disconnect_request.model_dump(mode='json')) + except Exception as e: + logger.warning(f'error drop connect {server_endpoint, p2p_disconnect_request}. ErrorMsg: {str(e)}') + + # trigger gc + logger.warning('cache block gc triggered.') + try: + for session_id in self.migration_session_shelf[(left, right)]: + cache_free(left, DistServeCacheFreeRequest(remote_engine_id=left, remote_session_id=session_id)) + except Exception as e: + logger.warning(f'gc error, ErrorMsg: {str(e)}') + finally: + self.migration_session_shelf.pop((left, right), None) + + # trigger p2p disconnect + logger.warning('drop connection triggered.') + try: + drop_connect(left, DistServeDropConnectionRequest(engine_id=left, remote_engine_id=right)) + drop_connect(right, DistServeDropConnectionRequest(engine_id=right, remote_engine_id=left)) + except Exception as e: + logger.warning(f'p2p disconnect error, ErrorMsg: {str(e)}') + + self.pool.pop((left, right), None) + + async def close(self): + if getattr(self, 'initialized', False): + try: + await self.conn_sess.close() + except Exception as e: + logger.warning(f'EPConnectionPool close error: {e}') diff --git a/lmdeploy/pytorch/disagg/conn/protocol.py b/lmdeploy/pytorch/disagg/conn/protocol.py index 3853974bf1..54af27a0fa 100644 --- a/lmdeploy/pytorch/disagg/conn/protocol.py +++ b/lmdeploy/pytorch/disagg/conn/protocol.py @@ -79,13 +79,13 @@ class DistServeConnectionResponse(BaseModel): class EncoderResult(BaseModel): - token_ids : List[int] # FIXME, why we need this? 
- image_mask : List[int] + token_ids: List[int] + image_mask: List[int] - protocol: MigrationProtocol # RDMA - remote_engine_id: str # encoder engine id - remote_session_id: int # specify encoder cache to free - remote_block_ids: List[int] # encoder multi-modal cache region + protocol: MigrationProtocol + remote_engine_id: str + remote_session_id: int + remote_block_ids: List[int] class MigrationRequest(BaseModel): diff --git a/lmdeploy/pytorch/disagg/messages.py b/lmdeploy/pytorch/disagg/messages.py index 9dac0b0391..d27f873e47 100644 --- a/lmdeploy/pytorch/disagg/messages.py +++ b/lmdeploy/pytorch/disagg/messages.py @@ -38,6 +38,15 @@ class PDConnectionMessage(BaseModel): nvlink_config: Optional[DistServeNVLinkConfig] = None +class EPConnectionMessage(BaseModel): + e_url: str + p_url: str + protocol: MigrationProtocol = MigrationProtocol.RDMA + tcp_config: Optional[DistServeTCPConfig] = None + rdma_config: Optional[DistServeRDMAConfig] = None + nvlink_config: Optional[DistServeNVLinkConfig] = None + + class DistServeRegisterMRMessage(BaseModel): protocol: MigrationProtocol diff --git a/lmdeploy/pytorch/engine/cache_engine.py b/lmdeploy/pytorch/engine/cache_engine.py index a67f69ef94..026882eea2 100644 --- a/lmdeploy/pytorch/engine/cache_engine.py +++ b/lmdeploy/pytorch/engine/cache_engine.py @@ -62,6 +62,8 @@ def __init__( # Initialize the cache. self.local_gpu_cache = self.allocate_gpu_cache() self.local_cpu_cache = self.allocate_cpu_cache() + # FIXME: hardcode cache size for interns1 series + self.encoder_gpu_cache = torch.empty(size=(128, 256, 4096), dtype=torch.bfloat16, device='cuda') self.migration_backend_impl: Optional[MigrationBackendImpl] = None @@ -334,6 +336,19 @@ def p2p_initialize(self, migration_init_request: DistServeInitRequest) -> DistSe offset=t.storage_offset(), length=t.numel() * t.itemsize) self.migration_backend_impl.register_memory_region(register_mr_request) + + # register memory region for encoder cache, otherwise cannot perform RDMA transfer + if self.encoder_gpu_cache.numel() > 0: + logger.info('p2p_init encoder_cache') + register_mr_request = DistServeRegisterMRMessage( + protocol=migration_init_request.protocol, + remote_engine_id=migration_init_request.remote_engine_id, + mr_key='encoder_cache', # Use the fixed mr key, same as the one in encoder_cache_engine + addr=self.encoder_gpu_cache.data_ptr(), + offset=self.encoder_gpu_cache.storage_offset(), + length=self.encoder_gpu_cache.numel() * self.encoder_gpu_cache.itemsize) + self.migration_backend_impl.register_memory_region(register_mr_request) + return DistServeKVTransferEndpointInfo(protocol=migration_init_request.protocol, endpoint_info=json.dumps( self.migration_backend_impl.endpoint_info( @@ -383,9 +398,40 @@ def get_assignment_batch(mr_key, block_ids, assignment_len, layer_stride, remote batch=assignment_batch, )) - async def ep_migrate(self): - # TODO, implement actual EP migration logic here - # TODO, we may consider a seperate MM cache, may not exactly be here - pass + async def ep_migrate(self, migration_execution_inputs: MigrationExecutionBatch): + """Handles the migration of the Multi-Modal (MM) cache.""" + if not self.migration_backend_impl: + logger.error('Migration backend is not initialized. Cannot perform EP migration.') + return + + if self.encoder_gpu_cache.numel() == 0: + logger.warning('MM GPU cache is not allocated or is empty. 
Skipping EP migration.') + return + + _, tokens_per_image, hidden_size = self.encoder_gpu_cache.shape + assignment_len = tokens_per_image * hidden_size * self.encoder_gpu_cache.element_size() + + assignment_batch: List[AssignmentInstruct] = [] + mr_key = 'encoder_cache' # Use the fixed mr key, same as the one in encoder_cache_engine + + for _, blocks_to_migration in migration_execution_inputs.requests: + for source_idx, target_idx in blocks_to_migration: + source_offset = source_idx * assignment_len + target_offset = target_idx * assignment_len + instruction = AssignmentInstruct(mr_key=mr_key, + target_offset=target_offset, + source_offset=source_offset, + length=assignment_len) + assignment_batch.append(instruction) + + if assignment_batch: + remote_engine_id = migration_execution_inputs.requests[0][0] + logger.debug(f'Migrating {len(assignment_batch)} MM feature blocks to {remote_engine_id}.') + await self.migration_backend_impl.p2p_migrate( + MigrationAssignment( + protocol=migration_execution_inputs.protocol, + remote_engine_id=remote_engine_id, + batch=assignment_batch, + )) """ Methods for PD Disaggregation End. """ diff --git a/lmdeploy/pytorch/engine/encoder_cache_engine.py b/lmdeploy/pytorch/engine/encoder_cache_engine.py new file mode 100644 index 0000000000..85ac364083 --- /dev/null +++ b/lmdeploy/pytorch/engine/encoder_cache_engine.py @@ -0,0 +1,140 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# modify from: https://github.com/vllm-project/vllm +import json +from typing import List, Optional, Tuple + +import torch + +from lmdeploy.pytorch.disagg.backend.backend import MIGRATION_BACKENDS +from lmdeploy.pytorch.disagg.backend.base import MigrationBackendImpl +from lmdeploy.pytorch.disagg.conn.protocol import DistServeInitRequest, DistServeKVTransferEndpointInfo +from lmdeploy.pytorch.disagg.messages import DistServeRegisterMRMessage +from lmdeploy.utils import get_logger + +from ..config import ModelConfig + +logger = get_logger('lmdeploy') + +FEATURE_BLOCK_SHAPE = (256, 4096) + + +class EncoderCacheEngine: + """Manages the memory pool for image features. + + This engine allocates and manages a contiguous block of GPU memory + to store image embeddings transferred from an encoder. It is adapted for + an encoder-LLM separated architecture. + + Args: + cache_config (CacheConfig): Configuration for the cache, such as the + number of blocks. + model_config (ModelConfig): Model configuration, used for dtype. + rank (int): Distributed rank. + tp_rank (int): Tensor parallelism rank. + world_size (int): Distributed world size. 
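To make the offset computation in `ep_migrate` above concrete: with the hard-coded (128, 256, 4096) bfloat16 pool, one image block occupies 256 * 4096 * 2 bytes = 2 MiB, and block `i` starts at byte offset `i * 2 MiB` inside the single registered `'encoder_cache'` memory region. A small worked example, assuming that hard-coded shape:

```python
# Worked example of the byte-offset arithmetic used by ep_migrate above,
# assuming the hard-coded (128, 256, 4096) bfloat16 encoder pool.
import torch

num_blocks, tokens_per_image, hidden_size = 128, 256, 4096
elem_size = torch.empty((), dtype=torch.bfloat16).element_size()  # 2 bytes

block_bytes = tokens_per_image * hidden_size * elem_size  # 2 MiB per image block
assert block_bytes == 2 * 1024 * 1024


def block_offsets(source_idx: int, target_idx: int) -> tuple:
    """Byte offsets of a (source, target) block pair inside the 'encoder_cache' MR."""
    return source_idx * block_bytes, target_idx * block_bytes


# Migrating encoder block 3 into prefill block 7 copies block_bytes bytes
# from source offset 3 * block_bytes to target offset 7 * block_bytes.
print(block_offsets(3, 7))  # (6291456, 14680064)
```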
+ """ + + def __init__( + self, + num_gpu_blocks: int = 128, + rank: int = 0, + tp_rank: int = 0, + world_size: int = 1, + ) -> None: + self.world_size = world_size + self.rank = rank + self.tp_rank = tp_rank + + # self.feature_dtype = torch.float16 + # FIXME: turbomind forward() returns float16, pytorch returns bfloat16 + self.feature_dtype = torch.bfloat16 + self._num_gpu_blocks = num_gpu_blocks + + self.encoder_gpu_cache = self._allocate_gpu_cache() + + self.migration_backend_impl: Optional[MigrationBackendImpl] = None + + self.cache_stream = torch.cuda.Stream() + assert self.cache_stream != torch.cuda.current_stream() + self.events = torch.cuda.Event() + + logger.debug(f'Initialize feature cache engine with {self.num_gpu_blocks} gpu blocks.') + + @property + def gpu_cache(self) -> torch.Tensor: + """The GPU feature pool tensor.""" + return self.encoder_gpu_cache + + @property + def num_gpu_blocks(self) -> int: + """Number of GPU blocks.""" + return self._num_gpu_blocks + + @staticmethod + def get_feature_block_shape() -> Tuple[int, int]: + """Get the shape of a single image feature block.""" + return FEATURE_BLOCK_SHAPE + + def _allocate_cache(self, num_blocks: int, device: torch.device) -> torch.Tensor: + """Allocate the memory pool on the specified device.""" + block_shape = self.get_feature_block_shape() + + # allocate a large contiguous tensor as the feature pool + encoder_cache = torch.empty( + size=(num_blocks, *block_shape), + dtype=self.feature_dtype, + device=device, + ) + return encoder_cache + + def _allocate_gpu_cache(self) -> torch.Tensor: + """Allocate the feature pool on the GPU.""" + return self._allocate_cache(self.num_gpu_blocks, 'cuda') + + @classmethod + def get_cache_block_size(cls, model_config: ModelConfig) -> int: + """Get the memory size in bytes of a single feature block. + + Args: + model_config (ModelConfig): The model config, used for dtype. + + Return: + int: Required memory size in bytes for one block. + """ + shape = cls.get_feature_block_shape() + dtype = model_config.dtype + + meta_tensor = torch.empty(shape, dtype=dtype, device='meta') + return meta_tensor.numel() * meta_tensor.element_size() + + """ Methods for Disaggregation Begin. """ + + def p2p_initialize(self, migration_init_request: DistServeInitRequest) -> List[DistServeKVTransferEndpointInfo]: + if not self.migration_backend_impl: + self.migration_backend_impl: MigrationBackendImpl = MIGRATION_BACKENDS.module_dict['DLSlime']() + migration_init_request.rank = self.rank + self.migration_backend_impl.p2p_initialize(migration_init_request) + + t = self.encoder_gpu_cache + if t.numel() > 0: + register_mr_request = DistServeRegisterMRMessage( + protocol=migration_init_request.protocol, + remote_engine_id=migration_init_request.remote_engine_id, + mr_key='encoder_cache', # use fixed key + addr=t.data_ptr(), + offset=t.storage_offset(), + length=t.numel() * t.itemsize) + self.migration_backend_impl.register_memory_region(register_mr_request) + + return [ + DistServeKVTransferEndpointInfo(protocol=migration_init_request.protocol, + endpoint_info=json.dumps( + self.migration_backend_impl.endpoint_info( + migration_init_request.remote_engine_id, + migration_init_request.protocol))) + ] + + def p2p_connect(self, remote_engine_id: str, migration_conn_request: List[DistServeKVTransferEndpointInfo]): + self.migration_backend_impl.p2p_connect(remote_engine_id, migration_conn_request[self.tp_rank]) + + """ Methods for Disaggregation End. 
""" diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index dfcd51c3bd..b6bede791d 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -649,11 +649,9 @@ def __update_max_new_tokens(msg): self.migration_event.set() # if have encoder results here, skip encoding, directly proceed to prefill if encoder_result: - print(f'set waiting EP migration!!!') + logger.info('set waiting EP migration') self.scheduler._set_message_status(msg, MessageStatus.WAITING_EP_MIGRATION) self.ep_migration_event.set() - - # FIXME, seems we need to care about tokens ids, otherwise, the token ids processed will not contain the image token id as place holder else: msg = next(iter(sess.sequences.values())) msg.update_token_ids( @@ -716,12 +714,8 @@ def __has_values(input_multimodals): return False has_encoder_result = any([msg.encoder_result is not None for msg in messages]) - # FIXME, suppose we already have the image feature migrated to local cache blocks - # The only thing left here is to create an indexing tensor to tell the model where the - # image tokens are in the input sequence. + # FIXME: any special treatment for encoder_result? if has_encoder_result: - # we need to determine input_embeddings here, but a fake one as place holder? for prefill instance to allocate mem? - # second one is the image mask, we can directly fetch from encoder_result pass has_embedding = any([len(msg.history_embeddings) > 0 for msg in messages]) @@ -795,6 +789,7 @@ def create_model_inputs(self, messages: SeqList, is_prefill: bool): # model_metas model_metas = [msg.model_meta for msg in messages] + encoder_results = [msg.encoder_result for msg in messages] # create model inputs for all required fields model_inputs = ModelInputs( @@ -808,6 +803,7 @@ def create_model_inputs(self, messages: SeqList, is_prefill: bool): max_kv_seqlen=max_kv_seqlen, sum_kv_seqlen=sum_kv_seqlen, model_metas=model_metas, + encoder_results=encoder_results, ) # adapters @@ -1092,25 +1088,29 @@ async def _async_loop_ep_migration(self, resp_que: asyncio.Queue, has_runable_ev elif ep_migration_running: self.ep_migration_event.clear() for msg in ep_migration_running: - print(f'fake migration here') + logger.info('performing ep migrations.') migration_execution_requests: List[Tuple[int, List[Tuple[int, int]]]] = [] ep_migration_request = msg.encoder_result encoder_block_ids = ep_migration_request.remote_block_ids - prefill_block_ids = list(self.scheduler.block_manager.get_block_table(msg=msg)) - - # assert len(encoder_block_ids) == len(prefill_block_ids), ( - # f'#encoder block ids ({len(encoder_block_ids)}) must equal to ' - # f'#prefill block ids ({len(prefill_block_ids)}) ' - # f'all id length: {len(msg.num_token_ids)}') - # migration_execution_requests.append(( - # ep_migration_request.remote_engine_id, - # list(zip(encoder_block_ids, prefill_block_ids)), - # )) - # migration_inputs = MigrationExecutionBatch(protocol=ep_migration_request.protocol, - # requests=migration_execution_requests) - # logger.info(f'migrating encoder features for session: {msg.session_id} begin') - # await self.executor.migrate(migration_inputs) - # logger.info(f'migrating encoder features for session: {msg.session_id} done') + + # FIXME: only test one request now, we simply use the same block ids + # ideally we should get block ids from scheduler, corresponding to the msg + prefill_block_ids = ep_migration_request.remote_block_ids + + assert len(encoder_block_ids) == len(prefill_block_ids), ( + f'#encoder block ids 
({len(encoder_block_ids)}) must equal to ' + f'#prefill block ids ({len(prefill_block_ids)}) ' + f'all id length: {len(msg.num_token_ids)}') + migration_execution_requests.append(( + ep_migration_request.remote_engine_id, + list(zip(encoder_block_ids, prefill_block_ids)), + )) + migration_inputs = MigrationExecutionBatch(protocol=ep_migration_request.protocol, + requests=migration_execution_requests) + logger.info(f'migrating encoder features for session: {msg.session_id} begin') + await self.executor.ep_migrate(migration_inputs) + logger.info(f'migrating encoder features for session: {msg.session_id} done') + # TODO: we don't send free cache via zmq now, leave as future work # await self.engine_conn.zmq_send(remote_engine_id=ep_migration_request.remote_engine_id, # remote_session_id=ep_migration_request.remote_session_id) diff --git a/lmdeploy/pytorch/engine/executor/ray_executor.py b/lmdeploy/pytorch/engine/executor/ray_executor.py index 327d56a5ca..47336fbda7 100644 --- a/lmdeploy/pytorch/engine/executor/ray_executor.py +++ b/lmdeploy/pytorch/engine/executor/ray_executor.py @@ -593,4 +593,8 @@ async def migrate(self, batch: MigrationExecutionBatch): jobs = (worker.migrate.remote(batch) for worker in self.workers) return await asyncio.gather(*jobs) + async def ep_migrate(self, batch: MigrationExecutionBatch): + jobs = (worker.ep_migrate.remote(batch) for worker in self.workers) + return await asyncio.gather(*jobs) + """ PD Disaggregation API Begin """ diff --git a/lmdeploy/pytorch/engine/executor/uni_executor.py b/lmdeploy/pytorch/engine/executor/uni_executor.py index 283a8cddc7..e5d41d92fe 100644 --- a/lmdeploy/pytorch/engine/executor/uni_executor.py +++ b/lmdeploy/pytorch/engine/executor/uni_executor.py @@ -117,4 +117,8 @@ async def migrate(self, batch: MigrationExecutionBatch): """KV Cache Migration.""" return await self.model_agent.cache_engine.migrate(batch) + async def ep_migrate(self, batch: MigrationExecutionBatch): + """Encoder Cache Migration.""" + return await self.model_agent.cache_engine.ep_migrate(batch) + """ PD Disaggregation API End """ diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index 7b332b6f0c..3796b0e8bd 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -242,6 +242,7 @@ def model_forward( ) input_dict = model.prepare_inputs_for_generation( past_key_values=cache_engine.gpu_cache, + encoder_cache=cache_engine.encoder_gpu_cache, context=context, ) output = model(**input_dict) diff --git a/lmdeploy/pytorch/model_inputs.py b/lmdeploy/pytorch/model_inputs.py index a377c9d4d6..7c7d801e3b 100644 --- a/lmdeploy/pytorch/model_inputs.py +++ b/lmdeploy/pytorch/model_inputs.py @@ -142,6 +142,7 @@ class ModelInputs: cross_length: torch.LongTensor = None history_cross_length: torch.LongTensor = None model_metas: List[Dict[str, Any]] = None + encoder_results: List[Dict[str, Any]] = None dp_meta: 'DPMeta' = None enable_microbatch: bool = False @@ -254,6 +255,7 @@ def __make_next_vision_inputs(flatten_mms: List, start: int): local_adapter_ids=self.local_adapter_ids, vision_inputs=vision_inputs, model_metas=self.model_metas, + encoder_results=self.encoder_results, cross_length=cross_length, history_cross_length=history_cross_length, ) @@ -319,6 +321,7 @@ class StepContext: cross_attn_metadata: Any = None kv_quant_policy: Literal[0, 4, 8] = 0 model_metas: List[Dict[str, Any]] = None + encoder_results: List[Dict[str, Any]] = None dp_meta: DPMeta = None enable_microbatch: bool = False @@ 
-385,6 +388,7 @@ def new( vision_inputs=inputs.vision_inputs, kv_quant_policy=kv_quant_policy, model_metas=inputs.model_metas, + encoder_results=inputs.encoder_results, cross_seqlens=cross_seqlens, cross_kv_seqlens=cross_kv_seqlens, dp_meta=inputs.dp_meta, diff --git a/lmdeploy/pytorch/models/internvl3_hf.py b/lmdeploy/pytorch/models/internvl3_hf.py index 6e760dbeac..6cde7c0468 100644 --- a/lmdeploy/pytorch/models/internvl3_hf.py +++ b/lmdeploy/pytorch/models/internvl3_hf.py @@ -578,25 +578,42 @@ def forward( input_ids: torch.Tensor, position_ids: torch.Tensor, past_key_values: List[List[torch.Tensor]], + vision_embeddings: torch.Tensor = None, attn_metadata: Any = None, pixel_values: torch.Tensor = None, image_mask: torch.Tensor = None, inputs_embeds: torch.Tensor = None, **kwargs, ): - if inputs_embeds is None and pixel_values is not None: - # extract feature - self._mark_dynamic_once(pixel_values, [0]) - vit_embeds = self.get_image_features( - pixel_values, - self.vision_feature_layer, - self.vision_feature_select_strategy, - ) - lang_embeds = self.get_input_embeddings()(input_ids) - lang_embeds.masked_scatter_(image_mask[..., None], vit_embeds) - - inputs_embeds = lang_embeds - input_ids = None + if False: + if inputs_embeds is None and pixel_values is not None: + # extract feature + self._mark_dynamic_once(pixel_values, [0]) + vit_embeds = self.get_image_features( + pixel_values, + self.vision_feature_layer, + self.vision_feature_select_strategy, + ) + lang_embeds = self.get_input_embeddings()(input_ids) + lang_embeds.masked_scatter_(image_mask[..., None], vit_embeds) + + inputs_embeds = lang_embeds + input_ids = None + else: + if inputs_embeds is None and vision_embeddings is not None and image_mask is not None: + print('Using encoder_cache as vit_embeds !!!!!') + print(f'input_ids: {input_ids.shape}') + # use cached feature + vit_embeds = vision_embeddings + lang_embeds = self.get_input_embeddings()(input_ids) + print(f'lang_embeds.shape: {lang_embeds.shape}') + print(f'vit_embeds.shape: {vit_embeds.shape}') + print(f'image_mask.shape: {image_mask.shape}') + print(f'image_mask[..., None].shape: {image_mask[..., None].shape}') + lang_embeds.masked_scatter_(image_mask[..., None], vit_embeds) + + inputs_embeds = lang_embeds + input_ids = None if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError('You must specify exactly one of input_ids or inputs_embeds') @@ -614,6 +631,7 @@ def forward( def prepare_inputs_for_generation( self, past_key_values: List[List[torch.Tensor]], + encoder_cache: torch.Tensor = None, inputs_embeds: torch.Tensor = None, context: StepContext = None, ): @@ -646,10 +664,27 @@ def prepare_inputs_for_generation( inputs_embeds = self.get_input_embeddings()(input_ids) inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds) + if not context.is_decoding: + if context.encoder_results is not None and context.encoder_results[0] is not None: + # FIXME: pick 0 index for now, should fix for batch > 1 + image_mask = context.encoder_results[0].image_mask + image_mask = torch.tensor(image_mask, device=input_ids.device, dtype=torch.bool) + remote_block_ids = context.encoder_results[0].remote_block_ids + vision_embeddings = encoder_cache[remote_block_ids] + print(f'len(remote_block_ids): {len(remote_block_ids)}') + print(f'remote_block_ids: {remote_block_ids}') + print(f'vision_embeddings.shape: {vision_embeddings.shape}') + print(f'vision_embeddings: {vision_embeddings}') + # FIXME: we need to change the input_ids here, or maybe even 
earlier + # since multi-modal requests input_ids has image token ids, different from the others + encoder_input_ids = context.encoder_results[0].token_ids + print(f'encoder_input_ids: {len(encoder_input_ids)}') + return dict( input_ids=input_ids, position_ids=position_ids, past_key_values=past_key_values, + vision_embeddings=vision_embeddings, attn_metadata=attn_metadata, pixel_values=pixel_values, image_mask=image_mask, diff --git a/lmdeploy/serve/proxy/proxy.py b/lmdeploy/serve/proxy/proxy.py index 4f88340593..953d2c32e4 100644 --- a/lmdeploy/serve/proxy/proxy.py +++ b/lmdeploy/serve/proxy/proxy.py @@ -22,9 +22,10 @@ from pydantic import BaseModel, Field from lmdeploy.pytorch.disagg.config import DistServeRDMAConfig, EngineRole, RDMALinkType, ServingStrategy +from lmdeploy.pytorch.disagg.conn.ep_proxy_conn import EPConnectionPool from lmdeploy.pytorch.disagg.conn.protocol import MigrationProtocol, MigrationRequest from lmdeploy.pytorch.disagg.conn.proxy_conn import PDConnectionPool -from lmdeploy.pytorch.disagg.messages import PDConnectionMessage +from lmdeploy.pytorch.disagg.messages import EPConnectionMessage, PDConnectionMessage from lmdeploy.serve.openai.api_server import check_api_key, create_error_response from lmdeploy.serve.openai.protocol import ModelCard # noqa: E501 from lmdeploy.serve.openai.protocol import ChatCompletionRequest, CompletionRequest, ModelList, ModelPermission @@ -108,6 +109,7 @@ def __init__(self, self.migration_protocol = MigrationProtocol[migration_protocol] self.rdma_config = DistServeRDMAConfig(with_gdr=with_gdr, link_type=RDMALinkType[link_type]) self.pd_connection_pool = PDConnectionPool() + self.ep_connection_pool = EPConnectionPool() self.dummy_prefill = False def get_nodes(self, role: EngineRole) -> Dict: @@ -126,6 +128,10 @@ def prefill_nodes(self): def decode_nodes(self): return self.get_nodes(EngineRole.Decode) + @property + def encoder_nodes(self): + return self.get_nodes(EngineRole.Encoder) + def update_config_file(self): """Update the config file.""" nodes = copy.deepcopy(self.nodes) @@ -504,6 +510,18 @@ async def connection_warmup(): rdma_config=node_manager.rdma_config, )) for p_url in node_manager.prefill_nodes for d_url in node_manager.decode_nodes ]) + logger.info(f'encoder nodes: {node_manager.decode_nodes}\nprefill nodes: {node_manager.prefill_nodes}') + # FIXME: use hybrid nodes now, since we start language server in hybrid, not prefill + await asyncio.gather(*[ + node_manager.ep_connection_pool.connect( + EPConnectionMessage( + e_url=e_url, + p_url=p_url, + protocol=node_manager.migration_protocol, + rdma_config=node_manager.rdma_config, + # )) for e_url in node_manager.encoder_nodes for p_url in node_manager.prefill_nodes + )) for e_url in node_manager.encoder_nodes for p_url in node_manager.hybrid_nodes + ]) return JSONResponse({'SUCCESS': True}) @@ -576,21 +594,80 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque return check_response if node_manager.serving_strategy == ServingStrategy.Hybrid: - node_url = node_manager.get_node_url(request.model) - if not node_url: - return node_manager.handle_unavailable_model(request.model) + # Helper: decide whether we need encoder stage + def _need_encoder(msgs: List[Dict]) -> bool: + try: + for m in msgs: + content = m.get('content') + # user role + list content -> possible multimodal + if m.get('role') == 'user' and isinstance(content, list): + for item in content: + if isinstance(item, dict) and item.get('type') in ['image_url', 'image_data', 'image']: + 
return True + return False + except Exception as e: # noqa + logger.warning(f'encoder detect failed, fallback no-encoder: {e}') + return False - logger.info(f'A request is dispatched to {node_url}') request_dict = request.model_dump() + # 1. Encoder stage (only if encoder node exists & messages contain images) + encoder_url = None + if len(node_manager.encoder_nodes): + if _need_encoder(request_dict.get('messages', [])): + encoder_url = node_manager.get_node_url(request.model, EngineRole.Encoder) + if not encoder_url: + logger.warning( + 'Encoder nodes registered but no suitable encoder node found for model; skip encoder stage.') + else: + logger.info(f'Encoder stage dispatched to {encoder_url}') + enc_start = node_manager.pre_call(encoder_url) + # encoder endpoint path: using vision server example: /v1/chat/completion (singular) + # fall back to /v1/chat/completions if first fails + encoder_response_text = await node_manager.generate(request_dict, encoder_url, + '/v1/chat/completion') + if isinstance(encoder_response_text, (bytes, bytearray)): + try: + encoder_response_text = encoder_response_text.decode('utf-8') + except Exception: # noqa + pass + # simple heuristic: if returns timeout structure (bytes) keep original + try: + enc_json = json.loads(encoder_response_text) + except Exception: + # try alternative endpoint if first not json (maybe 404 HTML) + alt_text = await node_manager.generate(request_dict, encoder_url, '/v1/chat/completions') + try: + enc_json = json.loads(alt_text) + encoder_response_text = alt_text + except Exception: + logger.error('Encoder stage failed: cannot parse JSON; skip encoder stage') + enc_json = None + node_manager.post_call(encoder_url, enc_start) + if enc_json and isinstance(enc_json, dict) and 'encoder_result' in enc_json: + # Replace messages with encoder returned (likely empty) to avoid double encoding + request_dict['messages'] = enc_json.get('messages', []) + request_dict['encoder_result'] = enc_json['encoder_result'] + else: + logger.warning('Encoder response lacks encoder_result, skip passing encoder_result.') + logger.info(f'Post-encoder request dict: {request_dict}') + # 2. Hybrid (LLM) generation stage + node_url = node_manager.get_node_url(request.model, EngineRole.Hybrid) + if not node_url: + return node_manager.handle_unavailable_model(request.model) + logger.info(f'LLM stage dispatched to {node_url}' + (f' (after encoder {encoder_url})' if encoder_url else '')) start = node_manager.pre_call(node_url) if request.stream is True: response = node_manager.stream_generate(request_dict, node_url, '/v1/chat/completions') background_task = node_manager.create_background_tasks(node_url, start) return StreamingResponse(response, background=background_task) else: - response = await node_manager.generate(request_dict, node_url, '/v1/chat/completions') + response_text = await node_manager.generate(request_dict, node_url, '/v1/chat/completions') node_manager.post_call(node_url, start) - return JSONResponse(json.loads(response)) + try: + return JSONResponse(json.loads(response_text)) + except Exception: + logger.error('Failed to parse LLM response JSON, returning raw text') + return JSONResponse({'raw': response_text}) elif node_manager.serving_strategy == ServingStrategy.DistServe: request_dict = request.model_dump() @@ -621,6 +698,8 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque if not node_manager.dummy_prefill: if not node_manager.pd_connection_pool.is_connected(p_url, d_url): + # FIXME: here perform connections! 
we need to add similar logic for encode connect + # currently we connect and warmup manually through /distserve/connection_warmup await node_manager.pd_connection_pool.connect( PDConnectionMessage( p_url=p_url, @@ -662,6 +741,7 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque raise ValueError(f'No serving strategy named {node_manager.serving_strategy}') +# TODO: also change to /v1/completions, similar to /v1/chat/completions @app.post('/v1/completions', dependencies=[Depends(check_api_key)]) async def completions_v1(request: CompletionRequest, raw_request: Request = None): """Completion API similar to OpenAI's API. From 810024907dfa759d8ae89a153da9a3cea519e669 Mon Sep 17 00:00:00 2001 From: zxy Date: Fri, 10 Oct 2025 11:57:46 +0800 Subject: [PATCH 05/11] cleanups --- .../{ep_proxy_conn.py => epd_proxy_conn.py} | 113 +++++++++--------- lmdeploy/pytorch/disagg/conn/proxy_conn.py | 9 +- lmdeploy/pytorch/disagg/messages.py | 4 +- lmdeploy/pytorch/engine/cache_engine.py | 6 +- .../pytorch/engine/encoder_cache_engine.py | 3 +- lmdeploy/pytorch/engine/engine.py | 91 +++++++------- .../pytorch/engine/executor/ray_executor.py | 2 +- .../pytorch/engine/executor/uni_executor.py | 4 +- lmdeploy/pytorch/messages.py | 8 +- lmdeploy/pytorch/paging/scheduler.py | 93 +++++++------- lmdeploy/serve/openai/api_server.py | 1 - lmdeploy/serve/proxy/proxy.py | 15 ++- 12 files changed, 171 insertions(+), 178 deletions(-) rename lmdeploy/pytorch/disagg/conn/{ep_proxy_conn.py => epd_proxy_conn.py} (74%) diff --git a/lmdeploy/pytorch/disagg/conn/ep_proxy_conn.py b/lmdeploy/pytorch/disagg/conn/epd_proxy_conn.py similarity index 74% rename from lmdeploy/pytorch/disagg/conn/ep_proxy_conn.py rename to lmdeploy/pytorch/disagg/conn/epd_proxy_conn.py index 52409a6668..a388fa5301 100644 --- a/lmdeploy/pytorch/disagg/conn/ep_proxy_conn.py +++ b/lmdeploy/pytorch/disagg/conn/epd_proxy_conn.py @@ -13,7 +13,7 @@ from lmdeploy.pytorch.disagg.conn.protocol import (DistServeCacheFreeRequest, DistServeConnectionRequest, DistServeConnectionResponse, DistServeDropConnectionRequest, DistServeInitRequest, DistServeInitResponse) -from lmdeploy.pytorch.disagg.messages import EPConnectionMessage +from lmdeploy.pytorch.disagg.messages import EPDConnectionMessage logger = get_logger('lmdeploy') @@ -26,23 +26,23 @@ AIOHTTP_TIMEOUT = None -class EPConnectionStatus(enum.Enum): +class EPDConnectionStatus(enum.Enum): Disconnected = enum.auto() Connected = enum.auto() Connecting = enum.auto() -class EPConnectionState: - """EPConnectionState (simple state holder with one event).""" +class EPDConnectionState: + """EPDConnectionState (simple state holder with one event).""" - def __init__(self, status: EPConnectionStatus, event: asyncio.Event): + def __init__(self, status: EPDConnectionStatus, event: asyncio.Event): self.status = status self.event = event async def wait(self): await self.event.wait() - def set_status(self, status: EPConnectionStatus): + def set_status(self, status: EPDConnectionStatus): self.status = status @@ -50,8 +50,9 @@ def get_server_api(url: str, api: str): return f'{url}/{api}' -class EPConnectionPool: - """Constructing the link of E & P engine for the migration of KVCache. +class EPDConnectionPool: + """Constructing the link of E & PD engine for the migration of Encoder + cache. Note: we use Peer to Peer transportation in KVCache migration. 
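From the caller's perspective, the pool below exposes a lazy contract: `connect()` enqueues a connection request, a background worker performs the handshake, and the call retries with incremental backoff before raising `TimeoutError`. A hedged sketch of how the proxy might drive it, assuming running encoder and prefill/decode servers (URLs are placeholders):

```python
# Hedged sketch of driving EPDConnectionPool from the proxy side.
# Requires live engines behind the placeholder URLs; connect() raises TimeoutError otherwise.
import asyncio

from lmdeploy.pytorch.disagg.conn.epd_proxy_conn import EPDConnectionPool
from lmdeploy.pytorch.disagg.messages import EPDConnectionMessage


async def ensure_epd_link(pool: EPDConnectionPool, e_url: str, pd_url: str) -> None:
    if pool.is_connected(e_url, pd_url):
        return
    # connect() retries internally (up to max_retry_cnt) with incremental backoff.
    await pool.connect(EPDConnectionMessage(e_url=e_url, pd_url=pd_url))


async def main():
    pool = EPDConnectionPool()
    await ensure_epd_link(pool, 'http://encoder:23333', 'http://prefill-decode:23334')
    await pool.close()


asyncio.run(main())
```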
Note: Lazy link construction is supported, which perform connection @@ -68,13 +69,13 @@ class EPConnectionPool: CONN_SEMAPHORE_SIZE = 2048 def __init__(self): - # all prefill and decode instances + # all encode, prefill and decode instances # TODO (JimyMa): Maybe encoding instances - self.prefill_endpoints: Set[str] = set() + self.prefill_decode_endpoints: Set[str] = set() self.encode_endpoints: Set[str] = set() - # Links of PD Connection. - self.pool: Dict[Tuple[str, str], EPConnectionState] = {} + # Links of EPD Connection. + self.pool: Dict[Tuple[str, str], EPDConnectionState] = {} # put migrating session to `self.migration_session_shelf` for increasing fault tolerance # if a session is finished, then pop it from `self.migration_session_shelf` @@ -83,7 +84,7 @@ def __init__(self): self.migration_session_shelf: Dict[Tuple[str, str], Set[int]] = defaultdict(set) # conn_perform handler queue - self.waiting_conn: asyncio.Queue[Tuple[EPConnectionMessage, asyncio.Event]] = asyncio.Queue() + self.waiting_conn: asyncio.Queue[Tuple[EPDConnectionMessage, asyncio.Event]] = asyncio.Queue() # conn Registry Lock self.conn_lock = asyncio.Lock() @@ -99,7 +100,7 @@ def __init__(self): def reg_instance(self, role: EngineRole, endpoint: str): if role == EngineRole.Prefill: - self.prefill_endpoints.add(endpoint) + self.prefill_decode_endpoints.add(endpoint) elif role == EngineRole.Encoder: self.encode_endpoints.add(endpoint) else: @@ -112,14 +113,14 @@ def dereg_instance(self, endpoint: str): for k in dropped_key: self.drop(k) self.encode_endpoints.remove(endpoint) - elif endpoint in self.prefill_endpoints: + elif endpoint in self.prefill_decode_endpoints: dropped_key = [k for k in self.pool.keys() if k[1] == endpoint] for k in dropped_key: self.drop(k) # TODO(JimyMa): handle side-effect by kvcache migration - self.prefill_endpoints.remove(endpoint) + self.prefill_decode_endpoints.remove(endpoint) - async def connect(self, conn_req: EPConnectionMessage): + async def connect(self, conn_req: EPDConnectionMessage): async def get_engine_config(server_endpoint): async with self.conn_sem: @@ -153,23 +154,23 @@ async def p2p_connect(server_endpoint, conn_request: DistServeConnectionRequest) result = await resp.json() return DistServeConnectionResponse.model_validate(result) - async def conn_worker(conn_req: EPConnectionMessage, conn_event: asyncio.Event): + async def conn_worker(conn_req: EPDConnectionMessage, conn_event: asyncio.Event): # try: - link = (conn_req.e_url, conn_req.p_url) + link = (conn_req.e_url, conn_req.pd_url) logger.debug(f'{link} connecting...') # Step 1. Get Remote Engine Configuration - prefill_engine_config = await get_engine_config(conn_req.p_url) + prefill_decode_engine_configs = await get_engine_config(conn_req.pd_url) encode_engine_config = await get_engine_config(conn_req.e_url) - print(f'prefill_engine_config: {prefill_engine_config}') + print(f'prefill_decode_engine_configs: {prefill_decode_engine_configs}') print(f'encode_engine_config: {encode_engine_config}') # encode 的 config 大部分字段为 空 # Step 2. 
Construct Initialize Configuration - prefill_init_req = DistServeInitRequest( + prefill_decode_init_req = DistServeInitRequest( protocol=conn_req.protocol, - local_engine_id=conn_req.p_url, - local_engine_config=prefill_engine_config, + local_engine_id=conn_req.pd_url, + local_engine_config=prefill_decode_engine_configs, remote_engine_id=conn_req.e_url, remote_engine_config=encode_engine_config, rdma_config=conn_req.rdma_config, @@ -179,41 +180,41 @@ async def conn_worker(conn_req: EPConnectionMessage, conn_event: asyncio.Event): protocol=conn_req.protocol, local_engine_id=conn_req.e_url, local_engine_config=encode_engine_config, - remote_engine_id=conn_req.p_url, - remote_engine_config=prefill_engine_config, + remote_engine_id=conn_req.pd_url, + remote_engine_config=prefill_decode_engine_configs, rdma_config=conn_req.rdma_config, nvlink_config=conn_req.nvlink_config, ) - print(f'prefill_init_req: {prefill_init_req}') + print(f'prefill_decode_init_req: {prefill_decode_init_req}') print(f'encode_init_req: {encode_init_req}') - prefill_init_resp = await p2p_initialize(conn_req.p_url, prefill_init_req) + prefill_decode_init_resp = await p2p_initialize(conn_req.pd_url, prefill_decode_init_req) encode_init_resp = await p2p_initialize(conn_req.e_url, encode_init_req) # Step 3. Connection encode_endpoint_conn_reqs = DistServeConnectionRequest( protocol=conn_req.protocol, - remote_engine_id=conn_req.p_url, - remote_engine_endpoint_info=prefill_init_resp.engine_endpoint_info, - remote_kvtransfer_endpoint_info=prefill_init_resp.kvtransfer_endpoint_info) - prefill_endpoint_conn_reqs = DistServeConnectionRequest( + remote_engine_id=conn_req.pd_url, + remote_engine_endpoint_info=prefill_decode_init_resp.engine_endpoint_info, + remote_kvtransfer_endpoint_info=prefill_decode_init_resp.kvtransfer_endpoint_info) + prefill_decode_endpoint_conn_reqs = DistServeConnectionRequest( protocol=conn_req.protocol, remote_engine_id=conn_req.e_url, remote_engine_endpoint_info=encode_init_resp.engine_endpoint_info, remote_kvtransfer_endpoint_info=encode_init_resp.kvtransfer_endpoint_info) print(f'encode_endpoint_conn_reqs: {encode_endpoint_conn_reqs}') - print(f'prefill_endpoint_conn_reqs: {prefill_endpoint_conn_reqs}') - await p2p_connect(conn_req.p_url, prefill_endpoint_conn_reqs) + print(f'prefill_decode_endpoint_conn_reqs: {prefill_decode_endpoint_conn_reqs}') + await p2p_connect(conn_req.pd_url, prefill_decode_endpoint_conn_reqs) await p2p_connect(conn_req.e_url, encode_endpoint_conn_reqs) - self.pool[link].set_status(EPConnectionStatus.Connected) - logger.debug(f'{(conn_req.e_url, conn_req.p_url)} connected') + self.pool[link].set_status(EPDConnectionStatus.Connected) + logger.debug(f'{(conn_req.e_url, conn_req.pd_url)} connected') # except Exception as e: - # self.pool[link].set_status(EPConnectionStatus.Disconnected) + # self.pool[link].set_status(EPDConnectionStatus.Disconnected) # logger.error(f'ep connection error: {e}') conn_event.set() - async def wait_for_conn(conn_req: EPConnectionMessage, conn_event: asyncio.Event): - await self.pool[(conn_req.e_url, conn_req.p_url)].event.wait() + async def wait_for_conn(conn_req: EPDConnectionMessage, conn_event: asyncio.Event): + await self.pool[(conn_req.e_url, conn_req.pd_url)].event.wait() conn_event.set() async def _perform_conn(): @@ -226,16 +227,16 @@ async def _perform_conn(): while not self.waiting_conn.empty(): conn_req, conn_event = self.waiting_conn.get_nowait() - link = (conn_req.e_url, conn_req.p_url) + link = (conn_req.e_url, conn_req.pd_url) if link 
not in self.pool: - self.pool[link] = EPConnectionState( - EPConnectionStatus.Disconnected, + self.pool[link] = EPDConnectionState( + EPDConnectionStatus.Disconnected, conn_event, ) - if self.pool[link].status == EPConnectionStatus.Connecting: + if self.pool[link].status == EPDConnectionStatus.Connecting: asyncio.create_task(wait_for_conn(conn_req, conn_event)) - elif self.pool[link].status == EPConnectionStatus.Disconnected: - self.pool[link].set_status(EPConnectionStatus.Connecting) + elif self.pool[link].status == EPDConnectionStatus.Disconnected: + self.pool[link].set_status(EPDConnectionStatus.Connecting) asyncio.create_task(conn_worker(conn_req, conn_event)) if not self.initialized: @@ -249,16 +250,16 @@ async def _perform_conn(): self.aiotimeout = aiohttp.ClientTimeout(total=AIOHTTP_TIMEOUT) self.initialized = True - print(f'EPConnectionPool connect called: {conn_req.e_url} <-> {conn_req.p_url}') + print(f'EPDConnectionPool connect called: {conn_req.e_url} <-> {conn_req.pd_url}') self.reg_instance(EngineRole.Encoder, conn_req.e_url) - self.reg_instance(EngineRole.Prefill, conn_req.p_url) + self.reg_instance(EngineRole.Prefill, conn_req.pd_url) cnt = 0 while cnt < self.max_retry_cnt: - if self.is_connected(conn_req.e_url, conn_req.p_url): + if self.is_connected(conn_req.e_url, conn_req.pd_url): return if cnt > 0: - logger.warning(f'EP connection failure, retry cnt: {cnt}') + logger.warning(f'EPD connection failure, retry cnt: {cnt}') # simple incremental backoff await asyncio.sleep(min(1.0, 0.2 * cnt)) conn_event = asyncio.Event() @@ -267,15 +268,15 @@ async def _perform_conn(): await conn_event.wait() cnt += 1 async with self.conn_lock: - if (conn_req.e_url, conn_req.p_url) in self.pool: - self.pool[conn_req.e_url, conn_req.p_url].set_status(EPConnectionStatus.Disconnected) - raise TimeoutError('EPConnection Failure') + if (conn_req.e_url, conn_req.pd_url) in self.pool: + self.pool[conn_req.e_url, conn_req.pd_url].set_status(EPDConnectionStatus.Disconnected) + raise TimeoutError('EPDConnection Failure') - def is_connected(self, e_url: str, p_url: str): - link = self.pool.get((e_url, p_url), None) + def is_connected(self, e_url: str, pd_url: str): + link = self.pool.get((e_url, pd_url), None) if not link: return False - return link.status == EPConnectionStatus.Connected + return link.status == EPDConnectionStatus.Connected def drop(self, ep_key: Tuple[str, str]): left = ep_key[0] @@ -320,4 +321,4 @@ async def close(self): try: await self.conn_sess.close() except Exception as e: - logger.warning(f'EPConnectionPool close error: {e}') + logger.warning(f'EPDConnectionPool close error: {e}') diff --git a/lmdeploy/pytorch/disagg/conn/proxy_conn.py b/lmdeploy/pytorch/disagg/conn/proxy_conn.py index a740c28284..f0ae916251 100644 --- a/lmdeploy/pytorch/disagg/conn/proxy_conn.py +++ b/lmdeploy/pytorch/disagg/conn/proxy_conn.py @@ -119,7 +119,6 @@ def unshelf_prefill_session(self, conn_key: Tuple[str, str], session_id: int): self.migration_session_shelf[conn_key].remove(session_id) async def connect(self, conn_req: PDConnectionMessage): - # perform connection here async def get_engine_config(server_endpoint): async with self.conn_sem: @@ -148,7 +147,7 @@ async def p2p_connect(server_endpoint, conn_request: DistServeConnectionRequest) timeout=self.aiotimeout, ) as resp: result = await resp.json() - print(f'p2p_connect response: {result}') + logger.info(f'=> p2p_connect response: {result}') return DistServeConnectionResponse.model_validate(result) async def conn_worker(conn_req: 
PDConnectionMessage, conn_event: asyncio.Event): @@ -163,7 +162,7 @@ async def conn_worker(conn_req: PDConnectionMessage, conn_event: asyncio.Event): assert prefill_engine_config.tp_size == decode_engine_config.tp_size # Step 2. Construct Initialize Configuration - print(f'check conn_req: {conn_req}') + logger.info(f'=> check conn_req: {conn_req}') prefill_init_req = DistServeInitRequest( protocol=conn_req.protocol, local_engine_id=conn_req.p_url, @@ -186,8 +185,8 @@ async def conn_worker(conn_req: PDConnectionMessage, conn_event: asyncio.Event): prefill_init_resp = await p2p_initialize(conn_req.p_url, prefill_init_req) decode_init_resp = await p2p_initialize(conn_req.d_url, decode_init_req) - print(f'=> p2p init, prefill_init_resp: \n{prefill_init_resp}\n') - print(f'=> p2p init, decode_init_resp: \n{decode_init_resp}\n') + logger.info(f'=> p2p init, prefill_init_resp: \n{prefill_init_resp}\n') + logger.info(f'=> p2p init, decode_init_resp: \n{decode_init_resp}\n') # Step 3. Connection prefill_endpoint_conn_reqs = DistServeConnectionRequest( protocol=conn_req.protocol, diff --git a/lmdeploy/pytorch/disagg/messages.py b/lmdeploy/pytorch/disagg/messages.py index d27f873e47..437ddbeebd 100644 --- a/lmdeploy/pytorch/disagg/messages.py +++ b/lmdeploy/pytorch/disagg/messages.py @@ -38,9 +38,9 @@ class PDConnectionMessage(BaseModel): nvlink_config: Optional[DistServeNVLinkConfig] = None -class EPConnectionMessage(BaseModel): +class EPDConnectionMessage(BaseModel): e_url: str - p_url: str + pd_url: str protocol: MigrationProtocol = MigrationProtocol.RDMA tcp_config: Optional[DistServeTCPConfig] = None rdma_config: Optional[DistServeRDMAConfig] = None diff --git a/lmdeploy/pytorch/engine/cache_engine.py b/lmdeploy/pytorch/engine/cache_engine.py index 026882eea2..d4b231277e 100644 --- a/lmdeploy/pytorch/engine/cache_engine.py +++ b/lmdeploy/pytorch/engine/cache_engine.py @@ -398,14 +398,14 @@ def get_assignment_batch(mr_key, block_ids, assignment_len, layer_stride, remote batch=assignment_batch, )) - async def ep_migrate(self, migration_execution_inputs: MigrationExecutionBatch): + async def epd_migrate(self, migration_execution_inputs: MigrationExecutionBatch): """Handles the migration of the Multi-Modal (MM) cache.""" if not self.migration_backend_impl: - logger.error('Migration backend is not initialized. Cannot perform EP migration.') + logger.error('Migration backend is not initialized. Cannot perform EPD migration.') return if self.encoder_gpu_cache.numel() == 0: - logger.warning('MM GPU cache is not allocated or is empty. Skipping EP migration.') + logger.warning('MM GPU cache is not allocated or is empty. 
Skipping EPD migration.') return _, tokens_per_image, hidden_size = self.encoder_gpu_cache.shape diff --git a/lmdeploy/pytorch/engine/encoder_cache_engine.py b/lmdeploy/pytorch/engine/encoder_cache_engine.py index 85ac364083..26d092a65a 100644 --- a/lmdeploy/pytorch/engine/encoder_cache_engine.py +++ b/lmdeploy/pytorch/engine/encoder_cache_engine.py @@ -45,8 +45,7 @@ def __init__( self.rank = rank self.tp_rank = tp_rank - # self.feature_dtype = torch.float16 - # FIXME: turbomind forward() returns float16, pytorch returns bfloat16 + # FIXME: turbomind forward() returns float16, pytorch forward uses bfloat16 self.feature_dtype = torch.bfloat16 self._num_gpu_blocks = num_gpu_blocks diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index b6bede791d..1d8c54cf51 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -418,9 +418,9 @@ def __init__(self, # For migrating prefill request to decode engine self.migration_event: asyncio.Event = None # For encoder result migration - self.ep_migration_event: asyncio.Event = None + self.epd_migration_event: asyncio.Event = None # For backpressure prefill request when cache is full - self.perfill_watermark_event: asyncio.Event = None + self.prefill_watermark_event: asyncio.Event = None self.engine_conn = EngineP2PConnection(self) @@ -585,8 +585,7 @@ def _on_add_message(self, reqs: List[Request], **kwargs): logger.warning('Vision encoder has not been loaded, multimodal inputs will be ignored.') continue - # FIXME, here if we deploy MLLM, will invoke input_processor to process multimodal data - # but now if encoder_result is detected, need to skip preprocess + # skip preprocess if encoder results exist if req_data.get('encoder_result') is None: result = self.input_processor.preprocess_input(input_ids, input_multimodals) input_ids = result.input_ids @@ -595,9 +594,8 @@ def _on_add_message(self, reqs: List[Request], **kwargs): req_data['token_ids'] = input_ids req_data['input_multimodals'] = input_multimodals else: - # ignore multimodal inputs req_data['input_multimodals'] = None - logger.info('Have encoder result, try to fetch from encode instance') + logger.info('Ignore multimodal inputs since encoder results exist.') if len(valid_reqs) > 0: self._add_message(valid_reqs) @@ -629,8 +627,8 @@ def __update_max_new_tokens(msg): if len(sess.sequences) == 0: migration_request = req.data.get('migration_request') encoder_result = req.data.get('encoder_result') - print(f'=> add msg, migration_request {migration_request}') - print(f'=> add msg, encoder_result {encoder_result}') + logger.info(f'=> add msg, migration_request {migration_request}') + logger.info(f'=> add msg, encoder_result {encoder_result}') assert len(req.data['token_ids']) > 0, ('Empty input is not allowed.') sess.add_sequence(req.data['token_ids'], sampling_param=sampling_param, @@ -649,9 +647,9 @@ def __update_max_new_tokens(msg): self.migration_event.set() # if have encoder results here, skip encoding, directly proceed to prefill if encoder_result: - logger.info('set waiting EP migration') - self.scheduler._set_message_status(msg, MessageStatus.WAITING_EP_MIGRATION) - self.ep_migration_event.set() + logger.info('=> set waiting EPD migration') + self.scheduler._set_message_status(msg, MessageStatus.WAITING_EPD_MIGRATION) + self.epd_migration_event.set() else: msg = next(iter(sess.sequences.values())) msg.update_token_ids( @@ -713,10 +711,10 @@ def __has_values(input_multimodals): return True return False - has_encoder_result = 
any([msg.encoder_result is not None for msg in messages]) - # FIXME: any special treatment for encoder_result? - if has_encoder_result: - pass + # has_encoder_result = any([msg.encoder_result is not None for msg in messages]) + # # FIXME: any special treatment for encoder_result? + # if has_encoder_result: + # pass has_embedding = any([len(msg.history_embeddings) > 0 for msg in messages]) if has_embedding: @@ -822,7 +820,6 @@ def create_model_inputs(self, messages: SeqList, is_prefill: bool): model_inputs.history_cross_length = history_cross_length # vision inputs - # FIXME, handle vision inputs differently, directly migrate feature from encoder_results vision_model_inputs = self._create_vision_model_inputs(messages, model_inputs) model_inputs.vision_inputs = vision_model_inputs @@ -1079,47 +1076,47 @@ async def _async_loop_migration(self, resp_que: asyncio.Queue, has_runable_event await asyncio.sleep(.5) @torch.inference_mode() - async def _async_loop_ep_migration(self, resp_que: asyncio.Queue, has_runable_event: asyncio.Event): + async def _async_loop_epd_migration(self, resp_que: asyncio.Queue, has_runable_event: asyncio.Event): """Async loop for encoder-prefill migration.""" while True: - ep_migration_running = self.scheduler._schedule_ep_migration() - if not ep_migration_running and not self.scheduler.has_ep_migration_waiting(): - await self.ep_migration_event.wait() - elif ep_migration_running: - self.ep_migration_event.clear() - for msg in ep_migration_running: - logger.info('performing ep migrations.') + epd_migration_running = self.scheduler._schedule_epd_migration() + if not epd_migration_running and not self.scheduler.has_epd_migration_waiting(): + await self.epd_migration_event.wait() + elif epd_migration_running: + self.epd_migration_event.clear() + for msg in epd_migration_running: + logger.info('performing epd migrations.') migration_execution_requests: List[Tuple[int, List[Tuple[int, int]]]] = [] - ep_migration_request = msg.encoder_result - encoder_block_ids = ep_migration_request.remote_block_ids + epd_migration_request = msg.encoder_result + encoder_block_ids = epd_migration_request.remote_block_ids # FIXME: only test one request now, we simply use the same block ids # ideally we should get block ids from scheduler, corresponding to the msg - prefill_block_ids = ep_migration_request.remote_block_ids + prefill_block_ids = epd_migration_request.remote_block_ids assert len(encoder_block_ids) == len(prefill_block_ids), ( f'#encoder block ids ({len(encoder_block_ids)}) must equal to ' f'#prefill block ids ({len(prefill_block_ids)}) ' f'all id length: {len(msg.num_token_ids)}') migration_execution_requests.append(( - ep_migration_request.remote_engine_id, + epd_migration_request.remote_engine_id, list(zip(encoder_block_ids, prefill_block_ids)), )) - migration_inputs = MigrationExecutionBatch(protocol=ep_migration_request.protocol, + migration_inputs = MigrationExecutionBatch(protocol=epd_migration_request.protocol, requests=migration_execution_requests) - logger.info(f'migrating encoder features for session: {msg.session_id} begin') - await self.executor.ep_migrate(migration_inputs) - logger.info(f'migrating encoder features for session: {msg.session_id} done') + logger.info(f'migrating encoder cache for session: {msg.session_id} begin') + await self.executor.epd_migrate(migration_inputs) + logger.info(f'migrating encoder cache for session: {msg.session_id} done') # TODO: we don't send free cache via zmq now, leave as future work - # await 
self.engine_conn.zmq_send(remote_engine_id=ep_migration_request.remote_engine_id, - # remote_session_id=ep_migration_request.remote_session_id) + # await self.engine_conn.zmq_send(remote_engine_id=epd_migration_request.remote_engine_id, + # remote_session_id=epd_migration_request.remote_session_id) # After migration, the sequences are ready for prefill. We change their status to WAITING # later it will be scheduled by self.scheduler.schedule_prefill() and proceed to prefill stage - self.scheduler.lock_running_ep_migration(ep_migration_running) - for msg in ep_migration_running: + self.scheduler.lock_running_epd_migration(epd_migration_running) + for msg in epd_migration_running: self.scheduler._set_message_status(msg, MessageStatus.WAITING) - self.scheduler.unlock_running_ep_migration(ep_migration_running) + self.scheduler.unlock_running_epd_migration(epd_migration_running) has_runable_event.set() else: @@ -1150,11 +1147,11 @@ async def _async_loop_main( forward_event.clear() scheduler.collect_migration_done() - scheduler.collect_ep_migration_done() + scheduler.collect_epd_migration_done() forward_inputs, next_running = await inputs_maker.send_next_inputs() if next_running is None: # TODO (JimyMa): add watermark check event instead of async sleep. - # self.perfill_watermark_event.wait() + # self.prefill_watermark_event.wait() logger.warning(f'no next prefill running request, Maybe cache is full, ' f'free gpu cache blocks: {scheduler.block_manager.get_num_free_gpu_blocks()}, ' f'total gpu cache blocks: {scheduler.block_manager.num_gpu_blocks}') @@ -1174,7 +1171,7 @@ async def _async_loop_main( # pre-forward before get last token if idx == num_loops - 1: scheduler.collect_migration_done() - scheduler.collect_ep_migration_done() + scheduler.collect_epd_migration_done() forward_inputs, next_running = await inputs_maker.prefetch_next_inputs() # send output @@ -1241,7 +1238,7 @@ async def async_loop(self): # migration task self.migration_event = asyncio.Event() - self.ep_migration_event = asyncio.Event() + self.epd_migration_event = asyncio.Event() logger.info('Starting executor.') self.executor.start(forward_event) @@ -1270,13 +1267,13 @@ async def async_loop(self): ) loop_tasks.append(loop_migration) - # TODO: modify proxy, add encoder role, only create this coroutine when in EPD mode - logger.info('Starting async task EPMigrationLoop.') - loop_ep_migration = event_loop.create_task( - self._async_loop_ep_migration(resp_que, has_runable_event=has_runable_event), - name='MainLoopEPMigration', + # TODO: only create this coroutine when in EPD mode + logger.info('Starting async task EPDMigrationLoop.') + loop_epd_migration = event_loop.create_task( + self._async_loop_epd_migration(resp_que, has_runable_event=has_runable_event), + name='MainLoopEPDMigration', ) - loop_tasks.append(loop_ep_migration) + loop_tasks.append(loop_epd_migration) # binding done callback self._add_loop_tasks_done_callback(loop_tasks) diff --git a/lmdeploy/pytorch/engine/executor/ray_executor.py b/lmdeploy/pytorch/engine/executor/ray_executor.py index 47336fbda7..0915ebf3cc 100644 --- a/lmdeploy/pytorch/engine/executor/ray_executor.py +++ b/lmdeploy/pytorch/engine/executor/ray_executor.py @@ -593,7 +593,7 @@ async def migrate(self, batch: MigrationExecutionBatch): jobs = (worker.migrate.remote(batch) for worker in self.workers) return await asyncio.gather(*jobs) - async def ep_migrate(self, batch: MigrationExecutionBatch): + async def epd_migrate(self, batch: MigrationExecutionBatch): jobs = (worker.ep_migrate.remote(batch) 
for worker in self.workers) return await asyncio.gather(*jobs) diff --git a/lmdeploy/pytorch/engine/executor/uni_executor.py b/lmdeploy/pytorch/engine/executor/uni_executor.py index e5d41d92fe..c43cc6ad50 100644 --- a/lmdeploy/pytorch/engine/executor/uni_executor.py +++ b/lmdeploy/pytorch/engine/executor/uni_executor.py @@ -117,8 +117,8 @@ async def migrate(self, batch: MigrationExecutionBatch): """KV Cache Migration.""" return await self.model_agent.cache_engine.migrate(batch) - async def ep_migrate(self, batch: MigrationExecutionBatch): + async def epd_migrate(self, batch: MigrationExecutionBatch): """Encoder Cache Migration.""" - return await self.model_agent.cache_engine.ep_migrate(batch) + return await self.model_agent.cache_engine.epd_migrate(batch) """ PD Disaggregation API End """ diff --git a/lmdeploy/pytorch/messages.py b/lmdeploy/pytorch/messages.py index a7d728c40a..8aa974845a 100644 --- a/lmdeploy/pytorch/messages.py +++ b/lmdeploy/pytorch/messages.py @@ -159,10 +159,10 @@ class MessageStatus(enum.Enum): MIGRATION_LOCKED = enum.auto() MIGRATION_DONE = enum.auto() - WAITING_EP_MIGRATION = enum.auto() # waiting for encoder => prefill migration - RUNNING_EP_MIGRATION = enum.auto() # running encoder => prefill migration - EP_MIGRATION_LOCKED = enum.auto() # locked during encoder => prefill migration - EP_MIGRATION_DONE = enum.auto() # done encoder => prefill migration + WAITING_EPD_MIGRATION = enum.auto() + RUNNING_EPD_MIGRATION = enum.auto() + EPD_MIGRATION_LOCKED = enum.auto() + EPD_MIGRATION_DONE = enum.auto() SeqMap = Dict[int, 'SchedulerSequence'] diff --git a/lmdeploy/pytorch/paging/scheduler.py b/lmdeploy/pytorch/paging/scheduler.py index b53b6d795b..1e05246edc 100644 --- a/lmdeploy/pytorch/paging/scheduler.py +++ b/lmdeploy/pytorch/paging/scheduler.py @@ -99,21 +99,21 @@ def migration_done(self): return list(seq_map.values()) @property - def waiting_ep_migration(self): + def waiting_epd_migration(self): """Get waiting sequence.""" - seq_map = self.seq_manager.get_sequences(MessageStatus.WAITING_EP_MIGRATION) + seq_map = self.seq_manager.get_sequences(MessageStatus.WAITING_EPD_MIGRATION) return list(seq_map.values()) @property - def running_ep_migration(self): + def running_epd_migration(self): """Get running sequence.""" - seq_map = self.seq_manager.get_sequences(MessageStatus.RUNNING_EP_MIGRATION) + seq_map = self.seq_manager.get_sequences(MessageStatus.RUNNING_EPD_MIGRATION) return list(seq_map.values()) @property - def ep_migration_done(self): + def epd_migration_done(self): """Get migration done sequence.""" - seq_map = self.seq_manager.get_sequences(MessageStatus.EP_MIGRATION_DONE) + seq_map = self.seq_manager.get_sequences(MessageStatus.EPD_MIGRATION_DONE) return list(seq_map.values()) def build_eviction_helper(self, eviction_type: str): @@ -201,14 +201,14 @@ def _reorder_migrating(): return running_migration @logging_timer('ScheduleEPMigration', logger) - def _schedule_ep_migration(self): - running_ep_migration: SeqList = [] + def _schedule_epd_migration(self): + running_epd_migration: SeqList = [] migrating_token_count = 0 def _to_running(seq: SchedulerSequence): """To running.""" - seq.status = MessageStatus.RUNNING_EP_MIGRATION - running_ep_migration.append(seq) + seq.status = MessageStatus.RUNNING_EPD_MIGRATION + running_epd_migration.append(seq) nonlocal migrating_token_count migrating_token_count += seq.num_token_ids @@ -223,24 +223,24 @@ def __evict_for_seq(seq: SchedulerSequence, waiting): def _reorder_migrating(): """Reorder waiting.""" - return 
sorted(self.waiting_ep_migration, key=lambda seq: seq.arrive_time) + return sorted(self.waiting_epd_migration, key=lambda seq: seq.arrive_time) - waiting_ep_migration = _reorder_migrating() - print(f'=> check waiting EP migration: {waiting_ep_migration}') + waiting_epd_migration = _reorder_migrating() + print(f'=> check waiting EPD migration: {waiting_epd_migration}') max_batches = self.scheduler_config.max_batches - self.num_running() - self.num_locked() - while len(waiting_ep_migration) > 0 and len(running_ep_migration) < max_batches: - seq = waiting_ep_migration.pop(0) - self.block_trie.match(waiting_ep_migration) - if not __evict_for_seq(seq, waiting_ep_migration): + while len(waiting_epd_migration) > 0 and len(running_epd_migration) < max_batches: + seq = waiting_epd_migration.pop(0) + self.block_trie.match(waiting_epd_migration) + if not __evict_for_seq(seq, waiting_epd_migration): break # allocate session memory self.block_manager.allocate(seq) _to_running(seq) - print(f'=> check running EP migration: {running_ep_migration}') - return running_ep_migration + print(f'=> check running EPD migration: {running_epd_migration}') + return running_epd_migration @logging_timer('SchedulePrefilling', logger) def _schedule_prefill(self, prealloc_size: int = 0): @@ -403,8 +403,7 @@ def end_session(self, session_id: int): def has_unfinished(self): """Check if there are any unfinished message.""" - # return self.has_running() or self.has_waiting() or self.has_migration_done() - return self.has_running() or self.has_waiting() or self.has_migration_done() or self.has_ep_migration_done() + return self.has_running() or self.has_waiting() or self.has_migration_done() or self.has_epd_migration_done() def has_running(self): return self.num_running() > 0 @@ -424,14 +423,14 @@ def has_migration_waiting(self): def has_migration_done(self): return self.num_migration_done() > 0 - def has_ep_migration_running(self): - return self.num_ep_migration_running() > 0 + def has_epd_migration_running(self): + return self.num_epd_migration_running() > 0 - def has_ep_migration_waiting(self): - return self.num_ep_migration_waiting() > 0 + def has_epd_migration_waiting(self): + return self.num_epd_migration_waiting() > 0 - def has_ep_migration_done(self): - return self.num_ep_migration_done() > 0 + def has_epd_migration_done(self): + return self.num_epd_migration_done() > 0 def get_block_tables(self, seqs: SeqList): """Get block table of the sequences.""" @@ -465,17 +464,17 @@ def num_migration_waiting(self): """Num waiting.""" return self.seq_manager.num_sequences(MessageStatus.WAITING_MIGRATION) - def num_ep_migration_running(self): - """Num EP migration running.""" - return self.seq_manager.num_sequences(MessageStatus.RUNNING_EP_MIGRATION) + def num_epd_migration_running(self): + """Num EPD migration running.""" + return self.seq_manager.num_sequences(MessageStatus.RUNNING_EPD_MIGRATION) - def num_ep_migration_done(self): - """Num EP migration done.""" - return self.seq_manager.num_sequences(MessageStatus.EP_MIGRATION_DONE) + def num_epd_migration_done(self): + """Num EPD migration done.""" + return self.seq_manager.num_sequences(MessageStatus.EPD_MIGRATION_DONE) - def num_ep_migration_waiting(self): - """Num EP migration waiting.""" - return self.seq_manager.num_sequences(MessageStatus.WAITING_EP_MIGRATION) + def num_epd_migration_waiting(self): + """Num EPD migration waiting.""" + return self.seq_manager.num_sequences(MessageStatus.WAITING_EPD_MIGRATION) def num_locked(self): """Num locked.""" @@ -504,26 +503,26 @@ def 
unlock_running_migration(self, locked: SeqList): if seq.status == MessageStatus.MIGRATION_LOCKED: self._set_message_status(seq, MessageStatus.MIGRATION_DONE) - def lock_running_ep_migration(self, running: SeqList): - """Lock running EP migration sequence.""" + def lock_running_epd_migration(self, running: SeqList): + """Lock running EPD migration sequence.""" for seq in running: - if seq.status == MessageStatus.RUNNING_EP_MIGRATION: - self._set_message_status(seq, MessageStatus.EP_MIGRATION_LOCKED) + if seq.status == MessageStatus.RUNNING_EPD_MIGRATION: + self._set_message_status(seq, MessageStatus.EPD_MIGRATION_LOCKED) - def unlock_running_ep_migration(self, locked: SeqList): - """Unlock running EP migration.""" + def unlock_running_epd_migration(self, locked: SeqList): + """Unlock running EPD migration.""" for seq in locked: - if seq.status == MessageStatus.EP_MIGRATION_LOCKED: - self._set_message_status(seq, MessageStatus.EP_MIGRATION_DONE) + if seq.status == MessageStatus.EPD_MIGRATION_LOCKED: + self._set_message_status(seq, MessageStatus.EPD_MIGRATION_DONE) def collect_migration_done(self): migration_done = self.migration_done for seq in migration_done: self._set_message_status(seq, MessageStatus.RUNNING) - def collect_ep_migration_done(self): - ep_migration_done = self.ep_migration_done - for seq in ep_migration_done: + def collect_epd_migration_done(self): + epd_migration_done = self.epd_migration_done + for seq in epd_migration_done: self._set_message_status(seq, MessageStatus.RUNNING) @property diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py index e8727d7de9..b37c92df5a 100644 --- a/lmdeploy/serve/openai/api_server.py +++ b/lmdeploy/serve/openai/api_server.py @@ -362,7 +362,6 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque encoder_result = EncoderResult.model_validate(encoder_result) print(f'=> api server, migration_request: \n{migration_request}\n') print(f'=> api server, encoder_result: \n{encoder_result}\n') - # import pdb; pdb.set_trace() if request.session_id == -1: VariableInterface.session_id += 1 request.session_id = VariableInterface.session_id diff --git a/lmdeploy/serve/proxy/proxy.py b/lmdeploy/serve/proxy/proxy.py index 953d2c32e4..ba6beb0053 100644 --- a/lmdeploy/serve/proxy/proxy.py +++ b/lmdeploy/serve/proxy/proxy.py @@ -22,10 +22,10 @@ from pydantic import BaseModel, Field from lmdeploy.pytorch.disagg.config import DistServeRDMAConfig, EngineRole, RDMALinkType, ServingStrategy -from lmdeploy.pytorch.disagg.conn.ep_proxy_conn import EPConnectionPool +from lmdeploy.pytorch.disagg.conn.epd_proxy_conn import EPDConnectionPool from lmdeploy.pytorch.disagg.conn.protocol import MigrationProtocol, MigrationRequest from lmdeploy.pytorch.disagg.conn.proxy_conn import PDConnectionPool -from lmdeploy.pytorch.disagg.messages import EPConnectionMessage, PDConnectionMessage +from lmdeploy.pytorch.disagg.messages import EPDConnectionMessage, PDConnectionMessage from lmdeploy.serve.openai.api_server import check_api_key, create_error_response from lmdeploy.serve.openai.protocol import ModelCard # noqa: E501 from lmdeploy.serve.openai.protocol import ChatCompletionRequest, CompletionRequest, ModelList, ModelPermission @@ -109,7 +109,7 @@ def __init__(self, self.migration_protocol = MigrationProtocol[migration_protocol] self.rdma_config = DistServeRDMAConfig(with_gdr=with_gdr, link_type=RDMALinkType[link_type]) self.pd_connection_pool = PDConnectionPool() - self.ep_connection_pool = EPConnectionPool() + 
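Taken together, the scheduler helpers above define a small state machine for a request whose image features still live on the encoder side. A rough walk-through of the intended transitions (hedged: the exact hand-off back to WAITING/RUNNING is still work-in-progress in this patch):

```python
from enum import Enum, auto


class MessageStatus(Enum):
    WAITING_EPD_MIGRATION = auto()   # request arrived with an encoder_result attached
    RUNNING_EPD_MIGRATION = auto()   # _schedule_epd_migration() allocated blocks for it
    EPD_MIGRATION_LOCKED = auto()    # lock_running_epd_migration() while features are pulled
    EPD_MIGRATION_DONE = auto()      # unlock_running_epd_migration() after the transfer
    WAITING = auto()                 # ready to be picked up by the normal prefill schedule
    RUNNING = auto()                 # collect_epd_migration_done() resumes normal execution


lifecycle = [
    MessageStatus.WAITING_EPD_MIGRATION,
    MessageStatus.RUNNING_EPD_MIGRATION,
    MessageStatus.EPD_MIGRATION_LOCKED,
    MessageStatus.EPD_MIGRATION_DONE,
    MessageStatus.WAITING,
]
print(' -> '.join(s.name for s in lifecycle))
```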
self.ep_connection_pool = EPDConnectionPool() self.dummy_prefill = False def get_nodes(self, role: EngineRole) -> Dict: @@ -511,16 +511,15 @@ async def connection_warmup(): )) for p_url in node_manager.prefill_nodes for d_url in node_manager.decode_nodes ]) logger.info(f'encoder nodes: {node_manager.decode_nodes}\nprefill nodes: {node_manager.prefill_nodes}') - # FIXME: use hybrid nodes now, since we start language server in hybrid, not prefill + # FIXME: we set pd_urls to use hybrid nodes, since we start LLM server in hybrid role await asyncio.gather(*[ node_manager.ep_connection_pool.connect( - EPConnectionMessage( + EPDConnectionMessage( e_url=e_url, - p_url=p_url, + pd_url=pd_url, protocol=node_manager.migration_protocol, rdma_config=node_manager.rdma_config, - # )) for e_url in node_manager.encoder_nodes for p_url in node_manager.prefill_nodes - )) for e_url in node_manager.encoder_nodes for p_url in node_manager.hybrid_nodes + )) for e_url in node_manager.encoder_nodes for pd_url in node_manager.hybrid_nodes ]) return JSONResponse({'SUCCESS': True}) From eda8966b4771f9cba15a545b640740de480663e7 Mon Sep 17 00:00:00 2001 From: zxy Date: Thu, 16 Oct 2025 17:20:38 +0800 Subject: [PATCH 06/11] WIP: seperate multimodal engine --- lmdeploy/cli/utils.py | 2 +- lmdeploy/multimodal/engine/cache_engine.py | 132 +++++++++++ lmdeploy/multimodal/engine/engine.py | 75 ++++++ lmdeploy/multimodal/engine/model_agent.py | 231 +++++++++++++++++++ lmdeploy/multimodal/engine/post_process.py | 55 +++++ lmdeploy/multimodal/engine/pre_process.py | 167 ++++++++++++++ lmdeploy/multimodal/models/base.py | 253 +++++++++++++++++++++ lmdeploy/multimodal/models/builder.py | 55 +++++ lmdeploy/multimodal/models/internvl3_hf.py | 212 +++++++++++++++++ lmdeploy/pytorch/engine/model_agent.py | 4 +- lmdeploy/pytorch/engine/request.py | 2 +- lmdeploy/serve/async_engine.py | 9 + lmdeploy/serve/openai/api_server.py | 14 +- lmdeploy/serve/vl_async_engine.py | 9 + 14 files changed, 1215 insertions(+), 5 deletions(-) create mode 100644 lmdeploy/multimodal/engine/cache_engine.py create mode 100644 lmdeploy/multimodal/engine/engine.py create mode 100644 lmdeploy/multimodal/engine/model_agent.py create mode 100644 lmdeploy/multimodal/engine/post_process.py create mode 100644 lmdeploy/multimodal/engine/pre_process.py create mode 100644 lmdeploy/multimodal/models/base.py create mode 100644 lmdeploy/multimodal/models/builder.py create mode 100644 lmdeploy/multimodal/models/internvl3_hf.py diff --git a/lmdeploy/cli/utils.py b/lmdeploy/cli/utils.py index bfd94182d0..a7b2e6a778 100644 --- a/lmdeploy/cli/utils.py +++ b/lmdeploy/cli/utils.py @@ -594,7 +594,7 @@ def role(parser): return parser.add_argument('--role', type=str, default='Hybrid', - choices=['Hybrid', 'Prefill', 'Decode'], + choices=['Hybrid', 'Prefill', 'Decode', 'Encoder'], help='Hybrid for Non-Disaggregated Engine; ' 'Prefill for Disaggregated Prefill Engine; ' 'Decode for Disaggregated Decode Engine') diff --git a/lmdeploy/multimodal/engine/cache_engine.py b/lmdeploy/multimodal/engine/cache_engine.py new file mode 100644 index 0000000000..cf4fa5343f --- /dev/null +++ b/lmdeploy/multimodal/engine/cache_engine.py @@ -0,0 +1,132 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
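The warmup endpoint above fans out one connection per (encoder, hybrid) pair. A self-contained sketch of that fan-out, with stand-in `EPDConnectionMessage`/`EPDConnectionPool` types (the real ones also carry the migration protocol and RDMA config and drive `p2p_initialize`/`p2p_connect`):

```python
import asyncio
from dataclasses import dataclass


@dataclass
class EPDConnectionMessage:          # stand-in; the real message also has protocol / rdma_config
    e_url: str
    pd_url: str


class EPDConnectionPool:             # stand-in; the real pool performs the p2p handshake
    async def connect(self, msg: EPDConnectionMessage):
        await asyncio.sleep(0)
        print(f'warmed up {msg.e_url} -> {msg.pd_url}')


async def connection_warmup(pool, encoder_nodes, hybrid_nodes):
    # one connect task per (encoder, hybrid) pair, mirroring the proxy's warmup endpoint
    await asyncio.gather(*[
        pool.connect(EPDConnectionMessage(e_url=e_url, pd_url=pd_url))
        for e_url in encoder_nodes for pd_url in hybrid_nodes
    ])


asyncio.run(connection_warmup(EPDConnectionPool(), ['http://enc:23334'], ['http://llm:23333']))
```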
+import json +from typing import List, Optional, Tuple + +import torch + +from lmdeploy.pytorch.disagg.backend.backend import MIGRATION_BACKENDS +from lmdeploy.pytorch.disagg.backend.base import MigrationBackendImpl +from lmdeploy.pytorch.disagg.conn.protocol import DistServeInitRequest, DistServeKVTransferEndpointInfo +from lmdeploy.pytorch.disagg.messages import DistServeRegisterMRMessage +from lmdeploy.utils import get_logger + +logger = get_logger('lmdeploy') + +FEATURE_BLOCK_SHAPE = (256, 4096) + + +class EncoderCacheEngine: + """Manages the memory pool for image features. + + This engine allocates and manages a contiguous block of GPU memory + to store image embeddings transferred from an encoder. It is adapted for + an encoder-LLM separated architecture. + Args: + rank (int): Distributed rank. + tp_rank (int): Tensor parallelism rank. + world_size (int): Distributed world size. + """ + + def __init__( + self, + num_gpu_blocks: int = 128, + rank: int = 0, + tp_rank: int = 0, + world_size: int = 1, + ) -> None: + self.world_size = world_size + self.rank = rank + self.tp_rank = tp_rank + + self.feature_dtype = torch.bfloat16 + self._num_gpu_blocks = num_gpu_blocks + + self.encoder_gpu_cache = self._allocate_gpu_cache() + + self.migration_backend_impl: Optional[MigrationBackendImpl] = None + + self.cache_stream = torch.cuda.Stream() + assert self.cache_stream != torch.cuda.current_stream() + self.events = torch.cuda.Event() + + # for memory block management + self.free_blocks = list(range(num_gpu_blocks)) + logger.debug(f'Initialize feature cache engine with {self.num_gpu_blocks} gpu blocks.') + + @property + def free_block_count(self) -> int: + """Number of free blocks available in the cache.""" + return len(self.free_blocks) + + @property + def gpu_cache(self) -> torch.Tensor: + """The GPU feature pool tensor.""" + return self.encoder_gpu_cache + + @property + def num_gpu_blocks(self) -> int: + """Number of GPU blocks.""" + return self._num_gpu_blocks + + @staticmethod + def get_feature_block_shape() -> Tuple[int, int]: + """Get the shape of a single image feature block.""" + return FEATURE_BLOCK_SHAPE + + def _allocate_cache(self, num_blocks: int, device: torch.device) -> torch.Tensor: + """Allocate the memory pool on the specified device.""" + block_shape = self.get_feature_block_shape() + + # allocate a large contiguous tensor as the encoder cache + encoder_cache = torch.empty( + size=(num_blocks, *block_shape), + dtype=self.feature_dtype, + device=device, + ) + return encoder_cache + + def _allocate_gpu_cache(self) -> torch.Tensor: + """Allocate the feature pool on the GPU.""" + return self._allocate_cache(self.num_gpu_blocks, 'cuda') + + @classmethod + def get_cache_block_size(cls) -> int: + """Get the memory size in bytes of a single feature block.""" + + shape = cls.get_feature_block_shape() + dtype = torch.bfloat16 + + meta_tensor = torch.empty(shape, dtype=dtype, device='meta') + return meta_tensor.numel() * meta_tensor.element_size() + + """ Methods for Disaggregation Begin. 
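Given the fixed `(256, 4096)` block shape and bfloat16 storage, the per-block and total pool sizes are easy to sanity-check; the same meta-tensor trick used by `get_cache_block_size()` avoids any real allocation. The numbers below assume the defaults shown above (128 GPU blocks):

```python
import torch

FEATURE_BLOCK_SHAPE = (256, 4096)     # image-token rows x hidden size, as in the patch
NUM_GPU_BLOCKS = 128                  # default pool size in the sketch above


def cache_block_size(dtype=torch.bfloat16) -> int:
    # meta device: shape/dtype bookkeeping only, no memory is actually allocated
    t = torch.empty(FEATURE_BLOCK_SHAPE, dtype=dtype, device='meta')
    return t.numel() * t.element_size()


block_bytes = cache_block_size()
print(block_bytes)                                    # 2097152 bytes = 2 MiB per block
print(block_bytes * NUM_GPU_BLOCKS // 2**20, 'MiB')   # 256 MiB for the whole feature pool
```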
""" + + def p2p_initialize(self, migration_init_request: DistServeInitRequest) -> List[DistServeKVTransferEndpointInfo]: + if not self.migration_backend_impl: + self.migration_backend_impl: MigrationBackendImpl = MIGRATION_BACKENDS.module_dict['DLSlime']() + migration_init_request.rank = self.rank + self.migration_backend_impl.p2p_initialize(migration_init_request) + + t = self.encoder_gpu_cache + if t.numel() > 0: + register_mr_request = DistServeRegisterMRMessage( + protocol=migration_init_request.protocol, + remote_engine_id=migration_init_request.remote_engine_id, + mr_key='encoder_cache', # fix memory registration key + addr=t.data_ptr(), + offset=t.storage_offset(), + length=t.numel() * t.itemsize) + self.migration_backend_impl.register_memory_region(register_mr_request) + + return [ + DistServeKVTransferEndpointInfo(protocol=migration_init_request.protocol, + endpoint_info=json.dumps( + self.migration_backend_impl.endpoint_info( + migration_init_request.remote_engine_id, + migration_init_request.protocol))) + ] + + def p2p_connect(self, remote_engine_id: str, migration_conn_request: List[DistServeKVTransferEndpointInfo]): + self.migration_backend_impl.p2p_connect(remote_engine_id, migration_conn_request[self.tp_rank]) + + """ Methods for Disaggregation End. """ diff --git a/lmdeploy/multimodal/engine/engine.py b/lmdeploy/multimodal/engine/engine.py new file mode 100644 index 0000000000..1755591862 --- /dev/null +++ b/lmdeploy/multimodal/engine/engine.py @@ -0,0 +1,75 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import asyncio +import copy + +from lmdeploy.messages import PytorchEngineConfig +from lmdeploy.utils import get_logger + +from ..models.builder import load_mm_model +from .model_agent import build_model_agent +from .post_process import PostProcessor +from .pre_process import PreProcessor + +logger = get_logger('lmdeploy') + + +class MultiModalEngine(): + """The multi-modal async inference engine of lmdeploy.""" + + def __init__(self, + model_path: str, + tokenizer: object, + engine_config: PytorchEngineConfig = None, + trust_remote_code: bool = True) -> None: + # make sure engine config exist + if engine_config is None: + engine_config = PytorchEngineConfig() + self.engine_config = copy.deepcopy(engine_config) + self.tokenizer = tokenizer + + # build model + self.model = load_mm_model(model_path, backend_config=self.engine_config) + + # build model agent + self.model_agent = build_model_agent(self.model) + self.model_agent.init() + + # init pre / post processor + self.post_processor = PostProcessor(self.model_agent) + self.pre_processor = PreProcessor(self.model_agent, self.post_processor) + + def start_loop(self): + """Start async loops.""" + # invoked in api server start up event, where we already have running event loop started by uvicorn.run() + # therefore we don't create a new event loop manually, simply start loops for each module + self.pre_processor.start_loop() + self.post_processor.start_loop() + self.model_agent.start_loop() + + def close(self): + """Close the engine and release resources.""" + self.pre_processor.close() + self.post_processor.close() + self.model_agent.close() + + @classmethod + def from_pretrained(cls, + pretrained_model_name_or_path: str, + tokenizer: object, + engine_config: PytorchEngineConfig = None, + trust_remote_code: bool = True, + **kwargs): + """Create a MultiModalEngine instance.""" + return cls(model_path=pretrained_model_name_or_path, + tokenizer=tokenizer, + engine_config=engine_config, + 
trust_remote_code=trust_remote_code) + + async def encode(self, messages, session_id: int): + """Async encode.""" + future = asyncio.Future() + + # future will later be set in post-processor + self.pre_processor.process(session_id, messages, future) + + return await future diff --git a/lmdeploy/multimodal/engine/model_agent.py b/lmdeploy/multimodal/engine/model_agent.py new file mode 100644 index 0000000000..f15402311f --- /dev/null +++ b/lmdeploy/multimodal/engine/model_agent.py @@ -0,0 +1,231 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import asyncio +from typing import List + +import torch + +from lmdeploy.utils import get_logger + +from .cache_engine import EncoderCacheEngine + +logger = get_logger('lmdeploy') + + +def _try_to_cuda(data, non_blocking: bool = True): + """Recursively traverses a data structure and moves all torch.Tensors to + the configured device.""" + if data is None: + return None + if isinstance(data, torch.Tensor): + return data.to('cuda', non_blocking=non_blocking) + if isinstance(data, list): + return [_try_to_cuda(item, non_blocking) for item in data] + if isinstance(data, tuple): + return tuple(_try_to_cuda(item, non_blocking) for item in data) + if isinstance(data, dict): + return {key: _try_to_cuda(value, non_blocking) for key, value in data.items()} + return data + + +def _try_to_cpu(data): + """Recursively traverses a data structure and moves all torch.Tensors to + the CPU.""" + if data is None: + return None + if isinstance(data, torch.Tensor): + return data.cpu() + if isinstance(data, list): + return [_try_to_cpu(item) for item in data] + if isinstance(data, tuple): + return tuple(_try_to_cpu(item) for item in data) + if isinstance(data, dict): + return {key: _try_to_cpu(value) for key, value in data.items()} + return data + + +class BaseModelAgent: + + def __init__(self, model): + + # PreProcessor -> h2d loop + self._pre_in_que = asyncio.Queue() + # h2d loop -> forward loop + self._in_que = asyncio.Queue() + # forward loop -> d2h loop + self._out_que = asyncio.Queue() + # d2h loop -> PostProcessor + self._post_proc_que = asyncio.Queue() + + # backpressure signal between h2d loop <-> forward loop + self.has_inputs = asyncio.Event() + + # CUDA streams + self.in_stream = torch.cuda.Stream() + self.out_stream = torch.cuda.Stream() + self.forward_stream = torch.cuda.Stream() + + # self.model_path = model_path + self.model = model + self.device = 'cuda' + + async def make_batch(self): + # TODO: fix for multi-batch + requests = [] + + req = await self._pre_in_que.get() + requests.append(req) + + return requests[0] + + async def async_model_forward(self): + """Model forward.""" + while True: + # wait for inputs + session_id, forward_inputs = await self._in_que.get() + print(f'get session_id: {session_id}') + print(f'get forward inputs from _in_que: {forward_inputs}') + self.next_inputs = None + + with torch.cuda.stream(self.forward_stream): + forward_outputs, allocated_blocks = self._forward_impl(forward_inputs) + print(f'check forward output: {forward_outputs}') + + # event for async fetch outputs + event = torch.cuda.Event() + event.record() + + # put inside out_que + out = dict( + session_id=session_id, + # output=forward_outputs, + output=allocated_blocks, + event=event, + ) + self._out_que.put_nowait(out) + + # reset events, for h2d prepare the next round inputs + self.has_inputs.set() + + async def h2d_loop(self): + """Host to device loop. + + preprocess inputs and put them into in_que. copy inputs to device in a different stream. 
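`encode()` above hands a bare `asyncio.Future` through the pre-processor and awaits it; the post-processor resolves it once the features sit in the cache. The pattern in isolation (the queue and worker here are stand-ins for the pre-process, forward and post-process stages):

```python
import asyncio


async def main():
    in_que: asyncio.Queue = asyncio.Queue()
    futures: dict = {}

    async def pipeline_worker():
        # stands in for pre-process -> model forward -> post-process
        while True:
            session_id, _messages = await in_que.get()
            futures.pop(session_id).set_result({'session_id': session_id, 'block_ids': [0, 1]})

    worker = asyncio.create_task(pipeline_worker())

    async def encode(session_id, messages):
        fut = asyncio.get_running_loop().create_future()
        futures[session_id] = fut
        in_que.put_nowait((session_id, messages))
        return await fut              # resolved by the worker, same shape as MultiModalEngine.encode()

    print(await encode(7, ['<image> describe this']))
    worker.cancel()


asyncio.run(main())
```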
+ """ + while True: + await self.has_inputs.wait() + + session_id, forward_inputs = await self.make_batch() + print(f'check forward_inputs: {forward_inputs}') + + # use a different stream to copy h2d + with torch.cuda.stream(self.in_stream): + forward_inputs = _try_to_cuda(forward_inputs) + + # put inputs inside in_que, reset has_inputs + self._in_que.put_nowait((session_id, forward_inputs)) + self.has_inputs.clear() + + async def d2h_loop(self): + """Device to host loop. + + copy outputs from device to host. put outputs into post processing queue. + """ + while True: + out = await self._out_que.get() + + # check event periodically + event = out.pop('event') + while not event.query(): + await asyncio.sleep(0.001) + + # use a different stream to copy d2h + with torch.cuda.stream(self.out_stream): + out = _try_to_cpu(out) + + self._post_proc_que.put_nowait(out) + + def start_loop(self): + """Start event loop.""" + event_loop = asyncio.get_event_loop() + + # set for the first batch + self.has_inputs.set() + + # forward task + logger.info('Create task MultiModal ModelAgent ForwardLoop.') + self._forward_task = event_loop.create_task(self.async_model_forward(), name='ModelAgentForwardLoop') + + # preprocess inputs task + logger.info('Create task MultiModal ModelAgent Preprocess.') + self._preprocess_task = event_loop.create_task(self.h2d_loop(), name='ModelAgentPreprocess') + + # postprocess outputs task + logger.info('Create task MultiModal ModelAgent Postprocess.') + self._postprocess_task = event_loop.create_task(self.d2h_loop(), name='ModelAgentPostprocess') + + loop_tasks: list[asyncio.Task] = [self._forward_task, self._preprocess_task, self._postprocess_task] + + # binding done callback + self._add_loop_tasks_done_callback(loop_tasks) + + @staticmethod + def _add_loop_tasks_done_callback(tasks: List[asyncio.Task]): + """Add loop tasks done callback.""" + + def __task_callback(task: asyncio.Task) -> None: + """Raise exception on finish.""" + task_name = task.get_name() + try: + task.result() + except asyncio.CancelledError: + logger.debug(f'Task <{task_name}> cancelled.') + return + except Exception: + logger.exception(f'Task <{task_name}> failed') + finally: + for task in tasks: + if not task.done(): + task.cancel() + + for task in tasks: + task.add_done_callback(__task_callback) + + def build_cache_engine(self): + cache_engine = EncoderCacheEngine() + self.cache_engine = cache_engine + + def _forward_impl(self, inputs): + """Model forward implementation.""" + feats = self.model.forward(inputs) + + # put feat into encoder cache + feats = feats[0] # FIXME + num_required_blocks = feats.shape[0] // 256 + if len(self.cache_engine.free_blocks) < num_required_blocks: + raise RuntimeError('Not enough free blocks in cache engine') + allocated_blocks = self.cache_engine.free_blocks[:num_required_blocks] + + # move into dedicated mm cache pool + # TODO: we dont want copy, better to just write into that memory region + # but current transformers get_image_features() returns a new tensor, seems no way to achieve this + for i in range(num_required_blocks): + src_chunk = feats[i * 256:(i + 1) * 256, :] + dst_block_id = allocated_blocks[i] + self.cache_engine.gpu_cache[dst_block_id].copy_(src_chunk) + print(f'=> allocated blocks: {allocated_blocks}') + + return feats, allocated_blocks + + def init(self): + self.build_cache_engine() + + def close(self): + self.cache_engine = None + self.model = None + torch.cuda.empty_cache() + + +def build_model_agent(model): + """Build model agent.""" + model_agent = 
BaseModelAgent(model=model, ) + return model_agent diff --git a/lmdeploy/multimodal/engine/post_process.py b/lmdeploy/multimodal/engine/post_process.py new file mode 100644 index 0000000000..cfe4f72624 --- /dev/null +++ b/lmdeploy/multimodal/engine/post_process.py @@ -0,0 +1,55 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import asyncio +from typing import Dict, List + +from lmdeploy.utils import get_logger + +logger = get_logger('lmdeploy') + + +class PostProcessor(): + + def __init__(self, model_agent): + print('=> PostProcessor init') + self.model_agent = model_agent + self._loop_task = None + + # session_id -> future + self._future_store: Dict[int, asyncio.Future] = {} + + def add_future(self, session_id, messages, future): + self._future_store[session_id] = (messages, future) + + def start_loop(self): + if not hasattr(self, '_loop_task') or self._loop_task is None: + logger.info('Starting PostProcessor loop') + self._loop_task = asyncio.create_task(self.async_loop()) + + def post_process(self, messages: List[Dict]): + # TODO: implement model-specific post process logic + return messages + + async def async_loop(self): + while True: + out = await self.model_agent._post_proc_que.get() + print(f'=> PostProcessor got data: {out}') + + out = self.post_process(out) + print(f'=> PostProcessor post-processed data: {out}') + + session_id = out.pop('session_id', None) + print(f'=> PostProcessor session_id: {session_id}') + messages, future = self._future_store.pop(session_id, None) + print(messages) + # add out as an attri named 'encoder_result' to messages + + messages[0]['encoder_result'] = out + if future and not future.done(): + print(f'=> PostProcessor setting future result: {messages}') + future.set_result(messages) + + def close(self): + """Cancel the background loop task.""" + if self._loop_task and not self._loop_task.done(): + self._loop_task.cancel() + logger.info('PostProcessor loop cancelled.') diff --git a/lmdeploy/multimodal/engine/pre_process.py b/lmdeploy/multimodal/engine/pre_process.py new file mode 100644 index 0000000000..ba04c5bbc7 --- /dev/null +++ b/lmdeploy/multimodal/engine/pre_process.py @@ -0,0 +1,167 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import asyncio +from typing import Dict, List + +from lmdeploy.utils import get_logger +from lmdeploy.vl.utils import load_image + +logger = get_logger('lmdeploy') + + +class PreProcessor(): + + def __init__(self, model_agent, post_processor): + print('=> PreProcessor init') + self._in_que = asyncio.Queue() + self.model_agent = model_agent + self.post_processor = post_processor + + self._loop_task = None + + @staticmethod + def collect_images(messages): + """Gather all images along with their respective parameters from the + messages and compile them into a single list. Each image is converted + to RGB color space. + + Args: + messages (List[Tuple[Image, Dict]]): a list of images with their + corresponding parameters + """ # noqa + images = [] + for message in messages: + content = message['content'] + if not isinstance(content, List): + continue + images.extend([(x['image'], { + k: v + for k, v in x.items() if k not in {'type', 'image'} + }) for x in content if x['type'] == 'image']) + return images + + @classmethod + async def async_convert_to_pil_images(cls, messages: List[Dict]) -> List[Dict]: + """Scan the provided messages to find image URLs or base64-encoded + image data. Loads the images into Pillow image objects. 
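`_forward_impl()` above carves the flattened feature map into 256-row chunks and copies each into a free slot of the pool. A CPU-only sketch of that bookkeeping (note: the patch currently slices `free_blocks` without removing the used ids; popping them here only illustrates the intended accounting):

```python
import torch

BLOCK_ROWS, HIDDEN = 256, 4096
pool = torch.zeros((8, BLOCK_ROWS, HIDDEN))        # tiny stand-in for the GPU feature pool
free_blocks = list(range(8))


def store_features(feats: torch.Tensor) -> list:
    # feats: (num_image_tokens, hidden) with num_image_tokens a multiple of 256
    assert feats.shape[0] % BLOCK_ROWS == 0 and feats.shape[1] == HIDDEN
    num_required = feats.shape[0] // BLOCK_ROWS
    if len(free_blocks) < num_required:
        raise RuntimeError('Not enough free blocks in cache engine')
    blocks = [free_blocks.pop(0) for _ in range(num_required)]
    for i, block_id in enumerate(blocks):
        pool[block_id].copy_(feats[i * BLOCK_ROWS:(i + 1) * BLOCK_ROWS])
    return blocks


print(store_features(torch.randn(512, HIDDEN)))    # e.g. [0, 1]
```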
+ + Args: + messages (List[Dict]): a user request of GPT4V message format + """ + if isinstance(messages, Dict): + messages = [messages] + assert isinstance(messages, List) + + out_messages = [None] * len(messages) + + def _inner_call(i, in_messages, out_messages): + role = in_messages[i]['role'] + content = in_messages[i]['content'] + assert role in ['system', 'user', 'assistant'], \ + f'unsupported role "{role}"' + if role != 'user' or isinstance(content, str): + # the content is a user's prompt or an assistant's prompt, + # returning it directly + out_messages[i] = in_messages[i] + return + # the role is a user and the content is a list, in which there + # might be image_url or image_data + assert isinstance(content, List) + message = dict(role=role, content=[]) + for item in content: + # image url or base64-encoded image data + if item['type'] == 'image_url': + """ + convert the following item: + { + 'type': 'image_url', + 'image_url': { + 'url': 'image url or base64-encoded image data', + 'key': 'value' # parameters used in image processing + ... + } + } + to: + { + 'type': 'image', + 'image': Pillow.Image, + 'key': 'value' # parameters used in image processing + ... + } + """ # noqa + data = item['image_url'].copy() + try: + url = data.pop('url') + image = load_image(url) + data.update(type='image', image=image) + message['content'].append(data) + except KeyError: + logger.error(f'invalid format {message}') + elif item['type'] == 'image_data': + """ + convert the following item: + { + 'type': 'image_data', + 'image_data': { + 'data': Pillow.Image, + 'key': 'value' # parameters used in image processing + ... + } + } + to: + { + 'type': 'image', + 'image': Pillow.Image, + 'key': 'value' # parameters used in image processing + ... + } + """ # noqa + data = item['image_data'].copy() + try: + image = data.pop('data') + data.update(type='image', image=image) + message['content'].append(data) + except KeyError: + logger.error(f'invalid format {message}') + elif item['type'] == 'text': + message['content'].append(item) + else: + logger.error(f'unexpected content type {message}') + out_messages[i] = message + + await asyncio.gather(*[ + asyncio.get_event_loop().run_in_executor(None, _inner_call, i, messages, out_messages) + for i in range(len(messages)) + ]) + return out_messages + + def start_loop(self): + """Creates a task for the given coroutine.""" + if not hasattr(self, '_loop_task') or self._loop_task is None: + logger.info('Starting PreProcessor loop') + self._loop_task = asyncio.create_task(self.async_loop()) + + async def async_loop(self): + while True: + session_id, messages = await self._in_que.get() + + messages = await self.async_convert_to_pil_images(messages) + print(f'after convert msg: {messages}') + + proc_inputs = self.model_agent.model.preprocess(messages) + print(f'after preproc msg: {proc_inputs}') + + # TODO: process to get token ids, image mask + + self.model_agent._pre_in_que.put_nowait((session_id, proc_inputs)) + + def process(self, session_id, messages, future): + if messages is not None: + self._in_que.put_nowait((session_id, messages)) + + self.post_processor.add_future(session_id, messages, future) + + def close(self): + """Cancel the background loop task.""" + if self._loop_task and not self._loop_task.done(): + self._loop_task.cancel() + logger.info('PreProcessor loop cancelled.') diff --git a/lmdeploy/multimodal/models/base.py b/lmdeploy/multimodal/models/base.py new file mode 100644 index 0000000000..c28dcaaab0 --- /dev/null +++ 
b/lmdeploy/multimodal/models/base.py @@ -0,0 +1,253 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABC, abstractmethod +from typing import Dict, List, Union + +import numpy as np +from mmengine import Registry +from transformers import AutoConfig, AutoTokenizer + +from lmdeploy.archs import get_model_arch +from lmdeploy.utils import get_logger + +logger = get_logger('lmdeploy') + +BASE_MODELS = Registry('base_model') + + +class BaseModel(ABC): + """Abstract base model class in the multimodal engine.""" + _arch: Union[str, List[str]] = None + + def __init__(self, + model_path: str, + with_llm: bool = False, + max_memory: Dict[int, int] = None, + hf_config: AutoConfig = None, + backend: str = ''): + """init.""" + self.model_path = model_path + self.with_llm = with_llm + self.max_memory = max_memory + self.backend = backend + if hf_config is None: + _, hf_config = get_model_arch(model_path) + self.hf_config = hf_config + self.image_token_id = self.get_pad_token_id(model_path, hf_config) or 0 + + def get_pad_token_id(self, model_path, hf_config): + """Get pad_token_id from hf_config or tokenizer.""" + pad_token_id = getattr(hf_config, 'pad_token_id', None) + if pad_token_id is None: + try: + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + pad_token_id = getattr(tokenizer, 'pad_token_id', None) + except Exception as e: + print(e) + pass + return pad_token_id + + @abstractmethod + def build_preprocessor(self, ): + """Build the preprocessor.""" + raise NotImplementedError() + + def build_model(self, ): + """Build the vision part of a VLM model when backend is turbomind. + + But when `with_llm=True`, load the whole VLM model + """ + raise NotImplementedError() + + @abstractmethod + def preprocess(self, messages: List[Dict]) -> List[Dict]: + """Preprocess multimodal data in the messages. + + The derived class, + i.e., a specific vision model, takes the charge of image preprocessing + and the result management. + It can integrate the result into the messages list, or insert it to + the individual image item. + Args: + message(Dict): multimodal data in a dict, which is as follows: + [ + {'role': 'user', 'content': 'user prompt'}, + {'role': 'assisant', 'content': 'AI reponse'}, + { + 'role': 'user', + 'content': [ + { + 'type': 'text', + 'text': 'string', + }, + { + 'type': 'image', + 'image': pillow.Image, + 'key1': value1, + ... + }, + { + 'type': 'image', + 'image': pillow.Image, + 'key1': value1, + ... + }, + ... + ] + } + {....} + ] + Returns: + the message list with preprocessing results included, which is + determined by the derived classes + """ # noqa + raise NotImplementedError() + + # @abstractmethod + # def postprocess(self, model_outputs: torch.Tensor, processed_inputs: List[Dict]) -> List[Dict]: + # """ + # Takes the model outputs and performs post-process. + # """ + # raise NotImplementedError() + + def forward(self, messages: List[Dict], max_batch_size: int = 1) -> List[Dict]: + """Extract image feature. ONLY implement it when the backend is + turbomind engine. 
+ + Args: + messages(List[Dict]): the outputs of `preprocess` + max_batch_size(int): the max batch size when forwarding vision + model + Return: + the message list with forwarding results included, which is + determined by the derived classes + """ + if self.backend == 'turbomind': + raise NotImplementedError() + + def to_pytorch(self, messages, chat_template, tokenizer, sequence_start, **kwargs): + """Pack the preprocessing results in a format compatible with what is + required by pytorch engine. ONLY implement it when the backend is + pytorch engine. + + Args: + messages(List[Dict]): the output of `preprocess` + chat_template: the chat template defined in `lmdeploy/model.py` + tokenzer: the tokenizer model + sequence_start: starting flag of a sequence + """ + if self.backend == 'pytorch': + raise NotImplementedError() + + def to_turbomind(self, messages, chat_template, tokenizer, sequence_start, **kwargs): + """Pack the forwarding results in a format compatible with what is + required by turbomind engine. ONLY implement it when the backend is + turbomind engine. + + Args: + messages(List[Dict]): the output of `preprocess` + chat_template: the chat template defined in `lmdeploy/model.py` + tokenzer: the tokenizer model + sequence_start: starting flag of a sequence + """ + if self.backend == 'turbomind': + raise NotImplementedError() + + @staticmethod + def collect_images(messages): + """Gather all images along with their respective parameters from the + messages and compile them into a single list. Each image is converted + to RGB color space. + + Args: + messages (List[Tuple[Image, Dict]]): a list of images with their + corresponding parameters + """ # noqa + images = [] + for message in messages: + content = message['content'] + if not isinstance(content, List): + continue + images.extend([(x['image'], { + k: v + for k, v in x.items() if k not in {'type', 'image'} + }) for x in content if x['type'] == 'image']) + return images + + def to_pytorch_aux(self, messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start): + """Auxiliary function to pack the preprocessing results in a format + compatible with what is required by pytorch engine. 
+ + Args: + messages(List[Dict]): the output of `preprocess` + prompt(str): the prompt after applying chat template + IMAGE_TOKEN(str): a placeholder where image tokens will be + inserted + tokenzer: the tokenizer model + sequence_start: starting flag of a sequence + """ + # collect all preprocessing result from messages + preps = [x['content'] for x in messages if x['role'] == 'preprocess'] + assert len(preps) == 1 + preps = preps[0] + + # split prompt into segments and validate data + segs = prompt.split(IMAGE_TOKEN) + assert len(segs) == len(preps) + 1, (f'the number of {IMAGE_TOKEN} is not equal ' + f'to input images, {len(segs) - 1} vs {len(preps)}') + + # calculate the image token offset for each image + input_ids = [] + for i, seg in enumerate(segs): + if i > 0 and i <= len(preps): + preps[i - 1].update(offset=len(input_ids)) + image_tokens = preps[i - 1]['image_tokens'] + assert self.image_token_id == preps[i - 1]['image_token_id'] + input_ids.extend([self.image_token_id] * image_tokens) + token_ids = tokenizer.encode(seg, add_bos=((i == 0) and sequence_start)) + input_ids.extend(token_ids) + + return dict(prompt=prompt, input_ids=input_ids, multimodal=preps) + + def to_turbomind_aux(self, messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start): + """Auxiliary function to pack the forwarding results in a format + compatible with what is required by turbomind engine. + + Args: + messages(List[Dict]): the output of `preprocess` + prompt(str): the prompt after applying chat template + IMAGE_TOKEN(str): a placeholder where image tokens will be + inserted + tokenzer: the tokenizer model + sequence_start: starting flag of a sequence + """ + # collect image features from messages + features = [x['content'] for x in messages if x['role'] == 'forward'] + features = features[0] + features = [x.cpu().numpy() for x in features] + # split prompt into segments and validate data + segs = prompt.split(IMAGE_TOKEN) + assert len(segs) == len(features) + 1, (f'the number of {IMAGE_TOKEN} is not equal ' + f'to input images, {len(segs) - 1} vs {len(features)}') + + # tokenizer prompt, and get input_embeddings and input_embedding_ranges + input_ids = [] + begins = [] + ends = [] + for i, seg in enumerate(segs): + if i > 0 and i <= len(features): + image_dim = features[i - 1].shape[0] + begins.append(len(input_ids)) + ends.append(begins[-1] + image_dim) + input_ids.extend([self.image_token_id] * image_dim) + seg_ids = tokenizer.encode(seg, add_bos=((i == 0) and sequence_start)) + input_ids.extend(seg_ids) + ranges = np.stack([begins, ends], axis=1).tolist() + return dict(prompt=prompt, input_ids=input_ids, input_embeddings=features, input_embedding_ranges=ranges) + + @classmethod + def match(cls, config: AutoConfig): + """Check whether the config match the model.""" + arch = config.architectures[0] if config.architectures else None + if arch and (arch == cls._arch or arch in cls._arch): + return True + return False diff --git a/lmdeploy/multimodal/models/builder.py b/lmdeploy/multimodal/models/builder.py new file mode 100644 index 0000000000..7c3f4a052c --- /dev/null +++ b/lmdeploy/multimodal/models/builder.py @@ -0,0 +1,55 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
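`to_pytorch_aux()` interleaves text tokens with runs of the image placeholder id and records each image's offset inside the final sequence. A toy version with a fake tokenizer (the image token id and the hash-based `encode()` below are made up purely for illustration):

```python
IMAGE_TOKEN = '<IMG_CONTEXT>'
IMAGE_TOKEN_ID = 151667            # illustrative id only


def encode(text):
    # toy whitespace tokenizer standing in for the real one
    return [abs(hash(tok)) % 50000 for tok in text.split()]


def pack_pytorch_inputs(prompt, preps):
    segs = prompt.split(IMAGE_TOKEN)
    assert len(segs) == len(preps) + 1
    input_ids = []
    for i, seg in enumerate(segs):
        if i > 0:
            preps[i - 1]['offset'] = len(input_ids)
            input_ids.extend([IMAGE_TOKEN_ID] * preps[i - 1]['image_tokens'])
        input_ids.extend(encode(seg))
    return dict(prompt=prompt, input_ids=input_ids, multimodal=preps)


out = pack_pytorch_inputs(f'{IMAGE_TOKEN}\nDescribe the image', [dict(image_tokens=256)])
print(out['multimodal'][0]['offset'], len(out['input_ids']))   # 0 259
```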
+import os +from typing import Optional, Union + +import torch + +from lmdeploy.archs import get_model_arch +from lmdeploy.messages import PytorchEngineConfig, TurbomindEngineConfig +from lmdeploy.multimodal.models.base import BASE_MODELS +from lmdeploy.utils import get_logger, get_model + +from .internvl3_hf import InternVL3VisionModel # noqa F401 + +logger = get_logger('lmdeploy') + + +def load_mm_model(model_path: str, + backend: str = '', + with_llm: bool = False, + backend_config: Optional[Union[TurbomindEngineConfig, PytorchEngineConfig]] = None): + """Load multimodal model. + + Args: + model_path(str): the path or repo_id from model hub of the model + backend(str): the name of inference backend + with_llm(bool): load LLM model or not. Set it to False for VLM + inference scenarios and True for VLM quantization + backend_config: the config of the inference engine + """ + if not os.path.exists(model_path): + revision = getattr(backend_config, 'revision', None) + download_dir = getattr(backend_config, 'download_dir', None) + model_path = get_model(model_path, revision=revision, download_dir=download_dir) + + max_memory = None + if not with_llm: + tp = getattr(backend_config, 'tp', 1) + max_memory = {i: torch.cuda.mem_get_info(i)[0] for i in range(tp)} + + _, hf_config = get_model_arch(model_path) + kwargs = dict(model_path=model_path, with_llm=with_llm, max_memory=max_memory, hf_config=hf_config, backend=backend) + + for name, module in BASE_MODELS.module_dict.items(): + try: + if module.match(hf_config): + logger.info(f'matching multimodal model: {name}') + model = module(**kwargs) + model.build_preprocessor() + model.build_model() + return model + except Exception as e: + logger.error(f'build multimodal model {name} failed, {e}') + raise + + raise ValueError(f'unsupported multimodal model with config {hf_config}') diff --git a/lmdeploy/multimodal/models/internvl3_hf.py b/lmdeploy/multimodal/models/internvl3_hf.py new file mode 100644 index 0000000000..088eac509e --- /dev/null +++ b/lmdeploy/multimodal/models/internvl3_hf.py @@ -0,0 +1,212 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
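`load_mm_model()` picks the vision model by asking every registered class whether it matches the HF config's architecture. A dependency-free sketch of that dispatch (a plain dict stands in for the mmengine `Registry`):

```python
from dataclasses import dataclass, field

REGISTRY = {}                      # stand-in for BASE_MODELS.module_dict


def register(cls):
    REGISTRY[cls.__name__] = cls
    return cls


@dataclass
class HFConfig:
    architectures: list = field(default_factory=list)


@register
class InternVL3VisionModel:
    _arch = ['InternVLForConditionalGeneration', 'InternS1ForConditionalGeneration']

    @classmethod
    def match(cls, config):
        arch = config.architectures[0] if config.architectures else None
        return arch is not None and arch in cls._arch


def load_mm_model(hf_config):
    for name, module in REGISTRY.items():
        if module.match(hf_config):
            return name            # the real builder instantiates the class and builds the model
    raise ValueError(f'unsupported multimodal model with config {hf_config}')


print(load_mm_model(HFConfig(['InternVLForConditionalGeneration'])))
```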
+ +# TODO: may consider separate similar to transformers +# - Internvl +# - configuration_internvl.py +# - modeling_internvl.py +# - processing_internvl.py + +# but this may bring too many files, so currently we just put all things together + +from typing import Dict, List, Optional + +import torch +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoProcessor +from transformers.processing_utils import ImagesKwargs, ProcessingKwargs + +from lmdeploy.utils import get_logger +from lmdeploy.vl.model.utils import disable_logging + +from .base import BASE_MODELS, BaseModel + +logger = get_logger('lmdeploy') + + +class InternVLImagesKwargs(ImagesKwargs, total=False): + crop_to_patches: Optional[bool] + min_patches: Optional[int] + max_patches: Optional[int] + + +class InternVLProcessorKwargs(ProcessingKwargs, total=False): + images_kwargs: InternVLImagesKwargs + _defaults = { + 'text_kwargs': { + 'padding': False, + }, + 'images_kwargs': { + 'crop_to_patches': True, + }, + 'videos_kwargs': {}, + } + + +@BASE_MODELS.register_module() +class InternVL3VisionModel(BaseModel): + """Internvl3 vision model.""" + + _arch = ['InternVLForConditionalGeneration', 'InternS1ForConditionalGeneration'] + + def __init__(self, + model_path: str, + with_llm: bool = False, + max_memory: Dict[int, int] = None, + hf_config: AutoConfig = None, + backend: str = ''): + super().__init__(model_path, with_llm, max_memory, hf_config, backend) + self.arch = hf_config.architectures[0] + + def build_preprocessor(self): + self.processor = AutoProcessor.from_pretrained(self.model_path, trust_remote_code=True) + tokenizer = self.processor.tokenizer + self.image_token_id = tokenizer.context_image_token_id + self.image_tokens_per_patch = self.processor.image_seq_length + self.tokenizer_init_kwargs = tokenizer.init_kwargs + + def build_model(self): + """Build the vision part of a VLM model when backend is turbomind, or + load the whole VLM model when `self.with_llm==True`""" + from accelerate import init_empty_weights + with init_empty_weights(): + if self.arch == 'InternVLForConditionalGeneration': + model = AutoModel.from_config(self.hf_config, trust_remote_code=True) + # if not self.with_llm: + # print('delete language model') + del model.language_model + elif self.arch == 'InternS1ForConditionalGeneration': + model = AutoModelForCausalLM.from_config(self.hf_config, trust_remote_code=True) + # if not self.with_llm: + # print('delete language model') + del model.model.language_model + else: + raise ValueError(f'unsupported model arch {self.arch}') + + model.half() + from accelerate import load_checkpoint_and_dispatch + with disable_logging(): + load_checkpoint_and_dispatch( + model=model, + checkpoint=self.model_path, + # device_map='auto' if not self.with_llm else {'': 'cpu'}, + device_map='auto', + max_memory=self.max_memory, + no_split_module_classes=['InternVLVisionLayer', 'InternS1VisionLayer'], + dtype=torch.half) + # We need eval mode to freeze the weights in model, thus, + # avoid randomness in inference. 
+ self.model = model.eval() + + def preprocess(self, messages: List[Dict]) -> List[Dict]: + """Refers to `super.preprocess() for spec.""" + from transformers.image_utils import make_flat_list_of_images + output_kwargs = self.processor._merge_kwargs( + InternVLProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer_init_kwargs, + **{ + 'return_tensors': 'pt', + 'add_special_tokens': False + }, + ) + images = self.collect_images(messages) + images = [image.convert('RGB') for image, _ in images] + num_image = len(images) + images = make_flat_list_of_images(images) + image_inputs = self.processor.image_processor(images, **output_kwargs['images_kwargs']) + image_num_patches = image_inputs.pop('num_patches').cpu().numpy().tolist() + image_pixel_values = image_inputs.pop('pixel_values') + outputs = [] + cum_num_patches = 0 + for idx in range(num_image): + cur_num_patches = image_num_patches[idx] + pixel_values = image_pixel_values[cum_num_patches:cum_num_patches + cur_num_patches, ...] + cum_num_patches += cur_num_patches + data = dict(pixel_values=pixel_values, + image_tokens=self.image_tokens_per_patch * cur_num_patches, + image_token_id=self.image_token_id) + outputs.append(data) + + return outputs + + @torch.no_grad() + def forward(self, processed_inputs: List[Dict]) -> torch.Tensor: + # FIXME: consider batch? + outputs = [] + pixel_values = [x['pixel_values'] for x in processed_inputs] + split = [x.shape[0] for x in pixel_values] + pixel_values = torch.cat(pixel_values, dim=0) + pixel_values = pixel_values.to(dtype=torch.float16) + logger.info(f'vision forward shape: {pixel_values.shape}') + feats = self.model.get_image_features( + pixel_values, + vision_feature_layer=self.hf_config.vision_feature_layer, + vision_feature_select_strategy=self.hf_config.vision_feature_select_strategy, + ) + feats = torch.split(feats, split, dim=0) + outputs.extend([x.reshape(-1, x.shape[-1]) for x in feats]) + return outputs + + @staticmethod + def proc_messages( + messages, + chat_template, + sequence_start, + tools: Optional[List[object]] = None, + enable_thinking: Optional[bool] = None, + ): + """Apply chat template to get the prompt.""" + prompt_messages = [] + IMAGE_TOKEN = '' + for message in messages: + if isinstance(message['content'], str): + prompt_messages.append(message) + continue + elif message['role'] in ['preprocess', 'forward']: + continue + n_images = len([1 for x in message['content'] if x['type'] == 'image']) + content = [x.get('text', '') for x in message['content'] if x['type'] == 'text'] + prompt = content[0] + if IMAGE_TOKEN in prompt and f'{IMAGE_TOKEN}' not in prompt: + prompt = prompt.replace(f'{IMAGE_TOKEN}', f'{IMAGE_TOKEN}') + prompt = prompt.replace('', '') + prompt = prompt.replace('', '') + prompt = prompt.replace('', '') + elif IMAGE_TOKEN not in prompt: + prompt = f'{IMAGE_TOKEN * n_images}\n' + prompt + else: + pass + prompt_messages.append(dict(role='user', content=prompt)) + prompt = chat_template.messages2prompt(prompt_messages, + sequence_start, + tools=tools, + enable_thinking=enable_thinking) + return prompt, IMAGE_TOKEN + + def to_pytorch(self, + messages, + chat_template, + tokenizer, + sequence_start, + tools: Optional[List[object]] = None, + enable_thinking: Optional[bool] = None, + **kwargs): + prompt, IMAGE_TOKEN = self.proc_messages(messages, + chat_template, + sequence_start, + tools=tools, + enable_thinking=enable_thinking) + return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start) + + def to_turbomind(self, + messages, + 
chat_template, + tokenizer, + sequence_start, + tools: Optional[List[object]] = None, + enable_thinking: Optional[bool] = None, + **kwargs): + prompt, IMAGE_TOKEN = self.proc_messages(messages, + chat_template, + sequence_start, + tools=tools, + enable_thinking=enable_thinking) + return self.to_turbomind_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start) diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index 3f697f9880..f7c606649f 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -30,7 +30,7 @@ from ..utils import get_gpu_memory from ..weight_loader.model_weight_loader import load_model_weights from .cache_engine import CacheEngine -from .guided_process import GuidedDecodingMangager +# from .guided_process import GuidedDecodingMangager from .logits_process import FusedLogitsProcessor, SamplingInputs logger = get_logger('lmdeploy') @@ -354,7 +354,7 @@ def __init__(self, self.patched_model = None self.cache_engine = None self.profiler: AgentProfiler = None - self.guided_decoding_manager = GuidedDecodingMangager(self.tokenizer, self.sampling_vocab_size) + # self.guided_decoding_manager = GuidedDecodingMangager(self.tokenizer, self.sampling_vocab_size) # microbatch self.enable_microbatch = self.dist_ctx.dist_config.enable_microbatch diff --git a/lmdeploy/pytorch/engine/request.py b/lmdeploy/pytorch/engine/request.py index 466e102e22..c1c0a545cc 100644 --- a/lmdeploy/pytorch/engine/request.py +++ b/lmdeploy/pytorch/engine/request.py @@ -48,7 +48,7 @@ class Request: def _run_until_complete(future: Awaitable): - """Run untile complete.""" + """Run until complete.""" try: event_loop = asyncio.get_event_loop() except Exception: diff --git a/lmdeploy/serve/async_engine.py b/lmdeploy/serve/async_engine.py index 0e7bbdf7eb..e3a856b2a3 100644 --- a/lmdeploy/serve/async_engine.py +++ b/lmdeploy/serve/async_engine.py @@ -911,6 +911,15 @@ def is_error(status): # manually end pytorch session await inst.async_end(session_id) + async def encode_generate(self, messages, session_id: int): + """Perform encoding.""" + if not hasattr(self.engine, 'encode'): + raise NotImplementedError('encode() is not implemented for the current backend') + + encoder_result = await self.engine.encode(messages, session_id) + + return encoder_result + def _run(self, fn=None, coro=None, loop=None): assert (fn or coro) and not (fn and coro) loop = loop or self.internal_thread.loop diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py index 170968fa0e..c5abe374c2 100644 --- a/lmdeploy/serve/openai/api_server.py +++ b/lmdeploy/serve/openai/api_server.py @@ -25,7 +25,7 @@ from lmdeploy.messages import GenerationConfig, LogitsProcessor, PytorchEngineConfig, TurbomindEngineConfig from lmdeploy.metrics.metrics_processor import metrics_processor from lmdeploy.model import ChatTemplateConfig -from lmdeploy.pytorch.disagg.config import DistServeEngineConfig +from lmdeploy.pytorch.disagg.config import DistServeEngineConfig, EngineRole from lmdeploy.pytorch.disagg.conn.protocol import (DistServeCacheFreeRequest, DistServeConnectionRequest, DistServeDropConnectionRequest, DistServeInitRequest, EncoderResult, MigrationRequest) @@ -372,6 +372,18 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque if VariableInterface.async_engine.id2step.get(request.session_id, 0) != 0: return create_error_response(HTTPStatus.BAD_REQUEST, f'The session_id {request.session_id!r} is occupied.') + # 
if encoder, we only do encoding and return + engine_role = VariableInterface.async_engine.backend_config.role + if engine_role == EngineRole.Encoder: + from fastapi.encoders import jsonable_encoder + from fastapi.responses import JSONResponse + encoder_result = await VariableInterface.async_engine.encode_generate( + request.messages, + request.session_id, + ) + print(f'api_server, v1/completion, encoder_result: {encoder_result}') + return JSONResponse(jsonable_encoder(encoder_result)) + model_name = request.model adapter_name = None if model_name != VariableInterface.async_engine.model_name: diff --git a/lmdeploy/serve/vl_async_engine.py b/lmdeploy/serve/vl_async_engine.py index a784e67e74..140b5c7bbd 100644 --- a/lmdeploy/serve/vl_async_engine.py +++ b/lmdeploy/serve/vl_async_engine.py @@ -7,6 +7,7 @@ from lmdeploy.messages import PytorchEngineConfig, TurbomindEngineConfig, VisionConfig from lmdeploy.model import BaseChatTemplate +from lmdeploy.pytorch.disagg.config import EngineRole from lmdeploy.serve.async_engine import AsyncEngine from lmdeploy.utils import get_logger, try_import_deeplink from lmdeploy.vl.engine import ImageEncoder @@ -38,6 +39,14 @@ def __init__(self, 'please specify chat template as guided in https://lmdeploy.readthedocs.io/en/latest/inference/vl_pipeline.html#set-chat-template' # noqa: E501 ) + # if the server started as encoder, we replace with mm engine + # TODO: find a way to disable LLM engine initialization and weight loading + if self.backend_config.role == EngineRole.Encoder: + from lmdeploy.multimodal.engine.engine import MultiModalEngine + self.engine = MultiModalEngine.from_pretrained(pretrained_model_name_or_path=model_path, + tokenizer=self.tokenizer, + engine_config=backend_config) + @classmethod def _convert_prompts(cls, prompts: Union[VLPromptType, List[Dict], List[VLPromptType], List[List[Dict]]]): """Convert prompts to openai GPT4V format.""" From 8d80d1d5cd72d974d630372d6eeebaf80fa72562 Mon Sep 17 00:00:00 2001 From: zxy Date: Thu, 16 Oct 2025 21:30:35 +0800 Subject: [PATCH 07/11] WIP: adjust encoder outputs --- lmdeploy/multimodal/engine/engine.py | 57 ++++++++++++++++++++++ lmdeploy/multimodal/engine/model_agent.py | 8 ++- lmdeploy/multimodal/engine/post_process.py | 6 +-- lmdeploy/serve/async_engine.py | 32 ++++++++++-- lmdeploy/serve/vl_async_engine.py | 1 + 5 files changed, 91 insertions(+), 13 deletions(-) diff --git a/lmdeploy/multimodal/engine/engine.py b/lmdeploy/multimodal/engine/engine.py index 1755591862..39b5de6eff 100644 --- a/lmdeploy/multimodal/engine/engine.py +++ b/lmdeploy/multimodal/engine/engine.py @@ -1,8 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. 
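With the Encoder role in place, a client (or the proxy) can call the encoder server's `/v1/chat/completions` first and forward its payload to the language server as `encoder_result`. An illustrative two-step request, assuming an encoder server on port 23334 and a hybrid LLM server on port 23333 (hosts, ports and model name are placeholders):

```python
import requests

messages = [{
    'role': 'user',
    'content': [
        {'type': 'image_url', 'image_url': {'url': 'https://example.com/cat.png'}},
        {'type': 'text', 'text': 'Describe the image.'},
    ],
}]

# Step 1: the encoder server runs only the vision model and returns the encoder result
encoder_result = requests.post('http://encoder-host:23334/v1/chat/completions',
                               json={'model': 'internvl3', 'messages': messages}).json()

# Step 2: the language server receives the same messages plus the encoder result, so it can
# skip multimodal preprocessing and pull the features from the encoder's cache instead
resp = requests.post('http://llm-host:23333/v1/chat/completions',
                     json={'model': 'internvl3', 'messages': messages,
                           'encoder_result': encoder_result}).json()
print(resp['choices'][0]['message']['content'])
```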
import asyncio import copy +from typing import Dict, List, Optional from lmdeploy.messages import PytorchEngineConfig +from lmdeploy.pytorch.disagg.conn.engine_conn import EngineP2PConnection +from lmdeploy.pytorch.disagg.conn.protocol import (DistServeConnectionRequest, DistServeDropConnectionRequest, + DistServeInitRequest) from lmdeploy.utils import get_logger from ..models.builder import load_mm_model @@ -18,6 +22,7 @@ class MultiModalEngine(): def __init__(self, model_path: str, + chat_template: object, tokenizer: object, engine_config: PytorchEngineConfig = None, trust_remote_code: bool = True) -> None: @@ -25,6 +30,7 @@ def __init__(self, if engine_config is None: engine_config = PytorchEngineConfig() self.engine_config = copy.deepcopy(engine_config) + self.chat_template = chat_template self.tokenizer = tokenizer # build model @@ -38,6 +44,8 @@ def __init__(self, self.post_processor = PostProcessor(self.model_agent) self.pre_processor = PreProcessor(self.model_agent, self.post_processor) + self.engine_conn = EngineP2PConnection(self) + def start_loop(self): """Start async loops.""" # invoked in api server start up event, where we already have running event loop started by uvicorn.run() @@ -55,12 +63,14 @@ def close(self): @classmethod def from_pretrained(cls, pretrained_model_name_or_path: str, + chat_template: object, tokenizer: object, engine_config: PytorchEngineConfig = None, trust_remote_code: bool = True, **kwargs): """Create a MultiModalEngine instance.""" return cls(model_path=pretrained_model_name_or_path, + chat_template=chat_template, tokenizer=tokenizer, engine_config=engine_config, trust_remote_code=trust_remote_code) @@ -73,3 +83,50 @@ async def encode(self, messages, session_id: int): self.pre_processor.process(session_id, messages, future) return await future + + # TODO: change this, put into pre-processor? + async def wrap_for_pytorch( + self, + messages: List[Dict], + chat_template, + tokenizer, + sequence_start, + tools: Optional[List[object]] = None, + enable_thinking: Optional[bool] = None, + ) -> List[Dict]: + """ + Args: + messages (List[Dict]): a list of message, which is supposed to be + the output of `preprocess` + Returns: + a dict which will be passed to pytorch engine_instance's forward. + The dict is like the following: + Dict( + 'prompt': 'the prompt after applying chat template' + 'input_ids': [], + 'multimodal': { + 'pixel_values': torch.Tensor, + ... 
+ ] + ) + """ + result = self.model.to_pytorch(messages, + chat_template, + tokenizer, + sequence_start, + tools=tools, + enable_thinking=enable_thinking) + # clear data + for i, message in enumerate(messages): + if isinstance(message['content'], List): + messages[i]['preprocess'] = None + return result + + async def p2p_initialize(self, init_request: DistServeInitRequest): + return await self.engine_conn.p2p_initialize(init_request) + + def p2p_connect(self, conn_request: DistServeConnectionRequest): + return self.engine_conn.p2p_connect(conn_request) + + async def p2p_drop_connect(self, drop_conn_request: DistServeDropConnectionRequest): + return self.engine_conn.p2p_drop_connect(drop_conn_request) diff --git a/lmdeploy/multimodal/engine/model_agent.py b/lmdeploy/multimodal/engine/model_agent.py index f15402311f..cd7e8ab5e4 100644 --- a/lmdeploy/multimodal/engine/model_agent.py +++ b/lmdeploy/multimodal/engine/model_agent.py @@ -64,7 +64,6 @@ def __init__(self, model): self.out_stream = torch.cuda.Stream() self.forward_stream = torch.cuda.Stream() - # self.model_path = model_path self.model = model self.device = 'cuda' @@ -87,8 +86,7 @@ async def async_model_forward(self): self.next_inputs = None with torch.cuda.stream(self.forward_stream): - forward_outputs, allocated_blocks = self._forward_impl(forward_inputs) - print(f'check forward output: {forward_outputs}') + feats, allocated_blocks = self._forward_impl(forward_inputs) # event for async fetch outputs event = torch.cuda.Event() @@ -97,8 +95,8 @@ async def async_model_forward(self): # put inside out_que out = dict( session_id=session_id, - # output=forward_outputs, - output=allocated_blocks, + feats=feats, + block_ids=allocated_blocks, event=event, ) self._out_que.put_nowait(out) diff --git a/lmdeploy/multimodal/engine/post_process.py b/lmdeploy/multimodal/engine/post_process.py index cfe4f72624..5c9ea70e53 100644 --- a/lmdeploy/multimodal/engine/post_process.py +++ b/lmdeploy/multimodal/engine/post_process.py @@ -38,12 +38,8 @@ async def async_loop(self): print(f'=> PostProcessor post-processed data: {out}') session_id = out.pop('session_id', None) - print(f'=> PostProcessor session_id: {session_id}') messages, future = self._future_store.pop(session_id, None) - print(messages) - # add out as an attri named 'encoder_result' to messages - - messages[0]['encoder_result'] = out + messages[0]['block_ids'] = out['block_ids'] if future and not future.done(): print(f'=> PostProcessor setting future result: {messages}') future.set_result(messages) diff --git a/lmdeploy/serve/async_engine.py b/lmdeploy/serve/async_engine.py index e3a856b2a3..4b35e71bd6 100644 --- a/lmdeploy/serve/async_engine.py +++ b/lmdeploy/serve/async_engine.py @@ -911,14 +911,40 @@ def is_error(status): # manually end pytorch session await inst.async_end(session_id) - async def encode_generate(self, messages, session_id: int): + async def encode_generate(self, messages, session_id: int, **kwargs): """Perform encoding.""" if not hasattr(self.engine, 'encode'): raise NotImplementedError('encode() is not implemented for the current backend') encoder_result = await self.engine.encode(messages, session_id) - - return encoder_result + print(f'check encoder_result: {encoder_result}') + remote_block_ids = encoder_result[0]['block_ids'] + + # FIXME: for simplicity, we reuse previous get_prompt_input function + # in order to get input_ids, and for calculatin image_mask + # but get_prompt_input() will invoke vl_encoder preprocess(), which duplicate with above encode() + # we 
should adopt a new function to get input_ids and image_mask only + prompt = messages + self.request_logger.log_prompt(session_id=session_id, prompt=prompt) + prompt_input = await self._get_prompt_input(prompt, + do_preprocess=True, + sequence_start=True, + adapter_name=None, + **kwargs) + prompt = prompt_input['prompt'] + input_ids = prompt_input['input_ids'] + + # get image_mask + image_token_id = prompt_input['multimodal'][0]['image_token_id'] + image_mask = [1 if x == image_token_id else 0 for x in prompt_input['input_ids']] + + # pack results together and return to api server + return { + 'token_ids': input_ids, + 'image_mask': image_mask, + 'remote_session_id': session_id, + 'remote_block_ids': remote_block_ids + } def _run(self, fn=None, coro=None, loop=None): assert (fn or coro) and not (fn and coro) diff --git a/lmdeploy/serve/vl_async_engine.py b/lmdeploy/serve/vl_async_engine.py index 140b5c7bbd..a51abd3628 100644 --- a/lmdeploy/serve/vl_async_engine.py +++ b/lmdeploy/serve/vl_async_engine.py @@ -44,6 +44,7 @@ def __init__(self, if self.backend_config.role == EngineRole.Encoder: from lmdeploy.multimodal.engine.engine import MultiModalEngine self.engine = MultiModalEngine.from_pretrained(pretrained_model_name_or_path=model_path, + chat_template=self.chat_template, tokenizer=self.tokenizer, engine_config=backend_config) From 31c761f497f6a47fceb3459528d8e1ff0146f5ac Mon Sep 17 00:00:00 2001 From: zxy Date: Fri, 17 Oct 2025 15:43:22 +0800 Subject: [PATCH 08/11] fix p2p connection, fix proxy --- lmdeploy/multimodal/engine/engine.py | 36 +++++++++++++++++++++---- lmdeploy/pytorch/engine/model_agent.py | 14 +++++++--- lmdeploy/serve/openai/api_server.py | 5 ++-- lmdeploy/serve/proxy/proxy.py | 32 +++++++++++++--------- lmdeploy/turbomind/turbomind.py | 37 +++++++++++++++----------- 5 files changed, 85 insertions(+), 39 deletions(-) diff --git a/lmdeploy/multimodal/engine/engine.py b/lmdeploy/multimodal/engine/engine.py index 39b5de6eff..ea1eb9924b 100644 --- a/lmdeploy/multimodal/engine/engine.py +++ b/lmdeploy/multimodal/engine/engine.py @@ -3,10 +3,15 @@ import copy from typing import Dict, List, Optional +from fastapi.encoders import jsonable_encoder +from fastapi.responses import JSONResponse + from lmdeploy.messages import PytorchEngineConfig from lmdeploy.pytorch.disagg.conn.engine_conn import EngineP2PConnection -from lmdeploy.pytorch.disagg.conn.protocol import (DistServeConnectionRequest, DistServeDropConnectionRequest, - DistServeInitRequest) +from lmdeploy.pytorch.disagg.conn.protocol import (DistServeConnectionRequest, DistServeConnectionResponse, + DistServeConnectionStatus, DistServeDropConnectionRequest, + DistServeEngineEndpointInfo, DistServeInitRequest, + DistServeInitResponse) from lmdeploy.utils import get_logger from ..models.builder import load_mm_model @@ -122,11 +127,32 @@ async def wrap_for_pytorch( messages[i]['preprocess'] = None return result - async def p2p_initialize(self, init_request: DistServeInitRequest): - return await self.engine_conn.p2p_initialize(init_request) + def p2p_initialize(self, init_request: DistServeInitRequest): + """Initialize p2p connection. + + FIXME: This method is synchronous (`def`). + The standard PytorchEngine (in multi-process mode) has a synchronous + `p2p_initialize` that acts as an RPC bridge to an async worker. + To maintain a compatible interface for the `AsyncEngine` adapter, + this single-process engine also provides a synchronous implementation. 
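+
+        A rough call sequence from the api_server side (sketch only; request
+        construction and actual routing are elided) would be:
+
+            init_resp = engine.p2p_initialize(DistServeInitRequest(...))
+            conn_resp = engine.p2p_connect(DistServeConnectionRequest(...))
+            await engine.p2p_drop_connect(DistServeDropConnectionRequest(...))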
+ """ + kv_eps = self.model_agent.cache_engine.p2p_initialize(init_request) + # encoder has no zmq communication for now; return a dummy address + zmq_addr = 'tcp://0.0.0.0:65001' + resp = DistServeInitResponse( + status=DistServeConnectionStatus.SUCCESS, + engine_endpoint_info=DistServeEngineEndpointInfo(zmq_address=zmq_addr), + kvtransfer_endpoint_info=kv_eps, + ) + return JSONResponse(jsonable_encoder(resp.model_dump())) def p2p_connect(self, conn_request: DistServeConnectionRequest): - return self.engine_conn.p2p_connect(conn_request) + self.model_agent.cache_engine.p2p_connect( + conn_request.remote_engine_id, + conn_request.remote_kvtransfer_endpoint_info, + ) + resp = DistServeConnectionResponse(status=DistServeConnectionStatus.SUCCESS) + return JSONResponse(jsonable_encoder(resp.model_dump())) async def p2p_drop_connect(self, drop_conn_request: DistServeDropConnectionRequest): return self.engine_conn.p2p_drop_connect(drop_conn_request) diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index f7c606649f..0507167b53 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -30,7 +30,7 @@ from ..utils import get_gpu_memory from ..weight_loader.model_weight_loader import load_model_weights from .cache_engine import CacheEngine -# from .guided_process import GuidedDecodingMangager +from .guided_process import GuidedDecodingMangager from .logits_process import FusedLogitsProcessor, SamplingInputs logger = get_logger('lmdeploy') @@ -354,7 +354,11 @@ def __init__(self, self.patched_model = None self.cache_engine = None self.profiler: AgentProfiler = None - # self.guided_decoding_manager = GuidedDecodingMangager(self.tokenizer, self.sampling_vocab_size) + try: + self.guided_decoding_manager = GuidedDecodingMangager(self.tokenizer, self.sampling_vocab_size) + except ValueError as e: + logger.warning(f'Failed to create GuidedManager for tokenizer {self.tokenizer}: {e}') + self.guided_decoding_manager = None # microbatch self.enable_microbatch = self.dist_ctx.dist_config.enable_microbatch @@ -861,7 +865,8 @@ def stop(self): if not self._preprocess_task.done(): self._preprocess_task.cancel() - self.guided_decoding_manager.clear() + if self.guided_decoding_manager: + self.guided_decoding_manager.clear() async def stop_async(self): """Stop task.""" @@ -891,7 +896,8 @@ async def stop_async(self): except asyncio.CancelledError: logger.debug('ModelAgent preprocess task cancelled.') - self.guided_decoding_manager.clear() + if self.guided_decoding_manager: + self.guided_decoding_manager.clear() def set_forward_inputs(self, inputs): """Set forward inputs.""" diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py index c5abe374c2..66123c26cd 100644 --- a/lmdeploy/serve/openai/api_server.py +++ b/lmdeploy/serve/openai/api_server.py @@ -375,14 +375,13 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque # if encoder, we only do encoding and return engine_role = VariableInterface.async_engine.backend_config.role if engine_role == EngineRole.Encoder: - from fastapi.encoders import jsonable_encoder - from fastapi.responses import JSONResponse encoder_result = await VariableInterface.async_engine.encode_generate( request.messages, request.session_id, ) + # TODO: use CompleteResponse prototype print(f'api_server, v1/completion, encoder_result: {encoder_result}') - return JSONResponse(jsonable_encoder(encoder_result)) + return encoder_result model_name = 
request.model adapter_name = None diff --git a/lmdeploy/serve/proxy/proxy.py b/lmdeploy/serve/proxy/proxy.py index 7056ef8426..d2e5719b26 100644 --- a/lmdeploy/serve/proxy/proxy.py +++ b/lmdeploy/serve/proxy/proxy.py @@ -23,7 +23,7 @@ from lmdeploy.pytorch.disagg.config import DistServeRDMAConfig, EngineRole, RDMALinkType, ServingStrategy from lmdeploy.pytorch.disagg.conn.epd_proxy_conn import EPDConnectionPool -from lmdeploy.pytorch.disagg.conn.protocol import MigrationProtocol, MigrationRequest +from lmdeploy.pytorch.disagg.conn.protocol import EncoderResult, MigrationProtocol, MigrationRequest from lmdeploy.pytorch.disagg.conn.proxy_conn import PDConnectionPool from lmdeploy.pytorch.disagg.messages import EPDConnectionMessage, PDConnectionMessage from lmdeploy.serve.openai.api_server import check_api_key, create_error_response @@ -510,7 +510,7 @@ async def connection_warmup(): rdma_config=node_manager.rdma_config, )) for p_url in node_manager.prefill_nodes for d_url in node_manager.decode_nodes ]) - logger.info(f'encoder nodes: {node_manager.decode_nodes}\nprefill nodes: {node_manager.prefill_nodes}') + logger.info(f'encoder nodes: {node_manager.encoder_nodes}\nprefill nodes: {node_manager.hybrid_nodes}') # FIXME: we set pd_urls to use hybrid nodes, since we start LLM server in hybrid role await asyncio.gather(*[ node_manager.ep_connection_pool.connect( @@ -620,24 +620,32 @@ def _need_encoder(msgs: List[Dict]) -> bool: else: logger.info(f'Encoder stage dispatched to {encoder_url}') enc_start = node_manager.pre_call(encoder_url) - # encoder endpoint path: using vision server example: /v1/chat/completion (singular) # fall back to /v1/chat/completions if first fails - encoder_response_text = await node_manager.generate(request_dict, encoder_url, - '/v1/chat/completion') - if isinstance(encoder_response_text, (bytes, bytearray)): - try: - encoder_response_text = encoder_response_text.decode('utf-8') - except Exception: # noqa - pass + encoder_info = await node_manager.generate(request_dict, encoder_url, '/v1/chat/completions') + print(encoder_info) + encoder_info = json.loads(encoder_info) + remote_session_id = encoder_info['remote_session_id'] + remote_block_ids = encoder_info['remote_block_ids'] + remote_token_ids = encoder_info['token_ids'] + image_mask = encoder_info['image_mask'] + request_dict['encoder_result'] = EncoderResult( + token_ids=remote_token_ids, + image_mask=image_mask, + protocol=node_manager.migration_protocol, + remote_engine_id=encoder_url, + remote_session_id=remote_session_id, + remote_block_ids=remote_block_ids).model_dump(mode='json') + # simple heuristic: if returns timeout structure (bytes) keep original try: - enc_json = json.loads(encoder_response_text) + # enc_json = json.loads(encoder_info) + enc_json = request_dict except Exception: # try alternative endpoint if first not json (maybe 404 HTML) alt_text = await node_manager.generate(request_dict, encoder_url, '/v1/chat/completions') try: enc_json = json.loads(alt_text) - encoder_response_text = alt_text + # encoder_response_text = alt_text except Exception: logger.error('Encoder stage failed: cannot parse JSON; skip encoder stage') enc_json = None diff --git a/lmdeploy/turbomind/turbomind.py b/lmdeploy/turbomind/turbomind.py index a1c37f573d..0a849c4ef9 100644 --- a/lmdeploy/turbomind/turbomind.py +++ b/lmdeploy/turbomind/turbomind.py @@ -708,22 +708,29 @@ async def async_stream_infer(self, if gen_config.response_format is not None: tokenizer = self.tm_model.tokenizer vocab_size = 
self.tm_model.config.model_config.vocab_size - decode_grammar_type = gen_config.response_format['type'] - decode_grammar = gen_config.response_format[decode_grammar_type]['schema'] - - tokenizer_info = TokenizerInfo.from_huggingface(tokenizer.model.model, vocab_size=vocab_size) - compiler = _xgr.GrammarCompiler(tokenizer_info) - - if decode_grammar_type == 'json_schema': - decode_grammar = json.dumps(decode_grammar) - grammar = compiler.compile_json_schema(decode_grammar) - elif decode_grammar_type == 'regex_schema': - decode_grammar = str(decode_grammar) - grammar = compiler.compile_regex(decode_grammar) - else: - assert False, f'Decode grammar type {decode_grammar_type} should be in ["json_schema", "regex_schema"]' - self.model_inst.set_grammar(grammar) + try: + tokenizer_info = TokenizerInfo.from_huggingface(tokenizer.model.model, vocab_size=vocab_size) + decode_grammar_type = gen_config.response_format['type'] + decode_grammar = gen_config.response_format[decode_grammar_type]['schema'] + + compiler = _xgr.GrammarCompiler(tokenizer_info) + + if decode_grammar_type == 'json_schema': + decode_grammar = json.dumps(decode_grammar) + grammar = compiler.compile_json_schema(decode_grammar) + elif decode_grammar_type == 'regex_schema': + decode_grammar = str(decode_grammar) + grammar = compiler.compile_regex(decode_grammar) + else: + assert False, f'Decode grammar type {decode_grammar_type} should be in ' \ + '["json_schema", "regex_schema"]' + + self.model_inst.set_grammar(grammar) + except ValueError as e: + logger.warning(f'Failed to initialize guided decoding for tokenizer {tokenizer}, ' + f'disable guided decoding: {e.message}') + gen_config.response_format = None session = _tm.SessionParam(id=session_id, step=step, start=sequence_start, end=sequence_end) From d0b0ddf2e755641212c47f0d1a02f36d961fc08e Mon Sep 17 00:00:00 2001 From: zxy Date: Fri, 17 Oct 2025 15:47:21 +0800 Subject: [PATCH 09/11] update test scripts, rm previous vl_oai.py --- 0_encode.sh | 12 +- 0_pd.sh | 10 +- 0_proxy.sh | 4 +- 0_vl_oai.py | 386 ---------------------------------------------------- 4 files changed, 18 insertions(+), 394 deletions(-) delete mode 100644 0_vl_oai.py diff --git a/0_encode.sh b/0_encode.sh index d897bba09c..7942ca091e 100644 --- a/0_encode.sh +++ b/0_encode.sh @@ -1 +1,11 @@ -python 0_vl_oai.py +model_path='/nvme3/interns1-mini-remote' + + +CUDA_VISIBLE_DEVICES=1 lmdeploy serve api_server \ + ${model_path} \ + --tp 1 \ + --role Encoder \ + --backend pytorch \ + --server-port 23334 \ + --proxy-url http://0.0.0.0:8001 \ + --log-level INFO diff --git a/0_pd.sh b/0_pd.sh index 2ebaae1cf1..54c132f9fe 100644 --- a/0_pd.sh +++ b/0_pd.sh @@ -1,12 +1,12 @@ -model_path="/mnt/137_nvme3/interns1-mini-remote" +model_path='/nvme3/interns1-mini-remote' CUDA_VISIBLE_DEVICES=2 lmdeploy serve api_server \ - $model_path \ - --server-port 23334 \ - --role Hybrid \ - --proxy-url http://0.0.0.0:8001 \ + ${model_path} \ --tp 1 \ + --role Hybrid \ --backend pytorch \ + --server-port 23335 \ + --proxy-url http://0.0.0.0:8001 \ --disable-vision-encoder \ --log-level INFO diff --git a/0_proxy.sh b/0_proxy.sh index 46e78d276c..ebcaa2f294 100644 --- a/0_proxy.sh +++ b/0_proxy.sh @@ -5,7 +5,7 @@ curl -X POST http://0.0.0.0:8001/distserve/connection_warmup curl http://0.0.0.0:8001/v1/chat/completions \ -H "Content-Type: application/json" \ -d '{ - "model": "/mnt/137_nvme3/interns1-mini-remote", + "model": "/nvme3/interns1-mini-remote", "messages": [ { "role": "user", @@ -18,7 +18,7 @@ curl 
http://0.0.0.0:8001/v1/chat/completions \ curl http://0.0.0.0:8001/v1/chat/completions \ -H "Content-Type: application/json" \ -d '{ - "model": "/mnt/137_nvme3/interns1-mini-remote", + "model": "/nvme3/interns1-mini-remote", "messages": [ { "role": "user", diff --git a/0_vl_oai.py b/0_vl_oai.py deleted file mode 100644 index 3a560047fc..0000000000 --- a/0_vl_oai.py +++ /dev/null @@ -1,386 +0,0 @@ -import asyncio -import logging -import os -from contextlib import asynccontextmanager -from dataclasses import asdict, dataclass -from threading import Lock -from typing import Dict, List, Optional - -import torch -import uvicorn -from fastapi import FastAPI, HTTPException -from fastapi.encoders import jsonable_encoder -from fastapi.responses import JSONResponse, Response - -from lmdeploy import Tokenizer -from lmdeploy.archs import get_model_arch -from lmdeploy.model import ChatTemplateConfig, best_match_model -from lmdeploy.pytorch.disagg.backend.base import MigrationBackendImpl -from lmdeploy.pytorch.disagg.config import DistServeEngineConfig, EngineRole -from lmdeploy.pytorch.disagg.conn.protocol import (DistServeCacheFreeRequest, DistServeConnectionRequest, - DistServeConnectionResponse, DistServeConnectionStatus, - DistServeEngineEndpointInfo, DistServeInitRequest, - DistServeInitResponse, MigrationProtocol) -from lmdeploy.pytorch.engine.encoder_cache_engine import EncoderCacheEngine -from lmdeploy.serve.openai.launch_server import get_host_ip -from lmdeploy.serve.openai.protocol import ChatCompletionRequest -from lmdeploy.vl.model.internvl3_hf import InternVL3VisionModel -from lmdeploy.vl.utils import load_image - -os.environ['CUDA_VISIBLE_DEVICES'] = '1' - -# --- 1. 全局模型变量 --- -model_instance: InternVL3VisionModel = None # type: ignore -migration_backend_impl: Optional[MigrationBackendImpl] = None -model_path = '/mnt/137_nvme3/interns1-mini-remote' -SERVER_PORT = 8086 -chat_template_name = best_match_model(model_path.lower()) -# chat_template_config = ChatTemplateConfig(chat_template_name) -chat_template_config = ChatTemplateConfig(model_name=chat_template_name, model_path=model_path) -chat_template = chat_template_config.chat_template -tokenizer = Tokenizer(model_path) -# encoder_url = f"http://{get_host_ip()}:{SERVER_PORT}" -encoder_url = f'http://0.0.0.0:{SERVER_PORT}' -# 初始化 Cache Engine 相关变量 -cache_engine_instance: EncoderCacheEngine = None # type: ignore -NUM_GPU_BLOCKS = 128 -free_blocks: List[int] = [] -session_blocks: Dict[int, List[int]] = {} -session_counter = 0 -block_manager_lock = Lock() # 线程锁,用于安全地分配和释放块 - - -def get_model_list(): - return [model_path] - - -# --- 2. 
生命周期事件处理器 --- - - -@asynccontextmanager -async def lifespan(app: FastAPI): - # 启动时的事件 - global model_instance, cache_engine_instance, free_blocks - logger = logging.getLogger('uvicorn.error') - logger.setLevel(logging.INFO) - logger.info('模型加载中,请稍候...') - try: - cfg = get_model_arch(model_path)[1] - kwargs = dict(model_path=model_path, with_llm=False, max_memory=None, hf_config=cfg, backend='pytorch') - model_instance = InternVL3VisionModel(**kwargs) - model_instance.build_model() - model_instance.build_preprocessor() - logger.info('✅ 模型加载成功!服务器已准备就绪。') - except Exception as e: - logger.error(f'❌ 模型加载失败: {e}', exc_info=True) - raise RuntimeError(f'模型初始化失败: {e}') from e - - # TODO MigrationBackendImpl () - - # TODO 增加 memory 页表注册 - logger.info('正在初始化 Cache Engine...') - try: - # 实例化 CacheEngine - cache_engine_instance = EncoderCacheEngine(NUM_GPU_BLOCKS) - - # 初始化空闲块列表 - free_blocks = list(range(NUM_GPU_BLOCKS)) - logger.info(f'✅ Cache Engine 初始化成功,总共 {NUM_GPU_BLOCKS} 个缓存块。') - - except Exception as e: - logger.error(f'❌ Cache Engine 初始化失败: {e}', exc_info=True) - raise RuntimeError(f'Cache Engine 初始化失败: {e}') from e - - # TODO 向 proxy 发送node add - try: - import requests - engine_role = EngineRole.Encoder.value - url = 'http://0.0.0.0:8001/nodes/add' - data = {'url': f'http://0.0.0.0:{SERVER_PORT}', 'status': {'models': get_model_list(), 'role': engine_role}} - headers = {'accept': 'application/json', 'Content-Type': 'application/json'} - response = requests.post(url, headers=headers, json=data) - - if response.status_code != 200: - raise HTTPException(status_code=response.status_code, detail=response.text) - else: - logger.info('✅ 服务注册成功!') - except Exception as e: - logger.error(f'Service registration failed: {e}') - # TODO p2p initialize(warm up) - # /nvme2/share/linbinbin1/src/lmdeploy-encoder/lmdeploy/serve/openai/api_server.py PD DIs - - # TODO p2p conn - - yield # 应用运行期间 - - # 关闭时的事件(如果需要清理资源) - logger.info('🔄 正在关闭服务器...') - del model_instance - torch.cuda.empty_cache() - logger.info('模型资源已释放。') - - -# --- 3. 初始化 FastAPI 应用 --- -app = FastAPI(title='InternVL Vision Model Server (Arrow Edition)', - description='一个用于通过 InternVL3 模型为图片数组提取特征张量,并使用 Apache Arrow 高效返回结果的 API', - version='1.2.0', - lifespan=lifespan) -logger = logging.getLogger('uvicorn.error') -logger.setLevel(logging.INFO) - - -# --- 4. 辅助函数 --- -def find_forward_content(output: list) -> list: - for item in output: - if isinstance(item, dict) and item.get('role') == 'forward': - return item.get('content', []) - return [] - - -async def async_convert_to_pil_images(messages: List[Dict]) -> List[Dict]: - """Scan the provided messages to find image URLs or base64-encoded image - data. Loads the images into Pillow image objects. 
- - Args: - messages (List[Dict]): a user request of GPT4V message format - """ - if isinstance(messages, Dict): - messages = [messages] - assert isinstance(messages, List) - - out_messages = [None] * len(messages) - - def _inner_call(i, in_messages, out_messages): - role = in_messages[i]['role'] - content = in_messages[i]['content'] - assert role in ['system', 'user', 'assistant'], \ - f'unsupported role "{role}"' - if role != 'user' or isinstance(content, str): - # the content is a user's prompt or an assistant's prompt, - # returning it directly - out_messages[i] = in_messages[i] - return - # the role is a user and the content is a list, in which there - # might be image_url or image_data - assert isinstance(content, List) - message = dict(role=role, content=[]) - for item in content: - # image url or base64-encoded image data - if item['type'] == 'image_url': - """ - convert the following item: - { - 'type': 'image_url', - 'image_url': { - 'url': 'image url or base64-encoded image data', - 'key': 'value' # parameters used in image processing - ... - } - } - to: - { - 'type': 'image', - 'image': Pillow.Image, - 'key': 'value' # parameters used in image processing - ... - } - """ # noqa - data = item['image_url'].copy() - try: - url = data.pop('url') - image = load_image(url) - data.update(type='image', image=image) - message['content'].append(data) - except KeyError: - logger.error(f'invalid format {message}') - elif item['type'] == 'image_data': - """ - convert the following item: - { - 'type': 'image_data', - 'image_data': { - 'data': Pillow.Image, - 'key': 'value' # parameters used in image processing - ... - } - } - to: - { - 'type': 'image', - 'image': Pillow.Image, - 'key': 'value' # parameters used in image processing - ... - } - """ # noqa - data = item['image_data'].copy() - try: - image = data.pop('data') - data.update(type='image', image=image) - message['content'].append(data) - except KeyError: - logger.error(f'invalid format {message}') - elif item['type'] == 'text': - message['content'].append(item) - else: - logger.error(f'unexpected content type {message}') - out_messages[i] = message - - await asyncio.gather(*[ - asyncio.get_event_loop().run_in_executor(None, _inner_call, i, messages, out_messages) - for i in range(len(messages)) - ]) - return out_messages - - -@app.get('/health') -async def health() -> Response: - """Health check.""" - return Response(status_code=200) - - -@dataclass -class EncoderResult: - token_ids: List[int] - image_mask: List[int] - # MigrationRequest 中相似的字段 - protocol: MigrationProtocol # RDMA - remote_engine_id: str # 标识 encode 引擎编号 - remote_session_id: int # 用于 encode 引擎释放指定区域 - remote_block_ids: List[int] # 从 encode 引擎读取指定区域内容 - - -# --- 5. 
API 端点:处理图片并返回特征 --- - - -@app.post('/v1/chat/completion', summary='接收 open ai 格式的请求,并且返回给 proxy') -async def process_images(request_raw: ChatCompletionRequest = None): - if model_instance is None: - raise HTTPException(status_code=503, detail='模型正在加载或加载失败,请稍后再试。') - - request = request_raw.model_dump() - messages = await async_convert_to_pil_images(request['messages']) - results = model_instance.preprocess(messages) - # print(results) - # import pdb; pdb.set_trace() - - # prompt = chat_template.messages2prompt(messages) - # input_ids = tokenizer.encode(prompt, add_bos=True) # 只包含了文本部分 - # prompt, input_ids(包含了图片 token 序列), multi_modal - # 这个是将要返回的内容 - to_pt = model_instance.to_pytorch(results, chat_template, tokenizer, True, None, None) - image_mask = [1 if x == to_pt['multimodal'][0]['image_token_id'] else 0 for x in to_pt['input_ids']] - # 这里用来获得 image embedding - output = model_instance.forward(results) - forward_content = find_forward_content(output) - # tensor_shape = forward_content[0].shape - if not forward_content: - raise HTTPException(status_code=500, detail="无法在模型输出中找到 'forward' 内容。") - # store the image embedding to gpu cache - image_embedding = forward_content[0] - image_embedding = image_embedding.to( - torch.bfloat16 - ) # FIXME: forward() is used by turbomind, which returns float16 feature, but pytorch will return bfloat16 - print(f'image_embedding shape: {image_embedding.shape}') - print(f'image_embedding: {image_embedding}') - num_required_blocks = image_embedding.shape[0] // 256 - global session_counter - allocated_block_ids = [] - session_id = -1 - with block_manager_lock: - if len(free_blocks) < num_required_blocks: - raise HTTPException(status_code=503, detail='GPU 缓存已满,请稍后再试。') - - allocated_block_ids = [free_blocks.pop() for _ in range(num_required_blocks)] - session_counter += 1 - session_id = session_counter - session_blocks[session_id] = allocated_block_ids - print('in blocks') - print(allocated_block_ids) - print(cache_engine_instance.gpu_cache[allocated_block_ids].shape) - print(cache_engine_instance.gpu_cache[allocated_block_ids]) - try: - with torch.cuda.stream(cache_engine_instance.cache_stream): - for i in range(num_required_blocks): - src_chunk = image_embedding[i * 256:(i + 1) * 256, :] - dst_block_id = allocated_block_ids[i] - cache_engine_instance.gpu_cache[dst_block_id].copy_(src_chunk) - cache_engine_instance.cache_stream.synchronize() - except Exception as e: - # 如果拷贝失败,必须归还申请的块,防止内存泄漏 - with block_manager_lock: - free_blocks.extend(allocated_block_ids) - del session_blocks[session_id] - logger.error(f'拷贝 embedding 到缓存失败: {e}') - raise HTTPException(status_code=500, detail='缓存图像 embedding 失败。') - - # 返回内容 - - # FIXME, zhouxinyu this should not be empty - # otherwise gen config related information are lost, for instance top_p, top_k, max_new_tokens - # request['messages'] = [] - encoder_result_obj = EncoderResult( - token_ids=to_pt['input_ids'], - image_mask=image_mask, - protocol=MigrationProtocol.RDMA, - remote_engine_id=encoder_url, # encode 引擎的 url - remote_session_id=session_id, # encode 阶段的 session id - remote_block_ids=allocated_block_ids # image embedding 的 memory block id - ) - request['encoder_result'] = asdict(encoder_result_obj) - - return JSONResponse(jsonable_encoder(request)) - - -@app.post('/distserve/p2p_initialize') -async def p2p_initialize(init_request: DistServeInitRequest): - kv_eps = cache_engine_instance.p2p_initialize(init_request) - # 目前 encoder 没有 zmq 通信;返回一个假地址 - zmq_addr = f'tcp://{get_host_ip()}:65001' - resp = 
DistServeInitResponse( - status=DistServeConnectionStatus.SUCCESS, - engine_endpoint_info=DistServeEngineEndpointInfo(zmq_address=zmq_addr), - kvtransfer_endpoint_info=kv_eps, - ) - return JSONResponse(jsonable_encoder(resp.model_dump())) - - -@app.post('/distserve/p2p_connect') -async def p2p_connect(conn_request: DistServeConnectionRequest): - cache_engine_instance.p2p_connect( - conn_request.remote_engine_id, - conn_request.remote_kvtransfer_endpoint_info, - ) - resp = DistServeConnectionResponse(status=DistServeConnectionStatus.SUCCESS) - return JSONResponse(jsonable_encoder(resp.model_dump())) - - -@app.post('/distserve/free_cache') -async def free_cache(free_req: DistServeCacheFreeRequest): - # Free allocated GPU blocks for a given session id - global free_blocks, session_blocks - sid = free_req.remote_session_id - with block_manager_lock: - blocks = session_blocks.pop(sid, []) - if blocks: - free_blocks.extend(blocks) - return JSONResponse({'success': True, 'freed_blocks': blocks if 'blocks' in locals() else []}) - - -@app.get('/distserve/engine_info') -async def engine_info(): - - response = DistServeEngineConfig(tp_size=1, - dp_size=1, - pp_size=1, - ep_size=1, - dp_rank=1, - block_size=256 * 4096, - num_cpu_blocks=0, - num_gpu_blocks=NUM_GPU_BLOCKS) - - return response.model_dump_json() - - -# --- 6. 运行服务器 --- -if __name__ == '__main__': - uvicorn.run(app, host='0.0.0.0', port=SERVER_PORT) From fa68324e25432213fde3e6f25e01dfa04a415a63 Mon Sep 17 00:00:00 2001 From: zxy Date: Fri, 17 Oct 2025 17:33:24 +0800 Subject: [PATCH 10/11] fix lint --- lmdeploy/multimodal/__init__.py | 1 + lmdeploy/multimodal/engine/__init__.py | 1 + lmdeploy/multimodal/models/__init__.py | 1 + 3 files changed, 3 insertions(+) create mode 100644 lmdeploy/multimodal/__init__.py create mode 100644 lmdeploy/multimodal/engine/__init__.py create mode 100644 lmdeploy/multimodal/models/__init__.py diff --git a/lmdeploy/multimodal/__init__.py b/lmdeploy/multimodal/__init__.py new file mode 100644 index 0000000000..ef101fec61 --- /dev/null +++ b/lmdeploy/multimodal/__init__.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. diff --git a/lmdeploy/multimodal/engine/__init__.py b/lmdeploy/multimodal/engine/__init__.py new file mode 100644 index 0000000000..ef101fec61 --- /dev/null +++ b/lmdeploy/multimodal/engine/__init__.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. diff --git a/lmdeploy/multimodal/models/__init__.py b/lmdeploy/multimodal/models/__init__.py new file mode 100644 index 0000000000..ef101fec61 --- /dev/null +++ b/lmdeploy/multimodal/models/__init__.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. From 4720904029839633aec1bdf54cddb8b56e7e1578 Mon Sep 17 00:00:00 2001 From: zxy Date: Fri, 17 Oct 2025 17:52:37 +0800 Subject: [PATCH 11/11] rm duplicate encoder_cache_engine.py --- .../pytorch/engine/encoder_cache_engine.py | 139 ------------------ 1 file changed, 139 deletions(-) delete mode 100644 lmdeploy/pytorch/engine/encoder_cache_engine.py diff --git a/lmdeploy/pytorch/engine/encoder_cache_engine.py b/lmdeploy/pytorch/engine/encoder_cache_engine.py deleted file mode 100644 index 26d092a65a..0000000000 --- a/lmdeploy/pytorch/engine/encoder_cache_engine.py +++ /dev/null @@ -1,139 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
-# modify from: https://github.com/vllm-project/vllm -import json -from typing import List, Optional, Tuple - -import torch - -from lmdeploy.pytorch.disagg.backend.backend import MIGRATION_BACKENDS -from lmdeploy.pytorch.disagg.backend.base import MigrationBackendImpl -from lmdeploy.pytorch.disagg.conn.protocol import DistServeInitRequest, DistServeKVTransferEndpointInfo -from lmdeploy.pytorch.disagg.messages import DistServeRegisterMRMessage -from lmdeploy.utils import get_logger - -from ..config import ModelConfig - -logger = get_logger('lmdeploy') - -FEATURE_BLOCK_SHAPE = (256, 4096) - - -class EncoderCacheEngine: - """Manages the memory pool for image features. - - This engine allocates and manages a contiguous block of GPU memory - to store image embeddings transferred from an encoder. It is adapted for - an encoder-LLM separated architecture. - - Args: - cache_config (CacheConfig): Configuration for the cache, such as the - number of blocks. - model_config (ModelConfig): Model configuration, used for dtype. - rank (int): Distributed rank. - tp_rank (int): Tensor parallelism rank. - world_size (int): Distributed world size. - """ - - def __init__( - self, - num_gpu_blocks: int = 128, - rank: int = 0, - tp_rank: int = 0, - world_size: int = 1, - ) -> None: - self.world_size = world_size - self.rank = rank - self.tp_rank = tp_rank - - # FIXME: turbomind forward() returns float16, pytorch forward uses bfloat16 - self.feature_dtype = torch.bfloat16 - self._num_gpu_blocks = num_gpu_blocks - - self.encoder_gpu_cache = self._allocate_gpu_cache() - - self.migration_backend_impl: Optional[MigrationBackendImpl] = None - - self.cache_stream = torch.cuda.Stream() - assert self.cache_stream != torch.cuda.current_stream() - self.events = torch.cuda.Event() - - logger.debug(f'Initialize feature cache engine with {self.num_gpu_blocks} gpu blocks.') - - @property - def gpu_cache(self) -> torch.Tensor: - """The GPU feature pool tensor.""" - return self.encoder_gpu_cache - - @property - def num_gpu_blocks(self) -> int: - """Number of GPU blocks.""" - return self._num_gpu_blocks - - @staticmethod - def get_feature_block_shape() -> Tuple[int, int]: - """Get the shape of a single image feature block.""" - return FEATURE_BLOCK_SHAPE - - def _allocate_cache(self, num_blocks: int, device: torch.device) -> torch.Tensor: - """Allocate the memory pool on the specified device.""" - block_shape = self.get_feature_block_shape() - - # allocate a large contiguous tensor as the feature pool - encoder_cache = torch.empty( - size=(num_blocks, *block_shape), - dtype=self.feature_dtype, - device=device, - ) - return encoder_cache - - def _allocate_gpu_cache(self) -> torch.Tensor: - """Allocate the feature pool on the GPU.""" - return self._allocate_cache(self.num_gpu_blocks, 'cuda') - - @classmethod - def get_cache_block_size(cls, model_config: ModelConfig) -> int: - """Get the memory size in bytes of a single feature block. - - Args: - model_config (ModelConfig): The model config, used for dtype. - - Return: - int: Required memory size in bytes for one block. - """ - shape = cls.get_feature_block_shape() - dtype = model_config.dtype - - meta_tensor = torch.empty(shape, dtype=dtype, device='meta') - return meta_tensor.numel() * meta_tensor.element_size() - - """ Methods for Disaggregation Begin. 
""" - - def p2p_initialize(self, migration_init_request: DistServeInitRequest) -> List[DistServeKVTransferEndpointInfo]: - if not self.migration_backend_impl: - self.migration_backend_impl: MigrationBackendImpl = MIGRATION_BACKENDS.module_dict['DLSlime']() - migration_init_request.rank = self.rank - self.migration_backend_impl.p2p_initialize(migration_init_request) - - t = self.encoder_gpu_cache - if t.numel() > 0: - register_mr_request = DistServeRegisterMRMessage( - protocol=migration_init_request.protocol, - remote_engine_id=migration_init_request.remote_engine_id, - mr_key='encoder_cache', # use fixed key - addr=t.data_ptr(), - offset=t.storage_offset(), - length=t.numel() * t.itemsize) - self.migration_backend_impl.register_memory_region(register_mr_request) - - return [ - DistServeKVTransferEndpointInfo(protocol=migration_init_request.protocol, - endpoint_info=json.dumps( - self.migration_backend_impl.endpoint_info( - migration_init_request.remote_engine_id, - migration_init_request.protocol))) - ] - - def p2p_connect(self, remote_engine_id: str, migration_conn_request: List[DistServeKVTransferEndpointInfo]): - self.migration_backend_impl.p2p_connect(remote_engine_id, migration_conn_request[self.tp_rank]) - - """ Methods for Disaggregation End. """