@@ -84,7 +84,8 @@ async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[
                 break  # this message means "done sending"

     def step(
-        self, inputs: torch.Tensor, prompts: torch.Tensor, hypo_ids: torch.LongTensor, *, step_id: str
+        self, inputs: torch.Tensor, prompts: torch.Tensor, hypo_ids: torch.LongTensor, *,
+        step_id: str, last_validated_position: int
     ) -> torch.Tensor:
         """
         Inference step: send a chunk of input tensors and receive a chunk of outputs
@@ -94,6 +95,12 @@ def step(
         if self.closed:
             raise Exception("Session is closed, cannot perform step")

+        if last_validated_position is not None:
+            assert last_validated_position <= self._position
+            self._position = last_validated_position
+            if self.history is not None and self.history.shape[1] >= last_validated_position:
+                self.history = self.history[:, :last_validated_position, :] if last_validated_position > 0 else None
+
         n_input_tokens = inputs.shape[1]
         if self.history is None:
             self.history = inputs
@@ -115,6 +122,8 @@ def step(
         request_metadata = dict(session_id=self.session_id, step_id=step_id)
         if not self.stepped:
             request_metadata.update(self.session_metadata)
+            if last_validated_position is not None:
+                request_metadata["last_validated_position"] = last_validated_position
         elif self.config.use_server_to_server:
             next_servers = self._collect_next_servers()
             if next_servers:
@@ -257,8 +266,13 @@ def __enter__(self) -> "InferenceSession":
         return self

     def step(
-        self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, hypo_ids: Optional[torch.Tensor] = None
+        self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None,
+        hypo_ids: Optional[torch.Tensor] = None, last_validated_position: Optional[int] = None
     ) -> torch.Tensor:
+
+        if last_validated_position is not None:
+            self._position = last_validated_position
+
         assert not self._closed
         if torch.is_grad_enabled():
             logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
@@ -303,7 +317,8 @@ def step(

         server_session = self._server_sessions[server_idx]
         inputs = server_session.step(
-            inputs, prompts[server_session.span.start : server_session.span.end], hypo_ids, step_id=step_id
+            inputs, prompts[server_session.span.start : server_session.span.end], hypo_ids,
+            step_id=step_id, last_validated_position=last_validated_position
         )

         server_idx += 1