@@ -52,7 +52,7 @@ def __init__(
5252 self .stepped = False
5353 self .closed = False
5454
55- self ._position = 0
55+ self .position = 0
5656 self .history = None # Used in case of server failures to regenerate attention caches on new servers
5757 self .next_session = None
5858
@@ -102,12 +102,11 @@ def step(
102102 n_input_tokens = inputs .shape [1 ]
103103 if self .history is None :
104104 self .history = inputs
105- elif self .history .shape [1 ] == self ._position :
105+ elif self .history .shape [1 ] == self .position :
106106 self .history = torch .cat ([self .history , inputs [:, - n_input_tokens :]], dim = 1 )
107- assert self .history .shape [1 ] == self ._position + n_input_tokens , (
108- f"Broken input cache: span={ self .span } shape={ self .history .shape } "
109- f"position={ self ._position } n_input_tokens={ n_input_tokens } "
110- )
107+ assert (
108+ self .history .shape [1 ] == self .position + n_input_tokens
109+ ), f"Broken input cache: { self .span = } { self .history .shape = } { self .position = } { n_input_tokens = } "
111110
112111 if not self .stepped :
113112 inputs = self .history # Pass full inputs including prefix
@@ -169,7 +168,7 @@ def step(
169168 outputs [0 ].shape == inputs .shape
170169 ), f"output activation shape is different from input shape: { outputs [0 ].shape } != { inputs .shape } "
171170
172- self ._position += n_input_tokens
171+ self .position += n_input_tokens
173172
174173 return outputs [0 ]
175174
@@ -359,6 +358,10 @@ def _update_sequence(self, server_idx: int, block_idx: int, attempt_no: int) ->
359358 # If there is a failed span, this code replaces it, otherwise it just adds new ones
360359 if server_idx < n_prev_spans :
361360 updated_sessions [0 ].history = self ._server_sessions [server_idx ].history
361+ updated_sessions [0 ].position = self ._position
362+ assert (
363+ updated_sessions [0 ].history .shape [1 ] == self ._position
364+ ), f"Broken input cache: { updated_sessions [0 ].history .shape = } { self ._position = } "
362365 self ._server_sessions [server_idx : server_idx + 1 ] = updated_sessions
363366
364367 # Update links to the next server session for direct server-to-server communication via rpc_push()
0 commit comments