diff --git a/pyproject.toml b/pyproject.toml
index 330d01a..9ce45d3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -13,7 +13,7 @@ classifiers = [
 ]
 
 dependencies = [
-    "vllm>=0.6.2,<0.6.5"
+    "vllm>=0.6.5"
 ]
 
 [project.optional-dependencies]
diff --git a/tests/generative_detectors/test_base.py b/tests/generative_detectors/test_base.py
index 60b10dd..0805714 100644
--- a/tests/generative_detectors/test_base.py
+++ b/tests/generative_detectors/test_base.py
@@ -1,5 +1,6 @@
 # Standard
 from dataclasses import dataclass
+from typing import Optional
 import asyncio
 
 # Third Party
@@ -16,6 +17,11 @@
 BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
 
 
+@dataclass
+class MockTokenizer:
+    type: Optional[str] = None
+
+
 @dataclass
 class MockHFConfig:
     model_type: str = "any"
@@ -23,6 +29,7 @@ class MockModelConfig:
+    task = "generate"
     tokenizer = MODEL_NAME
     trust_remote_code = False
     tokenizer_mode = "auto"
@@ -30,7 +37,13 @@
     tokenizer_revision = None
     embedding_mode = False
     multimodal_config = MultiModalConfig()
+    diff_sampling_param: Optional[dict] = None
     hf_config = MockHFConfig()
+    logits_processor_pattern = None
+    allowed_local_media_path: str = ""
+
+    def get_diff_sampling_param(self):
+        return self.diff_sampling_param or {}
 
 
 @dataclass
@@ -42,6 +55,7 @@ async def get_model_config(self):
 async def _async_serving_detection_completion_init():
     """Initialize a chat completion base with string templates"""
     engine = MockEngine()
