Commit 18d67d0

[Feat] Inference RPC Server Support (#5705)

* RPC support source
* KV cache logical/physical disaggregation
* Sampler refactor
* Built-in colossalai launch
* Unit tests
* Rpyc support

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Parent: de4bf3d · Commit: 18d67d0

File tree

15 files changed: +1032 additions, -63 deletions

colossalai/inference/config.py

Lines changed: 107 additions & 8 deletions
@@ -2,11 +2,11 @@
 Our config contains various options for inference optimization, it is a unified API that wraps all the configurations for inference.
 """
 import logging
+from abc import ABC, abstractmethod
 from dataclasses import dataclass, fields
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 import torch
-import torch.distributed as dist
 from transformers.generation import GenerationConfig
 
 from colossalai.inference.flash_decoding_utils import FDIntermTensors
@@ -30,8 +30,25 @@
 }
 
 
+class RPC_PARAM(ABC):
+    """
+    NOTE(lry89757) We use rpyc to transport param between client and server.
+    Rpyc only support the type of `POD` in python as the param, so we should take some smart ways to transport the data like tensor or some sophisticated classes.
+    Drawing on the logic of `__setstate__`, `__getstate__`, we will let some classes(will be rpc param later) inherit this base class, and rewrite the to_rpc_param and from_rpc_param. We will invoke `to_rpc_param` in client to pass the params and recover the param in server side by `from_rpc_param`.
+    """
+
+    @abstractmethod
+    def to_rpc_param(self):
+        return NotImplementedError
+
+    @staticmethod
+    @abstractmethod
+    def from_rpc_param():
+        return NotImplementedError
+
+
 @dataclass
-class InputMetaData:
+class InputMetaData(RPC_PARAM):
     """The input info for a single step
 
     Args:
@@ -48,6 +65,7 @@ class InputMetaData:
         dtype (torch.dtype, optional): The computation type of tensor, Defaults to torch.float32.
         use_spec_dec (bool): Indicate whether to use speculative decoding.
         num_tokens_to_verify (int): The number of tokens to verify in speculative decoding. Only valid when `use_spec_dec` is set to True.
+        batch_token_ids (List[List[int]], optional): input_token_ids + output_token_ids of current batch. Only used for `repetition_penalty`, `no_repeat_ngram_size` in sampler process.
     """
 
     block_tables: torch.Tensor = None
@@ -63,6 +81,54 @@ class InputMetaData:
     dtype: torch.dtype = torch.float32
     use_spec_dec: bool = False
     num_tokens_to_verify: int = 0
+    batch_token_ids: Optional[
+        List[List[int]]
+    ] = None  # for `repetition_penalty`, `no_repeat_ngram_size` in sampler process
+
+    def to_rpc_param(self) -> Dict[str, any]:
+        return {
+            "block_tables": self.block_tables.tolist(),
+            "sequence_lengths": self.sequence_lengths.tolist(),
+            "batch_size": self.batch_size,
+            "is_prompts": self.is_prompts,
+            "use_cuda_kernel": self.use_cuda_kernel,
+            "use_cuda_graph": self.use_cuda_graph,
+            "kv_seq_len": self.kv_seq_len,
+            "head_dim": self.head_dim,
+            "high_precision": self.high_precision,
+            "dtype": str(self.dtype).split(".")[-1],
+            "use_spec_dec": self.use_spec_dec,
+            "num_tokens_to_verify": self.num_tokens_to_verify,
+            "batch_token_ids": self.batch_token_ids,
+        }
+
+    @staticmethod
+    def from_rpc_param(rpc_dict: Dict[str, any]) -> "InputMetaData":
+        """
+        We intentionally don't use `dict.get` method to ensure we pass the right rpc param, or program will show error message
+        """
+        from colossalai.accelerator import get_accelerator
+
+        dtype = getattr(torch, rpc_dict["dtype"])
+        return InputMetaData(
+            block_tables=torch.tensor(
+                rpc_dict["block_tables"], dtype=torch.int, device=get_accelerator().get_current_device()
+            ),
+            sequence_lengths=torch.tensor(
+                rpc_dict["sequence_lengths"], dtype=torch.int, device=get_accelerator().get_current_device()
+            ),
+            batch_size=rpc_dict["batch_size"],
+            is_prompts=rpc_dict["is_prompts"],
+            use_cuda_kernel=rpc_dict["use_cuda_kernel"],
+            use_cuda_graph=rpc_dict["use_cuda_graph"],
+            kv_seq_len=rpc_dict["kv_seq_len"],
+            head_dim=rpc_dict["head_dim"],
+            high_precision=rpc_dict["high_precision"],
+            dtype=dtype,
+            use_spec_dec=rpc_dict["use_spec_dec"],
+            num_tokens_to_verify=rpc_dict["num_tokens_to_verify"],
+            batch_token_ids=rpc_dict["batch_token_ids"],
+        )
 
     def __repr__(self) -> str:
         return (
@@ -80,7 +146,7 @@ def __repr__(self) -> str:
 
 
 @dataclass
-class InferenceConfig:
+class InferenceConfig(RPC_PARAM):
     """The inference configuration.
 
     Args:
@@ -193,10 +259,6 @@ def _verify_config(self) -> None:
         if self.dtype == torch.float32:
             self.high_precision = False
 
-        # check distributed
-        assert (not torch.distributed.is_initialized() and self.tp_size * self.pp_size == 1) or (
-            self.tp_size * self.pp_size == dist.get_world_size()
-        ), f"TP size({self.tp_size}) * PP size({self.pp_size}) should be equal to the global world size ({dist.get_world_size()})"
         # check prompt template
         if self.prompt_template is None:
             return
@@ -226,6 +288,43 @@ def to_generation_config(self, model_config) -> GenerationConfig:
 
         return GenerationConfig.from_dict(meta_config)
 
+    def to_rpc_param(self) -> dict:
+        kwargs = {
+            "dtype": str(self.dtype).split(".")[-1],
+            "max_n_spec_tokens": self.max_n_spec_tokens,
+            "max_batch_size": self.max_batch_size,
+            "max_input_len": self.max_input_len,
+            "max_output_len": self.max_output_len,
+            "tp_size": self.tp_size,
+            "pp_size": self.pp_size,
+            "pad_input": self.pad_input,
+            "early_stopping": self.early_stopping,
+            "do_sample": self.do_sample,
+            "beam_width": self.beam_width,
+            "kv_cache_dtype": str(self.kv_cache_dtype).split(".")[-1],
+        }
+        return kwargs
+
+    @staticmethod
+    def from_rpc_param(rpc_dict: dict) -> "InferenceConfig":
+        """
+        We intentionally don't use `dict.get` method to ensure we pass the right rpc param, or program will show error message
+        """
+        return InferenceConfig(
+            dtype=getattr(torch, rpc_dict["dtype"]),
+            max_n_spec_tokens=rpc_dict["max_n_spec_tokens"],
+            max_batch_size=rpc_dict["max_batch_size"],
+            max_input_len=rpc_dict["max_input_len"],
+            max_output_len=rpc_dict["max_output_len"],
+            tp_size=rpc_dict["tp_size"],
+            pp_size=rpc_dict["pp_size"],
+            pad_input=rpc_dict["pad_input"],
+            early_stopping=rpc_dict["early_stopping"],
+            do_sample=rpc_dict["do_sample"],
+            beam_width=rpc_dict["beam_width"],
+            kv_cache_dtype=getattr(torch, rpc_dict["kv_cache_dtype"], None),
+        )
+
     @classmethod
     def from_dict(cls, config_dict: Dict[str, Any]) -> "InferenceConfig":
         # Get the list of attributes of this dataclass.
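
The new RPC_PARAM contract exists because rpyc only transports plain Python ("POD") values, so tensor-bearing objects are flattened into dicts and lists with to_rpc_param on the client and rebuilt with from_rpc_param on the server. Below is a minimal sketch of that round trip for InferenceConfig; it is illustrative only (not the PR's server code) and assumes the remaining default field values pass _verify_config.

import torch

from colossalai.inference.config import InferenceConfig

# Client side: flatten the config into POD values that rpyc can carry.
config = InferenceConfig(dtype=torch.float16, max_batch_size=8)
rpc_dict = config.to_rpc_param()
assert rpc_dict["dtype"] == "float16"  # a plain string, not a torch.dtype

# Server side: rebuild an equivalent config from the plain dict.
restored = InferenceConfig.from_rpc_param(rpc_dict)
assert restored.dtype is torch.float16 and restored.max_batch_size == 8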

colossalai/inference/core/engine.py

Lines changed: 14 additions & 3 deletions
@@ -21,6 +21,7 @@
 from colossalai.inference.config import InferenceConfig, InputMetaData
 from colossalai.inference.graph_runner import CUDAGraphRunner
 from colossalai.inference.modeling.policy import model_policy_map
+from colossalai.inference.sampler import search_tokens
 from colossalai.inference.spec import Drafter, GlideInput
 from colossalai.inference.struct import Sequence
 from colossalai.inference.utils import get_model_size, has_index_file
@@ -424,7 +425,7 @@ def steps_spec_dec(self) -> List[Sequence]:
 
         # 2. Prefill main model (Verifier) - fill past kv cache for main model
         logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
-        next_tokens = self.request_handler.search_tokens(self.generation_config, logits, batch)
+        next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids)
         # append new inputs to the batch, temporarily
         batch.append_batch_tokens(next_tokens)
         self.request_handler.allocate_batch_spec_dec(batch, 1)
@@ -472,7 +473,7 @@ def steps_spec_dec(self) -> List[Sequence]:
            input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
            logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
 
-           next_tokens = self.request_handler.search_tokens(self.generation_config, logits, batch)
+           next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids)
 
            # 5. Compare and process the results
            diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec))
@@ -689,6 +690,13 @@ def prepare_input(self, batch: BatchBucket) -> Tuple[torch.Tensor, torch.Tensor,
                 (n_tokens, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
             )
 
+        batch_token_ids = None
+        config_dict = self.generation_config.to_dict()
+        # process repetition_penalty, no_repeat_ngram_size
+        for type in ["repetition_penalty", "no_repeat_ngram_size"]:
+            if type in config_dict and config_dict[type] is not None:
+                batch_token_ids = batch.batch_token_ids
+
         # only when we have the graph for specific decoding batch size can we use the cuda graph for inference
         use_cuda_graph = False
         if self.use_cuda_graph and not batch.is_prompts and batch.current_batch_size in self.graph_runners.keys():
@@ -708,6 +716,7 @@ def prepare_input(self, batch: BatchBucket) -> Tuple[torch.Tensor, torch.Tensor,
             dtype=batch.dtype,
             use_spec_dec=batch.use_spec_dec,
             num_tokens_to_verify=batch.num_tokens_to_verify,
+            batch_token_ids=batch_token_ids,
         )
 
         return input_ids, output_tensor, input_meta_data
@@ -738,7 +747,9 @@ def step(self) -> List[str]:
         logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
         if self.inference_config.pad_input:
             logits = logits[:, -1, :]
-        next_tokens = self.request_handler.search_tokens(self.generation_config, logits, batch)
+        next_tokens = search_tokens(
+            self.generation_config, logits, input_meta_data.is_prompts, batch_token_ids=input_meta_data.batch_token_ids
+        )
         self.request_handler.append_next_tokens(next_tokens)
         finished_sequences = self.request_handler.update()
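
With the sampler refactor, the engine now calls the module-level search_tokens from colossalai.inference.sampler instead of RequestHandler.search_tokens, and prepare_input only materializes batch_token_ids when the generation config enables a penalty that needs the full token history. The helper below restates that gating logic on its own; the function name is a hypothetical one chosen for illustration.

from typing import List, Optional

from transformers.generation import GenerationConfig


def maybe_collect_batch_token_ids(
    generation_config: GenerationConfig, batch_token_ids: List[List[int]]
) -> Optional[List[List[int]]]:
    # Mirrors the loop added in prepare_input: the per-sequence token history
    # is only forwarded to the sampler for repetition_penalty / no_repeat_ngram_size.
    config_dict = generation_config.to_dict()
    for key in ("repetition_penalty", "no_repeat_ngram_size"):
        if key in config_dict and config_dict[key] is not None:
            return batch_token_ids
    return None

# Sketch of the call site:
# batch_token_ids = maybe_collect_batch_token_ids(self.generation_config, batch.batch_token_ids)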

colossalai/inference/core/request_handler.py

Lines changed: 54 additions & 41 deletions
@@ -7,10 +7,11 @@
 from colossalai.inference.batch_bucket import BatchBucket
 from colossalai.inference.config import InferenceConfig
 from colossalai.inference.flash_decoding_utils import FDIntermTensors
-from colossalai.inference.kv_cache import KVCacheManager
-from colossalai.inference.logit_processors import logit_processor
-from colossalai.inference.sampler import *
+from colossalai.inference.kv_cache import KVCacheManager, RPCKVCacheManager
 from colossalai.inference.struct import RequestStatus, Sequence
+from colossalai.logging import get_dist_logger
+
+logger = get_dist_logger(__name__)
 
 __all__ = ["RunningList", "RequestHandler"]
 
@@ -295,17 +296,6 @@ def _find_sequence(self, request_id: int) -> Sequence:
 
         return None
 
-    def _sample(self, probs: torch.Tensor, logprobs: torch.Tensor, generation_config: GenerationConfig):
-        if generation_config.num_beams == 1:
-            if generation_config.do_sample:
-                sample_tokens = multinomial_sample(generation_config, probs)
-            else:
-                sample_tokens = greedy_sample(generation_config, logprobs)
-        else:
-            sample_tokens = beam_search_sample(generation_config, logprobs, is_prompt=not self.prefill_bb.is_empty)
-
-        return sample_tokens
-
     def update_seq_finished(self, sequence: Sequence, generation_config: GenerationConfig):
         if (
             sequence.output_token_id[-1] == generation_config.eos_token_id
@@ -328,33 +318,6 @@ def check_unfinished_seqs(self) -> bool:
     def total_requests_in_batch_bucket(self) -> int:
         return self.prefill_bb.current_batch_size + self.running_bb.current_batch_size
 
-    def search_tokens(self, generation_config: GenerationConfig, logits, cur_batch: BatchBucket):
-        """
-        Sample tokens for finished requests.
-        """
-
-        # NOTE: need to decide the granularity to process logits (sequence or batch)
-        config_dict = generation_config.to_dict()
-        # process repetition_penalty, no_repeat_ngram_size
-        for type in ["repetition_penalty", "no_repeat_ngram_size"]:
-            if type in config_dict and config_dict[type] is not None:
-                logits = logit_processor(type, logits, config_dict[type], cur_batch)
-
-        # do logit processor
-        if generation_config.do_sample:
-            # process temperature, top_k, top_p
-            for type in ["temperature", "top_k", "top_p"]:
-                if type in config_dict and config_dict[type] is not None:
-                    logits = logit_processor(type, logits, config_dict[type])
-
-        # calculate probs
-        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
-        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
-
-        # sample the next tokens
-        sample_tokens = self._sample(probs, logprobs, generation_config)
-        return sample_tokens
-
     def append_next_tokens(self, sample_tokens: torch.Tensor):
         assert sample_tokens.dim() == 1
         n_elements = sample_tokens.size(0)
@@ -386,3 +349,53 @@ def update(self):
         self.done_list.extend(finished_seqs)
 
         return finished_seqs
+
+
+class RPCRequestHandler(RequestHandler):
+    """
+    RPC Version of request handler
+    """
+
+    def __init__(self, inference_config: InferenceConfig, model_config: PretrainedConfig) -> None:
+        self.inference_config = inference_config
+        self.running_list: RunningList = RunningList(inference_config.prefill_ratio)
+        self.waiting_list: List[List] = [[], [], []]
+        self.done_list: List[Sequence] = []
+        self.dtype = inference_config.dtype
+        self.max_batch_size = inference_config.max_batch_size
+
+        # initialize cache
+        self._init_cache(model_config)
+
+        # initialize batch
+        torch.cuda.current_device()
+        kv_max_split_num = (
+            inference_config.max_input_len + inference_config.max_output_len + inference_config.block_size - 1
+        ) // inference_config.block_size
+        head_dim = model_config.hidden_size // model_config.num_attention_heads
+
+        # TODO In the continuous batching scenario, the batch size may be greater than max_batch_size,
+        # which may cause bugs and this issue should be fixed later.
+        self.running_bb = BatchBucket(
+            num_heads=model_config.num_attention_heads // inference_config.tp_size,
+            head_dim=head_dim,
+            max_batch_size=self.max_batch_size,
+            max_length=inference_config.max_input_len + inference_config.max_output_len,
+            block_size=inference_config.block_size,
+            kv_max_split_num=kv_max_split_num,
+            fd_interm_tensor=None,
+            dtype=self.dtype,
+        )
+        self.prefill_bb = BatchBucket(
+            num_heads=model_config.num_attention_heads // inference_config.tp_size,
+            head_dim=head_dim,
+            max_batch_size=self.max_batch_size,
+            max_length=inference_config.max_input_len + inference_config.max_output_len,
+            block_size=inference_config.block_size,
+            kv_max_split_num=kv_max_split_num,
+            fd_interm_tensor=None,
+            dtype=self.dtype,
+        )
+
+    def _init_cache(self, model_config):
+        self.cache_manager = RPCKVCacheManager(self.inference_config, model_config)
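
RPCRequestHandler reuses the parent class's scheduling logic but builds its BatchBuckets with fd_interm_tensor=None and swaps in RPCKVCacheManager, presumably because the physical KV blocks are managed on the server workers (the commit message calls this "kv cache logical/physical disaggregation"). For reference, kv_max_split_num above is a ceil-division of the maximum sequence length by the block size; a quick worked example with made-up sizes:

# Illustrative numbers only; the real values come from InferenceConfig.
max_input_len, max_output_len, block_size = 512, 256, 16

# ceil((max_input_len + max_output_len) / block_size) without floating point
kv_max_split_num = (max_input_len + max_output_len + block_size - 1) // block_size
assert kv_max_split_num == 48  # 768 tokens / 16 tokens per block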
