@@ -15,11 +15,9 @@
 from queue import Queue
 from typing import Any, Dict, List, Optional
 
-import numpy as np
 import pybase64
 import torch
 import yaml
-from torch.nn.utils.rnn import pad_sequence
 
 import lmdeploy
 from lmdeploy.messages import EngineOutput, GenerationConfig, ResponseType, ScheduleMetrics, TurbomindEngineConfig
@@ -195,28 +193,22 @@ def _load_weights(self):
     def _process_weights(self):
         """Process weight."""
         with ThreadPoolExecutor(max_workers=self.gpu_count) as e:
-            ranks = [self.node_id * self.gpu_count + device_id for device_id in range(self.gpu_count)]
-            for _ in e.map(self.model_comm.process_weight, range(self.gpu_count), ranks):
+            for _ in e.map(self.model_comm.process_weight, range(self.gpu_count)):
                 pass
 
     def _create_engine(self):
         """Create engine."""
         with ThreadPoolExecutor(max_workers=self.gpu_count) as e:
-            ranks = [self.node_id * self.gpu_count + device_id for device_id in range(self.gpu_count)]
-            for _ in e.map(self.model_comm.create_engine, range(self.gpu_count), ranks):
+            for _ in e.map(self.model_comm.create_engine, range(self.gpu_count)):
                 pass
         self._engine_created = True
 
     def _create_weight(self, model_comm):
         """Allocate weight buffer, load params if from_workspace."""
 
-        engine_cfg = self.config_dict['engine_config']
-        self.node_id = engine_cfg['node_rank']
-
         # create weight
         def _create_weight_func(device_id):
-            rank = self.node_id * self.gpu_count + device_id
-            model_comm.create_shared_weights(device_id, rank)
+            model_comm.create_weights(device_id)
 
         with ThreadPoolExecutor(max_workers=self.gpu_count) as executor:
             futures = []
@@ -233,8 +225,7 @@ def _get_model_params(self):
         tm_params.clear()
 
         def _get_params(device_id, que):
-            rank = self.node_id * self.gpu_count + device_id
-            out = model_comm.get_params(device_id, rank)
+            out = model_comm.get_weights(device_id)
             que.put(out)
 
         que = Queue()
@@ -266,12 +257,6 @@ def _postprocess_config(self, tm_config: TurbomindModelConfig, engine_config: Tu
         # update some attributes of `engine_config` which depends on
         # `session_len`
         self.engine_config = engine_config
-        if engine_config.max_prefill_token_num is not None \
-                and engine_config.num_tokens_per_iter == 0:
-            self.engine_config.num_tokens_per_iter = \
-                engine_config.max_prefill_token_num
-            self.engine_config.max_prefill_iters = (self.config.session_len + engine_config.max_prefill_token_num -
-                                                    1) // engine_config.max_prefill_token_num
 
         # pack `self.config` and `self.engine_config` into a dict
         self.config_dict = self.config.to_dict()
@@ -290,9 +275,9 @@ def _from_hf(self, model_path: str, engine_config: TurbomindEngineConfig):
 
         self._postprocess_config(tm_model.tm_config, engine_config)
 
-        model_comm = _tm.AbstractTransformerModel.create_llama_model(model_dir='',
-                                                                     config=yaml.safe_dump(self.config_dict),
-                                                                     weight_type=self.config.model_config.weight_type)
+        model_comm = _tm.TurboMind.create(model_dir='',
+                                          config=yaml.safe_dump(self.config_dict),
+                                          weight_type=self.config.model_config.weight_type)
 
         # create empty weight
         self._create_weight(model_comm)
@@ -311,8 +296,7 @@ def wakeup(self, tags: Optional[list[str]] = None):
         if tags is None:
            tags = ['weights', 'kv_cache']
         with ThreadPoolExecutor(max_workers=self.gpu_count) as e:
-            ranks = [self.node_id * self.gpu_count + device_id for device_id in range(self.gpu_count)]
-            for _ in e.map(self.model_comm.wakeup, range(self.gpu_count), [tags] * self.gpu_count, ranks):
+            for _ in e.map(self.model_comm.wakeup, range(self.gpu_count), [tags] * self.gpu_count):
                 pass
 
     def update_params(self, request: UpdateParamsRequest):
@@ -501,7 +485,7 @@ def _func(out: EngineOutput, step: int, **kwargs):
                 out.req_metrics = RequestMetrics(token_timestamp=time.time())
             else:
                 events = [
-                    EngineEvent(EventType.QUEUED, metrics.enque_time / 1000000),
+                    EngineEvent(EventType.QUEUED, metrics.enqueue_time / 1000000),
                     EngineEvent(EventType.SCHEDULED, metrics.scheduled_time / 1000000),
                 ]
                 out.req_metrics = RequestMetrics(token_timestamp=time.time(), engine_events=events)
@@ -547,7 +531,7 @@ def __init__(self, tm_model: TurboMind, config: TurbomindModelConfig, cuda_strea
 
         # create model instances
         lazy_init = self.tm_model.config_dict['engine_config'].get('empty_init', False)
-        self._model_inst = None if lazy_init else self._create_model_instance(0)
+        self._model_inst = None if lazy_init else self._create_model_instance()
 
         self.config = config
         self.lock = None
@@ -564,17 +548,18 @@ def __init__(self, tm_model: TurboMind, config: TurbomindModelConfig, cuda_strea
             7: ResponseType.FINISH,
             8: ResponseType.CANCEL,
             9: ResponseType.PREFIX_CACHE_CONFLICT_INTERACTIVE_MODE,
+            10: ResponseType.NO_QUEUE,
             -1: ResponseType.INTERNAL_ENGINE_ERROR,
         }
 
     @property
     def model_inst(self):
         if self._model_inst is None:
-            self._model_inst = self._create_model_instance(0)
+            self._model_inst = self._create_model_instance()
         return self._model_inst
 
-    def _create_model_instance(self, device_id):
-        model_inst = self.tm_model.model_comm.create_model_instance(device_id)
+    def _create_model_instance(self):
+        model_inst = self.tm_model.model_comm.create_request()
         return model_inst
 
     def _get_extra_output_processors(self, outputs: Dict[str, torch.Tensor], gen_config: GenerationConfig,
@@ -598,47 +583,27 @@ def _get_offset(type):
 
     def prepare_embeddings(self, input_embeddings=None, input_embedding_ranges=None):
         """Convert embeddings."""
-        if input_embeddings is None:
+        if not input_embeddings:
             return None, None
 
+        assert isinstance(input_embeddings, List)
+        assert isinstance(input_embedding_ranges, List)
         assert len(input_embeddings) == len(input_embedding_ranges)
-        if not isinstance(input_embeddings[0], (list, type(None))):
-            input_embeddings = [input_embeddings]
-            input_embedding_ranges = [input_embedding_ranges]
 
-        if all([isinstance(x, type(None)) for x in input_embeddings]):
-            return None, None
+        length = sum([x.shape[0] for x in input_embeddings])
+
+        _MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16)
+        dtype = _MAP[self.tm_model.config.model_config.data_type]
+
+        values = torch.empty((length, input_embeddings[0].shape[-1]), dtype=dtype, device='cpu')
+        ranges = torch.tensor(input_embedding_ranges, dtype=torch.int32, device='cpu')
+
+        offset = 0
+        for embeds in input_embeddings:
+            values[offset:offset + embeds.shape[0]].copy_(embeds)
+            offset += embeds.shape[0]
 
-        hidden_dim = None
-        for embeddings in input_embeddings:
-            if embeddings is not None:
-                hidden_dim = embeddings[0].squeeze().shape[-1]
-                break
-        assert hidden_dim is not None
-
-        # construct input_embeddings
-        for i in range(len(input_embeddings)):
-            item = input_embeddings[i] or []
-            # convert to torch.Tensor if input is np.ndarray
-            if item and isinstance(item[0], np.ndarray):
-                item = [torch.from_numpy(x).squeeze() for x in item]
-            # convert to lookup table type
-            _MAP = dict(float=torch.float, bfloat16=torch.bfloat16, float16=torch.float16, fp8=torch.bfloat16)
-            dtype = _MAP.get(self.tm_model.config.weight_type, torch.float16)
-            item = [x.to(dtype=dtype) for x in item]
-            item = item or [torch.zeros(0, hidden_dim, dtype=dtype)]
-            input_embeddings[i] = item
-        input_embeddings = [torch.cat(x) for x in input_embeddings]
-        input_embeddings = pad_sequence(input_embeddings, batch_first=True)
-        input_embeddings = input_embeddings.reshape(input_embeddings.shape[0], -1).view(torch.int8)
-        # construct input_embedding_ranges
-        for i in range(len(input_embedding_ranges)):
-            item = input_embedding_ranges[i] or []
-            item = torch.IntTensor(item).reshape(-1, 2)
-            input_embedding_ranges[i] = item
-        input_embedding_ranges = pad_sequence(input_embedding_ranges, batch_first=True, padding_value=-1)
-
-        return input_embeddings, input_embedding_ranges
+        return values, ranges
 
     def prepare_mrope(self, input_meta: Dict[str, Any], input_len: int):
         mrope_position_ids = input_meta['mrope_position_ids']
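
Note: the new `prepare_embeddings` drops the old numpy/pad_sequence pad-and-pack path in favor of a flat copy into one lookup tensor plus an int32 range tensor. A minimal standalone sketch of that flattening, for illustration only (the function name and sample shapes here are hypothetical, not part of the diff):

    import torch

    def flatten_embeddings(chunks, ranges, dtype=torch.bfloat16):
        # chunks: list of (num_tokens, hidden_dim) tensors for one request
        # ranges: list of [start, end) positions the chunks occupy in the prompt
        assert len(chunks) == len(ranges)
        total = sum(x.shape[0] for x in chunks)
        values = torch.empty((total, chunks[0].shape[-1]), dtype=dtype, device='cpu')
        offset = 0
        for x in chunks:
            # copy each chunk back to back; positions are carried by `ranges`
            values[offset:offset + x.shape[0]].copy_(x)
            offset += x.shape[0]
        return values, torch.tensor(ranges, dtype=torch.int32, device='cpu')

    values, ranges = flatten_embeddings(
        [torch.randn(16, 4096, dtype=torch.bfloat16),
         torch.randn(9, 4096, dtype=torch.bfloat16)],
        [[5, 21], [30, 39]])
    print(values.shape, ranges.tolist())  # torch.Size([25, 4096]) [[5, 21], [30, 39]]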