+    engine.errored = False
     model_config = await engine.get_model_config()
 
     detection_completion = ChatCompletionDetectionBase(
@@ -52,6 +66,7 @@ async def _async_serving_detection_completion_init():
         base_model_paths=BASE_MODEL_PATHS,
         response_role="assistant",
         chat_template=CHAT_TEMPLATE,
+        chat_template_content_format="auto",
         lora_modules=None,
         prompt_adapters=None,
         request_logger=None,
diff --git a/tests/generative_detectors/test_granite_guardian.py b/tests/generative_detectors/test_granite_guardian.py
index 7afb02a..7862684 100644
--- a/tests/generative_detectors/test_granite_guardian.py
+++ b/tests/generative_detectors/test_granite_guardian.py
@@ -1,6 +1,7 @@
 # Standard
 from dataclasses import dataclass
 from http import HTTPStatus
+from typing import Optional
 from unittest.mock import patch
 import asyncio
 
@@ -33,6 +34,11 @@
 BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
 
 
+@dataclass
+class MockTokenizer:
+    type: Optional[str] = None
+
+
 @dataclass
 class MockHFConfig:
     model_type: str = "any"
@@ -40,6 +46,7 @@ class MockModelConfig:
+    task = "generate"
     tokenizer = MODEL_NAME
     trust_remote_code = False
     tokenizer_mode = "auto"
@@ -47,7 +54,13 @@
     tokenizer_revision = None
     embedding_mode = False
     multimodal_config = MultiModalConfig()
+    diff_sampling_param: Optional[dict] = None
     hf_config = MockHFConfig()
+    logits_processor_pattern = None
+    allowed_local_media_path: str = ""
+
+    def get_diff_sampling_param(self):
+        return self.diff_sampling_param or {}
 
 
 @dataclass
@@ -59,6 +72,7 @@ async def get_model_config(self):
 async def _granite_guardian_init():
     """Initialize a granite guardian"""
     engine = MockEngine()
+    engine.errored = False
     model_config = await engine.get_model_config()
 
     granite_guardian = GraniteGuardian(
@@ -69,6 +83,7 @@ async def _granite_guardian_init():
         base_model_paths=BASE_MODEL_PATHS,
         response_role="assistant",
         chat_template=CHAT_TEMPLATE,
+        chat_template_content_format="auto",
         lora_modules=None,
         prompt_adapters=None,
         request_logger=None,
@@ -229,18 +244,3 @@ def test_chat_detection_errors_on_stream(granite_guardian_detection):
     assert type(detection_response) == ErrorResponse
     assert detection_response.code == HTTPStatus.BAD_REQUEST.value
     assert "streaming is not supported" in detection_response.message
-
-
-def test_chat_detection_with_extra_unallowed_params(granite_guardian_detection):
-    granite_guardian_detection_instance = asyncio.run(granite_guardian_detection)
-    chat_request = ChatDetectionRequest(
-        messages=[
-            DetectionChatMessageParam(role="user", content="How do I pick a lock?")
-        ],
-        detector_params={"boo": 3},  # unallowed param
-    )
-    detection_response = asyncio.run(
-        granite_guardian_detection_instance.chat(chat_request)
-    )
-    assert type(detection_response) == ErrorResponse
-    assert detection_response.code == HTTPStatus.BAD_REQUEST.value
diff --git a/tests/generative_detectors/test_llama_guard.py b/tests/generative_detectors/test_llama_guard.py
index 3d59705..92e9954 100644
--- a/tests/generative_detectors/test_llama_guard.py
+++ b/tests/generative_detectors/test_llama_guard.py
@@ -1,6 +1,7 @@
 # Standard
 from dataclasses import dataclass
 from http import HTTPStatus
+from typing import Optional
 from unittest.mock import patch
 import asyncio
 
@@ -33,6 +34,11 @@
 BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
 
 
+@dataclass
+class MockTokenizer:
+    type: Optional[str] = None
+
+
 @dataclass
 class MockHFConfig:
     model_type: str = "any"
@@ -40,6 +46,7 @@ class MockModelConfig:
+    task = "generate"
     tokenizer = MODEL_NAME
     trust_remote_code = False
     tokenizer_mode = "auto"
@@ -47,7 +54,13 @@
     tokenizer_revision = None
     embedding_mode = False
     multimodal_config = MultiModalConfig()
+    diff_sampling_param: Optional[dict] = None
     hf_config = MockHFConfig()
+    logits_processor_pattern = None
+    allowed_local_media_path: str = ""
+
+    def get_diff_sampling_param(self):
+        return self.diff_sampling_param or {}
 
 
 @dataclass
@@ -59,6 +72,7 @@ async def get_model_config(self):
 async def _llama_guard_init():
     """Initialize a llama guard"""
     engine = MockEngine()
+    engine.errored = False
     model_config = await engine.get_model_config()
 
     llama_guard_detection = LlamaGuard(
@@ -69,6 +83,7 @@ async def _llama_guard_init():
         base_model_paths=BASE_MODEL_PATHS,
         response_role="assistant",
         chat_template=CHAT_TEMPLATE,
+        chat_template_content_format="auto",
         lora_modules=None,
         prompt_adapters=None,
         request_logger=None,
@@ -177,16 +192,3 @@ def test_chat_detection(llama_guard_detection, llama_guard_completion_response):
     assert detection_0["detection"] == "safe"
     assert detection_0["detection_type"] == "risk"
     assert pytest.approx(detection_0["score"]) == 0.001346767
-
-
-def test_chat_detection_with_extra_unallowed_params(llama_guard_detection):
-    llama_guard_detection_instance = asyncio.run(llama_guard_detection)
-    chat_request = ChatDetectionRequest(
-        messages=[
-            DetectionChatMessageParam(role="user", content="How do I search for moose?")
-        ],
-        detector_params={"moo": "unallowed"},  # unallowed param
-    )
-    detection_response = asyncio.run(llama_guard_detection_instance.chat(chat_request))
-    assert type(detection_response) == ErrorResponse
-    assert detection_response.code == HTTPStatus.BAD_REQUEST.value
diff --git a/tests/test_protocol.py b/tests/test_protocol.py
index 93ae9e8..8480691 100644
--- a/tests/test_protocol.py
+++ b/tests/test_protocol.py
@@ -54,8 +54,8 @@ def test_detection_to_completion_request_unknown_params():
         detector_params={"moo": 2},
     )
     request = chat_request.to_chat_completion_request(MODEL_NAME)
-    assert type(request) == ErrorResponse
-    assert request.code == HTTPStatus.BAD_REQUEST.value
+    # As of vllm >= 0.6.5, extra fields are allowed
+    assert type(request) == ChatCompletionRequest
 
 
 def test_response_from_completion_response():
diff --git a/vllm_detector_adapter/api_server.py b/vllm_detector_adapter/api_server.py
index bff5ffb..51ce754 100644
--- a/vllm_detector_adapter/api_server.py
+++ b/vllm_detector_adapter/api_server.py
@@ -10,6 +10,7 @@
 from vllm.config import ModelConfig
 from vllm.engine.arg_utils import nullable_str
 from vllm.engine.protocol import EngineClient
+from vllm.entrypoints.chat_utils import load_chat_template
 from vllm.entrypoints.launcher import serve_http
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai import api_server
@@ -61,6 +62,16 @@ def init_app_state_with_detectors(
         BaseModelPath(name=name, model_path=args.model) for name in served_model_names
     ]
 
+    resolved_chat_template = load_chat_template(args.chat_template)
+    # Post-0.6.6 incoming change for vllm - ref. https://github.com/vllm-project/vllm/pull/11660
+    # Will be included after an official release includes this refactor
+    # state.openai_serving_models = OpenAIServingModels(
+    #     model_config=model_config,
+    #     base_model_paths=base_model_paths,
+    #     lora_modules=args.lora_modules,
+    #     prompt_adapters=args.prompt_adapters,
+    # )
+
     # Use vllm app state init
     api_server.init_app_state(engine_client, model_config, state, args)
@@ -72,15 +83,18 @@ def init_app_state_with_detectors(
         args.output_template,
         engine_client,
         model_config,
-        base_model_paths,
+        base_model_paths,  # Not present in post-0.6.6 incoming change
+        # state.openai_serving_models,  # Post-0.6.6 incoming change
         args.response_role,
-        lora_modules=args.lora_modules,
-        prompt_adapters=args.prompt_adapters,
+        lora_modules=args.lora_modules,  # Not present in post-0.6.6 incoming change
+        prompt_adapters=args.prompt_adapters,  # Not present in post-0.6.6 incoming change
         request_logger=request_logger,
-        chat_template=args.chat_template,
+        chat_template=resolved_chat_template,
+        chat_template_content_format=args.chat_template_content_format,
         return_tokens_as_token_ids=args.return_tokens_as_token_ids,
         enable_auto_tools=args.enable_auto_tool_choice,
         tool_parser=args.tool_call_parser,
+        enable_prompt_tokens_details=args.enable_prompt_tokens_details,
     )
diff --git a/vllm_detector_adapter/protocol.py b/vllm_detector_adapter/protocol.py
index c952b21..2c6bb92 100644
--- a/vllm_detector_adapter/protocol.py
+++ b/vllm_detector_adapter/protocol.py
@@ -68,9 +68,8 @@ def to_chat_completion_request(self, model_name: str):
         ]
 
         # Try to pass all detector_params through as additional parameters to chat completions.
-        # This will error if extra unallowed parameters are included. We do not try to provide
-        # validation or changing of parameters here to not be dependent on chat completion API
-        # changes
+        # We do not try to provide validation or changing of parameters here to not be dependent
+        # on chat completion API changes. As of vllm >= 0.6.5, extra fields are allowed
         try:
             return ChatCompletionRequest(
                 messages=messages,