
Run `pytest tests/quantization/test_fp8.py --forked`.
"""
+
import pytest
import torch

-from tests.quantization.utils import is_quant_method_supported
+from tests.quantization.utils import (
+    is_quant_method_supported,
+    check_logprobs_close,
+    check_target_closer,
+)
from vllm import _custom_ops as ops
-from vllm.model_executor.layers.quantization.fp8 import (Fp8KVCacheMethod,
-                                                         Fp8LinearMethod)
+from vllm.model_executor.layers.quantization.fp8 import (
+    Fp8KVCacheMethod,
+    Fp8LinearMethod,
+)
+from vllm.model_executor.layers.quantization.ptpc_fp8 import PTPCFp8LinearMethod
+from vllm.model_executor.models.utils import PPMissingLayer
from vllm.platforms import current_platform

+
MODELS = [
    "neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV",
    "nm-testing/Phi-3-mini-128k-instruct-FP8",
    "nm-testing/Qwen2-0.5B-Instruct-FP8-SkipQKV",
]


-@pytest.mark.skipif(not is_quant_method_supported("fp8"),
-                    reason="FP8 is not supported on this GPU type.")
+@pytest.mark.skipif(
+    not is_quant_method_supported("fp8"),
+    reason="FP8 is not supported on this GPU type.",
+)
@pytest.mark.parametrize("model_id", MODELS)
@pytest.mark.parametrize("force_marlin", [False, True])
-def test_model_load_and_run(vllm_runner, model_id: str, force_marlin: bool,
-                            monkeypatch) -> None:
+def test_model_load_and_run(
+    vllm_runner, model_id: str, force_marlin: bool, monkeypatch
+) -> None:
    if force_marlin:
        monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1")

    with vllm_runner(model_id) as llm:
        # note: this does not test accuracy, just that we can run through
        # see lm-eval tests for accuracy
-        outputs = llm.generate_greedy(prompts=["Hello my name is"],
-                                      max_tokens=10)
+        outputs = llm.generate_greedy(prompts=["Hello my name is"], max_tokens=10)
        print(outputs[0][1])


@@ -43,13 +55,17 @@ def test_model_load_and_run(vllm_runner, model_id: str, force_marlin: bool,
]


-@pytest.mark.skipif(not is_quant_method_supported("fp8"),
-                    reason="FP8 is not supported on this GPU type.")
+@pytest.mark.skipif(
+    not is_quant_method_supported("fp8"),
+    reason="FP8 is not supported on this GPU type.",
+)
@pytest.mark.parametrize("model_id", KV_CACHE_MODELS)
def test_kv_cache_model_load_and_run(vllm_runner, model_id: str):
    with vllm_runner(model_id, kv_cache_dtype="fp8") as llm:

-        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
+        model = (
+            llm.model.llm_engine.model_executor.driver_worker.model_runner.model
+        )  # noqa: E501
        attn = model.model.layers[0].self_attn.attn
        assert isinstance(attn.quant_method, Fp8KVCacheMethod)
        # NOTE: it is valid for scales to be 1.0 (default value), but we know
@@ -59,25 +75,29 @@ def test_kv_cache_model_load_and_run(vllm_runner, model_id: str):

        # note: this does not test accuracy, just that we can run through
        # see lm-eval tests for accuracy
-        outputs = llm.generate_greedy(prompts=["Hello my name is"],
-                                      max_tokens=10)
+        outputs = llm.generate_greedy(prompts=["Hello my name is"], max_tokens=10)
        print(outputs[0][1])


-@pytest.mark.skipif(not is_quant_method_supported("fp8"),
-                    reason="FP8 is not supported on this GPU type.")
+@pytest.mark.skipif(
+    not is_quant_method_supported("fp8"),
+    reason="FP8 is not supported on this GPU type.",
+)
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
@pytest.mark.parametrize("force_marlin", [False, True])
-def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
-                         monkeypatch) -> None:
+def test_load_fp16_model(
+    vllm_runner, kv_cache_dtype: str, force_marlin: bool, monkeypatch
+) -> None:
    if force_marlin:
        monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1")

-    with vllm_runner("facebook/opt-125m",
-                     quantization="fp8",
-                     kv_cache_dtype=kv_cache_dtype) as llm:
+    with vllm_runner(
+        "facebook/opt-125m", quantization="fp8", kv_cache_dtype=kv_cache_dtype
+    ) as llm:

-        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
+        model = (
+            llm.model.llm_engine.model_executor.driver_worker.model_runner.model
+        )  # noqa: E501
        fc1 = model.model.decoder.layers[0].fc1
        assert isinstance(fc1.quant_method, Fp8LinearMethod)
        if kv_cache_dtype == "fp8":
@@ -95,8 +115,10 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
            assert fc1.weight.dtype == torch.int32


-@pytest.mark.skipif(not is_quant_method_supported("fp8"),
-                    reason="FP8 is not supported on this GPU type.")
+@pytest.mark.skipif(
+    not is_quant_method_supported("fp8"),
+    reason="FP8 is not supported on this GPU type.",
+)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_scaled_fp8_quant(dtype) -> None:

@@ -105,8 +127,7 @@ def quantize_ref(tensor, inv_scale):
        # the kernel being tested.
        finfo = torch.finfo(torch.float8_e4m3fn)
        scale = inv_scale.reciprocal()
-        qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min,
-                                                           max=finfo.max)
+        qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min, max=finfo.max)
        qweight = qweight.to(torch.float8_e4m3fn)
        return qweight

@@ -125,18 +146,172 @@ def per_tensor_dequantize(tensor, inv_scale, dtype):

    # Reference dynamic quantization
    y = quantize_ref(x, inv_scale)
-    torch.testing.assert_close(ref_y,
-                               per_tensor_dequantize(y, inv_scale, dtype))
+    torch.testing.assert_close(ref_y, per_tensor_dequantize(y, inv_scale, dtype))

    # Static quantization
    y, _ = ops.scaled_fp8_quant(x, inv_scale)
-    torch.testing.assert_close(ref_y,
-                               per_tensor_dequantize(y, inv_scale, dtype))
+    torch.testing.assert_close(ref_y, per_tensor_dequantize(y, inv_scale, dtype))

    # Padding
    y, _ = ops.scaled_fp8_quant(x, inv_scale, num_token_padding=17)
    assert y.shape[0] == 17
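+    # Only the original x.shape[0] rows carry data; torch.narrow() drops the
+    # padded rows before dequantizing and comparing against the reference.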
    torch.testing.assert_close(
        ref_y,
-        per_tensor_dequantize(torch.narrow(y, 0, 0, x.shape[0]), inv_scale,
-                              dtype))
+        per_tensor_dequantize(torch.narrow(y, 0, 0, x.shape[0]), inv_scale, dtype),
+    )
+
+
+PTPC_MODELS = ["meta-llama/Llama-3.1-8B-Instruct"]
+
+MAX_MODEL_LEN = 1024
+NUM_LOG_PROBS = 8
+
+
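+# The tests below exercise PTPC FP8 (per-token activation / per-channel weight
+# dynamic quantization via PTPCFp8LinearMethod). They are ROCm-only, hence the
+# HIP skip marker on each test.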
+@pytest.mark.skipif(
+    not is_quant_method_supported("fp8"),
+    reason="FP8 is not supported on this GPU type.",
+)
+@pytest.mark.skipif(not torch.version.hip, reason="Requires HIP")
+@pytest.mark.parametrize("test_model", PTPC_MODELS)
+@pytest.mark.parametrize("force_marlin", [False, True])
+def test_load_fp16_ptpc_model(
+    vllm_runner, test_model: str, force_marlin: bool, monkeypatch
+) -> None:
+    if force_marlin:
+        monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1")
+
+    with vllm_runner(test_model, quantization="ptpc_fp8") as llm:
+
+        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model
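+        # Layers assigned to other pipeline-parallel ranks appear here as
+        # PPMissingLayer placeholders; only real layers carry the PTPC method.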
+        for layer in model.model.layers:
+            if not isinstance(layer, PPMissingLayer):
+                assert isinstance(
+                    layer.self_attn.qkv_proj.quant_method,
+                    PTPCFp8LinearMethod,
+                )
+
+
+@pytest.mark.skipif(
+    not is_quant_method_supported("fp8"),
+    reason="FP8 is not supported on this GPU type.",
+)
+@pytest.mark.skipif(not torch.version.hip, reason="Requires HIP")
+@pytest.mark.parametrize("test_model", PTPC_MODELS)
+@pytest.mark.parametrize("kv_cache_dtype", ["fp8_e4m3"])
+@pytest.mark.parametrize("max_tokens", [4])
+@pytest.mark.parametrize("enforce_eager", [True])
+@pytest.mark.parametrize("tensor_parallel_size", [1])
+@pytest.mark.parametrize("disable_async_output_proc", [True])
+@pytest.mark.parametrize("force_marlin", [False, True])
+def test_ptpc_fp8(
+    vllm_runner,
+    example_prompts,
+    test_model: str,
+    kv_cache_dtype: str,
+    max_tokens: int,
+    enforce_eager: bool,
+    tensor_parallel_size: int,
+    disable_async_output_proc: bool,
+    force_marlin: bool,
+    monkeypatch,
+) -> None:
+    if force_marlin:
+        monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1")
+
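+    # Generate greedy logprobs for the same prompts three ways: without weight
+    # quantization (baseline), with ptpc_fp8, and with plain fp8.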
+    with vllm_runner(
+        test_model,
+        max_model_len=MAX_MODEL_LEN,
+        tensor_parallel_size=tensor_parallel_size,
+        enforce_eager=enforce_eager,
+        kv_cache_dtype=kv_cache_dtype,
+        disable_async_output_proc=disable_async_output_proc,
+    ) as vllm_model:
+        baseline_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, NUM_LOG_PROBS
+        )
+
+    with vllm_runner(
+        test_model,
+        max_model_len=MAX_MODEL_LEN,
+        tensor_parallel_size=tensor_parallel_size,
+        enforce_eager=enforce_eager,
+        kv_cache_dtype=kv_cache_dtype,
+        disable_async_output_proc=disable_async_output_proc,
+        quantization="ptpc_fp8",
+    ) as vllm_model:
+        ptpc_fp8_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, NUM_LOG_PROBS
+        )
+
+    with vllm_runner(
+        test_model,
+        max_model_len=MAX_MODEL_LEN,
+        tensor_parallel_size=tensor_parallel_size,
+        enforce_eager=enforce_eager,
+        kv_cache_dtype=kv_cache_dtype,
+        disable_async_output_proc=disable_async_output_proc,
+        quantization="fp8",
+    ) as vllm_model:
+        fp8_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, NUM_LOG_PROBS
+        )
+
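+    # check_target_closer (a test-utils helper) is meant to verify that the
+    # ptpc_fp8 run stays at least as close to the baseline as the fp8 run does.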
+    check_target_closer(baseline_outputs, ptpc_fp8_outputs, fp8_outputs)
+
+
+@pytest.mark.skipif(
+    not is_quant_method_supported("fp8"),
+    reason="FP8 is not supported on this GPU type.",
+)
+@pytest.mark.skipif(not torch.version.hip, reason="Requires HIP")
+@pytest.mark.parametrize("test_model", PTPC_MODELS)
+@pytest.mark.parametrize("kv_cache_dtype", ["fp8_e4m3"])
+@pytest.mark.parametrize("max_tokens", [4])
+@pytest.mark.parametrize("enforce_eager", [True])
+@pytest.mark.parametrize("tensor_parallel_size", [1])
+@pytest.mark.parametrize("disable_async_output_proc", [True])
+@pytest.mark.parametrize("force_marlin", [False, True])
+def test_ptpc_baseline(
+    vllm_runner,
+    example_prompts,
+    test_model: str,
+    kv_cache_dtype: str,
+    max_tokens: int,
+    enforce_eager: bool,
+    tensor_parallel_size: int,
+    disable_async_output_proc: bool,
+    force_marlin: bool,
+    monkeypatch,
+) -> None:
+
+    if force_marlin:
+        monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1")
+
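+    # Loose tolerances: some drift from the unquantized baseline is expected
+    # once the weights are quantized to FP8.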
+    rtol, atol = (1e-1, 5e-1)
+
+    with vllm_runner(
+        test_model,
+        max_model_len=MAX_MODEL_LEN,
+        tensor_parallel_size=tensor_parallel_size,
+        enforce_eager=enforce_eager,
+        kv_cache_dtype=kv_cache_dtype,
+        disable_async_output_proc=disable_async_output_proc,
+    ) as vllm_model:
+        baseline_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, NUM_LOG_PROBS
+        )
+
+    with vllm_runner(
+        test_model,
+        max_model_len=MAX_MODEL_LEN,
+        tensor_parallel_size=tensor_parallel_size,
+        enforce_eager=enforce_eager,
+        kv_cache_dtype=kv_cache_dtype,
+        disable_async_output_proc=disable_async_output_proc,
+        quantization="ptpc_fp8",
+    ) as vllm_model:
+        ptpc_fp8_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, NUM_LOG_PROBS
+        )
+
+    check_logprobs_close(baseline_outputs, ptpc_fp8_outputs, rtol=rtol, atol=atol)