@@ -52,7 +52,7 @@ def __init__(
5252        self .stepped  =  False 
5353        self .closed  =  False 
5454
55-         self ._position  =  0 
55+         self .position  =  0 
5656        self .history  =  None   # Used in case of server failures to regenerate attention caches on new servers 
5757        self .next_session  =  None 
5858
@@ -97,12 +97,11 @@ def step(
9797        n_input_tokens  =  inputs .shape [1 ]
9898        if  self .history  is  None :
9999            self .history  =  inputs 
100-         elif  self .history .shape [1 ] ==  self ._position :
100+         elif  self .history .shape [1 ] ==  self .position :
101101            self .history  =  torch .cat ([self .history , inputs [:, - n_input_tokens :]], dim = 1 )
102-         assert  self .history .shape [1 ] ==  self ._position  +  n_input_tokens , (
103-             f"Broken input cache: span={ self .span }   shape={ self .history .shape }   " 
104-             f"position={ self ._position }   n_input_tokens={ n_input_tokens }  " 
105-         )
102+         assert  (
103+             self .history .shape [1 ] ==  self .position  +  n_input_tokens 
104+         ), f"Broken input cache: { self .span = }   { self .history .shape = }   { self .position = }   { n_input_tokens = }  " 
106105
107106        if  not  self .stepped :
108107            inputs  =  self .history   # Pass full inputs including prefix 
@@ -154,7 +153,7 @@ def step(
154153            outputs [0 ].shape  ==  inputs .shape 
155154        ), f"output activation shape is different from input shape: { outputs [0 ].shape }   != { inputs .shape }  " 
156155
157-         self ._position  +=  n_input_tokens 
156+         self .position  +=  n_input_tokens 
158157
159158        return  outputs [0 ]
160159
@@ -356,6 +355,10 @@ def _update_sequence(self, server_idx: int, block_idx: int, attempt_no: int) ->
356355        # If there is a failed span, this code replaces it, otherwise it just adds new ones 
357356        if  server_idx  <  n_prev_spans :
358357            updated_sessions [0 ].history  =  self ._server_sessions [server_idx ].history 
358+             updated_sessions [0 ].position  =  self ._position 
359+             assert  (
360+                 updated_sessions [0 ].history .shape [1 ] ==  self ._position 
361+             ), f"Broken input cache: { updated_sessions [0 ].history .shape = }   { self ._position = }  " 
359362        self ._server_sessions [server_idx  : server_idx  +  1 ] =  updated_sessions 
360363
361364        # Update links to the next server session for direct server-to-server communication via rpc_push() 
0 commit comments