 from text_generation_server.utils.token_types import TokenInfo, InputTokens
 from text_generation_server.utils.tokens import HeterogeneousNextTokenChooser, get_token_info, get_input_tokens_info
 from text_generation_server.utils.paged import (
+    load_speculator,
     prepare_inputs_without_speculation,
     prepare_inputs_with_speculation,
     process_outputs_with_speculation,
     prepare_inputs_for_prefill
 )
 from text_generation_server.inference_engine import get_inference_engine_class

-# HF name or path to speculator model (None means no speculation will be used)
-SPECULATOR_NAME = os.getenv("SPECULATOR_NAME", None)
-
 # we will only do speculation if the batch size is <= this parameter
 SPECULATOR_MAX_BATCH_SIZE = int(os.getenv("SPECULATOR_MAX_BATCH_SIZE", "16"))
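
The body of `load_speculator` is not shown in this diff (it lives in `text_generation_server/utils/paged.py`). Judging from the environment variables and constructor code it replaces further down, it presumably just relocates the env-driven loading of the MLP speculator. A minimal sketch of what it plausibly looks like, assuming the same `SPECULATOR_NAME`/`SPECULATOR_REVISION` variables and the `MLPSpeculatorPreTrainedModel` class from `fms_extras`; the real helper may resolve local model paths differently:

```python
# Hypothetical sketch only; the actual load_speculator in utils/paged.py may differ.
import os
import torch


def load_speculator(device: torch.device, dtype: torch.dtype):
    """Return an MLP speculator model, or None when SPECULATOR_NAME is unset."""
    speculator_name = os.getenv("SPECULATOR_NAME", None)
    if speculator_name is None:
        return None  # speculation disabled

    from fms_extras.models.hf.modeling_mlp_speculator import MLPSpeculatorPreTrainedModel

    revision = os.getenv("SPECULATOR_REVISION", None)
    with device:  # materialize weights directly on the target device (torch >= 2.0)
        speculator = MLPSpeculatorPreTrainedModel.from_pretrained(
            speculator_name,
            revision=revision,
            torch_dtype=dtype,
        )
    return speculator.to(device=device)
```
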
@@ -277,6 +275,7 @@ def __init__(
         quantize: Optional[str],
         model_config: Union[Any] = None,
         max_sequence_length: Optional[int] = None,
+        memory_scaling_model: Optional["MemoryScalingModel"] = None,
     ):
         model_path = get_model_path(model_name, revision)
@@ -300,27 +299,41 @@ def __init__(

         from fms_extras.utils.cache.paged import PagedKVCacheManager

-        if SPECULATOR_NAME is not None:
-            from fms_extras.models.hf.modeling_mlp_speculator import MLPSpeculatorPreTrainedModel
-            speculator_revision = os.getenv("SPECULATOR_REVISION", None)
-            speculator_model_path = get_model_path(SPECULATOR_NAME, speculator_revision)
-            print_rank_n(f"Loading speculator model from: {speculator_model_path}")
+        # load speculator
+        self.speculator = load_speculator(self.device, dtype)
+
+        if self.speculator is not None:
             print_rank_n(f"Speculation will be enabled up to batch size {SPECULATOR_MAX_BATCH_SIZE}")
-            kwargs = {
-                "pretrained_model_name_or_path": speculator_model_path,
-                "local_files_only": True,
-                "torch_dtype": dtype,
-            }
-            with self.device:
-                self.speculator = MLPSpeculatorPreTrainedModel.from_pretrained(**kwargs)
-            self.speculator.to(device=self.device)
-        else:
-            self.speculator = None
+
+        block_size = 16

         if KV_CACHE_MANAGER_NUM_GPU_BLOCKS is not None:
             total_num_gpu_blocks = int(KV_CACHE_MANAGER_NUM_GPU_BLOCKS)
         else:
-            total_num_gpu_blocks = None
+            # First, compute the size of a cache block in bytes
+            kv_cache_block_size = self.model.get_kv_cache_block_size(block_size)
+            total_size = model_config.num_hidden_layers * kv_cache_block_size
+            dtype_size = torch.tensor([], dtype=dtype).element_size()
+            cache_block_size = dtype_size * total_size
+            # We then use our memory scaling model to determine the fraction of the prefill memory
+            # usage that is due to cache blocks (as opposed to the other data needed for forward):
+            pf_cache_block_ratio = cache_block_size / block_size / memory_scaling_model.linear_fit_params[0]
+            # We can then do the same for the next-token (decoding) step:
+            nt_cache_block_ratio = cache_block_size / block_size / memory_scaling_model.next_token_params[1]
+            # In general the next-token phase can use many more cache blocks relative to the
+            # prefill phase (i.e., nt_cache_block_ratio > pf_cache_block_ratio).
+            # Thus, we need to allocate enough cache blocks to handle the more extreme case:
+            total_num_gpu_blocks = int(nt_cache_block_ratio * memory_scaling_model.free_memory // cache_block_size)
+            # This creates an issue though: if we then try to perform a large prefill, while we
+            # will certainly have enough cache blocks available, we may not have enough memory left over
+            # to allocate the other data structures needed during a forward pass.
+            # To overcome this, we can increase the batch_safety_margin a bit to ensure that:
+            #   free_memory * (1.0 - batch_safety_margin/100 - 0.05) * (1.0 - pf_cache_block_ratio) <
+            #   free_memory * (1.0 - nt_cache_block_ratio)
+            # This should ensure that our prefill batches can never get so big as to cause OOM.
+            recommend_safety_margin = 5 + int(100 * (1.0 - (1.0 - nt_cache_block_ratio) / (1.0 - pf_cache_block_ratio)))
+            if memory_scaling_model.safety_margin < recommend_safety_margin:
+                print(f"WARN: We recommend increasing the value of BATCH_SAFETY_MARGIN to: {recommend_safety_margin}")

         self.kv_cache_manager = PagedKVCacheManager(
             model_config.num_hidden_layers,
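
To make the block-count arithmetic above concrete, here is a small, self-contained restatement of the formulas. All numeric values (the per-block byte count, free memory, and the two fitted per-token coefficients standing in for `memory_scaling_model.linear_fit_params[0]` and `memory_scaling_model.next_token_params[1]`) are purely illustrative; only the formulas mirror the diff:

```python
# Illustrative plug-in numbers for the formulas used above; not real measurements.
block_size = 16

# Bytes per KV-cache block: 2 (K and V) * layers * tokens/block * kv_heads * head_dim * 2 bytes (fp16)
cache_block_size = 2 * 32 * block_size * 8 * 128 * 2          # = 2_097_152 bytes (~2 MiB)
free_memory = 70e9                                             # bytes free after loading weights (assumed)

prefill_bytes_per_token = 4.0e6      # stand-in for memory_scaling_model.linear_fit_params[0]
next_token_bytes_per_token = 1.5e6   # stand-in for memory_scaling_model.next_token_params[1]

# Fraction of per-token memory that goes to cache blocks in each phase
pf_cache_block_ratio = cache_block_size / block_size / prefill_bytes_per_token      # ~0.033
nt_cache_block_ratio = cache_block_size / block_size / next_token_bytes_per_token   # ~0.087

# Size the block pool for the decode phase, which devotes the larger fraction to cache
total_num_gpu_blocks = int(nt_cache_block_ratio * free_memory // cache_block_size)  # ~2916 blocks

# Safety margin needed so that a maximal prefill still leaves room for non-cache tensors
recommend_safety_margin = 5 + int(
    100 * (1.0 - (1.0 - nt_cache_block_ratio) / (1.0 - pf_cache_block_ratio))
)                                                                                    # = 10 here

print(total_num_gpu_blocks, recommend_safety_margin)
```
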
@@ -331,8 +344,11 @@ def __init__(
             dtype=dtype,
             device=self.device,
             total_num_gpu_blocks=total_num_gpu_blocks,
+            block_size=block_size,
         )

+        self.memory_scaling_model = memory_scaling_model
+
         # log number of free blocks at init
         print("[PagedKVCacheManager] number of free blocks: %d" % (len(self.kv_cache_manager.free_blocks)))

@@ -413,12 +429,18 @@ def _prefill(
         )

         t0 = time.time_ns()
-        output = self.model(
-            input_ids,
-            position_ids=position_ids,
-            cache_data=cache_data,
-            return_embeds=True,
-        )
+        try:
+            output = self.model(
+                input_ids,
+                position_ids=position_ids,
+                cache_data=cache_data,
+                return_embeds=True,
+            )
+        except:
+            # if something goes wrong during forward, we still need to set the sequence ids
+            # TODO: it would be better to fix the forward method to avoid the possibility of partial failures
+            batch.sequence_ids = cache_data.sequence_ids
+            raise
         t_forward_ns = time.time_ns() - t0
         logits, embeds = output

@@ -603,10 +625,7 @@ def generate_token(
             )
         else:
             bsize = batch.input_ids.shape[0]
-
-            tokens_remaining = 0
-            for i in range(len(batch.total_lengths)):
-                tokens_remaining += batch.total_lengths[i] - batch.input_lengths[i]
+            weight = sum(batch.total_lengths) * self.memory_scaling_model.next_token_params[1]

             spec_ind = []
             for i, sample in enumerate(batch.next_token_chooser.do_sample):
@@ -618,7 +637,7 @@ def generate_token(
                 len(spec_ind) > 0 and
                 bsize <= SPECULATOR_MAX_BATCH_SIZE and
                 batch.next_token_chooser.repetition_processor is None and
-                tokens_remaining < 0.25 * len(self.kv_cache_manager.free_blocks) * self.kv_cache_manager.block_size
+                (weight / self.memory_scaling_model.weight_limit) <= 0.75
             )

             if speculate:
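
The new gating condition can be read as: only speculate while the batch's estimated next-token memory weight stays at or below 75% of the calibrated `weight_limit`, replacing the old free-block-count heuristic. A hedged, standalone restatement of that check follows; `MemoryScalingModel` here is a stub with only the fields the check needs, and `num_spec_eligible` stands in for `len(spec_ind)`, whose construction is only partly visible in this diff:

```python
# Sketch of the gating heuristic; not the server's actual code.
from dataclasses import dataclass
from typing import Sequence


@dataclass
class MemoryScalingModel:              # stub: only the fields used below
    next_token_params: Sequence[float]
    weight_limit: float


def should_speculate(
    total_lengths: Sequence[int],      # batch.total_lengths
    batch_size: int,                   # bsize
    num_spec_eligible: int,            # stands in for len(spec_ind)
    has_repetition_processor: bool,    # batch.next_token_chooser.repetition_processor is not None
    mem: MemoryScalingModel,
    max_batch_size: int = 16,          # SPECULATOR_MAX_BATCH_SIZE default
    weight_fraction: float = 0.75,     # threshold used in the diff
) -> bool:
    # Estimated decode-phase memory "weight": total resident tokens scaled by the
    # fitted per-token next-token coefficient.
    weight = sum(total_lengths) * mem.next_token_params[1]
    return (
        num_spec_eligible > 0
        and batch_size <= max_batch_size
        and not has_repetition_processor
        and (weight / mem.weight_limit) <= weight_fraction
    )
```
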