
Commit fb23def

Free blocks in KVCacheManager upon error (IBM#96)
#### Motivation

We are seeing pods with speculative decoding getting restarted in BAM due to health checks failing. Upon inspection of the logs, it looks like we are running out of blocks and never recovering from it.

#### Modifications

I added a simple check so that, if something goes wrong when generating a token, we free the blocks associated with that batch. I also had to ensure that we free the child sequences that get created during speculation if something goes wrong there too.

#### Result

I've verified this allows us to recover from failures related to running out of blocks. Hopefully after this fix, we don't see the inference server getting restarted.

Signed-off-by: Thomas Parnell <[email protected]>
1 parent 0734973 commit fb23def
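Both changed files apply the same pattern: wrap the call that can exhaust KV-cache blocks in a try/except, free whatever was reserved for the failing batch or sequences, and re-raise so the caller still sees the original error. Below is a minimal sketch of that idea, not the actual server code; the helper name `generate_token_safely` and the `batch.sequence_ids` attribute are illustrative assumptions.

```python
def generate_token_safely(model, kv_cache_manager, batch):
    """Illustrative sketch: run one generation step without leaking KV-cache blocks.

    If generate_token raises (for example because the cache manager ran out of
    blocks), the blocks reserved for this batch are released before the error
    propagates, so later requests can still be served instead of the server
    getting stuck and failing health checks.
    """
    try:
        return model.generate_token(batch)
    except Exception:
        # Hypothetical cleanup path: give back every block held by the failing
        # batch, then re-raise so the caller still observes the original failure.
        kv_cache_manager.free_sequences(batch.sequence_ids)
        raise
```

The actual change uses the server's existing `self._free_paged_sequences(batch, None)` helper in `Prefill`/`NextToken`, and `kv_cache_manager.free_sequences(...)` in the speculation path, as shown in the diffs below.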

File tree

2 files changed: +21 −6 lines


server/text_generation_server/server.py

Lines changed: 15 additions & 5 deletions
```diff
@@ -141,10 +141,15 @@ async def Prefill(self, request: generate_pb2.PrefillRequest, context) -> genera
         batch_id = 0
         if batch is not None:
             for_concat = len(self.cache) > 0
-            # Prefill and generate first token
-            output_tokens, input_token_info, decode_errors, forward_time_ns = self.model.generate_token(
-                batch, first=True, for_concat=for_concat,
-            )
+            try:
+                # Prefill and generate first token
+                output_tokens, input_token_info, decode_errors, forward_time_ns = self.model.generate_token(
+                    batch, first=True, for_concat=for_concat,
+                )
+            except:
+                self._free_paged_sequences(batch, None)
+                raise
+
             if hasattr(batch, "past_key_values"):
                 clean_attribute("past_key_values", batch.past_key_values)
             if not is_healthcheck:
@@ -206,7 +211,12 @@ async def NextToken(self, request: generate_pb2.NextTokenRequest, context) -> ge
         # Ensure batches are garbage-collected post-concatenation
         del batches

-        output_tokens, _, errors, forward_time_ns = self.model.generate_token(batch)
+        try:
+            output_tokens, _, errors, forward_time_ns = self.model.generate_token(batch)
+        except:
+            self._free_paged_sequences(batch, None)
+            raise
+
         self.cache.set(batch)

         return generate_pb2.NextTokenResponse(
```

server/text_generation_server/utils/paged.py

Lines changed: 6 additions & 1 deletion
```diff
@@ -169,7 +169,12 @@ def prepare_inputs_with_speculation(
         child_sequence_ids_flattened.extend(child_sequence_ids)

     # add n_adds tokens to each candidate
-    cache_data = kv_cache_manager.allocate_tokens(num_tokens_per_sequence, child_sequence_ids_flattened)
+    try:
+        cache_data = kv_cache_manager.allocate_tokens(num_tokens_per_sequence, child_sequence_ids_flattened)
+    except:
+        kv_cache_manager.free_sequences(child_sequence_ids_flattened)
+        raise
+
     position_ids = cache_data.position_ids

     # Get candidate set of speculations
```
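To see why freeing on failure is enough to recover, consider a toy block pool: once the sequences from a failed allocation are released, the blocks they grabbed (including those from a partially completed allocation) return to the pool and the next request can proceed. Everything below (`ToyKVCacheManager` and its one-block-per-token accounting) is a simplified stand-in for illustration, not the real KVCacheManager.

```python
class ToyKVCacheManager:
    """Simplified stand-in for a paged KV-cache manager with a fixed block pool."""

    def __init__(self, total_blocks):
        self.free_blocks = total_blocks
        self.blocks_per_seq = {}

    def allocate_tokens(self, num_tokens_per_sequence, sequence_ids):
        # One block per token here, just to keep the accounting obvious.
        for seq_id, n in zip(sequence_ids, num_tokens_per_sequence):
            if n > self.free_blocks:
                raise RuntimeError("out of KV-cache blocks")
            self.free_blocks -= n
            self.blocks_per_seq[seq_id] = self.blocks_per_seq.get(seq_id, 0) + n

    def free_sequences(self, sequence_ids):
        for seq_id in sequence_ids:
            self.free_blocks += self.blocks_per_seq.pop(seq_id, 0)


mgr = ToyKVCacheManager(total_blocks=8)
child_ids = [0, 1, 2]
try:
    # The second child alone needs more blocks than the whole pool holds,
    # so this fails after the first child has already taken 4 blocks.
    mgr.allocate_tokens([4, 16, 4], child_ids)
except RuntimeError:
    # Same cleanup as in prepare_inputs_with_speculation: free the child
    # sequences so the partially reserved blocks go back to the pool.
    mgr.free_sequences(child_ids)

assert mgr.free_blocks == 8  # fully recovered; the next request can allocate again
```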
