Commit 930d7bb

⬆️✅ Support 0.6.5+ vllm (#7)
* ⬆️ Unpin vllm
* ✅🔧 Update mock model configs
* ✅ Update test for extra fields
* ⬆️ Upgrade lower bound of vllm
* 🔥 Remove error on extra params tests
* ♻️ API server updates

Signed-off-by: Evaline Ju <[email protected]>
1 parent aaaecc1 commit 930d7bb

7 files changed: +68 -38 lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@ classifiers = [
 ]
 
 dependencies = [
-    "vllm>=0.6.2,<0.6.5"
+    "vllm>=0.6.5"
 ]
 
 [project.optional-dependencies]
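The dependency line above drops the upper pin (`<0.6.5`) in favor of a plain lower bound, so any vllm release from 0.6.5 onward satisfies the constraint. A quick sanity check of an existing environment against the new bound (a sketch; assumes the `packaging` library is installed):

# Sketch: verify the installed vllm satisfies the new lower bound.
# Assumes the "packaging" library is available in the environment.
from importlib.metadata import version
from packaging.version import Version

assert Version(version("vllm")) >= Version("0.6.5"), "vllm is older than 0.6.5"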

tests/generative_detectors/test_base.py

Lines changed: 15 additions & 0 deletions
@@ -1,5 +1,6 @@
 # Standard
 from dataclasses import dataclass
+from typing import Optional
 import asyncio
 
 # Third Party
@@ -16,21 +17,33 @@
 BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
 
 
+@dataclass
+class MockTokenizer:
+    type: Optional[str] = None
+
+
 @dataclass
 class MockHFConfig:
     model_type: str = "any"
 
 
 @dataclass
 class MockModelConfig:
+    task = "generate"
     tokenizer = MODEL_NAME
     trust_remote_code = False
     tokenizer_mode = "auto"
     max_model_len = 100
     tokenizer_revision = None
     embedding_mode = False
     multimodal_config = MultiModalConfig()
+    diff_sampling_param: Optional[dict] = None
     hf_config = MockHFConfig()
+    logits_processor_pattern = None
+    allowed_local_media_path: str = ""
+
+    def get_diff_sampling_param(self):
+        return self.diff_sampling_param or {}
 
 
 @dataclass
@@ -42,6 +55,7 @@ async def get_model_config(self):
 async def _async_serving_detection_completion_init():
     """Initialize a chat completion base with string templates"""
     engine = MockEngine()
+    engine.errored = False
     model_config = await engine.get_model_config()
 
     detection_completion = ChatCompletionDetectionBase(
@@ -52,6 +66,7 @@ async def _async_serving_detection_completion_init():
         base_model_paths=BASE_MODEL_PATHS,
         response_role="assistant",
         chat_template=CHAT_TEMPLATE,
+        chat_template_content_format="auto",
         lora_modules=None,
         prompt_adapters=None,
         request_logger=None,
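The extra attributes on MockModelConfig (task, diff_sampling_param, logits_processor_pattern, allowed_local_media_path, and the get_diff_sampling_param helper) mirror attributes the newer vllm serving layer expects a ModelConfig to provide. A small sketch of how the mock behaves, using only names defined in this test file:

# Sketch: only diff_sampling_param and allowed_local_media_path are annotated,
# so they are the dataclass fields; the remaining names stay class attributes.
config = MockModelConfig()
assert config.get_diff_sampling_param() == {}  # no sampling overrides by default

overridden = MockModelConfig(diff_sampling_param={"temperature": 0.0})
assert overridden.get_diff_sampling_param() == {"temperature": 0.0}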

tests/generative_detectors/test_granite_guardian.py

Lines changed: 15 additions & 15 deletions
@@ -1,6 +1,7 @@
 # Standard
 from dataclasses import dataclass
 from http import HTTPStatus
+from typing import Optional
 from unittest.mock import patch
 import asyncio
 
@@ -33,21 +34,33 @@
 BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
 
 
+@dataclass
+class MockTokenizer:
+    type: Optional[str] = None
+
+
 @dataclass
 class MockHFConfig:
     model_type: str = "any"
 
 
 @dataclass
 class MockModelConfig:
+    task = "generate"
     tokenizer = MODEL_NAME
     trust_remote_code = False
     tokenizer_mode = "auto"
     max_model_len = 100
     tokenizer_revision = None
     embedding_mode = False
     multimodal_config = MultiModalConfig()
+    diff_sampling_param: Optional[dict] = None
     hf_config = MockHFConfig()
+    logits_processor_pattern = None
+    allowed_local_media_path: str = ""
+
+    def get_diff_sampling_param(self):
+        return self.diff_sampling_param or {}
 
 
 @dataclass
@@ -59,6 +72,7 @@ async def get_model_config(self):
 async def _granite_guardian_init():
     """Initialize a granite guardian"""
     engine = MockEngine()
+    engine.errored = False
     model_config = await engine.get_model_config()
 
     granite_guardian = GraniteGuardian(
@@ -69,6 +83,7 @@ async def _granite_guardian_init():
         base_model_paths=BASE_MODEL_PATHS,
         response_role="assistant",
         chat_template=CHAT_TEMPLATE,
+        chat_template_content_format="auto",
         lora_modules=None,
         prompt_adapters=None,
         request_logger=None,
@@ -229,18 +244,3 @@ def test_chat_detection_errors_on_stream(granite_guardian_detection):
     assert type(detection_response) == ErrorResponse
     assert detection_response.code == HTTPStatus.BAD_REQUEST.value
     assert "streaming is not supported" in detection_response.message
-
-
-def test_chat_detection_with_extra_unallowed_params(granite_guardian_detection):
-    granite_guardian_detection_instance = asyncio.run(granite_guardian_detection)
-    chat_request = ChatDetectionRequest(
-        messages=[
-            DetectionChatMessageParam(role="user", content="How do I pick a lock?")
-        ],
-        detector_params={"boo": 3},  # unallowed param
-    )
-    detection_response = asyncio.run(
-        granite_guardian_detection_instance.chat(chat_request)
-    )
-    assert type(detection_response) == ErrorResponse
-    assert detection_response.code == HTTPStatus.BAD_REQUEST.value

tests/generative_detectors/test_llama_guard.py

Lines changed: 15 additions & 13 deletions
@@ -1,6 +1,7 @@
 # Standard
 from dataclasses import dataclass
 from http import HTTPStatus
+from typing import Optional
 from unittest.mock import patch
 import asyncio
 
@@ -33,21 +34,33 @@
 BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
 
 
+@dataclass
+class MockTokenizer:
+    type: Optional[str] = None
+
+
 @dataclass
 class MockHFConfig:
     model_type: str = "any"
 
 
 @dataclass
 class MockModelConfig:
+    task = "generate"
     tokenizer = MODEL_NAME
     trust_remote_code = False
     tokenizer_mode = "auto"
     max_model_len = 100
     tokenizer_revision = None
     embedding_mode = False
     multimodal_config = MultiModalConfig()
+    diff_sampling_param: Optional[dict] = None
     hf_config = MockHFConfig()
+    logits_processor_pattern = None
+    allowed_local_media_path: str = ""
+
+    def get_diff_sampling_param(self):
+        return self.diff_sampling_param or {}
 
 
 @dataclass
@@ -59,6 +72,7 @@ async def get_model_config(self):
 async def _llama_guard_init():
     """Initialize a llama guard"""
     engine = MockEngine()
+    engine.errored = False
     model_config = await engine.get_model_config()
 
     llama_guard_detection = LlamaGuard(
@@ -69,6 +83,7 @@ async def _llama_guard_init():
         base_model_paths=BASE_MODEL_PATHS,
         response_role="assistant",
         chat_template=CHAT_TEMPLATE,
+        chat_template_content_format="auto",
         lora_modules=None,
         prompt_adapters=None,
         request_logger=None,
@@ -177,16 +192,3 @@ def test_chat_detection(llama_guard_detection, llama_guard_completion_response):
     assert detection_0["detection"] == "safe"
     assert detection_0["detection_type"] == "risk"
     assert pytest.approx(detection_0["score"]) == 0.001346767
-
-
-def test_chat_detection_with_extra_unallowed_params(llama_guard_detection):
-    llama_guard_detection_instance = asyncio.run(llama_guard_detection)
-    chat_request = ChatDetectionRequest(
-        messages=[
-            DetectionChatMessageParam(role="user", content="How do I search for moose?")
-        ],
-        detector_params={"moo": "unallowed"},  # unallowed param
-    )
-    detection_response = asyncio.run(llama_guard_detection_instance.chat(chat_request))
-    assert type(detection_response) == ErrorResponse
-    assert detection_response.code == HTTPStatus.BAD_REQUEST.value

tests/test_protocol.py

Lines changed: 2 additions & 2 deletions
@@ -54,8 +54,8 @@ def test_detection_to_completion_request_unknown_params():
         detector_params={"moo": 2},
     )
     request = chat_request.to_chat_completion_request(MODEL_NAME)
-    assert type(request) == ErrorResponse
-    assert request.code == HTTPStatus.BAD_REQUEST.value
+    # As of vllm >= 0.6.5, extra fields are allowed
+    assert type(request) == ChatCompletionRequest
 
 
 def test_response_from_completion_response():
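The updated assertion reflects that vllm's ChatCompletionRequest no longer rejects unknown fields from 0.6.5 onward, so the adapter's pass-through returns a request object instead of an ErrorResponse. Roughly equivalent behavior constructed directly (model name and message content are placeholders):

# Sketch: an unknown field such as "moo" is accepted rather than raising a
# validation error on vllm >= 0.6.5 (mirrors the updated test above).
from vllm.entrypoints.openai.protocol import ChatCompletionRequest

request = ChatCompletionRequest(
    model="my-model",  # placeholder model name
    messages=[{"role": "user", "content": "hello"}],
    moo=2,  # extra field, previously rejected
)
assert isinstance(request, ChatCompletionRequest)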

vllm_detector_adapter/api_server.py

Lines changed: 18 additions & 4 deletions
@@ -10,6 +10,7 @@
 from vllm.config import ModelConfig
 from vllm.engine.arg_utils import nullable_str
 from vllm.engine.protocol import EngineClient
+from vllm.entrypoints.chat_utils import load_chat_template
 from vllm.entrypoints.launcher import serve_http
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai import api_server
@@ -61,6 +62,16 @@ def init_app_state_with_detectors(
         BaseModelPath(name=name, model_path=args.model) for name in served_model_names
     ]
 
+    resolved_chat_template = load_chat_template(args.chat_template)
+    # Post-0.6.6 incoming change for vllm - ref. https://github.com/vllm-project/vllm/pull/11660
+    # Will be included after an official release includes this refactor
+    # state.openai_serving_models = OpenAIServingModels(
+    #     model_config=model_config,
+    #     base_model_paths=base_model_paths,
+    #     lora_modules=args.lora_modules,
+    #     prompt_adapters=args.prompt_adapters,
+    # )
+
     # Use vllm app state init
     api_server.init_app_state(engine_client, model_config, state, args)
 
@@ -72,15 +83,18 @@ def init_app_state_with_detectors(
         args.output_template,
         engine_client,
         model_config,
-        base_model_paths,
+        base_model_paths,  # Not present in post-0.6.6 incoming change
+        # state.openai_serving_models,  # Post-0.6.6 incoming change
         args.response_role,
-        lora_modules=args.lora_modules,
-        prompt_adapters=args.prompt_adapters,
+        lora_modules=args.lora_modules,  # Not present in post-0.6.6 incoming change
+        prompt_adapters=args.prompt_adapters,  # Not present in post-0.6.6 incoming change
         request_logger=request_logger,
-        chat_template=args.chat_template,
+        chat_template=resolved_chat_template,
+        chat_template_content_format=args.chat_template_content_format,
         return_tokens_as_token_ids=args.return_tokens_as_token_ids,
         enable_auto_tools=args.enable_auto_tool_choice,
         tool_parser=args.tool_call_parser,
+        enable_prompt_tokens_details=args.enable_prompt_tokens_details,
     )
 
 
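The new resolved_chat_template step loads the template contents up front via vllm's load_chat_template (imported above), instead of handing the raw --chat-template argument to the serving class. A minimal sketch of that call, with a hypothetical template path:

# Sketch: resolve a chat template before constructing the serving objects.
# The path below is hypothetical; load_chat_template returns the template
# text (or None when no template is configured), as used in the diff above.
from vllm.entrypoints.chat_utils import load_chat_template

resolved_chat_template = load_chat_template("templates/detector_chat.jinja")
print(resolved_chat_template)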

vllm_detector_adapter/protocol.py

Lines changed: 2 additions & 3 deletions
@@ -68,9 +68,8 @@ def to_chat_completion_request(self, model_name: str):
         ]
 
         # Try to pass all detector_params through as additional parameters to chat completions.
-        # This will error if extra unallowed parameters are included. We do not try to provide
-        # validation or changing of parameters here to not be dependent on chat completion API
-        # changes
+        # We do not try to provide validation or changing of parameters here to not be dependent
+        # on chat completion API changes. As of vllm >= 0.6.5, extra fields are allowed
         try:
             return ChatCompletionRequest(
                 messages=messages,
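For reference, the pass-through described in the comment amounts to unpacking the caller-supplied detector_params as keyword arguments on the request; a simplified sketch (parameter values are illustrative):

# Sketch of the detector_params pass-through: known sampling parameters and
# unknown extras both flow into ChatCompletionRequest on vllm >= 0.6.5.
from vllm.entrypoints.openai.protocol import ChatCompletionRequest

detector_params = {"temperature": 0.2, "custom_flag": True}  # illustrative values
request = ChatCompletionRequest(
    model="my-model",  # placeholder
    messages=[{"role": "user", "content": "hello"}],
    **detector_params,
)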
