
Commit d5340ca

Merge pull request IBM#86 from opendatahub-io/main
Sync release to main branches for 2.11
2 parents ed9d828 + 43623db commit d5340ca

6 files changed: +108 −51 lines

server/text_generation_server/models/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -35,6 +35,7 @@ def get_model(
     dtype_str: str,
     quantize: Optional[str],
     max_sequence_length: Optional[int],
+    memory_scaling_model: Optional[int] = None,
 ) -> Model:
     dtype = get_torch_dtype(dtype_str)
     model_path = get_model_path(model_name, revision)
@@ -74,6 +75,7 @@ def get_model(
             dtype, quantize,
             model_config,
             max_sequence_length=max_sequence_length,
+            memory_scaling_model=memory_scaling_model,
         )

     if FLASH_ATTENTION:

server/text_generation_server/models/custom_modeling/paged_llama_modeling.py

Lines changed: 3 additions & 0 deletions
@@ -434,6 +434,9 @@ def __init__(self, config, weights):
             weights=weights,
         )

+    def get_kv_cache_block_size(self, block_size: int) -> int:
+        return block_size * self.model.num_key_value_heads * self.model.head_size * 2
+
     def get_input_embeddings(self) -> nn.Module:
         return self.model.embed_tokens

server/text_generation_server/models/custom_modeling/paged_santacoder_modeling.py

Lines changed: 3 additions & 0 deletions
@@ -407,6 +407,9 @@ def __init__(self, config, weights):
             config, prefix="transformer.wte", weights=weights
         )

+    def get_kv_cache_block_size(self, block_size: int) -> int:
+        return block_size * self.transformer.head_size * 2
+
     def get_input_embeddings(self) -> nn.Module:
         return self.transformer.wte
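Both new hooks return the number of scalar elements a single paged-attention cache block holds per layer (keys plus values, hence the factor of 2); the SantaCoder variant omits a head-count factor because it uses multi-query attention with a single key/value head. A minimal sketch of how this feeds the per-block byte count used in paged_causal_lm.py, with illustrative Llama-7B-like shapes (the helper below is not part of the repo):

import torch

def cache_block_size_bytes(
    block_size: int,           # tokens per cache block (16 in this PR)
    num_hidden_layers: int,
    num_key_value_heads: int,  # 1 for multi-query models such as SantaCoder
    head_size: int,
    dtype: torch.dtype,
) -> int:
    # per-layer elements for one block: keys + values -> factor of 2
    per_layer_elements = block_size * num_key_value_heads * head_size * 2
    dtype_size = torch.tensor([], dtype=dtype).element_size()
    return num_hidden_layers * per_layer_elements * dtype_size

# Illustrative Llama-7B-like shapes: 32 layers, 32 KV heads, head size 128, fp16
print(cache_block_size_bytes(16, 32, 32, 128, torch.float16))  # 8388608 bytes = 8 MiB per block

This per-block byte count is what the new __init__ logic in paged_causal_lm.py divides by block_size to get a per-token cache cost.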

server/text_generation_server/models/paged_causal_lm.py

Lines changed: 49 additions & 30 deletions
@@ -18,16 +18,14 @@
 from text_generation_server.utils.token_types import TokenInfo, InputTokens
 from text_generation_server.utils.tokens import HeterogeneousNextTokenChooser, get_token_info, get_input_tokens_info
 from text_generation_server.utils.paged import (
+    load_speculator,
     prepare_inputs_without_speculation,
     prepare_inputs_with_speculation,
    process_outputs_with_speculation,
     prepare_inputs_for_prefill
 )
 from text_generation_server.inference_engine import get_inference_engine_class

-# HF name or path to speculator model (None means no speculation will be used)
-SPECULATOR_NAME = os.getenv("SPECULATOR_NAME", None)
-
 # we will only do speculation if the batch size is <= this parameter
 SPECULATOR_MAX_BATCH_SIZE = int(os.getenv("SPECULATOR_MAX_BATCH_SIZE", "16"))

@@ -277,6 +275,7 @@ def __init__(
         quantize: Optional[str],
         model_config: Union[Any] = None,
         max_sequence_length: Optional[int] = None,
+        memory_scaling_model: Optional["MemoryScalingModel"] = None,
     ):
         model_path = get_model_path(model_name, revision)

@@ -300,27 +299,41 @@ def __init__(

         from fms_extras.utils.cache.paged import PagedKVCacheManager

-        if SPECULATOR_NAME is not None:
-            from fms_extras.models.hf.modeling_mlp_speculator import MLPSpeculatorPreTrainedModel
-            speculator_revision = os.getenv("SPECULATOR_REVISION", None)
-            speculator_model_path = get_model_path(SPECULATOR_NAME, speculator_revision)
-            print_rank_n(f"Loading speculator model from: {speculator_model_path}")
+        # load speculator
+        self.speculator = load_speculator(self.device, dtype)
+
+        if self.speculator is not None:
             print_rank_n(f"Speculation will be enabled up to batch size {SPECULATOR_MAX_BATCH_SIZE}")
-            kwargs = {
-                "pretrained_model_name_or_path": speculator_model_path,
-                "local_files_only": True,
-                "torch_dtype": dtype,
-            }
-            with self.device:
-                self.speculator = MLPSpeculatorPreTrainedModel.from_pretrained(**kwargs)
-            self.speculator.to(device=self.device)
-        else:
-            self.speculator = None
+
+        block_size = 16

         if KV_CACHE_MANAGER_NUM_GPU_BLOCKS is not None:
             total_num_gpu_blocks = int(KV_CACHE_MANAGER_NUM_GPU_BLOCKS)
         else:
-            total_num_gpu_blocks = None
+            # Firstly, let's compute the size of a cache block in bytes
+            kv_cache_block_size = self.model.get_kv_cache_block_size(block_size)
+            total_size = model_config.num_hidden_layers * kv_cache_block_size
+            dtype_size = torch.tensor([], dtype=dtype).element_size()
+            cache_block_size = dtype_size * total_size
+            # We then use our memory scaling model to determine the fraction of the prefill memory
+            # usage that is due to cache blocks (as opposed to the other stuff needed for forward):
+            pf_cache_block_ratio = cache_block_size / block_size / memory_scaling_model.linear_fit_params[0]
+            # We can then do the same for the next token (decoding) step:
+            nt_cache_block_ratio = cache_block_size / block_size / memory_scaling_model.next_token_params[1]
+            # In general we know that the next token phase can use many more cache blocks
+            # relative to the prefill phase (e.g., nt_cache_block_ratio > pf_cache_block_ratio).
+            # Thus, we need to allocate enough cache blocks to handle the more extreme case:
+            total_num_gpu_blocks = int(nt_cache_block_ratio * memory_scaling_model.free_memory // cache_block_size)
+            # This creates an issue though, because if we then try to perform a large prefill, while we
+            # will certainly have enough cache blocks available, we may not have enough memory leftover
+            # to allocate the other data structures needed during a forward pass.
+            # To overcome this, we can set the batch_safety_margin a bit higher to ensure that:
+            #   free_memory * (1.0-batch_safety_margin/100-0.05) * (1.0-pf_cache_block_ratio) <
+            #   free_memory * (1.0-nt_cache_block_ratio)
+            # This should ensure that our prefill batches can never get so big as to cause OOM.
+            recommend_safety_margin = 5 + int(100*(1.0 - (1.0 - nt_cache_block_ratio)/(1.0 - pf_cache_block_ratio)))
+            if memory_scaling_model.safety_margin < recommend_safety_margin:
+                print(f"WARN: We recommend increasing the value of BATCH_SAFETY_MARGIN to: {recommend_safety_margin}")

         self.kv_cache_manager = PagedKVCacheManager(
             model_config.num_hidden_layers,
@@ -331,8 +344,11 @@ def __init__(
             dtype=dtype,
             device=self.device,
             total_num_gpu_blocks=total_num_gpu_blocks,
+            block_size=block_size,
         )

+        self.memory_scaling_model = memory_scaling_model
+
         # log number of free blocks at init
         print("[PagedKVCacheManager] number of free blocks: %d" % (len(self.kv_cache_manager.free_blocks)))

@@ -413,12 +429,18 @@ def _prefill(
         )

         t0 = time.time_ns()
-        output = self.model(
-            input_ids,
-            position_ids=position_ids,
-            cache_data=cache_data,
-            return_embeds=True,
-        )
+        try:
+            output = self.model(
+                input_ids,
+                position_ids=position_ids,
+                cache_data=cache_data,
+                return_embeds=True,
+            )
+        except:
+            # if something goes wrong during forward, we still need to set the sequence ids
+            # TODO it would be better to fix the forward method to avoid possibility of partial failures
+            batch.sequence_ids = cache_data.sequence_ids
+            raise
         t_forward_ns = time.time_ns()-t0
         logits, embeds = output

@@ -603,10 +625,7 @@ def generate_token(
             )
         else:
             bsize = batch.input_ids.shape[0]
-
-            tokens_remaining = 0
-            for i in range(len(batch.total_lengths)):
-                tokens_remaining += batch.total_lengths[i] - batch.input_lengths[i]
+            weight = sum(batch.total_lengths) * self.memory_scaling_model.next_token_params[1]

             spec_ind = []
             for i, sample in enumerate(batch.next_token_chooser.do_sample):
@@ -618,7 +637,7 @@ def generate_token(
                 len(spec_ind) > 0 and
                 bsize <= SPECULATOR_MAX_BATCH_SIZE and
                 batch.next_token_chooser.repetition_processor is None and
-                tokens_remaining < 0.25*len(self.kv_cache_manager.free_blocks)*self.kv_cache_manager.block_size
+                (weight/self.memory_scaling_model.weight_limit) <= 0.75
             )

             if speculate:
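The block-budget comments in the __init__ hunk above amount to a small calculation over the fitted memory model. A standalone sketch of that arithmetic, using the field names referenced in the diff (linear_fit_params, next_token_params, free_memory, safety_margin) on an assumed plain object; all numeric values below are made up for illustration:

from dataclasses import dataclass

@dataclass
class MemoryScalingModelLike:
    linear_fit_params: list  # [0]: prefill memory per token (bytes/token), per the diff
    next_token_params: list  # [1]: next-token memory per token of context (bytes/token)
    free_memory: int         # bytes left after loading weights
    safety_margin: int       # currently configured BATCH_SAFETY_MARGIN percentage

def gpu_block_budget(msm: MemoryScalingModelLike, cache_block_size: int, block_size: int = 16):
    # fraction of prefill / next-token memory attributable to cache blocks
    pf_ratio = cache_block_size / block_size / msm.linear_fit_params[0]
    nt_ratio = cache_block_size / block_size / msm.next_token_params[1]
    # size the block pool for the decode phase, which devotes the larger share to cache
    total_num_gpu_blocks = int(nt_ratio * msm.free_memory // cache_block_size)
    # margin large enough that a big prefill still leaves room for non-cache activations
    recommended_margin = 5 + int(100 * (1.0 - (1.0 - nt_ratio) / (1.0 - pf_ratio)))
    return total_num_gpu_blocks, recommended_margin

msm = MemoryScalingModelLike(
    linear_fit_params=[4.0e7],       # made-up prefill bytes per token
    next_token_params=[0.0, 1.0e6],  # made-up decode bytes per token
    free_memory=40 * 1024**3,        # 40 GiB free for cache + activations
    safety_margin=20,
)
blocks, margin = gpu_block_budget(msm, cache_block_size=8 * 1024**2)
print(blocks, margin)  # number of 16-token blocks to allocate, recommended BATCH_SAFETY_MARGIN

If the configured safety margin is below the recommendation, the __init__ code only logs a warning; it does not change the block budget.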

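Similarly, the speculation gate in generate_token now compares a projected decode-time memory weight against the fitted weight_limit instead of counting free cache blocks. A sketch of just that condition, with the 16 and 0.75 thresholds copied from the diff and everything else treated as plain arguments:

def should_speculate(
    total_lengths: list,              # batch.total_lengths
    bsize: int,                       # batch.input_ids.shape[0]
    num_spec_eligible: int,           # len(spec_ind) in the diff
    has_repetition_processor: bool,
    next_token_mem_per_token: float,  # memory_scaling_model.next_token_params[1]
    weight_limit: float,              # memory_scaling_model.weight_limit
    speculator_max_batch_size: int = 16,
) -> bool:
    # projected decode-phase memory footprint of the whole batch
    weight = sum(total_lengths) * next_token_mem_per_token
    return (
        num_spec_eligible > 0
        and bsize <= speculator_max_batch_size
        and not has_repetition_processor
        and weight / weight_limit <= 0.75
    )

# a batch of 4 sequences totalling 6000 tokens of context stays under the 75% threshold
print(should_speculate([2000, 1500, 1500, 1000], 4, 4, False, 1.0e6, 1.2e10))  # True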
server/text_generation_server/server.py

Lines changed: 24 additions & 21 deletions
@@ -56,7 +56,7 @@ async def func_with_log(*args, **kwargs):


 class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
-    def __init__(self, model: Model, cache: Cache, server_urls: List[str], memory_scaling_model: MemoryScalingModelPB):
+    def __init__(self, model: Model, cache: Cache, server_urls: List[str], memory_scaling_model: MemoryScalingModel):
         self.cache = cache
         self.model = model
         self.server_urls = server_urls
@@ -81,7 +81,7 @@ async def ModelInfo(self, request: generate_pb2.ModelInfoRequest, context) -> ge
             if isinstance(self.model, Seq2SeqLM) else ModelInfoResponse.ModelType.CAUSAL_LM,
             eos_token=getattr(self.model.tokenizer, 'model_eos_token_id', self.model.tokenizer.eos_token_id),
             batch_padding=not isinstance(self.model, FlashCausalLM),
-            memory_scaling_model=self.memory_scaling_model,
+            memory_scaling_model=self.memory_scaling_model.as_pb(),
         )

     @log_rpc_handler_errors
@@ -244,8 +244,9 @@ def _free_paged_sequences(self, batch: "Batch", completed_ids: Optional[List[int
             ]
         else:
             return
-        self.model.kv_cache_manager.free_sequences(sequence_ids_to_free, recursive=True)

+        if sequence_ids_to_free is not None:
+            self.model.kv_cache_manager.free_sequences(sequence_ids_to_free, recursive=True)

 def serve(
     model_name: str,
@@ -273,6 +274,22 @@ async def serve_inner(
         batch_safety_margin: int,
         sharded: bool = False,
     ):
+        if quantize not in [None, "gptq", "bitsandbytes"]:
+            raise ValueError(f"Unrecognized quantization method specified: {quantize}")
+
+        if quantize is None and dtype_str == "int8":
+            print_rank_n("Inferring quantize = bitsandbytes because dtype == int8")
+            quantize = "bitsandbytes"
+
+        cuda_available = torch.cuda.is_available()
+
+        # Default dtype based on device if not provided
+        if dtype_str is None:
+            dtype_str = "float16" if cuda_available else "float32"
+
+        if quantize is not None and not cuda_available:
+            raise ValueError("Quantization requires CUDA")
+
         if ESTIMATE_MEMORY == "auto" and PAGED_ATTENTION:
             # fit memory model using flash model in separate process (ensures GPU memory is entirely cleaned up)
             from text_generation_server.utils.paged import fit_memory_scaling_model
@@ -286,6 +303,8 @@ async def serve_inner(
             proc.start()
             memory_scaling_model_ext = q_out.get()
             proc.join()
+        else:
+            memory_scaling_model_ext = None

         unix_socket_template = "unix://{}-{}"
         world_size = int(os.getenv("WORLD_SIZE", "1"))
@@ -296,28 +315,12 @@ async def serve_inner(
         ]
         local_url = server_urls[local_rank]

-        if quantize not in [None, "gptq", "bitsandbytes"]:
-            raise ValueError(f"Unrecognized quantization method specified: {quantize}")
-
-        # Default dtype based on device if not provided
-        if dtype_str is None:
-            dtype_str = "float16" if torch.cuda.is_available() else "float32"
-
-        if quantize is None and dtype_str == "int8":
-            print_rank_n("Inferring quantize = bitsandbytes because dtype == int8")
-            quantize = "bitsandbytes"
-
-        cuda_available = torch.cuda.is_available()
-
-        if quantize is not None and not cuda_available:
-            raise ValueError("Quantization requires CUDA")
-
         # Set the fraction of cuda/gpu mem available to this process, then load the model
         if cuda_available and cuda_process_memory_fraction < 1:
             torch.cuda.set_per_process_memory_fraction(cuda_process_memory_fraction)

         model = get_model(
-            model_name, revision, deployment_framework, dtype_str, quantize, max_sequence_length
+            model_name, revision, deployment_framework, dtype_str, quantize, max_sequence_length, memory_scaling_model_ext,
         )

         device = model.engine.get_device()
@@ -424,7 +427,7 @@ def estimate_memory():

         server = aio.server()
         generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
-            TextGenerationService(model, Cache(), server_urls, memory_scaling_model.as_pb()), server
+            TextGenerationService(model, Cache(), server_urls, memory_scaling_model), server
         )
         # SERVICE_NAMES = (
         #     generate_pb2.DESCRIPTOR.services_by_name["TextGenerationService"].full_name,
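The serve_inner changes hoist the quantize/dtype validation above the memory-estimation subprocess, so the scaling model is fitted with the same resolved settings the real model load will use, and the service now holds a MemoryScalingModel object, converting it to protobuf only inside ModelInfo via as_pb(). A small sketch of the hoisted checks as a pure function (a hypothetical helper, not code from this repo):

from typing import Optional, Tuple

def resolve_dtype_and_quantize(
    dtype_str: Optional[str],
    quantize: Optional[str],
    cuda_available: bool,
) -> Tuple[str, Optional[str]]:
    if quantize not in [None, "gptq", "bitsandbytes"]:
        raise ValueError(f"Unrecognized quantization method specified: {quantize}")
    if quantize is None and dtype_str == "int8":
        # int8 implies bitsandbytes quantization
        quantize = "bitsandbytes"
    if dtype_str is None:
        # default dtype based on device
        dtype_str = "float16" if cuda_available else "float32"
    if quantize is not None and not cuda_available:
        raise ValueError("Quantization requires CUDA")
    return dtype_str, quantize

print(resolve_dtype_and_quantize(None, None, True))    # ('float16', None)
print(resolve_dtype_and_quantize("int8", None, True))  # ('int8', 'bitsandbytes')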

server/text_generation_server/utils/paged.py

Lines changed: 27 additions & 0 deletions
@@ -5,12 +5,37 @@

 from fms_extras.models.speculator import flatten_batch, apply_index_map

+# HF name or path to speculator model (None means no speculation will be used)
+SPECULATOR_NAME = os.getenv("SPECULATOR_NAME", None)
+
+# speculator revision
+SPECULATOR_REVISION = os.getenv("SPECULATOR_REVISION", None)
+
 # number of candidates during speculation
 SPECULATOR_N_CANDIDATES = os.getenv("SPECULATOR_N_CANDIDATES", None)

 # number of candidates per head
 SPECULATOR_TOP_K_TOKENS_PER_HEAD = os.getenv("SPECULATOR_TOP_K_TOKENS_PER_HEAD", None)

+def load_speculator(device, dtype):
+
+    if SPECULATOR_NAME is not None:
+        from fms_extras.models.hf.modeling_mlp_speculator import MLPSpeculatorPreTrainedModel
+        from text_generation_server.utils.hub import get_model_path
+        from text_generation_server.utils import print_rank_n
+        speculator_model_path = get_model_path(SPECULATOR_NAME, SPECULATOR_REVISION)
+        print_rank_n(f"Loading speculator model from: {speculator_model_path}")
+        kwargs = {
+            "pretrained_model_name_or_path": speculator_model_path,
+            "local_files_only": True,
+            "torch_dtype": dtype,
+        }
+        with device:
+            speculator = MLPSpeculatorPreTrainedModel.from_pretrained(**kwargs)
+        speculator.to(device=device)
+        return speculator
+    else:
+        return None

 def fit_memory_scaling_model(
     model_name: str,
@@ -38,6 +63,8 @@ def fit_memory_scaling_model(
         model_name, revision, deployment_framework, dtype_str, quantize, max_sequence_length
     )

+    speculator = load_speculator(model.device, model.dtype)
+
     memory_scaling_model = Estimator.build_from_env(
         model,
         batch_safety_margin,
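load_speculator reads its configuration from environment variables at import time, so both fit_memory_scaling_model and PagedCausalLM can call it without threading extra arguments through. A usage sketch under assumed settings (the speculator name and revision below are placeholders; the model must already be available locally since the loader passes local_files_only=True):

import os
import torch

# Choose a speculator before importing the module; leaving SPECULATOR_NAME
# unset keeps speculation disabled (load_speculator returns None).
os.environ["SPECULATOR_NAME"] = "org/some-mlp-speculator"  # placeholder
os.environ["SPECULATOR_REVISION"] = "main"

from text_generation_server.utils.paged import load_speculator

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
speculator = load_speculator(device, torch.float16)
print("speculation enabled" if speculator is not None else "speculation disabled")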
