Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
239 changes: 207 additions & 32 deletions tests/quantization/test_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,36 +2,48 @@

Run `pytest tests/quantization/test_fp8.py --forked`.
"""

import pytest
import torch

from tests.quantization.utils import is_quant_method_supported
from tests.quantization.utils import (
is_quant_method_supported,
check_logprobs_close,
check_target_closer,
)
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.fp8 import (Fp8KVCacheMethod,
Fp8LinearMethod)
from vllm.model_executor.layers.quantization.fp8 import (
Fp8KVCacheMethod,
Fp8LinearMethod,
)
from vllm.model_executor.layers.quantization.ptpc_fp8 import PTPCFp8LinearMethod
from vllm.model_executor.models.utils import PPMissingLayer
from vllm.platforms import current_platform


MODELS = [
"neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV",
"nm-testing/Phi-3-mini-128k-instruct-FP8",
"nm-testing/Qwen2-0.5B-Instruct-FP8-SkipQKV",
]


@pytest.mark.skipif(not is_quant_method_supported("fp8"),
reason="FP8 is not supported on this GPU type.")
@pytest.mark.skipif(
not is_quant_method_supported("fp8"),
reason="FP8 is not supported on this GPU type.",
)
@pytest.mark.parametrize("model_id", MODELS)
@pytest.mark.parametrize("force_marlin", [False, True])
def test_model_load_and_run(vllm_runner, model_id: str, force_marlin: bool,
monkeypatch) -> None:
def test_model_load_and_run(
vllm_runner, model_id: str, force_marlin: bool, monkeypatch
) -> None:
if force_marlin:
monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1")

with vllm_runner(model_id) as llm:
# note: this does not test accuracy, just that we can run through
# see lm-eval tests for accuracy
outputs = llm.generate_greedy(prompts=["Hello my name is"],
max_tokens=10)
outputs = llm.generate_greedy(prompts=["Hello my name is"], max_tokens=10)
print(outputs[0][1])


Expand All @@ -43,13 +55,17 @@ def test_model_load_and_run(vllm_runner, model_id: str, force_marlin: bool,
]


@pytest.mark.skipif(not is_quant_method_supported("fp8"),
reason="FP8 is not supported on this GPU type.")
@pytest.mark.skipif(
not is_quant_method_supported("fp8"),
reason="FP8 is not supported on this GPU type.",
)
@pytest.mark.parametrize("model_id", KV_CACHE_MODELS)
def test_kv_cache_model_load_and_run(vllm_runner, model_id: str):
with vllm_runner(model_id, kv_cache_dtype="fp8") as llm:

model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
model = (
llm.model.llm_engine.model_executor.driver_worker.model_runner.model
) # noqa: E501
attn = model.model.layers[0].self_attn.attn
assert isinstance(attn.quant_method, Fp8KVCacheMethod)
# NOTE: it is valid for scales to be 1.0 (default value), but we know
Expand All @@ -59,25 +75,29 @@ def test_kv_cache_model_load_and_run(vllm_runner, model_id: str):

# note: this does not test accuracy, just that we can run through
# see lm-eval tests for accuracy
outputs = llm.generate_greedy(prompts=["Hello my name is"],
max_tokens=10)
outputs = llm.generate_greedy(prompts=["Hello my name is"], max_tokens=10)
print(outputs[0][1])


@pytest.mark.skipif(not is_quant_method_supported("fp8"),
reason="FP8 is not supported on this GPU type.")
@pytest.mark.skipif(
not is_quant_method_supported("fp8"),
reason="FP8 is not supported on this GPU type.",
)
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
@pytest.mark.parametrize("force_marlin", [False, True])
def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
monkeypatch) -> None:
def test_load_fp16_model(
vllm_runner, kv_cache_dtype: str, force_marlin: bool, monkeypatch
) -> None:
if force_marlin:
monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1")

with vllm_runner("facebook/opt-125m",
quantization="fp8",
kv_cache_dtype=kv_cache_dtype) as llm:
with vllm_runner(
"facebook/opt-125m", quantization="fp8", kv_cache_dtype=kv_cache_dtype
) as llm:

model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
model = (
llm.model.llm_engine.model_executor.driver_worker.model_runner.model
) # noqa: E501
fc1 = model.model.decoder.layers[0].fc1
assert isinstance(fc1.quant_method, Fp8LinearMethod)
if kv_cache_dtype == "fp8":
Expand All @@ -95,8 +115,10 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
assert fc1.weight.dtype == torch.int32


@pytest.mark.skipif(not is_quant_method_supported("fp8"),
reason="FP8 is not supported on this GPU type.")
@pytest.mark.skipif(
not is_quant_method_supported("fp8"),
reason="FP8 is not supported on this GPU type.",
)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_scaled_fp8_quant(dtype) -> None:

Expand All @@ -105,8 +127,7 @@ def quantize_ref(tensor, inv_scale):
# the kernel being tested.
finfo = torch.finfo(torch.float8_e4m3fn)
scale = inv_scale.reciprocal()
qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min,
max=finfo.max)
qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min, max=finfo.max)
qweight = qweight.to(torch.float8_e4m3fn)
return qweight

Expand All @@ -125,18 +146,172 @@ def per_tensor_dequantize(tensor, inv_scale, dtype):

# Reference dynamic quantizaton
y = quantize_ref(x, inv_scale)
torch.testing.assert_close(ref_y,
per_tensor_dequantize(y, inv_scale, dtype))
torch.testing.assert_close(ref_y, per_tensor_dequantize(y, inv_scale, dtype))

# Static quantization
y, _ = ops.scaled_fp8_quant(x, inv_scale)
torch.testing.assert_close(ref_y,
per_tensor_dequantize(y, inv_scale, dtype))
torch.testing.assert_close(ref_y, per_tensor_dequantize(y, inv_scale, dtype))

# Padding
y, _ = ops.scaled_fp8_quant(x, inv_scale, num_token_padding=17)
assert y.shape[0] == 17
torch.testing.assert_close(
ref_y,
per_tensor_dequantize(torch.narrow(y, 0, 0, x.shape[0]), inv_scale,
dtype))
per_tensor_dequantize(torch.narrow(y, 0, 0, x.shape[0]), inv_scale, dtype),
)


PTPC_MODELS = ["meta-llama/Llama-3.1-8B-Instruct"]

MAX_MODEL_LEN = 1024
NUM_LOG_PROBS = 8


@pytest.mark.skipif(
not is_quant_method_supported("fp8"),
reason="FP8 is not supported on this GPU type.",
)
@pytest.mark.skipif(not torch.version.hip, reason="Requires HIP")
@pytest.mark.parametrize("test_model", PTPC_MODELS)
@pytest.mark.parametrize("force_marlin", [False, True])
def test_load_fp16_ptpc_model(
vllm_runner, test_model: str, force_marlin: bool, monkeypatch
) -> None:
if force_marlin:
monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1")

with vllm_runner(test_model, quantization="ptpc_fp8") as llm:

model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model
for layer in model.model.layers:
if not isinstance(layer, PPMissingLayer):
assert isinstance(
layer.self_attn.qkv_proj.quant_method,
PTPCFp8LinearMethod,
)


@pytest.mark.skipif(
not is_quant_method_supported("fp8"),
reason="FP8 is not supported on this GPU type.",
)
@pytest.mark.skipif(not torch.version.hip, reason="Requires HIP")
@pytest.mark.parametrize("test_model", PTPC_MODELS)
@pytest.mark.parametrize("kv_cache_dtype", ["fp8_e4m3"])
@pytest.mark.parametrize("max_tokens", [4])
@pytest.mark.parametrize("enforce_eager", [True])
@pytest.mark.parametrize("tensor_parallel_size", [1])
@pytest.mark.parametrize("disable_async_output_proc", [True])
@pytest.mark.parametrize("force_marlin", [False, True])
def test_ptpc_fp18(
vllm_runner,
example_prompts,
test_model: str,
kv_cache_dtype: str,
max_tokens: int,
enforce_eager: bool,
tensor_parallel_size: int,
disable_async_output_proc: bool,
force_marlin: bool,
monkeypatch,
) -> None:
if force_marlin:
monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1")

with vllm_runner(
test_model,
max_model_len=MAX_MODEL_LEN,
tensor_parallel_size=tensor_parallel_size,
enforce_eager=enforce_eager,
kv_cache_dtype=kv_cache_dtype,
disable_async_output_proc=disable_async_output_proc,
) as vllm_model:
baseline_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, NUM_LOG_PROBS
)

with vllm_runner(
test_model,
max_model_len=MAX_MODEL_LEN,
tensor_parallel_size=tensor_parallel_size,
enforce_eager=enforce_eager,
kv_cache_dtype=kv_cache_dtype,
disable_async_output_proc=disable_async_output_proc,
quantization="ptpc_fp8",
) as vllm_model:
ptpc_fp8_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, NUM_LOG_PROBS
)

with vllm_runner(
test_model,
max_model_len=MAX_MODEL_LEN,
tensor_parallel_size=tensor_parallel_size,
enforce_eager=enforce_eager,
kv_cache_dtype=kv_cache_dtype,
disable_async_output_proc=disable_async_output_proc,
quantization="fp8",
) as vllm_model:
fp8_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, NUM_LOG_PROBS
)

check_target_closer(baseline_outputs, ptpc_fp8_outputs, fp8_outputs)


@pytest.mark.skipif(
not is_quant_method_supported("fp8"),
reason="FP8 is not supported on this GPU type.",
)
@pytest.mark.skipif(not torch.version.hip, reason="Requires HIP")
@pytest.mark.parametrize("test_model", PTPC_MODELS)
@pytest.mark.parametrize("kv_cache_dtype", ["fp8_e4m3"])
@pytest.mark.parametrize("max_tokens", [4])
@pytest.mark.parametrize("enforce_eager", [True])
@pytest.mark.parametrize("tensor_parallel_size", [1])
@pytest.mark.parametrize("disable_async_output_proc", [True])
@pytest.mark.parametrize("force_marlin", [False, True])
def test_ptpc_baseline(
vllm_runner,
example_prompts,
test_model: str,
kv_cache_dtype: str,
max_tokens: int,
enforce_eager: bool,
tensor_parallel_size: int,
disable_async_output_proc: bool,
force_marlin: bool,
monkeypatch,
) -> None:

if force_marlin:
monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1")

rtol, atol = (1e-1, 5e-1)

with vllm_runner(
test_model,
max_model_len=MAX_MODEL_LEN,
tensor_parallel_size=tensor_parallel_size,
enforce_eager=enforce_eager,
kv_cache_dtype=kv_cache_dtype,
disable_async_output_proc=disable_async_output_proc,
) as vllm_model:
baseline_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, NUM_LOG_PROBS
)

with vllm_runner(
test_model,
max_model_len=MAX_MODEL_LEN,
tensor_parallel_size=tensor_parallel_size,
enforce_eager=enforce_eager,
kv_cache_dtype=kv_cache_dtype,
disable_async_output_proc=disable_async_output_proc,
quantization="ptpc_fp8",
) as vllm_model:
ptpc_fp8_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, NUM_LOG_PROBS
)

check_logprobs_close(baseline_outputs, ptpc_fp8_outputs, rtol=rtol, atol=atol)
Loading
Loading