@@ -84,8 +84,13 @@ async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[
8484 break # this message means "done sending"
8585
8686 def step (
87- self , inputs : torch .Tensor , prompts : torch .Tensor , hypo_ids : torch .LongTensor , * ,
88- step_id : str , start_from_position : int
87+ self ,
88+ inputs : torch .Tensor ,
89+ prompts : torch .Tensor ,
90+ hypo_ids : torch .LongTensor ,
91+ * ,
92+ step_id : str ,
93+ start_from_position : int ,
8994 ) -> torch .Tensor :
9095 """
9196 Inference step: send a chunk of input tensors and receive a chunk of outputs
@@ -266,8 +271,11 @@ def __enter__(self) -> "InferenceSession":
266271 return self
267272
268273 def step (
269- self , inputs : torch .Tensor , prompts : Optional [torch .Tensor ] = None ,
270- hypo_ids : Optional [torch .Tensor ] = None , start_from_position : Optional [int ] = None
274+ self ,
275+ inputs : torch .Tensor ,
276+ prompts : Optional [torch .Tensor ] = None ,
277+ hypo_ids : Optional [torch .Tensor ] = None ,
278+ start_from_position : Optional [int ] = None ,
271279 ) -> torch .Tensor :
272280
273281 if start_from_position is not None :
@@ -317,8 +325,11 @@ def step(
317325
318326 server_session = self ._server_sessions [server_idx ]
319327 inputs = server_session .step (
320- inputs , prompts [server_session .span .start : server_session .span .end ], hypo_ids ,
321- step_id = step_id , start_from_position = start_from_position
328+ inputs ,
329+ prompts [server_session .span .start : server_session .span .end ],
330+ hypo_ids ,
331+ step_id = step_id ,
332+ start_from_position = start_from_position ,
322333 )
323334
324335 server_idx += 1
0 commit comments