Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pkg/inference/models/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ type Model struct {
// ID is the globally unique model identifier.
ID string `json:"id"`
// Tags are the list of tags associated with the model.
Tags []string `json:"tags"`
Tags []string `json:"tags,omitempty"`
// Created is the Unix epoch timestamp corresponding to the model creation.
Created int64 `json:"created"`
// Config describes the model.
Expand Down
93 changes: 82 additions & 11 deletions pkg/inference/models/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ type Manager struct {
router *http.ServeMux
// distributionClient is the client for model distribution.
distributionClient *distribution.Client
// registryClient is the client for model registry.
registryClient *registry.Client
}

type ClientConfig struct {
Expand Down Expand Up @@ -65,12 +67,19 @@ func NewManager(log logging.Logger, c ClientConfig, allowedOrigins []string) *Ma
// respond to requests, but may return errors if the client is required.
}

// Create the model registry client.
registryClient := registry.NewClient(
registry.WithTransport(c.Transport),
registry.WithUserAgent(c.UserAgent),
)

// Create the manager.
m := &Manager{
log: log,
pullTokens: make(chan struct{}, maximumConcurrentModelPulls),
router: http.NewServeMux(),
distributionClient: distributionClient,
registryClient: registryClient,
}

// Register routes.
Expand Down Expand Up @@ -189,24 +198,36 @@ func (m *Manager) handleGetModels(w http.ResponseWriter, r *http.Request) {

// handleGetModel handles GET <inference-prefix>/models/{name} requests.
func (m *Manager) handleGetModel(w http.ResponseWriter, r *http.Request) {
if m.distributionClient == nil {
http.Error(w, "model distribution service unavailable", http.StatusServiceUnavailable)
// Parse remote query parameter
remote := false
if r.URL.Query().Has("remote") {
if val, err := strconv.ParseBool(r.URL.Query().Get("remote")); err != nil {
m.log.Warnln("Error while parsing remote query parameter:", err)
} else {
remote = val
}
}

if remote && m.registryClient == nil {
http.Error(w, "registry client unavailable", http.StatusServiceUnavailable)
return
}

// Query the model.
model, err := m.GetModel(r.PathValue("name"))
var apiModel *Model
var err error

if remote {
apiModel, err = getRemoteModel(r.Context(), m, r.PathValue("name"))
} else {
apiModel, err = getLocalModel(m, r.PathValue("name"))
}

if err != nil {
if errors.Is(err, distribution.ErrModelNotFound) {
if errors.Is(err, distribution.ErrModelNotFound) || errors.Is(err, registry.ErrModelNotFound) {
http.Error(w, err.Error(), http.StatusNotFound)
} else {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
return
}

apiModel, err := ToModel(model)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
Expand All @@ -218,6 +239,56 @@ func (m *Manager) handleGetModel(w http.ResponseWriter, r *http.Request) {
}
}

func getLocalModel(m *Manager, name string) (*Model, error) {
if m.distributionClient == nil {
return nil, errors.New("model distribution service unavailable")
}

// Query the model.
model, err := m.GetModel(name)
if err != nil {
return nil, err
}

return ToModel(model)
}

func getRemoteModel(ctx context.Context, m *Manager, name string) (*Model, error) {
if m.registryClient == nil {
return nil, errors.New("registry client unavailable")
}

m.log.Infoln("Getting remote model:", name)
model, err := m.registryClient.Model(ctx, name)
if err != nil {
return nil, err
}

id, err := model.ID()
if err != nil {
return nil, err
}

descriptor, err := model.Descriptor()
if err != nil {
return nil, err
}

config, err := model.Config()
if err != nil {
return nil, err
}

apiModel := &Model{
ID: id,
Tags: nil,
Created: descriptor.Created.Unix(),
Config: config,
}

return apiModel, nil
}

// handleDeleteModel handles DELETE <inference-prefix>/models/{name} requests.
// query params:
// - force: if true, delete the model even if it has multiple tags
Expand Down
141 changes: 141 additions & 0 deletions pkg/inference/models/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package models

import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"os"
Expand All @@ -13,6 +15,7 @@ import (

"github.com/docker/model-distribution/builder"
reg "github.com/docker/model-distribution/registry"
"github.com/docker/model-runner/pkg/inference"

"github.com/sirupsen/logrus"
)
Expand Down Expand Up @@ -136,3 +139,141 @@ func TestPullModel(t *testing.T) {
})
}
}

func TestHandleGetModel(t *testing.T) {
// Create temp directory for store
tempDir, err := os.MkdirTemp("", "model-distribution-test-*")
if err != nil {
t.Fatalf("Failed to create temp directory: %v", err)
}
defer os.RemoveAll(tempDir)

// Create a test registry
server := httptest.NewServer(registry.New())
defer server.Close()

uri, err := url.Parse(server.URL)
if err != nil {
t.Fatalf("Failed to parse registry URL: %v", err)
}

// Prepare the OCI model artifact
projectRoot := getProjectRoot(t)
model, err := builder.FromGGUF(filepath.Join(projectRoot, "assets", "dummy.gguf"))
if err != nil {
t.Fatalf("Failed to create model builder: %v", err)
}

license, err := model.WithLicense(filepath.Join(projectRoot, "assets", "license.txt"))
if err != nil {
t.Fatalf("Failed to add license to model: %v", err)
}

// Build the OCI model artifact + push it
tag := uri.Host + "/ai/model:v1.0.0"
client := reg.NewClient()
target, err := client.NewTarget(tag)
if err != nil {
t.Fatalf("Failed to create model target: %v", err)
}
err = license.Build(context.Background(), target, os.Stdout)
if err != nil {
t.Fatalf("Failed to build model: %v", err)
}

tests := []struct {
name string
remote bool
modelName string
expectedCode int
expectedError string
}{
{
name: "get local model - success",
remote: false,
modelName: tag,
expectedCode: http.StatusOK,
},
{
name: "get local model - not found",
remote: false,
modelName: "nonexistent:v1",
expectedCode: http.StatusNotFound,
expectedError: "error while getting model",
},
{
name: "get remote model - success",
remote: true,
modelName: tag,
expectedCode: http.StatusOK,
},
{
name: "get remote model - not found",
remote: true,
modelName: uri.Host + "/ai/nonexistent:v1",
expectedCode: http.StatusNotFound,
expectedError: "not found",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
log := logrus.NewEntry(logrus.StandardLogger())
m := NewManager(log, ClientConfig{
StoreRootPath: tempDir,
Logger: log.WithFields(logrus.Fields{"component": "model-manager"}),
Transport: http.DefaultTransport,
UserAgent: "test-agent",
}, nil)

// First pull the model if we're testing local access
if !tt.remote && !strings.Contains(tt.modelName, "nonexistent") {
r := httptest.NewRequest("POST", "/models/create", strings.NewReader(`{"from": "`+tt.modelName+`"}`))
w := httptest.NewRecorder()
err = m.PullModel(tt.modelName, r, w)
if err != nil {
t.Fatalf("Failed to pull model: %v", err)
}
}

// Create request with remote query param
path := inference.ModelsPrefix + "/" + tt.modelName
if tt.remote {
path += "?remote=true"
}
r := httptest.NewRequest("GET", path, nil)
w := httptest.NewRecorder()

// Set the path value for {name} so r.PathValue("name") works
r.SetPathValue("name", tt.modelName)

// Call the handler directly
m.handleGetModel(w, r)

// Check response
if w.Code != tt.expectedCode {
t.Errorf("Expected status code %d, got %d", tt.expectedCode, w.Code)
}

if tt.expectedError != "" {
if !strings.Contains(w.Body.String(), tt.expectedError) {
t.Errorf("Expected error containing %q, got %q", tt.expectedError, w.Body.String())
}
} else {
// For successful responses, verify we got a valid JSON response
var response Model
if err := json.NewDecoder(w.Body).Decode(&response); err != nil {
t.Errorf("Failed to decode response body: %v", err)
}
}

// Clean tempDir after each test
if err := os.RemoveAll(tempDir); err != nil {
t.Fatalf("Failed to clean temp directory: %v", err)
}
if err := os.MkdirAll(tempDir, 0755); err != nil {
t.Fatalf("Failed to recreate temp directory: %v", err)
}
})
}
}