Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions api/openapi/catalog.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,18 @@ paths:
type: string
in: query
required: false
- name: sourceLabel
description: |-
Filter MCP servers by the label associated with the source. Multiple
values can be separated by commas. If one of the values is the
string `null`, then MCP servers from every source without a label will
be returned.
schema:
type: array
items:
type: string
in: query
required: false
- $ref: "#/components/parameters/mcpServerFilterQuery"
- $ref: "#/components/parameters/namedQuery"
- $ref: "#/components/parameters/includeTools"
Expand Down Expand Up @@ -326,6 +338,7 @@ paths:
- ModelCatalogService
parameters:
- $ref: "#/components/parameters/name"
- $ref: "#/components/parameters/assetType"
- $ref: "#/components/parameters/pageSize"
- $ref: "#/components/parameters/orderBy"
- $ref: "#/components/parameters/sortOrder"
Expand Down Expand Up @@ -2056,6 +2069,13 @@ components:
maximum: 100
in: query
required: false
assetType:
name: assetType
description: Filter sources by asset type.
schema:
$ref: "#/components/schemas/CatalogAssetType"
in: query
required: false
artifactType:
style: form
explode: true
Expand Down
20 changes: 20 additions & 0 deletions api/openapi/src/catalog.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ paths:
- ModelCatalogService
parameters:
- $ref: "#/components/parameters/name"
- $ref: "#/components/parameters/assetType"
- $ref: "#/components/parameters/pageSize"
- $ref: "#/components/parameters/orderBy"
- $ref: "#/components/parameters/sortOrder"
Expand Down Expand Up @@ -224,6 +225,18 @@ paths:
type: string
in: query
required: false
- name: sourceLabel
description: |-
Filter MCP servers by the label associated with the source. Multiple
values can be separated by commas. If one of the values is the
string `null`, then MCP servers from every source without a label will
be returned.
schema:
type: array
items:
type: string
in: query
required: false
- $ref: "#/components/parameters/mcpServerFilterQuery"
- $ref: "#/components/parameters/namedQuery"
- $ref: "#/components/parameters/includeTools"
Expand Down Expand Up @@ -1779,6 +1792,13 @@ components:
maximum: 100
in: query
required: false
assetType:
name: assetType
description: Filter sources by asset type.
schema:
$ref: "#/components/schemas/CatalogAssetType"
in: query
required: false
artifactType:
style: form
explode: true
Expand Down
14 changes: 13 additions & 1 deletion catalog/clients/python/src/model_catalog/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from functools import wraps
from typing import Any, TypeVar
from urllib.parse import quote
from catalog_openapi.models import OrderByField, SortOrder

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -338,7 +339,6 @@ def get_models(
Returns:
Dict with models response.
"""
from catalog_openapi.models import OrderByField, SortOrder

source_list = [source] if source else None
page_size_str = str(page_size) if page_size is not None else None
Expand Down Expand Up @@ -370,6 +370,8 @@ def get_artifacts(
model_name: str,
artifact_type: str | list[str] | None = None,
filter_query: str | None = None,
order_by: str | None = None,
sort_order: str | None = None,
page_size: int | None = None,
next_page_token: str | None = None,
) -> dict[str, Any]:
Expand All @@ -382,12 +384,16 @@ def get_artifacts(
Accepts "model-artifact", "metrics-artifact", or a list of these values.
Can be a single string or list of strings.
filter_query: Optional filter query.
order_by: Optional field to order by (ID, NAME, CREATE_TIME, LAST_UPDATE_TIME,
or custom property path like "accuracy.double_value").
sort_order: Optional sort order (ASC or DESC).
page_size: Optional page size.
next_page_token: Optional pagination token.

Returns:
Dict with artifacts response.
"""

# Convert artifact_type to list format expected by OpenAPI client
artifact_type_list = None
if artifact_type is not None:
Expand All @@ -396,12 +402,18 @@ def get_artifacts(
else:
artifact_type_list = artifact_type

sort_order_enum: SortOrder | None = None
if sort_order:
sort_order_enum = SortOrder(sort_order.upper())

page_size_str = str(page_size) if page_size is not None else None
response = self.catalog_api.get_all_model_artifacts(
source_id=_encode_path_param(source_id),
model_name=_encode_path_param(model_name),
artifact_type=artifact_type_list,
filter_query=filter_query,
order_by=order_by,
sort_order=sort_order_enum,
page_size=page_size_str,
next_page_token=next_page_token,
)
Expand Down
26 changes: 10 additions & 16 deletions catalog/clients/python/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,46 +236,40 @@ def api_client(user_token: str | None, verify_ssl: bool) -> Generator[CatalogAPI


@pytest.fixture(scope="session")
def model_with_artifacts(api_client: CatalogAPIClient, verify_ssl: bool) -> tuple[str, str]:
def model_with_artifacts(
request: pytest.FixtureRequest, api_client: CatalogAPIClient, verify_ssl: bool
) -> tuple[str, str]:
"""Get a model that has artifacts for testing.

Searches available models to find one with artifacts.
Fails if no models or no models with artifacts are found.
The param value controls the minimum number of artifacts required.
Default is 1. Override via indirect parametrization for tests needing more.

Returns:
Tuple of (source_id, model_name) for a model with artifacts.

Raises:
pytest.fail: If no models are available or no model has artifacts.
pytest.fail: If no models are available or no model has enough artifacts.
"""
min_artifacts = getattr(request, "param", 1)
if not verify_ssl:
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

models = api_client.get_models()
if not models.get("items"):
pytest.fail("No models available - test data may not be loaded")

# Find a model that has artifacts
for model in models["items"]:
source_id = model.get("source_id")
model_name = model.get("name")
if not source_id or not model_name:
continue

# Check if this model has artifacts
artifacts = api_client.get_artifacts(source_id=source_id, model_name=model_name)
if artifacts.get("items"):
items = artifacts.get("items", [])
if len(items) >= min_artifacts:
return source_id, model_name

# Fallback to first model with required fields
model = models["items"][0]
source_id = model.get("source_id")
model_name = model.get("name")

if not source_id or not model_name:
pytest.fail("Model missing source_id or name - test data may be malformed")

return source_id, model_name
pytest.fail(f"No model with at least {min_artifacts} artifact(s) found")


@pytest.fixture(scope="session")
Expand Down
68 changes: 53 additions & 15 deletions catalog/clients/python/tests/sorting_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Utility functions for sorting tests."""

import pytest
from typing import Any


Expand Down Expand Up @@ -42,25 +41,64 @@ def get_field_value(item: dict[str, Any], field: str) -> Any:
return value


def validate_items_sorted_correctly(items: list[dict], field: str, order: str) -> bool:
"""Verify items are sorted correctly by the specified field.
def sort_items_by_field(items: list[dict], field: str, order: str) -> list[dict]:
"""Sort items by the specified field and order.

Args:
items: List of items to validate
field: Field name to check sorting on (ID, CREATE_TIME, LAST_UPDATE_TIME)
items: List of items to sort
field: Field name to sort on (ID, CREATE_TIME, LAST_UPDATE_TIME)
order: Sort order (ASC or DESC)

Returns:
True if items are sorted correctly, False otherwise
A new list of items sorted by the specified field and order.

Raises:
ValueError: If field or order is invalid
"""
if order in {"ASC", "DESC"}:
return sorted(items, key=lambda item: get_field_value(item, field), reverse=(order == "DESC"))
raise ValueError(f"Invalid sort order: {order}")



def sort_items_by_custom_property(items: list[dict], property_field: str, sort_order: str) -> list[dict]:
"""Sort items by a custom property value with fallback behavior.

Expected behavior:
1. Items WITH the property appear first, sorted by value (ASC/DESC)
2. Items WITHOUT the property appear after, sorted by ID ASC (fallback)

Args:
items: List of artifact items from API response
property_field: Property field path (e.g., "accuracy.double_value")
sort_order: Sort order (ASC or DESC)

Returns:
A new list of items sorted by the custom property with fallback to ID ASC.

Raises:
ValueError: If sort_order is invalid
"""
if len(items) < 2:
pytest.fail("List has fewer than 2 items, double check the data you are passing to this function")
if sort_order not in ("ASC", "DESC"):
raise ValueError(f"Invalid sort order: {sort_order}")

property_name, value_type = property_field.rsplit(".", 1)

items_with = []
items_without = []

for item in items:
custom_props = item.get("customProperties", {})
if property_name in custom_props:
value = custom_props[property_name].get(value_type)
if value is not None:
items_with.append((item, value))
else:
items_without.append(item)
else:
items_without.append(item)

values = [get_field_value(item, field) for item in items]
sorted_with = sorted(items_with, key=lambda x: x[1], reverse=(sort_order == "DESC"))
sorted_without = sorted(items_without, key=lambda item: int(item["id"]))

if order == "ASC":
return all(values[i] <= values[i + 1] for i in range(len(values) - 1))
elif order == "DESC":
return all(values[i] >= values[i + 1] for i in range(len(values) - 1))
else:
raise ValueError(f"Invalid sort order: {order}")
return [item for item, _ in sorted_with] + sorted_without
85 changes: 84 additions & 1 deletion catalog/clients/python/tests/test_artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import pytest

from model_catalog import CatalogAPIClient, CatalogValidationError
from tests.sorting_utils import sort_items_by_custom_property, sort_items_by_field


class TestArtifacts:
Expand Down Expand Up @@ -428,4 +429,86 @@ def test_filter_artifacts_by_invalid_artifact_type(
source_id=source_id,
model_name=model_name,
artifact_type=invalid_artifact_type,
)
)


@pytest.mark.parametrize("model_with_artifacts", [2], indirect=True)
class TestArtifactsSorting:
"""Basic sorting coverage for artifact endpoints. Thorough testing is done downstream."""

@pytest.mark.parametrize(
"order_by,sort_order",
[
pytest.param("ID", "ASC", id="id_asc"),
pytest.param("ID", "DESC", id="id_desc"),
],
)
def test_artifacts_sorting_by_id(
self,
api_client: CatalogAPIClient,
model_with_artifacts: tuple[str, str],
order_by: str,
sort_order: str,
) -> None:
"""Test artifacts endpoint sorts correctly by ID."""
source_id, model_name = model_with_artifacts

response = api_client.get_artifacts(
source_id=source_id,
model_name=model_name,
order_by=order_by,
sort_order=sort_order,
)

items = response["items"]
assert len(items) >= 2, "Need at least 2 artifacts to validate sorting"
assert items == sort_items_by_field(items, order_by, sort_order)

@pytest.mark.parametrize(
"order_by,sort_order",
[
pytest.param("accuracy.double_value", "ASC", id="accuracy_asc"),
pytest.param("accuracy.double_value", "DESC", id="accuracy_desc"),
],
)
def test_artifacts_sorting_by_custom_property(
self,
api_client: CatalogAPIClient,
model_with_artifacts: tuple[str, str],
order_by: str,
sort_order: str,
) -> None:
"""Test artifacts endpoint sorts correctly by custom property values."""
source_id, model_name = model_with_artifacts

response = api_client.get_artifacts(
source_id=source_id,
model_name=model_name,
order_by=order_by,
sort_order=sort_order,
)

items = response["items"]
assert len(items) >= 2, "Need at least 2 artifacts to validate sorting"
assert items == sort_items_by_custom_property(items, order_by, sort_order)

def test_sorting_by_non_existing_property_falls_back_to_id(
self,
api_client: CatalogAPIClient,
model_with_artifacts: tuple[str, str],
) -> None:
"""Test that sorting by a non-existing property falls back to ID ASC."""
source_id, model_name = model_with_artifacts

response = api_client.get_artifacts(
source_id=source_id,
model_name=model_name,
order_by="non_existing_property.double_value",
sort_order="DESC",
)

items = response["items"]
assert len(items) >= 2, "Need at least 2 artifacts to validate sorting"

# No artifacts have this property, so all should fall back to ID ASC
assert items == sort_items_by_field(items, "ID", "ASC")
Loading
Loading