|
| 1 | +"""Compare the outputs of HF and vLLM for Granite models using greedy sampling. |
| 2 | +
|
| 3 | +Run `pytest tests/models/test_granite.py`. |
| 4 | +""" |
| 5 | +import importlib.metadata |
| 6 | + |
| 7 | +import pytest |
| 8 | + |
| 9 | +from .utils import check_logprobs_close |
| 10 | + |
| 11 | +TRANSFORMERS_VERSION = tuple( |
| 12 | + map(int, |
| 13 | + importlib.metadata.version("transformers").split("."))) |
| 14 | + |
| 15 | +MODELS = [ |
| 16 | + "ibm/PowerLM-3b", |
| 17 | +] |
| 18 | + |
| 19 | + |
| 20 | +# GraniteForCausalLM will be in transformers >= 4.45 |
| 21 | +@pytest.mark.skipif(TRANSFORMERS_VERSION < (4, 45), |
| 22 | + reason="granite model test requires transformers >= 4.45") |
| 23 | +@pytest.mark.parametrize("model", MODELS) |
| 24 | +@pytest.mark.parametrize("dtype", ["bfloat16"]) |
| 25 | +@pytest.mark.parametrize("max_tokens", [64]) |
| 26 | +@pytest.mark.parametrize("num_logprobs", [5]) |
| 27 | +def test_models( |
| 28 | + hf_runner, |
| 29 | + vllm_runner, |
| 30 | + example_prompts, |
| 31 | + model: str, |
| 32 | + dtype: str, |
| 33 | + max_tokens: int, |
| 34 | + num_logprobs: int, |
| 35 | +) -> None: |
| 36 | + # TODO(sang): Sliding window should be tested separately. |
| 37 | + with hf_runner(model, dtype=dtype) as hf_model: |
| 38 | + hf_outputs = hf_model.generate_greedy_logprobs_limit( |
| 39 | + example_prompts, max_tokens, num_logprobs) |
| 40 | + |
| 41 | + with vllm_runner(model, dtype=dtype) as vllm_model: |
| 42 | + vllm_outputs = vllm_model.generate_greedy_logprobs( |
| 43 | + example_prompts, max_tokens, num_logprobs) |
| 44 | + check_logprobs_close( |
| 45 | + outputs_0_lst=hf_outputs, |
| 46 | + outputs_1_lst=vllm_outputs, |
| 47 | + name_0="hf", |
| 48 | + name_1="vllm", |
| 49 | + ) |
0 commit comments