@@ -88,6 +88,7 @@ class InferenceConfig:
         max_output_len (int): Maximum output length, defaults to 256.
         max_input_len (int): Maximum input length, defaults to 256.
         dtype (Union[str, torch.dtype]): The data type for weights and activations.
+        kv_cache_dtype (Optional[str]): The data type of kv_cache, defaults to None.
         prompt_template (Optional[str]): The prompt template for generation, defaults to None.
         do_sample (bool): Whether to use sampling for generation, defaults to False.
         beam_width (int): The maximum beam width used to initialize KV Cache, defaults to 1.
@@ -122,6 +123,7 @@ class InferenceConfig:
 
     # general configs
     dtype: Union[str, torch.dtype] = torch.float16  # use fp16 by default
+    kv_cache_dtype: Optional[str] = None
 
     # generation configs
     prompt_template: Optional[str] = None
@@ -177,6 +179,12 @@ def _verify_config(self) -> None:
             self.dtype in _ALLOWED_DTYPES
         ), f"Expected dtype to be in {_ALLOWED_DTYPES} but found an unknown dtype: {self.dtype}"
 
+        if self.kv_cache_dtype:
+            assert (
+                self.use_cuda_kernel and self.kv_cache_dtype == "fp8"
+            ), "FP8 kv_cache is only supported when use_cuda_kernel is enabled"
+            self.kv_cache_dtype = torch.uint8
+
         # skip using casting when the data type is float32
         if self.dtype == torch.float32:
             self.high_precision = False
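
Usage sketch (not part of the diff): a minimal example of enabling the new FP8 KV cache option. The import path, and the assumption that InferenceConfig runs _verify_config() when it is constructed, are mine and not confirmed by this diff.

import torch

from colossalai.inference.config import InferenceConfig  # assumed import path

config = InferenceConfig(
    dtype=torch.float16,
    use_cuda_kernel=True,  # FP8 kv_cache is gated on the CUDA kernels
    kv_cache_dtype="fp8",  # checked by _verify_config()
)

# Once _verify_config() has run, the "fp8" marker is replaced by torch.uint8,
# the storage dtype used for the packed FP8 KV cache.
print(config.kv_cache_dtype)  # expected: torch.uint8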