Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pkg/inference/models/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ type Model struct {
// ID is the globally unique model identifier.
ID string `json:"id"`
// Tags are the list of tags associated with the model.
Tags []string `json:"tags"`
Tags []string `json:"tags,omitempty"`
// Created is the Unix epoch timestamp corresponding to the model creation.
Created int64 `json:"created"`
// Config describes the model.
Expand Down
92 changes: 81 additions & 11 deletions pkg/inference/models/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ type Manager struct {
router *http.ServeMux
// distributionClient is the client for model distribution.
distributionClient *distribution.Client
// registryClient is the client for the model registry.
registryClient *registry.Client
}

type ClientConfig struct {
Expand Down Expand Up @@ -65,12 +67,19 @@ func NewManager(log logging.Logger, c ClientConfig, allowedOrigins []string) *Ma
// respond to requests, but may return errors if the client is required.
}

// Create the model registry client.
registryClient := registry.NewClient(
registry.WithTransport(c.Transport),
registry.WithUserAgent(c.UserAgent),
)

// Create the manager.
m := &Manager{
log: log,
pullTokens: make(chan struct{}, maximumConcurrentModelPulls),
router: http.NewServeMux(),
distributionClient: distributionClient,
registryClient: registryClient,
}

// Register routes.
Expand Down Expand Up @@ -189,24 +198,36 @@ func (m *Manager) handleGetModels(w http.ResponseWriter, r *http.Request) {

// handleGetModel handles GET <inference-prefix>/models/{name} requests.
func (m *Manager) handleGetModel(w http.ResponseWriter, r *http.Request) {
if m.distributionClient == nil {
http.Error(w, "model distribution service unavailable", http.StatusServiceUnavailable)
// Parse remote query parameter
remote := false
if r.URL.Query().Has("remote") {
if val, err := strconv.ParseBool(r.URL.Query().Get("remote")); err != nil {
m.log.Warnln("Error while parsing remote query parameter:", err)
} else {
remote = val
}
}

if remote && m.registryClient == nil {
http.Error(w, "registry client unavailable", http.StatusServiceUnavailable)
return
}

// Query the model.
model, err := m.GetModel(r.PathValue("name"))
var apiModel *Model
var err error

if remote {
apiModel, err = getRemoteModel(m, r.PathValue("name"))
} else {
apiModel, err = getLocalModel(m, r.PathValue("name"))
}

if err != nil {
if errors.Is(err, distribution.ErrModelNotFound) {
if errors.Is(err, distribution.ErrModelNotFound) || errors.Is(err, registry.ErrModelNotFound) {
http.Error(w, err.Error(), http.StatusNotFound)
} else {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
return
}

apiModel, err := ToModel(model)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
Expand All @@ -218,6 +239,55 @@ func (m *Manager) handleGetModel(w http.ResponseWriter, r *http.Request) {
}
}

// getLocalModel looks up name through the manager's GetModel method and
// converts the result to the API representation with ToModel.
//
// It returns an error when the manager has no distribution client, when the
// lookup fails, or when the conversion fails. Lookup errors are returned
// unwrapped so callers can still match them with errors.Is.
func getLocalModel(m *Manager, name string) (*Model, error) {
	if m.distributionClient == nil {
		return nil, errors.New("model distribution service unavailable")
	}

	stored, lookupErr := m.GetModel(name)
	if lookupErr != nil {
		return nil, lookupErr
	}

	// Convert the stored model into its API form.
	apiModel, convErr := ToModel(stored)
	return apiModel, convErr
}

// getRemoteModel fetches the model called name through the manager's registry
// client and assembles an API Model from its ID, descriptor and config.
//
// The registry call runs with context.Background(), so it is not tied to any
// request's lifetime. The returned Model always carries an empty, non-nil Tags
// slice. Any error from the registry client or from reading the model's ID,
// descriptor or config is returned unwrapped.
func getRemoteModel(m *Manager, name string) (*Model, error) {
	if m.registryClient == nil {
		return nil, errors.New("registry client unavailable")
	}

	remote, err := m.registryClient.Model(context.Background(), name)
	if err != nil {
		return nil, err
	}

	// Read the three pieces in a fixed order; the first failure aborts.
	modelID, err := remote.ID()
	if err != nil {
		return nil, err
	}
	desc, err := remote.Descriptor()
	if err != nil {
		return nil, err
	}
	cfg, err := remote.Config()
	if err != nil {
		return nil, err
	}

	return &Model{
		ID:      modelID,
		Tags:    []string{},
		Created: desc.Created.Unix(),
		Config:  cfg,
	}, nil
}

// handleDeleteModel handles DELETE <inference-prefix>/models/{name} requests.
// query params:
// - force: if true, delete the model even if it has multiple tags
Expand Down
141 changes: 141 additions & 0 deletions pkg/inference/models/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package models

import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"os"
Expand All @@ -13,6 +15,7 @@ import (

"github.com/docker/model-distribution/builder"
reg "github.com/docker/model-distribution/registry"
"github.com/docker/model-runner/pkg/inference"

"github.com/sirupsen/logrus"
)
Expand Down Expand Up @@ -136,3 +139,141 @@ func TestPullModel(t *testing.T) {
})
}
}

// TestHandleGetModel exercises handleGetModel for both local and remote
// lookups against an in-process test registry.
//
// A dummy GGUF model with a license is built and pushed to the registry once.
// Each subtest then creates its own Manager backed by its own store directory,
// pulls the model first when a successful local lookup is expected, and calls
// the handler directly with the {name} path value set by hand.
func TestHandleGetModel(t *testing.T) {
	// Create a test registry.
	server := httptest.NewServer(registry.New())
	defer server.Close()

	uri, err := url.Parse(server.URL)
	if err != nil {
		t.Fatalf("Failed to parse registry URL: %v", err)
	}

	// Prepare the OCI model artifact.
	projectRoot := getProjectRoot(t)
	model, err := builder.FromGGUF(filepath.Join(projectRoot, "assets", "dummy.gguf"))
	if err != nil {
		t.Fatalf("Failed to create model builder: %v", err)
	}

	license, err := model.WithLicense(filepath.Join(projectRoot, "assets", "license.txt"))
	if err != nil {
		t.Fatalf("Failed to add license to model: %v", err)
	}

	// Build the OCI model artifact and push it.
	tag := uri.Host + "/ai/model:v1.0.0"
	client := reg.NewClient()
	target, err := client.NewTarget(tag)
	if err != nil {
		t.Fatalf("Failed to create model target: %v", err)
	}
	if err := license.Build(context.Background(), target, os.Stdout); err != nil {
		t.Fatalf("Failed to build model: %v", err)
	}

	tests := []struct {
		name          string
		remote        bool
		modelName     string
		expectedCode  int
		expectedError string
	}{
		{
			name:         "get local model - success",
			remote:       false,
			modelName:    tag,
			expectedCode: http.StatusOK,
		},
		{
			name:          "get local model - not found",
			remote:        false,
			modelName:     "nonexistent:v1",
			expectedCode:  http.StatusNotFound,
			expectedError: "error while getting model",
		},
		{
			name:         "get remote model - success",
			remote:       true,
			modelName:    tag,
			expectedCode: http.StatusOK,
		},
		{
			name:          "get remote model - not found",
			remote:        true,
			modelName:     uri.Host + "/ai/nonexistent:v1",
			expectedCode:  http.StatusNotFound,
			expectedError: "not found",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Give every subtest a fresh store directory. The testing
			// package removes it when the subtest ends, including when the
			// subtest aborts early via t.Fatalf, so no pulled-model state
			// can leak into the next case.
			storeDir := t.TempDir()

			log := logrus.NewEntry(logrus.StandardLogger())
			m := NewManager(log, ClientConfig{
				StoreRootPath: storeDir,
				Logger:        log.WithFields(logrus.Fields{"component": "model-manager"}),
				Transport:     http.DefaultTransport,
				UserAgent:     "test-agent",
			}, nil)

			// First pull the model if we're testing local access.
			if !tt.remote && !strings.Contains(tt.modelName, "nonexistent") {
				pullReq := httptest.NewRequest("POST", "/models/create", strings.NewReader(`{"from": "`+tt.modelName+`"}`))
				pullRec := httptest.NewRecorder()
				// Keep the error local so the closure does not write to the
				// enclosing test's err variable.
				if err := m.PullModel(tt.modelName, pullReq, pullRec); err != nil {
					t.Fatalf("Failed to pull model: %v", err)
				}
			}

			// Create the request, adding the remote query param if needed.
			path := inference.ModelsPrefix + "/" + tt.modelName
			if tt.remote {
				path += "?remote=true"
			}
			r := httptest.NewRequest("GET", path, nil)
			w := httptest.NewRecorder()

			// Set the path value for {name} so r.PathValue("name") works
			// even though the request does not go through the router.
			r.SetPathValue("name", tt.modelName)

			// Call the handler directly.
			m.handleGetModel(w, r)

			// Check the response status.
			if w.Code != tt.expectedCode {
				t.Errorf("Expected status code %d, got %d", tt.expectedCode, w.Code)
			}

			if tt.expectedError != "" {
				if !strings.Contains(w.Body.String(), tt.expectedError) {
					t.Errorf("Expected error containing %q, got %q", tt.expectedError, w.Body.String())
				}
				return
			}

			// For successful responses, verify we got a valid JSON response.
			var response Model
			if err := json.NewDecoder(w.Body).Decode(&response); err != nil {
				t.Errorf("Failed to decode response body: %v", err)
			}
		})
	}
}