@@ -52,7 +52,7 @@ def __init__(
5252 self .stepped = False
5353 self .closed = False
5454
55- self ._position = 0
55+ self .position = 0
5656 self .history = None # Used in case of server failures to regenerate attention caches on new servers
5757 self .next_session = None
5858
@@ -97,12 +97,11 @@ def step(
9797 n_input_tokens = inputs .shape [1 ]
9898 if self .history is None :
9999 self .history = inputs
100- elif self .history .shape [1 ] == self ._position :
100+ elif self .history .shape [1 ] == self .position :
101101 self .history = torch .cat ([self .history , inputs [:, - n_input_tokens :]], dim = 1 )
102- assert self .history .shape [1 ] == self ._position + n_input_tokens , (
103- f"Broken input cache: span={ self .span } shape={ self .history .shape } "
104- f"position={ self ._position } n_input_tokens={ n_input_tokens } "
105- )
102+ assert (
103+ self .history .shape [1 ] == self .position + n_input_tokens
104+ ), f"Broken input cache: { self .span = } { self .history .shape = } { self .position = } { n_input_tokens = } "
106105
107106 if not self .stepped :
108107 inputs = self .history # Pass full inputs including prefix
@@ -154,7 +153,7 @@ def step(
154153 outputs [0 ].shape == inputs .shape
155154 ), f"output activation shape is different from input shape: { outputs [0 ].shape } != { inputs .shape } "
156155
157- self ._position += n_input_tokens
156+ self .position += n_input_tokens
158157
159158 return outputs [0 ]
160159
@@ -356,6 +355,10 @@ def _update_sequence(self, server_idx: int, block_idx: int, attempt_no: int) ->
356355 # If there is a failed span, this code replaces it, otherwise it just adds new ones
357356 if server_idx < n_prev_spans :
358357 updated_sessions [0 ].history = self ._server_sessions [server_idx ].history
358+ updated_sessions [0 ].position = self ._position
359+ assert (
360+ updated_sessions [0 ].history .shape [1 ] == self ._position
361+ ), f"Broken input cache: { updated_sessions [0 ].history .shape = } { self ._position = } "
359362 self ._server_sessions [server_idx : server_idx + 1 ] = updated_sessions
360363
361364 # Update links to the next server session for direct server-to-server communication via rpc_push()
0 commit comments