158 changes: 158 additions & 0 deletions clarifai/runners/models/vllm_openai_class.py
@@ -0,0 +1,158 @@
import threading
from typing import Iterator

import httpx
from clarifai_protocol import get_item_id, register_item_abort_callback

from clarifai.runners.models.openai_class import OpenAIModelClass


class VLLMCancellationHandler:
    # Important: closing the httpx response kills the TCP connection;
    # vLLM detects the disconnect via is_disconnected(), triggers engine.abort(),
    # and frees the KV cache.
def __init__(self):
self._cancel_events = {}
self._responses = {}
self._early_aborts = set()
self._lock = threading.Lock()
register_item_abort_callback(self._handle_abort)

def _handle_abort(self, item_id: str) -> None:
with self._lock:
event = self._cancel_events.get(item_id)
response = self._responses.get(item_id)
if event:
event.set()
if response:
try:
response.close()
except Exception:
pass
else:
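                # Abort arrived before register_request(); record it so the event
                # is set (and the response closed) as soon as the item registers.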
self._early_aborts.add(item_id)

def register_request(self, item_id: str, response=None) -> threading.Event:
cancel_event = threading.Event()
with self._lock:
self._cancel_events[item_id] = cancel_event
if response is not None:
self._responses[item_id] = response
if item_id in self._early_aborts:
cancel_event.set()
self._early_aborts.discard(item_id)
if response is not None:
try:
response.close()
except Exception:
pass
return cancel_event

def unregister_request(self, item_id: str) -> None:
with self._lock:
self._cancel_events.pop(item_id, None)
self._responses.pop(item_id, None)
self._early_aborts.discard(item_id)


class VLLMOpenAIModelClass(OpenAIModelClass):
"""vLLM-backed OpenAI model with /health probes and cancellation support.

    Subclasses must set client, model, server, and cancellation_handler in load_model(), for example:

def load_model(self):
self.server = vllm_openai_server(checkpoints, **server_args)
self.client = OpenAI(base_url=f"http://{self.server.host}:{self.server.port}/v1", api_key="x")
self.model = self.client.models.list().data[0].id
self.cancellation_handler = VLLMCancellationHandler()

For cancellation in generate() or custom streaming methods, follow this pattern:

def generate(self, prompt, ...) -> Iterator[str]:
item_id = None
cancel_event = None
try:
item_id = get_item_id()
except Exception:
pass
try:
response = self.client.chat.completions.create(..., stream=True)
if item_id:
cancel_event = self.cancellation_handler.register_request(item_id, response=response.response)
for chunk in response:
if cancel_event and cancel_event.is_set():
return
yield ...
except httpx.ReadError:
pass
finally:
if item_id:
self.cancellation_handler.unregister_request(item_id)
"""

server = None
cancellation_handler = None

def handle_liveness_probe(self) -> bool:
if self.server is None:
return super().handle_liveness_probe()
        # /health is a fast, non-blocking endpoint dedicated to health checks
try:
resp = httpx.get(f"http://{self.server.host}:{self.server.port}/health", timeout=5.0)
return resp.status_code == 200
except Exception:
return False

def handle_readiness_probe(self) -> bool:
if self.server is None:
return super().handle_readiness_probe()
        # /health is a fast, non-blocking endpoint dedicated to health checks
try:
resp = httpx.get(f"http://{self.server.host}:{self.server.port}/health", timeout=10.0)
return resp.status_code == 200
except Exception:
return False

@OpenAIModelClass.method
def openai_stream_transport(self, msg: str) -> Iterator[str]:
from pydantic_core import from_json

item_id = None
try:
item_id = get_item_id()
except Exception:
pass
cancel_event = None
try:
request_data = from_json(msg)
request_data = self._update_old_fields(request_data)
endpoint = request_data.pop("openai_endpoint", self.DEFAULT_ENDPOINT)
if endpoint not in [self.ENDPOINT_CHAT_COMPLETIONS, self.ENDPOINT_RESPONSES]:
raise ValueError(
f"Only {self.ENDPOINT_CHAT_COMPLETIONS} and {self.ENDPOINT_RESPONSES} endpoints are supported for streaming."
)

if endpoint == self.ENDPOINT_RESPONSES:
# /responses endpoint — direct call (no retry), same Stream[T] interface
response_args = {**request_data}
response_args.update({"model": self.model})
response = self.client.responses.create(**response_args)
else:
# /chat/completions endpoint
completion_args = self._create_completion_args(request_data)
response = self.client.chat.completions.create(**completion_args)

if item_id and self.cancellation_handler:
cancel_event = self.cancellation_handler.register_request(
item_id, response=response.response
)

for chunk in response:
if cancel_event and cancel_event.is_set():
return
self._set_usage(chunk)
yield chunk.model_dump_json()
except httpx.ReadError:
pass
finally:
if item_id and self.cancellation_handler:
self.cancellation_handler.unregister_request(item_id)
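
For reference, a minimal concrete subclass wiring the pieces above together could look like the sketch below. It is not part of this diff: vllm_openai_server is the helper named in the class docstring and is assumed to start the vLLM OpenAI-compatible server and expose host/port; the checkpoint path and the generate() signature are placeholders.

from typing import Iterator

import httpx
from clarifai_protocol import get_item_id
from openai import OpenAI

from clarifai.runners.models.vllm_openai_class import (
    VLLMCancellationHandler,
    VLLMOpenAIModelClass,
)


class MyVLLMModel(VLLMOpenAIModelClass):
    def load_model(self):
        # vllm_openai_server and its arguments are placeholders for whatever
        # starts the vLLM OpenAI-compatible server in your model.
        self.server = vllm_openai_server("/path/to/checkpoints")
        self.client = OpenAI(
            base_url=f"http://{self.server.host}:{self.server.port}/v1", api_key="x"
        )
        self.model = self.client.models.list().data[0].id
        self.cancellation_handler = VLLMCancellationHandler()

    def generate(self, prompt: str) -> Iterator[str]:
        item_id = None
        cancel_event = None
        try:
            item_id = get_item_id()
        except Exception:
            pass
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[{"role": "user", "content": prompt}],
                stream=True,
            )
            if item_id:
                cancel_event = self.cancellation_handler.register_request(
                    item_id, response=response.response
                )
            for chunk in response:
                if cancel_event and cancel_event.is_set():
                    return
                text = chunk.choices[0].delta.content
                if text:
                    yield text
        except httpx.ReadError:
            # The abort handler closed the underlying connection; treat as cancellation.
            pass
        finally:
            if item_id:
                self.cancellation_handler.unregister_request(item_id)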
230 changes: 230 additions & 0 deletions tests/runners/test_vllm_openai_class.py
@@ -0,0 +1,230 @@
"""Unit tests for VLLMOpenAIModelClass and VLLMCancellationHandler."""

import json
import threading
from unittest.mock import MagicMock, patch

import pytest

from clarifai.runners.models.dummy_openai_model import MockOpenAIClient
from clarifai.runners.models.vllm_openai_class import VLLMCancellationHandler, VLLMOpenAIModelClass


# ---------------------------------------------------------------------------
# Minimal concrete subclass — no real vLLM server needed
# ---------------------------------------------------------------------------
class DummyVLLMModel(VLLMOpenAIModelClass):
client = MockOpenAIClient()
model = "dummy-model"


# ---------------------------------------------------------------------------
# VLLMCancellationHandler
# ---------------------------------------------------------------------------
class TestVLLMCancellationHandler:
def _make_handler(self):
with patch("clarifai.runners.models.vllm_openai_class.register_item_abort_callback"):
return VLLMCancellationHandler()

def test_register_request_returns_unset_event(self):
handler = self._make_handler()
event = handler.register_request("item-1")
assert isinstance(event, threading.Event)
assert not event.is_set()

def test_handle_abort_sets_event_for_registered_item(self):
handler = self._make_handler()
event = handler.register_request("item-1")
handler._handle_abort("item-1")
assert event.is_set()

def test_handle_abort_closes_response(self):
handler = self._make_handler()
mock_response = MagicMock()
handler.register_request("item-1", response=mock_response)
handler._handle_abort("item-1")
mock_response.close.assert_called_once()

def test_early_abort_sets_event_on_late_register(self):
"""Abort arrives before register_request — event is immediately set on registration."""
handler = self._make_handler()
handler._handle_abort("item-early")
event = handler.register_request("item-early")
assert event.is_set()

def test_handle_abort_unknown_item_recorded_as_early_abort(self):
handler = self._make_handler()
handler._handle_abort("unknown-item")
assert "unknown-item" in handler._early_aborts

def test_unregister_removes_all_state(self):
handler = self._make_handler()
mock_response = MagicMock()
handler.register_request("item-1", response=mock_response)
handler.unregister_request("item-1")
assert "item-1" not in handler._cancel_events
assert "item-1" not in handler._responses
assert "item-1" not in handler._early_aborts


# ---------------------------------------------------------------------------
# VLLMOpenAIModelClass — health probes
# ---------------------------------------------------------------------------
class TestVLLMOpenAIModelClassProbes:
def test_liveness_probe_no_server_delegates_to_super(self):
model = DummyVLLMModel()
# server is None → falls back to OpenAIModelClass.handle_liveness_probe() which returns True
assert model.handle_liveness_probe() is True

def test_readiness_probe_no_server_delegates_to_super(self):
model = DummyVLLMModel()
assert model.handle_readiness_probe() is True

def test_liveness_probe_returns_true_on_http_200(self):
model = DummyVLLMModel()
model.server = MagicMock(host="localhost", port=8000)
mock_resp = MagicMock(status_code=200)
with patch("clarifai.runners.models.vllm_openai_class.httpx.get", return_value=mock_resp):
assert model.handle_liveness_probe() is True

def test_liveness_probe_returns_false_on_non_200(self):
model = DummyVLLMModel()
model.server = MagicMock(host="localhost", port=8000)
mock_resp = MagicMock(status_code=503)
with patch("clarifai.runners.models.vllm_openai_class.httpx.get", return_value=mock_resp):
assert model.handle_liveness_probe() is False

def test_liveness_probe_returns_false_on_exception(self):
model = DummyVLLMModel()
model.server = MagicMock(host="localhost", port=8000)
with patch(
"clarifai.runners.models.vllm_openai_class.httpx.get", side_effect=Exception("timeout")
):
assert model.handle_liveness_probe() is False

def test_readiness_probe_returns_true_on_http_200(self):
model = DummyVLLMModel()
model.server = MagicMock(host="localhost", port=8000)
mock_resp = MagicMock(status_code=200)
with patch("clarifai.runners.models.vllm_openai_class.httpx.get", return_value=mock_resp):
assert model.handle_readiness_probe() is True

def test_readiness_probe_returns_false_on_exception(self):
model = DummyVLLMModel()
model.server = MagicMock(host="localhost", port=8000)
with patch(
"clarifai.runners.models.vllm_openai_class.httpx.get",
side_effect=Exception("conn refused"),
):
assert model.handle_readiness_probe() is False


# ---------------------------------------------------------------------------
# VLLMOpenAIModelClass — openai_stream_transport with cancellation
# ---------------------------------------------------------------------------
def _make_mock_stream(*chunk_texts):
"""Return a mock streaming response whose chunks have the expected interface.

_set_usage asserts that a chunk doesn't have both .usage and .response.usage set,
so we explicitly set both to None on each chunk.
"""
chunks = []
for text in chunk_texts:
chunk = MagicMock()
chunk.usage = None
chunk.response = None
chunk.model_dump_json.return_value = json.dumps(
{"choices": [{"delta": {"content": text}}], "usage": None}
)
chunks.append(chunk)
mock_stream = MagicMock()
mock_stream.__iter__ = MagicMock(return_value=iter(chunks))
mock_stream.response = MagicMock()
return mock_stream


class TestVLLMStreamTransportCancellation:
def _model_with_mock_client_and_handler(self, cancel_event):
model = DummyVLLMModel()
mock_handler = MagicMock()
mock_handler.register_request.return_value = cancel_event
model.cancellation_handler = mock_handler
mock_stream = _make_mock_stream("Hello", " world")
model.client = MagicMock()
model.client.chat.completions.create.return_value = mock_stream
return model, mock_handler

def test_cancel_before_iteration_yields_no_chunks(self):
cancel_event = threading.Event()
cancel_event.set() # already cancelled
model, mock_handler = self._model_with_mock_client_and_handler(cancel_event)

request = json.dumps(
{
"model": "dummy-model",
"messages": [{"role": "user", "content": "Hello"}],
"stream": True,
}
)
with patch(
"clarifai.runners.models.vllm_openai_class.get_item_id", return_value="item-abc"
):
chunks = list(model.openai_stream_transport(request))

assert chunks == []
mock_handler.unregister_request.assert_called_once_with("item-abc")

def test_no_cancel_yields_all_chunks(self):
cancel_event = threading.Event() # never set
model, mock_handler = self._model_with_mock_client_and_handler(cancel_event)

request = json.dumps(
{
"model": "dummy-model",
"messages": [{"role": "user", "content": "Hello"}],
"stream": True,
}
)
with patch(
"clarifai.runners.models.vllm_openai_class.get_item_id", return_value="item-xyz"
):
chunks = list(model.openai_stream_transport(request))

assert len(chunks) == 2
mock_handler.unregister_request.assert_called_once_with("item-xyz")

def test_unregister_called_even_when_get_item_id_fails(self):
"""If get_item_id raises, no cancellation handler is used but stream still works."""
model = DummyVLLMModel()
mock_stream = _make_mock_stream("chunk1")
model.client = MagicMock()
model.client.chat.completions.create.return_value = mock_stream

request = json.dumps(
{
"model": "dummy-model",
"messages": [{"role": "user", "content": "Hello"}],
"stream": True,
}
)
with patch(
"clarifai.runners.models.vllm_openai_class.get_item_id",
side_effect=Exception("no context"),
):
chunks = list(model.openai_stream_transport(request))

assert len(chunks) == 1

def test_invalid_endpoint_raises_value_error(self):
model = DummyVLLMModel()
request = json.dumps(
{
"model": "dummy-model",
"messages": [{"role": "user", "content": "Hello"}],
"stream": True,
"openai_endpoint": "/unsupported",
}
)
with patch("clarifai.runners.models.vllm_openai_class.get_item_id", side_effect=Exception):
with pytest.raises(ValueError, match="Only"):
list(model.openai_stream_transport(request))
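
The suite above only drives the /chat/completions path; a similar check for the /responses branch could look like the sketch below. It is not part of the diff: it reuses _make_mock_stream and DummyVLLMModel, and assumes the inherited ENDPOINT_RESPONSES constant exists and that _update_old_fields passes a responses-style payload through unchanged.

class TestVLLMResponsesEndpointSketch:
    def test_responses_endpoint_yields_chunks(self):
        model = DummyVLLMModel()
        mock_stream = _make_mock_stream("part1", "part2")
        model.client = MagicMock()
        model.client.responses.create.return_value = mock_stream

        request = json.dumps(
            {
                "model": "dummy-model",
                "input": "Hello",
                "stream": True,
                "openai_endpoint": DummyVLLMModel.ENDPOINT_RESPONSES,
            }
        )
        with patch(
            "clarifai.runners.models.vllm_openai_class.get_item_id",
            side_effect=Exception("no context"),
        ):
            chunks = list(model.openai_stream_transport(request))

        assert len(chunks) == 2
        model.client.responses.create.assert_called_once()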