@@ -23,10 +23,7 @@ def __init__(
         self.device = device
 
         self.inputs_embeds = torch.zeros(
-            max_num_tokens,
-            hidden_size,
-            dtype=dtype,
-            device=device,
+            max_num_tokens, hidden_size, dtype=dtype, device=device
         )
         self.req_id_to_mm_features: dict[str, list[MultiModalFeatureSpec]] = {}
         self.encoder_cache: dict[str, torch.Tensor] = {}
@@ -57,8 +54,7 @@ def remove_request(self, req_id: str) -> None:
         self.req_id_to_mm_features.pop(req_id, None)
 
     def prepare_mm_inputs(
-        self,
-        scheduled_encoder_inputs: dict[str, list[int]],
+        self, scheduled_encoder_inputs: dict[str, list[int]]
     ) -> tuple[list[str], list[tuple[str, MultiModalKwargsItem]]]:
         mm_hashes: list[str] = []
         mm_kwargs: list[tuple[str, MultiModalKwargsItem]] = []
@@ -85,20 +81,16 @@ def execute_mm_encoder(
 
         encoder_outputs: list[torch.Tensor] = []
         for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
-            mm_kwargs,
-            device=self.device,
-            pin_memory=False,
+            mm_kwargs, device=self.device, pin_memory=False
         ):
             curr_group_outputs = model.embed_multimodal(**mm_kwargs_group)
             sanity_check_mm_encoder_outputs(
-                curr_group_outputs,
-                expected_num_items=num_items,
+                curr_group_outputs, expected_num_items=num_items
             )
             encoder_outputs.extend(curr_group_outputs)
 
         # Cache the encoder outputs by mm_hash
-        for mm_hash, output in zip(mm_hashes, encoder_outputs):
-            self.encoder_cache[mm_hash] = output
+        self.encoder_cache.update(zip(mm_hashes, encoder_outputs))
         return encoder_outputs
 
     def gather_mm_embeddings(
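Note on the last change in the hunk above: the manual loop is replaced with dict.update(zip(...)), which consumes the (hash, output) pairs directly. A minimal standalone sketch of the equivalence, using made-up hashes and stand-in tensors rather than real encoder outputs:

    import torch

    # Stand-in data; the real code maps multimodal input hashes to encoder outputs.
    mm_hashes = ["hash_a", "hash_b"]
    encoder_outputs = [torch.zeros(4), torch.ones(4)]

    # Before: explicit loop over the pairs.
    cache_loop: dict[str, torch.Tensor] = {}
    for mm_hash, output in zip(mm_hashes, encoder_outputs):
        cache_loop[mm_hash] = output

    # After: dict.update consumes the (key, value) pairs from zip directly.
    cache_update: dict[str, torch.Tensor] = {}
    cache_update.update(zip(mm_hashes, encoder_outputs))

    assert all(torch.equal(cache_loop[k], cache_update[k]) for k in mm_hashes)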
@@ -115,20 +107,15 @@ def gather_mm_embeddings(
         if all_decode:
             # All decode requests, so no need to gather any embeddings.
             return [], torch.zeros(
-                total_num_scheduled_tokens,
-                dtype=torch.bool,
-                device=self.device,
+                total_num_scheduled_tokens, dtype=torch.bool, device=self.device
             )
 
         query_start = computed_prefill_lens.tolist()
         query_end = (computed_prefill_lens + num_scheduled_tokens).tolist()
 
         mm_embeds: list[torch.Tensor] = []
         is_mm_embed = torch.zeros(
-            total_num_scheduled_tokens,
-            dtype=torch.bool,
-            device="cpu",
-            pin_memory=True,
+            total_num_scheduled_tokens, dtype=torch.bool, device="cpu", pin_memory=True
         )
         for i, req_id in enumerate(req_ids):
            if not is_prefilling[i]:
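Side note on pin_memory=True in the hunk above: allocating the CPU-side is_mm_embed mask in pinned (page-locked) memory is presumably so a later host-to-device copy can run asynchronously; non_blocking copies only overlap with host work when the source tensor is pinned. A minimal sketch of that general pattern (the names and sizes here are illustrative, not from this file):

    import torch

    num_tokens = 1024  # illustrative size

    # Pinned CPU staging buffer; pinning requires a CUDA-enabled build.
    use_cuda = torch.cuda.is_available()
    mask_cpu = torch.zeros(num_tokens, dtype=torch.bool, pin_memory=use_cuda)
    mask_cpu[:16] = True  # e.g. mark positions that take multimodal embeddings

    if use_cuda:
        # With a pinned source, this copy can overlap with other host work.
        mask_gpu = mask_cpu.to("cuda", non_blocking=True)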
@@ -189,9 +176,7 @@ def get_inputs_embeds(
         is_mm_embed: torch.Tensor,
     ) -> torch.Tensor:
         x = model.embed_input_ids(
-            input_ids,
-            multimodal_embeddings=mm_embeds,
-            is_multimodal=is_mm_embed,
+            input_ids, multimodal_embeddings=mm_embeds, is_multimodal=is_mm_embed
         )
         # Copy to the pre-allocated buffer for CUDA graphs.
         self.inputs_embeds[: x.shape[0]] = x
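The final hunk keeps the in-place copy into self.inputs_embeds, the buffer pre-allocated in __init__. Writing each batch into a fixed buffer rather than rebinding a fresh tensor keeps the input's storage address stable across steps, which CUDA graph replay depends on. A minimal sketch of that buffer pattern, independent of this class (sizes and names are illustrative):

    import torch

    MAX_NUM_TOKENS, HIDDEN_SIZE = 2048, 64  # illustrative capacity

    # Pre-allocated, worst-case-sized buffer with a fixed storage address.
    inputs_embeds = torch.zeros(MAX_NUM_TOKENS, HIDDEN_SIZE)

    def write_batch(x: torch.Tensor) -> torch.Tensor:
        # Copy in place instead of rebinding, so the pointer a captured
        # CUDA graph replays against never changes between steps.
        inputs_embeds[: x.shape[0]] = x
        return inputs_embeds[: x.shape[0]]

    out = write_batch(torch.randn(10, HIDDEN_SIZE))
    assert out.data_ptr() == inputs_embeds.data_ptr()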