diff --git a/tests/quantization/test_fp8.py b/tests/quantization/test_fp8.py
index a0c1d7e24c50..b5dfaa920074 100644
--- a/tests/quantization/test_fp8.py
+++ b/tests/quantization/test_fp8.py
@@ -2,15 +2,25 @@
 
 Run `pytest tests/quantization/test_fp8.py --forked`.
 """
+
 import pytest
 import torch
 
-from tests.quantization.utils import is_quant_method_supported
+from tests.quantization.utils import (
+    is_quant_method_supported,
+    check_logprobs_close,
+    check_target_closer,
+)
 from vllm import _custom_ops as ops
-from vllm.model_executor.layers.quantization.fp8 import (Fp8KVCacheMethod,
-                                                         Fp8LinearMethod)
+from vllm.model_executor.layers.quantization.fp8 import (
+    Fp8KVCacheMethod,
+    Fp8LinearMethod,
+)
+from vllm.model_executor.layers.quantization.ptpc_fp8 import PTPCFp8LinearMethod
+from vllm.model_executor.models.utils import PPMissingLayer
 from vllm.platforms import current_platform
+
 
 MODELS = [
     "neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV",
     "nm-testing/Phi-3-mini-128k-instruct-FP8",
@@ -18,20 +28,22 @@
 ]
 
 
-@pytest.mark.skipif(not is_quant_method_supported("fp8"),
-                    reason="FP8 is not supported on this GPU type.")
+@pytest.mark.skipif(
+    not is_quant_method_supported("fp8"),
+    reason="FP8 is not supported on this GPU type.",
+)
 @pytest.mark.parametrize("model_id", MODELS)
 @pytest.mark.parametrize("force_marlin", [False, True])
-def test_model_load_and_run(vllm_runner, model_id: str, force_marlin: bool,
-                            monkeypatch) -> None:
+def test_model_load_and_run(
+    vllm_runner, model_id: str, force_marlin: bool, monkeypatch
+) -> None:
     if force_marlin:
         monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1")
 
     with vllm_runner(model_id) as llm:
         # note: this does not test accuracy, just that we can run through
         # see lm-eval tests for accuracy
-        outputs = llm.generate_greedy(prompts=["Hello my name is"],
-                                      max_tokens=10)
+        outputs = llm.generate_greedy(prompts=["Hello my name is"], max_tokens=10)
         print(outputs[0][1])
 
 
@@ -43,13 +55,17 @@ def test_model_load_and_run(vllm_runner, model_id: str, force_marlin: bool,
 ]
 
 
-@pytest.mark.skipif(not is_quant_method_supported("fp8"),
-                    reason="FP8 is not supported on this GPU type.")
+@pytest.mark.skipif(
+    not is_quant_method_supported("fp8"),
+    reason="FP8 is not supported on this GPU type.",
+)
 @pytest.mark.parametrize("model_id", KV_CACHE_MODELS)
 def test_kv_cache_model_load_and_run(vllm_runner, model_id: str):
     with vllm_runner(model_id, kv_cache_dtype="fp8") as llm:
-        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
+        model = (
+            llm.model.llm_engine.model_executor.driver_worker.model_runner.model
+        )  # noqa: E501
         attn = model.model.layers[0].self_attn.attn
         assert isinstance(attn.quant_method, Fp8KVCacheMethod)
         # NOTE: it is valid for scales to be 1.0 (default value), but we know
@@ -59,25 +75,29 @@ def test_kv_cache_model_load_and_run(vllm_runner, model_id: str):
 
         # note: this does not test accuracy, just that we can run through
         # see lm-eval tests for accuracy
-        outputs = llm.generate_greedy(prompts=["Hello my name is"],
-                                      max_tokens=10)
+        outputs = llm.generate_greedy(prompts=["Hello my name is"], max_tokens=10)
         print(outputs[0][1])
 
 
-@pytest.mark.skipif(not is_quant_method_supported("fp8"),
-                    reason="FP8 is not supported on this GPU type.")
+@pytest.mark.skipif(
+    not is_quant_method_supported("fp8"),
+    reason="FP8 is not supported on this GPU type.",
+)
 @pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
 @pytest.mark.parametrize("force_marlin", [False, True])
-def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
-                         monkeypatch) -> None:
+def test_load_fp16_model(
+    vllm_runner, kv_cache_dtype: str, force_marlin: bool, monkeypatch
+) -> None:
     if force_marlin:
         monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1")
 
-    with vllm_runner("facebook/opt-125m",
-                     quantization="fp8",
-                     kv_cache_dtype=kv_cache_dtype) as llm:
+    with vllm_runner(
+        "facebook/opt-125m", quantization="fp8", kv_cache_dtype=kv_cache_dtype
+    ) as llm:
 
-        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
+        model = (
+            llm.model.llm_engine.model_executor.driver_worker.model_runner.model
+        )  # noqa: E501
         fc1 = model.model.decoder.layers[0].fc1
         assert isinstance(fc1.quant_method, Fp8LinearMethod)
         if kv_cache_dtype == "fp8":
@@ -95,8 +115,10 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
         assert fc1.weight.dtype == torch.int32
 
 
-@pytest.mark.skipif(not is_quant_method_supported("fp8"),
-                    reason="FP8 is not supported on this GPU type.")
+@pytest.mark.skipif(
+    not is_quant_method_supported("fp8"),
+    reason="FP8 is not supported on this GPU type.",
+)
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 def test_scaled_fp8_quant(dtype) -> None:
 
@@ -105,8 +127,7 @@ def quantize_ref(tensor, inv_scale):
         # the kernel being tested.
         finfo = torch.finfo(torch.float8_e4m3fn)
         scale = inv_scale.reciprocal()
-        qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min,
-                                                           max=finfo.max)
+        qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min, max=finfo.max)
         qweight = qweight.to(torch.float8_e4m3fn)
         return qweight
 
@@ -125,18 +146,172 @@ def per_tensor_dequantize(tensor, inv_scale, dtype):
 
     # Reference dynamic quantizaton
     y = quantize_ref(x, inv_scale)
-    torch.testing.assert_close(ref_y,
-                               per_tensor_dequantize(y, inv_scale, dtype))
+    torch.testing.assert_close(ref_y, per_tensor_dequantize(y, inv_scale, dtype))
 
     # Static quantization
     y, _ = ops.scaled_fp8_quant(x, inv_scale)
-    torch.testing.assert_close(ref_y,
-                               per_tensor_dequantize(y, inv_scale, dtype))
+    torch.testing.assert_close(ref_y, per_tensor_dequantize(y, inv_scale, dtype))
 
     # Padding
     y, _ = ops.scaled_fp8_quant(x, inv_scale, num_token_padding=17)
     assert y.shape[0] == 17
     torch.testing.assert_close(
         ref_y,
-        per_tensor_dequantize(torch.narrow(y, 0, 0, x.shape[0]), inv_scale,
-                              dtype))
+        per_tensor_dequantize(torch.narrow(y, 0, 0, x.shape[0]), inv_scale, dtype),
+    )
+
+
+PTPC_MODELS = ["meta-llama/Llama-3.1-8B-Instruct"]
+
+MAX_MODEL_LEN = 1024
+NUM_LOG_PROBS = 8
+
+
+@pytest.mark.skipif(
+    not is_quant_method_supported("fp8"),
+    reason="FP8 is not supported on this GPU type.",
+)
+@pytest.mark.skipif(not torch.version.hip, reason="Requires HIP")
+@pytest.mark.parametrize("test_model", PTPC_MODELS)
+@pytest.mark.parametrize("force_marlin", [False, True])
+def test_load_fp16_ptpc_model(
+    vllm_runner, test_model: str, force_marlin: bool, monkeypatch
+) -> None:
+    if force_marlin:
+        monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1")
+
+    with vllm_runner(test_model, quantization="ptpc_fp8") as llm:
+
+        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model
+        for layer in model.model.layers:
+            if not isinstance(layer, PPMissingLayer):
+                assert isinstance(
+                    layer.self_attn.qkv_proj.quant_method,
+                    PTPCFp8LinearMethod,
+                )
+
+
+@pytest.mark.skipif(
+    not is_quant_method_supported("fp8"),
+    reason="FP8 is not supported on this GPU type.",
+)
+@pytest.mark.skipif(not torch.version.hip, reason="Requires HIP")
+@pytest.mark.parametrize("test_model", PTPC_MODELS) +@pytest.mark.parametrize("kv_cache_dtype", ["fp8_e4m3"]) +@pytest.mark.parametrize("max_tokens", [4]) +@pytest.mark.parametrize("enforce_eager", [True]) +@pytest.mark.parametrize("tensor_parallel_size", [1]) +@pytest.mark.parametrize("disable_async_output_proc", [True]) +@pytest.mark.parametrize("force_marlin", [False, True]) +def test_ptpc_fp18( + vllm_runner, + example_prompts, + test_model: str, + kv_cache_dtype: str, + max_tokens: int, + enforce_eager: bool, + tensor_parallel_size: int, + disable_async_output_proc: bool, + force_marlin: bool, + monkeypatch, +) -> None: + if force_marlin: + monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1") + + with vllm_runner( + test_model, + max_model_len=MAX_MODEL_LEN, + tensor_parallel_size=tensor_parallel_size, + enforce_eager=enforce_eager, + kv_cache_dtype=kv_cache_dtype, + disable_async_output_proc=disable_async_output_proc, + ) as vllm_model: + baseline_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, NUM_LOG_PROBS + ) + + with vllm_runner( + test_model, + max_model_len=MAX_MODEL_LEN, + tensor_parallel_size=tensor_parallel_size, + enforce_eager=enforce_eager, + kv_cache_dtype=kv_cache_dtype, + disable_async_output_proc=disable_async_output_proc, + quantization="ptpc_fp8", + ) as vllm_model: + ptpc_fp8_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, NUM_LOG_PROBS + ) + + with vllm_runner( + test_model, + max_model_len=MAX_MODEL_LEN, + tensor_parallel_size=tensor_parallel_size, + enforce_eager=enforce_eager, + kv_cache_dtype=kv_cache_dtype, + disable_async_output_proc=disable_async_output_proc, + quantization="fp8", + ) as vllm_model: + fp8_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, NUM_LOG_PROBS + ) + + check_target_closer(baseline_outputs, ptpc_fp8_outputs, fp8_outputs) + + +@pytest.mark.skipif( + not is_quant_method_supported("fp8"), + reason="FP8 is not supported on this GPU type.", +) +@pytest.mark.skipif(not torch.version.hip, reason="Requires HIP") +@pytest.mark.parametrize("test_model", PTPC_MODELS) +@pytest.mark.parametrize("kv_cache_dtype", ["fp8_e4m3"]) +@pytest.mark.parametrize("max_tokens", [4]) +@pytest.mark.parametrize("enforce_eager", [True]) +@pytest.mark.parametrize("tensor_parallel_size", [1]) +@pytest.mark.parametrize("disable_async_output_proc", [True]) +@pytest.mark.parametrize("force_marlin", [False, True]) +def test_ptpc_baseline( + vllm_runner, + example_prompts, + test_model: str, + kv_cache_dtype: str, + max_tokens: int, + enforce_eager: bool, + tensor_parallel_size: int, + disable_async_output_proc: bool, + force_marlin: bool, + monkeypatch, +) -> None: + + if force_marlin: + monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1") + + rtol, atol = (1e-1, 5e-1) + + with vllm_runner( + test_model, + max_model_len=MAX_MODEL_LEN, + tensor_parallel_size=tensor_parallel_size, + enforce_eager=enforce_eager, + kv_cache_dtype=kv_cache_dtype, + disable_async_output_proc=disable_async_output_proc, + ) as vllm_model: + baseline_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, NUM_LOG_PROBS + ) + + with vllm_runner( + test_model, + max_model_len=MAX_MODEL_LEN, + tensor_parallel_size=tensor_parallel_size, + enforce_eager=enforce_eager, + kv_cache_dtype=kv_cache_dtype, + disable_async_output_proc=disable_async_output_proc, + quantization="ptpc_fp8", + ) as vllm_model: + ptpc_fp8_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, 
+            example_prompts, max_tokens, NUM_LOG_PROBS
+        )
+
+    check_logprobs_close(baseline_outputs, ptpc_fp8_outputs, rtol=rtol, atol=atol)
diff --git a/tests/quantization/utils.py b/tests/quantization/utils.py
index 8ebd8dd2be0d..f325ee479088 100644
--- a/tests/quantization/utils.py
+++ b/tests/quantization/utils.py
@@ -1,5 +1,9 @@
+from typing import Dict, List, Optional, Sequence, Tuple, Union
+from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs
+
 from vllm.model_executor.layers.quantization import get_quantization_config
 from vllm.platforms import current_platform
+import torch
 
 
 def is_quant_method_supported(quant_method: str) -> bool:
@@ -13,3 +17,94 @@ def is_quant_method_supported(quant_method: str) -> bool:
     min_capability = get_quantization_config(quant_method).get_min_capability()
 
     return capability.to_int() >= min_capability
+
+
+TokensTextLogprobs = Tuple[
+    List[int], str, Optional[Union[List[Dict[int, float]], SampleLogprobs]]
+]
+
+TextTextLogprobs = Tuple[
+    List[str], str, Optional[Union[List[Dict[str, float]], List[Dict[str, Logprob]]]]
+]
+
+TokensTextLogprobsPromptLogprobs = Tuple[
+    List[int],
+    str,
+    Optional[Union[List[Dict[int, float]], SampleLogprobs]],
+    Optional[Union[List[Optional[Dict[int, float]]], PromptLogprobs]],
+]
+
+ModelOutputSequence = Sequence[
+    Union[TokensTextLogprobs, TokensTextLogprobsPromptLogprobs, TextTextLogprobs]
+]
+
+
+def extract_log_probs(log_prob_dict_seq: Sequence[Dict[str, Logprob]]) -> List[float]:
+    return [
+        logprob.logprob
+        for log_prob_dict in log_prob_dict_seq
+        for logprob in log_prob_dict.values()
+    ]
+
+
+def extract_log_probs_from_model_output_sequence(
+    model_output_sequence: ModelOutputSequence,
+) -> torch.Tensor:
+    log_probs_all = []
+    for model_output in model_output_sequence:
+        if len(model_output) == 3:
+            _, _, logprobs_list = model_output
+        elif len(model_output) == 4:
+            _, _, logprobs_list, _ = model_output
+        else:
+            raise ValueError(
+                f"Outputs tuple must have 3 or 4 elements but "
+                f"{len(model_output)} elements were provided: "
+                f"{model_output}"
+            )
+
+        log_probs_all.extend(extract_log_probs(logprobs_list))
+    return torch.tensor(log_probs_all)
+
+
+def check_target_closer(
+    base_output_sequence: ModelOutputSequence,
+    target_a_output_sequence: ModelOutputSequence,
+    target_b_output_sequence: ModelOutputSequence,
+) -> None:
+    """Compare the logprobs of two target outputs against a base model output
+    to determine whether target A is closer to the base model than target B.
+
+    Args:
+        base_output_sequence: Output from the base model
+        target_a_output_sequence: Output from the first target model
+        target_b_output_sequence: Output from the second target model
+    """
+
+    base_model_log_probs = extract_log_probs_from_model_output_sequence(
+        base_output_sequence
+    )
+    target_a_log_probs = extract_log_probs_from_model_output_sequence(
+        target_a_output_sequence
+    )
+    target_b_log_probs = extract_log_probs_from_model_output_sequence(
+        target_b_output_sequence
+    )
+    assert torch.linalg.norm(
+        base_model_log_probs - target_a_log_probs
+    ) < torch.linalg.norm(base_model_log_probs - target_b_log_probs)
+
+
+def check_logprobs_close(
+    a_output_sequence: ModelOutputSequence,
+    b_output_sequence: ModelOutputSequence,
+    rtol: float,
+    atol: float,
+) -> None:
+    """
+    Compare log probabilities of two model output sequences for closeness.
+    """
+
+    a_log_probs = extract_log_probs_from_model_output_sequence(a_output_sequence)
+    b_log_probs = extract_log_probs_from_model_output_sequence(b_output_sequence)
+    assert torch.allclose(a_log_probs, b_log_probs, rtol=rtol, atol=atol)
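
For reviewers, a minimal sketch of how the new comparison helpers in tests/quantization/utils.py are meant to be exercised, using hand-built Logprob dictionaries instead of real model output. It assumes the helpers land exactly as in the diff above and that `vllm.sequence.Logprob` is importable; the names `fake_output`, `base`, `near`, and `far` are illustrative only. In `test_ptpc_fp8` above, `near` plays the role of the ptpc_fp8 run and `far` the plain fp8 run.

from tests.quantization.utils import check_logprobs_close, check_target_closer
from vllm.sequence import Logprob


def fake_output(logprobs):
    # Build a single-prompt output in TokensTextLogprobs form:
    # (token_ids, generated_text, per-token {token_id: Logprob} dicts).
    return [(
        list(range(len(logprobs))),
        "dummy text",
        [{i: Logprob(logprob=lp)} for i, lp in enumerate(logprobs)],
    )]


base = fake_output([-0.10, -0.20])  # stand-in for the unquantized baseline
near = fake_output([-0.11, -0.22])  # small deviation from the baseline
far = fake_output([-0.50, -0.90])   # larger deviation from the baseline

# Passes: the L2 distance of `near` to `base` is smaller than that of `far`.
check_target_closer(base, near, far)

# Passes: every logprob in `near` is within rtol/atol of the baseline value.
check_logprobs_close(base, near, rtol=1e-1, atol=5e-1)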