Skip to content
This repository was archived by the owner on Sep 4, 2025. It is now read-only.

Commit f8d6014

Browse files
shawntannjhill
and authored
[Model] Add Granite model (vllm-project#7436)
Co-authored-by: Nick Hill <[email protected]>
1 parent 5b86b19 commit f8d6014

File tree

4 files changed

+792
-0
lines changed

4 files changed

+792
-0
lines changed

tests/models/test_granite.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
"""Compare the outputs of HF and vLLM for Granite models using greedy sampling.
2+
3+
Run `pytest tests/models/test_granite.py`.
4+
"""
5+
import importlib.metadata
6+
7+
import pytest
8+
9+
from .utils import check_logprobs_close
10+
11+
def _version_tuple(version: str) -> tuple[int, ...]:
    """Return the leading numeric components of *version* as ints.

    Stops at the first non-numeric component, so pre-release strings such
    as "4.45.0.dev0" parse as (4, 45, 0) instead of raising ValueError —
    the original ``tuple(map(int, version.split(".")))`` crashed pytest
    collection whenever a dev/rc build of transformers was installed.
    """
    parts: list[int] = []
    for piece in version.split("."):
        if not piece.isdigit():
            break
        parts.append(int(piece))
    return tuple(parts)


# (major, minor, ...) of the installed transformers package; used by the
# skipif gate below, since GraniteForCausalLM ships in transformers >= 4.45.
TRANSFORMERS_VERSION = _version_tuple(
    importlib.metadata.version("transformers"))

# Granite checkpoints exercised by this test.
MODELS = [
    "ibm/PowerLM-3b",
]
18+
19+
20+
# GraniteForCausalLM will be in transformers >= 4.45
@pytest.mark.skipif(TRANSFORMERS_VERSION < (4, 45),
                    reason="granite model test requires transformers >= 4.45")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    """Greedy-decode *model* with both HF and vLLM and compare top logprobs."""
    # TODO(sang): Sliding window should be tested separately.
    generate_args = (example_prompts, max_tokens, num_logprobs)

    # Each runner is opened in its own context so only one copy of the
    # model weights is resident at a time.
    with hf_runner(model, dtype=dtype) as hf_model:
        hf_outputs = hf_model.generate_greedy_logprobs_limit(*generate_args)

    with vllm_runner(model, dtype=dtype) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy_logprobs(*generate_args)

    check_logprobs_close(
        outputs_0_lst=hf_outputs,
        outputs_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
    )

vllm/model_executor/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
"EAGLEModel": ("eagle", "EAGLE"),
6666
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
6767
"JambaForCausalLM": ("jamba", "JambaForCausalLM"),
68+
"GraniteForCausalLM": ("granite", "GraniteForCausalLM")
6869
}
6970

7071
_EMBEDDING_MODELS = {

0 commit comments

Comments (0)