@@ -85,7 +85,7 @@ async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[
 
     def step(
         self, inputs: torch.Tensor, prompts: torch.Tensor, hypo_ids: torch.LongTensor, *,
-        step_id: str, last_validated_position: int
+        step_id: str, start_from_position: int
     ) -> torch.Tensor:
         """
         Inference step: send a chunk of input tensors and receive a chunk of outputs
@@ -95,11 +95,11 @@ def step(
         if self.closed:
             raise Exception("Session is closed, cannot perform step")
 
-        if last_validated_position is not None:
-            assert last_validated_position <= self._position
-            self._position = last_validated_position
-            if self.history is not None and self.history.shape[1] >= last_validated_position:
-                self.history = self.history[:, :last_validated_position, :] if last_validated_position > 0 else None
+        if start_from_position is not None:
+            assert start_from_position <= self._position
+            self._position = start_from_position
+            if self.history is not None and self.history.shape[1] >= start_from_position:
+                self.history = self.history[:, :start_from_position, :] if start_from_position > 0 else None
 
         n_input_tokens = inputs.shape[1]
         if self.history is None:
@@ -122,8 +122,8 @@ def step(
         request_metadata = dict(session_id=self.session_id, step_id=step_id)
         if not self.stepped:
             request_metadata.update(self.session_metadata)
-            if last_validated_position is not None:
-                request_metadata["last_validated_position"] = last_validated_position
+            if start_from_position is not None:
+                request_metadata["start_from_position"] = start_from_position
         elif self.config.use_server_to_server:
             next_servers = self._collect_next_servers()
             if next_servers:
@@ -267,11 +267,11 @@ def __enter__(self) -> "InferenceSession":
 
     def step(
         self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None,
-        hypo_ids: Optional[torch.Tensor] = None, last_validated_position: Optional[int] = None
+        hypo_ids: Optional[torch.Tensor] = None, start_from_position: Optional[int] = None
    ) -> torch.Tensor:
 
-        if last_validated_position is not None:
-            self._position = last_validated_position
+        if start_from_position is not None:
+            self._position = start_from_position
 
         assert not self._closed
         if torch.is_grad_enabled():
@@ -318,7 +318,7 @@ def step(
                     server_session = self._server_sessions[server_idx]
                     inputs = server_session.step(
                         inputs, prompts[server_session.span.start : server_session.span.end], hypo_ids,
-                        step_id=step_id, last_validated_position=last_validated_position
+                        step_id=step_id, start_from_position=start_from_position
                     )
 
                     server_idx += 1
0 commit comments