Commit bfad393

[Inference/Feat] Add quant kvcache interface (#5700)

* add quant kvcache interface
* delete unused output
* complete args comments

1 parent 492520d

2 files changed: +16 −2 lines
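
In short, the commit lets a user opt into FP8 kv-cache storage through InferenceConfig. A minimal usage sketch, assuming only the two fields this diff touches plus the existing use_cuda_kernel flag (the other arguments shown are illustrative, not required):

    from colossalai.inference.config import InferenceConfig

    config = InferenceConfig(
        max_input_len=256,
        max_output_len=256,
        use_cuda_kernel=True,  # required: FP8 kv_cache needs the CUDA kernels
        kv_cache_dtype="fp8",  # rewritten to torch.uint8 by _verify_config()
    )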

colossalai/inference/config.py (8 additions, 0 deletions)
@@ -88,6 +88,7 @@ class InferenceConfig:
         max_output_len (int): Maximum output length, defaults to 256.
         max_input_len (int): Maximum input length, defaults to 256.
         dtype (Union[str, torch.dtype]): The data type for weights and activations.
+        kv_cache_dtype (Optional[str]): The data type of kv_cache, defaults to None.
         prompt_template (Optional[str]): The prompt template for generation, defaults to None.
         do_sample (bool): Whether to use sampling for generation, defaults to False.
         beam_width (int): The maximum beam width used to initialize KV Cache, defaults to 1.
@@ -122,6 +123,7 @@ class InferenceConfig:
 
     # general configs
     dtype: Union[str, torch.dtype] = torch.float16  # use fp16 by default
+    kv_cache_dtype: Optional[str] = None
 
     # generation configs
     prompt_template: Optional[str] = None
@@ -177,6 +179,12 @@ def _verify_config(self) -> None:
             self.dtype in _ALLOWED_DTYPES
         ), f"Expected dtype to be in {_ALLOWED_DTYPES} but found an unknown dtype: {self.dtype}"
 
+        if self.kv_cache_dtype:
+            assert (
+                self.use_cuda_kernel and self.kv_cache_dtype == "fp8"
+            ), "FP8 kv_cache is currently only supported when use_cuda_kernel is enabled"
+            self.kv_cache_dtype = torch.uint8
+
         # skip using casting when the data type is float32
         if self.dtype == torch.float32:
             self.high_precision = False
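
Taken together with the manager change below, the effective dtype resolution can be sketched standalone. The helper name resolve_kv_cache_dtype is hypothetical, and the reading that the uint8 buffer holds raw FP8 bytes for the CUDA kernels to reinterpret is an assumption, not stated in the diff:

    from typing import Optional
    import torch

    # Hypothetical helper mirroring the checks in _verify_config() and the
    # fallback in KVCacheManager.__init__(); not part of the commit itself.
    def resolve_kv_cache_dtype(
        kv_cache_dtype: Optional[str],
        model_dtype: torch.dtype,
        use_cuda_kernel: bool,
    ) -> torch.dtype:
        if kv_cache_dtype is None:
            # No quantization requested: the cache reuses the model dtype.
            return model_dtype
        assert (
            use_cuda_kernel and kv_cache_dtype == "fp8"
        ), "FP8 kv_cache is currently only supported when use_cuda_kernel is enabled"
        # Assumption: FP8 values are stored byte-wise in a uint8 buffer
        # that the CUDA kernels reinterpret as FP8.
        return torch.uint8

    assert resolve_kv_cache_dtype(None, torch.float16, False) is torch.float16
    assert resolve_kv_cache_dtype("fp8", torch.float16, True) is torch.uint8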

colossalai/inference/kv_cache/kvcache_manager.py (8 additions, 2 deletions)
@@ -53,6 +53,12 @@ def __init__(self, config: InferenceConfig, model_config: PretrainedConfig) -> None:
         self.tp_size = config.tp_size
         # Model settings
         self.dtype = config.dtype
+
+        if config.kv_cache_dtype is None:
+            self.kv_cache_dtype = config.dtype
+        else:
+            self.kv_cache_dtype = config.kv_cache_dtype
+
         self.elem_size_in_bytes = torch.tensor([], dtype=self.dtype).element_size()
         self.num_layers = model_config.num_hidden_layers
         self.head_num = model_config.num_attention_heads
@@ -488,6 +494,6 @@ def _init_device_caches(
         k_cache: List[torch.Tensor] = []
         v_cache: List[torch.Tensor] = []
         for _ in range(self.num_layers):
-            k_cache.append(torch.zeros(kalloc_shape, dtype=self.dtype, device=self.device))
-            v_cache.append(torch.zeros(valloc_shape, dtype=self.dtype, device=self.device))
+            k_cache.append(torch.zeros(kalloc_shape, dtype=self.kv_cache_dtype, device=self.device))
+            v_cache.append(torch.zeros(valloc_shape, dtype=self.kv_cache_dtype, device=self.device))
         return k_cache, v_cache
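
Since torch.uint8 is one byte per element versus two for torch.float16, allocating the caches this way halves kv-cache memory. A quick check (the shape below is made up for illustration; the real kalloc_shape and valloc_shape are computed by KVCacheManager from the model and block configuration):

    import torch

    # Hypothetical cache block shape, for sizing only.
    shape = (2, 128, 8, 16, 64)

    fp16_block = torch.zeros(shape, dtype=torch.float16)
    fp8_block = torch.zeros(shape, dtype=torch.uint8)  # FP8 stored byte-wise

    print(fp16_block.element_size())  # 2 bytes per element
    print(fp8_block.element_size())   # 1 byte per element: half the memory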
