Skip to content

Commit ac85dd5

Browse files
committed
fix: Get tracking metadata for running models
Include model info and metadata retrieved from the tracking server when fetching the list of running models, if the 'verbose' query parameter is set to 'true' and the running instance holds a label with the associated model URI. The API was previously attempting to fetch information about a registered model using the URI which would always fail (the function expects a model name, not a URI). Signed-off-by: Phoevos Kalemkeris <[email protected]>
1 parent 431c9c4 commit ac85dd5

File tree

3 files changed

+28
-19
lines changed

3 files changed

+28
-19
lines changed

cogstack_model_gateway/common/tracking.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import logging
22

33
import mlflow
4+
import mlflow.models
45
from mlflow import MlflowClient, MlflowException
56
from mlflow.entities import Run, RunStatus
67

@@ -76,6 +77,7 @@ class TrackingClient:
7677
def __init__(self, tracking_uri: str = None):
7778
self.tracking_uri = tracking_uri or mlflow.get_tracking_uri()
7879
self._mlflow_client = MlflowClient(self.tracking_uri)
80+
mlflow.set_tracking_uri(self.tracking_uri)
7981

8082
def get_task(self, tracking_id: str) -> TrackingTask:
8183
"""Get a task by its tracking ID."""
@@ -119,3 +121,20 @@ def get_model_uri(self, tracking_id: str) -> str:
119121
except Exception as e:
120122
log.error(f"Failed to get model URI for task with tracking ID '{tracking_id}': {e}")
121123
return None
124+
125+
def get_model_metadata(self, model_uri: str) -> dict:
126+
"""Get metadata for a model URI."""
127+
try:
128+
model_info = mlflow.models.get_model_info(model_uri)
129+
return {
130+
"uuid": model_info.model_uuid,
131+
"signature": model_info.signature.to_dict() if model_info.signature else {},
132+
"flavors": model_info.flavors,
133+
"run_id": model_info.run_id,
134+
"artifact_path": model_info.artifact_path,
135+
"utc_time_created": model_info.utc_time_created,
136+
"mlflow_version": model_info.mlflow_version,
137+
}
138+
except MlflowException as e:
139+
log.error(f"Failed to get model metadata for model URI '{model_uri}': {e}")
140+
return None

cogstack_model_gateway/gateway/core/models.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
import os
22

33
import docker
4-
import mlflow
54
from docker.models.containers import Container
6-
from mlflow.entities.model_registry import RegisteredModel
75

86
from cogstack_model_gateway.common.containers import (
97
IS_MODEL_LABEL,
@@ -37,14 +35,6 @@ def get_running_models() -> list[dict]:
3735
]
3836

3937

40-
def get_model_meta(model_uri: str) -> RegisteredModel:
41-
try:
42-
client = mlflow.tracking.MlflowClient()
43-
return client.get_registered_model(model_uri)
44-
except mlflow.exceptions.MlflowException:
45-
return None
46-
47-
4838
def run_model_container(model_name: str, model_uri: str, ttl: int):
4939
client = docker.from_env()
5040
cms_project = os.getenv(CMS_PROJECT_ENV_VAR)

cogstack_model_gateway/gateway/routers/models.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,15 @@
44

55
import requests
66
from docker.errors import DockerException
7-
from fastapi import APIRouter, Body, Depends, Header, HTTPException, Request
7+
from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query, Request
88
from starlette.datastructures import UploadFile as StarletteUploadFile
99

1010
from cogstack_model_gateway.common.config import Config, get_config
1111
from cogstack_model_gateway.common.object_store import ObjectStoreManager
1212
from cogstack_model_gateway.common.queue import QueueManager
1313
from cogstack_model_gateway.common.tasks import Status, TaskManager
1414
from cogstack_model_gateway.common.tracking import TrackingClient
15-
from cogstack_model_gateway.gateway.core.models import (
16-
get_model_meta,
17-
get_running_models,
18-
run_model_container,
19-
)
15+
from cogstack_model_gateway.gateway.core.models import get_running_models, run_model_container
2016
from cogstack_model_gateway.gateway.core.priority import calculate_task_priority
2117
from cogstack_model_gateway.gateway.routers.utils import (
2218
get_content_type,
@@ -103,11 +99,15 @@
10399

104100

105101
@router.get("/models/", response_model=list[dict], tags=["models"])
106-
async def get_models():
102+
async def get_models(
103+
verbose: Annotated[
104+
bool | None, Query(description="Include model metadata from the tracking server")
105+
] = False,
106+
):
107107
models = get_running_models()
108108
for model in models:
109-
if model["uri"]:
110-
if model_info := get_model_meta(model["uri"]):
109+
if model["uri"] and verbose:
110+
if model_info := TrackingClient().get_model_metadata(model["uri"]):
111111
model["info"] = model_info
112112
return models
113113

0 commit comments

Comments
 (0)