@@ -84,12 +84,7 @@ async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[
             break  # this message means "done sending"
 
     def step(
-        self,
-        inputs: torch.Tensor,
-        prompts: Optional[torch.Tensor] = None,
-        hypo_ids: Optional[torch.Tensor] = None,
-        *,
-        step_id: str,
+        self, inputs: torch.Tensor, prompts: torch.Tensor, hypo_ids: torch.LongTensor, *, step_id: str
     ) -> torch.Tensor:
         """
         Inference step: send a chunk of input tensors and receive a chunk of outputs
@@ -114,21 +109,6 @@ def step(
         else:
             inputs = inputs[:, -n_input_tokens:]  # No need to pass prefix further
 
-        if prompts is None or is_dummy(prompts):
-            prompts = DUMMY
-        else:
-            assert prompts.ndim == 4, "deep prompts should have shape [num_blocks, batch_size, prefix_len, hid_size]"
-            assert prompts.shape[0] == self.num_blocks
-            assert prompts.shape[1] in (inputs.shape[0], 1)
-            assert prompts.shape[2] <= inputs.shape[1]
-            assert prompts.shape[3] == inputs.shape[2]
-
-        if hypo_ids is None or is_dummy(hypo_ids):
-            hypo_ids = DUMMY_INT64
-        else:
-            assert len(hypo_ids) == len(inputs)
-            assert hypo_ids.dtype == torch.int64
-
         # serialize inputs and put them into the queue
         input_tensors, args_structure = pack_args_kwargs(inputs, prompts, hypo_ids)
 
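Note: the placeholder logic removed above (and re-added on the client side below) relies on petals' "dummy tensor" convention. A minimal sketch of that convention, assuming DUMMY / DUMMY_INT64 / is_dummy come from petals.utils.misc and that "dummy" simply means "zero elements" (my reading of the codebase, not verified against this exact revision):

import torch

# Placeholders standing in for "no prompts" / "no hypo_ids" (assumed definitions)
DUMMY = torch.empty(0)
DUMMY_INT64 = torch.empty(0, dtype=torch.int64)

def is_dummy(tensor: torch.Tensor) -> bool:
    # A tensor with zero elements is treated as "not provided"
    return tensor.numel() == 0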
@@ -275,7 +255,9 @@ def __enter__(self) -> "InferenceSession":
         assert not self._closed and not self._server_sessions
         return self
 
-    def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
+    def step(
+        self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, hypo_ids: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
         assert not self._closed
         if torch.is_grad_enabled():
             logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
@@ -285,11 +267,21 @@ def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **k
         else:
             assert prompts.ndim == 4, "deep prompts should have shape [num_blocks, batch_size, prefix_len, hid_size]"
             assert prompts.shape[0] == self.num_blocks
+            assert prompts.shape[1] in (inputs.shape[0], 1)
+            assert prompts.shape[2] <= inputs.shape[1]
+            assert prompts.shape[3] == inputs.shape[2]
+
+        if hypo_ids is None or is_dummy(hypo_ids):
+            hypo_ids = DUMMY_INT64
+        else:
+            assert len(hypo_ids) == len(inputs)
+            assert hypo_ids.dtype == torch.int64
 
         inputs_device = inputs.device
         inputs_dtype = inputs.dtype
         inputs = inputs.cpu()
         prompts = prompts.cpu()
+        hypo_ids = hypo_ids.cpu()
         step_id = str(uuid.uuid4())
 
         n_input_tokens = inputs.shape[1]
@@ -310,7 +302,7 @@ def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **k
 
                     server_session = self._server_sessions[server_idx]
                     inputs = server_session.step(
-                        inputs, prompts[server_session.span.start : server_session.span.end], step_id=step_id, **kwargs
+                        inputs, prompts[server_session.span.start : server_session.span.end], hypo_ids, step_id=step_id
                     )
 
                     server_idx += 1
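With this change, InferenceSession.step validates prompts and hypo_ids once on the client and forwards them to every server session, instead of each _ServerInferenceSession re-validating its chunk. A hedged usage sketch, assuming petals' public inference_session API; the checkpoint name and shapes are illustrative, not taken from this commit:

import torch
from petals import AutoDistributedModelForCausalLM  # assumed import path

model = AutoDistributedModelForCausalLM.from_pretrained("bigscience/bloom-560m-petals")  # illustrative checkpoint

with model.transformer.h.inference_session(max_length=32) as session:
    # batch of 2 sequences, 1 new token each, hidden size from the model config
    hidden = torch.randn(2, 1, model.config.hidden_size)

    # prompts/hypo_ids may be omitted; step() substitutes DUMMY placeholders
    outputs = session.step(hidden)

    # during beam search, hypo_ids reorders the server-side attention caches
    hypo_ids = torch.tensor([1, 0], dtype=torch.int64)
    outputs = session.step(hidden, hypo_ids=hypo_ids)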