@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import os
 from typing import Optional

 import pytest
@@ -99,9 +98,10 @@
 @pytest.mark.parametrize("num_logprobs", [5])
 @pytest.mark.parametrize(
     "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
+@pytest.mark.parametrize("use_prompt_embeds", [True, False])
 def test_models(hf_runner, vllm_runner, example_prompts, model: str,
                 max_tokens: int, num_logprobs: int, use_rocm_aiter: bool,
-                monkeypatch) -> None:
+                use_prompt_embeds: bool, monkeypatch) -> None:

     model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
     model_info.check_available_online(on_fail="skip")
@@ -119,8 +119,6 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str,
         # in parts of the operators
         pytest.skip(f"Skipping '{model}' model test with AITER kernel.")

-    use_prompt_embeds = os.getenv("VLLM_USE_V1") == "0"
-
     with hf_runner(model) as hf_model:
         hf_outputs = hf_model.generate_greedy_logprobs_limit(
             example_prompts, max_tokens, num_logprobs)
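The diff replaces the old `VLLM_USE_V1` environment-variable toggle with a `use_prompt_embeds` test parameter, so both the prompt-embeddings path and the plain-text path are exercised on every run instead of depending on how the environment is set. Below is a minimal, self-contained sketch of that pytest parametrization pattern; `render_prompt` and the toy assertion are illustrative stand-ins, not part of the vLLM test suite.

```python
import pytest


def render_prompt(prompt: str, use_prompt_embeds: bool) -> str:
    # Hypothetical stand-in for the branching the real test does on the flag:
    # choose the embeddings-based path or the plain-text path.
    return f"[embeds] {prompt}" if use_prompt_embeds else prompt


@pytest.mark.parametrize("use_prompt_embeds", [True, False])
@pytest.mark.parametrize("prompt", ["Hello, my name is"])
def test_prompt_paths(prompt: str, use_prompt_embeds: bool) -> None:
    # Stacked parametrize decorators expand into the cross product, so this
    # test runs once per (prompt, use_prompt_embeds) combination.
    rendered = render_prompt(prompt, use_prompt_embeds)
    assert prompt in rendered
```

Because stacked `@pytest.mark.parametrize` decorators multiply, the added `use_prompt_embeds: bool` argument in the `test_models` signature means every existing (model, max_tokens, num_logprobs, use_rocm_aiter) case now runs twice, once per flag value.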