Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/8389.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Use `model_validate` for health check info to apply Pydantic defaults, preventing pydantic validation error when `initial_delay` field is nullable
Copy link

Copilot AI Jan 28, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The description says "when initial_delay field is nullable" but this is inaccurate. The issue is that initial_delay is NOT nullable (it's typed as float, not float | None), which is why passing None prevented Pydantic from using its default value. Consider rephrasing to: "preventing pydantic validation error when initial_delay field is missing from the YAML"

Suggested change
Use `model_validate` for health check info to apply Pydantic defaults, preventing pydantic validation error when `initial_delay` field is nullable
Use `model_validate` for health check info to apply Pydantic defaults, preventing pydantic validation error when `initial_delay` field is missing from the YAML

Copilot uses AI. Check for mistakes.
9 changes: 1 addition & 8 deletions src/ai/backend/manager/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2386,14 +2386,7 @@ async def get_health_check_info(

for model_info in model_definition["models"]:
if health_check_info := model_info.get("service", {}).get("health_check"):
_info = ModelHealthCheck(
path=health_check_info["path"],
interval=health_check_info["interval"],
max_retries=health_check_info["max_retries"],
max_wait_time=health_check_info["max_wait_time"],
expected_status_code=health_check_info["expected_status_code"],
initial_delay=health_check_info.get("initial_delay"),
)
_info = ModelHealthCheck.model_validate(health_check_info)
break
elif (
self.config_provider.config.deployment.enable_model_definition_override
Expand Down
246 changes: 243 additions & 3 deletions tests/unit/manager/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,33 @@

import uuid
from collections.abc import AsyncGenerator, Mapping
from dataclasses import dataclass
from decimal import Decimal
from typing import Any
from unittest.mock import AsyncMock, MagicMock
from typing import TYPE_CHECKING, Any
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from ai.backend.common.auth import PublicKey, SecretKey
from ai.backend.common.plugin.hook import HookPluginContext
from ai.backend.common.types import BinarySize, DeviceId, SessionId, SlotName
from ai.backend.common.types import (
MODEL_SERVICE_RUNTIME_PROFILES,
BinarySize,
DeviceId,
QuotaScopeID,
QuotaScopeType,
RuntimeVariant,
SessionId,
SlotName,
VFolderID,
)
from ai.backend.manager.config.provider import ManagerConfigProvider
from ai.backend.manager.plugin.network import NetworkPluginContext
from ai.backend.manager.registry import AgentRegistry

if TYPE_CHECKING:
from collections.abc import Iterator


class DummyEtcd:
async def get_prefix(self, key: str) -> Mapping[str, Any]:
Expand Down Expand Up @@ -164,3 +178,229 @@ async def test_convert_resource_spec_to_resource_slot(
converted_allocations = registry.convert_resource_spec_to_resource_slot(allocations)
assert converted_allocations["cpu"] == Decimal("4")
assert converted_allocations["ram"] == Decimal(BinarySize.from_str("1g")) * 3


@dataclass
class MockEndpointData:
    """Lightweight stand-in for the endpoint argument of ``get_health_check_info``.

    The tests pass instances of this class where the real endpoint data type
    is expected (hence the ``type: ignore[arg-type]`` at the call sites), so it
    carries only the two attributes the tests populate.
    """

    # Runtime variant of the endpoint (the tests use CUSTOM and VLLM).
    runtime_variant: RuntimeVariant
    # Path of the model definition file; None when the test does not supply one.
    model_definition_path: str | None = None


@dataclass
class MockVFolderRow:
    """Lightweight stand-in for the vfolder row argument of ``get_health_check_info``.

    Holds only the two attributes the tests populate.
    """

    # Storage host name (the tests use "local").
    host: str
    # Identifier of the virtual folder.
    vfid: VFolderID


@dataclass
class MockDeploymentConfig:
    """Stand-in for the ``deployment`` section of the manager configuration.

    Exposes only the flag that ``get_health_check_info`` reads through
    ``config_provider.config.deployment``.
    """

    # Disabled by default, matching the value the tests construct explicitly.
    enable_model_definition_override: bool = False


@dataclass
class MockConfig:
    """Stand-in for the manager configuration object.

    Provides just the ``deployment`` attribute so that
    ``config_provider.config.deployment`` resolves during the tests.
    """

    # Deployment-related settings.
    deployment: MockDeploymentConfig


@dataclass
class MockConfigProvider:
    """Stand-in for the registry's ``config_provider`` attribute.

    Provides just the ``config`` attribute so that
    ``config_provider.config`` resolves during the tests.
    """

    # The (mock) configuration exposed by this provider.
    config: MockConfig


@dataclass
class HealthCheckTestCase:
"""Test case for health check configuration."""

input: dict[str, float | int | str]
expected_path: str
expected_interval: float = 10.0
expected_max_retries: int = 10
expected_max_wait_time: float = 15.0
expected_status_code: int = 200
expected_initial_delay: float = 60.0


class TestGetHealthCheckInfo:
    """Tests for get_health_check_info method."""

    @pytest.fixture
    def mock_storage_manager(self) -> AsyncMock:
        """Provide an ``AsyncMock`` used as the registry's storage manager."""
        return AsyncMock()

    @pytest.fixture
    def mock_config_provider(self) -> MockConfigProvider:
        """Provide a config provider with model definition override disabled."""
        return MockConfigProvider(
            config=MockConfig(
                deployment=MockDeploymentConfig(enable_model_definition_override=False)
            )
        )

    @pytest.fixture
    def mock_endpoint_custom(self) -> MockEndpointData:
        """Provide a CUSTOM-variant endpoint pointing at ``model-definition.yaml``."""
        return MockEndpointData(
            runtime_variant=RuntimeVariant.CUSTOM,
            model_definition_path="model-definition.yaml",
        )

    @pytest.fixture
    def mock_vfolder(self) -> MockVFolderRow:
        """Provide a vfolder row on host ``local`` with freshly generated IDs."""
        quota_scope_id = QuotaScopeID(QuotaScopeType.PROJECT, uuid.uuid4())
        return MockVFolderRow(
            host="local",
            vfid=VFolderID(quota_scope_id=quota_scope_id, folder_id=uuid.uuid4()),
        )

    @pytest.fixture
    def patch_model_service_helper(self) -> Iterator[AsyncMock]:
        """Patch ModelServiceHelper methods for testing.

        Yields the mock that replaces ``validate_model_definition`` so each
        test can set the model definition it should return.
        """
        with (
            patch(
                "ai.backend.manager.registry.ModelServiceHelper.validate_model_definition_file_exists",
                new_callable=AsyncMock,
                return_value="model-definition.yaml",
            ),
            patch(
                "ai.backend.manager.registry.ModelServiceHelper.validate_model_definition",
                new_callable=AsyncMock,
            ) as mock_validate_definition,
        ):
            yield mock_validate_definition

    @pytest.fixture
    def mock_registry(
        self,
        mock_storage_manager: AsyncMock,
        mock_config_provider: MockConfigProvider,
    ) -> MagicMock:
        """Create a mock AgentRegistry with required dependencies."""
        registry = MagicMock(spec=AgentRegistry)
        registry.storage_manager = mock_storage_manager
        registry.config_provider = mock_config_provider
        return registry

    @pytest.mark.parametrize(
        "test_case",
        [
            HealthCheckTestCase(
                input={
                    "path": "/custom-health",
                    "interval": 5.0,
                    "max_retries": 3,
                    "max_wait_time": 30.0,
                    "expected_status_code": 201,
                    "initial_delay": 120.0,
                },
                expected_path="/custom-health",
                expected_interval=5.0,
                expected_max_retries=3,
                expected_max_wait_time=30.0,
                expected_status_code=201,
                expected_initial_delay=120.0,
            ),
            HealthCheckTestCase(
                input={
                    "path": "/health",
                    "interval": 10.0,
                    "max_retries": 5,
                    "max_wait_time": 20.0,
                    "expected_status_code": 200,
                    # initial_delay omitted - should use Pydantic default (60.0)
                },
                expected_path="/health",
                expected_interval=10.0,
                expected_max_retries=5,
                expected_max_wait_time=20.0,
                expected_status_code=200,
                expected_initial_delay=60.0,
            ),
            HealthCheckTestCase(
                input={"path": "/health"},
                # All optional fields use Pydantic defaults
                expected_path="/health",
                expected_interval=10.0,
                expected_max_retries=10,
                expected_max_wait_time=15.0,
                expected_status_code=200,
                expected_initial_delay=60.0,
            ),
        ],
        # Explicit ids make a failing scenario identifiable in the pytest report.
        ids=["all-fields-explicit", "initial-delay-omitted", "path-only"],
    )
    async def test_custom_variant_health_check_config(
        self,
        mock_registry: MagicMock,
        mock_endpoint_custom: MockEndpointData,
        mock_vfolder: MockVFolderRow,
        patch_model_service_helper: AsyncMock,
        test_case: HealthCheckTestCase,
    ) -> None:
        """Test CUSTOM variant with various health check configurations."""
        mock_validate_definition = patch_model_service_helper
        mock_validate_definition.return_value = {
            "models": [{"service": {"health_check": test_case.input}}]
        }

        # Call the unbound method so the real implementation runs against the mock.
        result = await AgentRegistry.get_health_check_info(
            mock_registry,
            mock_endpoint_custom,  # type: ignore[arg-type]
            mock_vfolder,  # type: ignore[arg-type]
        )

        assert result is not None
        assert result.path == test_case.expected_path
        assert result.interval == test_case.expected_interval
        assert result.max_retries == test_case.expected_max_retries
        assert result.max_wait_time == test_case.expected_max_wait_time
        assert result.expected_status_code == test_case.expected_status_code
        assert result.initial_delay == test_case.expected_initial_delay

    async def test_custom_variant_without_health_check_returns_none(
        self,
        mock_registry: MagicMock,
        mock_endpoint_custom: MockEndpointData,
        mock_vfolder: MockVFolderRow,
        patch_model_service_helper: AsyncMock,
    ) -> None:
        """Test CUSTOM variant without health_check in model definition returns None."""
        mock_validate_definition = patch_model_service_helper
        mock_validate_definition.return_value = {
            "models": [{"service": {}}]  # No health_check defined
        }

        result = await AgentRegistry.get_health_check_info(
            mock_registry,
            mock_endpoint_custom,  # type: ignore[arg-type]
            mock_vfolder,  # type: ignore[arg-type]
        )

        assert result is None

    async def test_vllm_variant_returns_default_health_check(
        self,
        mock_registry: MagicMock,
        mock_vfolder: MockVFolderRow,
    ) -> None:
        """Test VLLM variant returns default health check endpoint from profile."""
        endpoint = MockEndpointData(
            runtime_variant=RuntimeVariant.VLLM,
            model_definition_path=None,
        )

        result = await AgentRegistry.get_health_check_info(
            mock_registry,
            endpoint,  # type: ignore[arg-type]
            mock_vfolder,  # type: ignore[arg-type]
        )

        assert result is not None
        expected_path = MODEL_SERVICE_RUNTIME_PROFILES[RuntimeVariant.VLLM].health_check_endpoint
        assert result.path == expected_path
Loading