diff --git a/pkg/inference/models/api.go b/pkg/inference/models/api.go
index 68be185dd..a21864dd0 100644
--- a/pkg/inference/models/api.go
+++ b/pkg/inference/models/api.go
@@ -88,7 +88,7 @@ type Model struct {
 	// ID is the globally unique model identifier.
 	ID string `json:"id"`
 	// Tags are the list of tags associated with the model.
-	Tags []string `json:"tags"`
+	Tags []string `json:"tags,omitempty"`
 	// Created is the Unix epoch timestamp corresponding to the model creation.
 	Created int64 `json:"created"`
 	// Config describes the model.
diff --git a/pkg/inference/models/manager.go b/pkg/inference/models/manager.go
index 65f50a388..43389b5cc 100644
--- a/pkg/inference/models/manager.go
+++ b/pkg/inference/models/manager.go
@@ -37,6 +37,8 @@ type Manager struct {
 	router *http.ServeMux
 	// distributionClient is the client for model distribution.
 	distributionClient *distribution.Client
+	// registryClient is the client for model registry.
+	registryClient *registry.Client
 }
 
 type ClientConfig struct {
@@ -65,12 +67,19 @@ func NewManager(log logging.Logger, c ClientConfig, allowedOrigins []string) *Ma
 		// respond to requests, but may return errors if the client is required.
 	}
 
+	// Create the model registry client.
+	registryClient := registry.NewClient(
+		registry.WithTransport(c.Transport),
+		registry.WithUserAgent(c.UserAgent),
+	)
+
 	// Create the manager.
 	m := &Manager{
 		log:                log,
 		pullTokens:         make(chan struct{}, maximumConcurrentModelPulls),
 		router:             http.NewServeMux(),
 		distributionClient: distributionClient,
+		registryClient:     registryClient,
 	}
 
 	// Register routes.
@@ -189,24 +198,42 @@ func (m *Manager) handleGetModels(w http.ResponseWriter, r *http.Request) {
 
 // handleGetModel handles GET /models/{name} requests.
 func (m *Manager) handleGetModel(w http.ResponseWriter, r *http.Request) {
-	if m.distributionClient == nil {
-		http.Error(w, "model distribution service unavailable", http.StatusServiceUnavailable)
+	// Parse remote query parameter
+	remote := false
+	if r.URL.Query().Has("remote") {
+		if val, err := strconv.ParseBool(r.URL.Query().Get("remote")); err != nil {
+			m.log.Warnln("Error while parsing remote query parameter:", err)
+		} else {
+			remote = val
+		}
+	}
+
+	if remote && m.registryClient == nil {
+		http.Error(w, "registry client unavailable", http.StatusServiceUnavailable)
 		return
 	}
+	// Keep answering 503 (not 500) for local lookups when the distribution
+	// client could not be created, as the handler did before remote support.
+	if !remote && m.distributionClient == nil {
+		http.Error(w, "model distribution service unavailable", http.StatusServiceUnavailable)
+		return
+	}
 
-	// Query the model.
-	model, err := m.GetModel(r.PathValue("name"))
+	var apiModel *Model
+	var err error
+
+	if remote {
+		apiModel, err = getRemoteModel(r.Context(), m, r.PathValue("name"))
+	} else {
+		apiModel, err = getLocalModel(m, r.PathValue("name"))
+	}
+
 	if err != nil {
-		if errors.Is(err, distribution.ErrModelNotFound) {
+		if errors.Is(err, distribution.ErrModelNotFound) || errors.Is(err, registry.ErrModelNotFound) {
 			http.Error(w, err.Error(), http.StatusNotFound)
-		} else {
-			http.Error(w, err.Error(), http.StatusInternalServerError)
+			return
 		}
-		return
-	}
 
-	apiModel, err := ToModel(model)
-	if err != nil {
 		http.Error(w, err.Error(), http.StatusInternalServerError)
 		return
 	}
@@ -218,6 +245,56 @@ func (m *Manager) handleGetModel(w http.ResponseWriter, r *http.Request) {
 	}
 }
 
+func getLocalModel(m *Manager, name string) (*Model, error) {
+	if m.distributionClient == nil {
+		return nil, errors.New("model distribution service unavailable")
+	}
+
+	// Query the model.
+	model, err := m.GetModel(name)
+	if err != nil {
+		return nil, err
+	}
+
+	return ToModel(model)
+}
+
+func getRemoteModel(ctx context.Context, m *Manager, name string) (*Model, error) {
+	if m.registryClient == nil {
+		return nil, errors.New("registry client unavailable")
+	}
+
+	m.log.Infoln("Getting remote model:", name)
+	model, err := m.registryClient.Model(ctx, name)
+	if err != nil {
+		return nil, err
+	}
+
+	id, err := model.ID()
+	if err != nil {
+		return nil, err
+	}
+
+	descriptor, err := model.Descriptor()
+	if err != nil {
+		return nil, err
+	}
+
+	config, err := model.Config()
+	if err != nil {
+		return nil, err
+	}
+
+	apiModel := &Model{
+		ID:      id,
+		Tags:    nil,
+		Created: descriptor.Created.Unix(),
+		Config:  config,
+	}
+
+	return apiModel, nil
+}
+
 // handleDeleteModel handles DELETE /models/{name} requests.
 // query params:
 // - force: if true, delete the model even if it has multiple tags
diff --git a/pkg/inference/models/manager_test.go b/pkg/inference/models/manager_test.go
index 62ae16ba8..7edd357a4 100644
--- a/pkg/inference/models/manager_test.go
+++ b/pkg/inference/models/manager_test.go
@@ -2,6 +2,8 @@ package models
 
 import (
 	"context"
+	"encoding/json"
+	"net/http"
 	"net/http/httptest"
 	"net/url"
 	"os"
@@ -13,6 +15,7 @@ import (
 
 	"github.com/docker/model-distribution/builder"
 	reg "github.com/docker/model-distribution/registry"
+	"github.com/docker/model-runner/pkg/inference"
 	"github.com/sirupsen/logrus"
 )
 
@@ -136,3 +139,129 @@ func TestPullModel(t *testing.T) {
 		})
 	}
 }
+
+func TestHandleGetModel(t *testing.T) {
+	// Create a test registry
+	server := httptest.NewServer(registry.New())
+	defer server.Close()
+
+	uri, err := url.Parse(server.URL)
+	if err != nil {
+		t.Fatalf("Failed to parse registry URL: %v", err)
+	}
+
+	// Prepare the OCI model artifact
+	projectRoot := getProjectRoot(t)
+	model, err := builder.FromGGUF(filepath.Join(projectRoot, "assets", "dummy.gguf"))
+	if err != nil {
+		t.Fatalf("Failed to create model builder: %v", err)
+	}
+
+	license, err := model.WithLicense(filepath.Join(projectRoot, "assets", "license.txt"))
+	if err != nil {
+		t.Fatalf("Failed to add license to model: %v", err)
+	}
+
+	// Build the OCI model artifact + push it
+	tag := uri.Host + "/ai/model:v1.0.0"
+	client := reg.NewClient()
+	target, err := client.NewTarget(tag)
+	if err != nil {
+		t.Fatalf("Failed to create model target: %v", err)
+	}
+	err = license.Build(context.Background(), target, os.Stdout)
+	if err != nil {
+		t.Fatalf("Failed to build model: %v", err)
+	}
+
+	tests := []struct {
+		name          string
+		remote        bool
+		modelName     string
+		expectedCode  int
+		expectedError string
+	}{
+		{
+			name:         "get local model - success",
+			remote:       false,
+			modelName:    tag,
+			expectedCode: http.StatusOK,
+		},
+		{
+			name:          "get local model - not found",
+			remote:        false,
+			modelName:     "nonexistent:v1",
+			expectedCode:  http.StatusNotFound,
+			expectedError: "error while getting model",
+		},
+		{
+			name:         "get remote model - success",
+			remote:       true,
+			modelName:    tag,
+			expectedCode: http.StatusOK,
+		},
+		{
+			name:          "get remote model - not found",
+			remote:        true,
+			modelName:     uri.Host + "/ai/nonexistent:v1",
+			expectedCode:  http.StatusNotFound,
+			expectedError: "not found",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			// Give each subtest its own store directory; t.TempDir removes it
+			// automatically, even when the subtest fails early.
+			storeDir := t.TempDir()
+
+			log := logrus.NewEntry(logrus.StandardLogger())
+			m := NewManager(log, ClientConfig{
+				StoreRootPath: storeDir,
+				Logger:        log.WithFields(logrus.Fields{"component": "model-manager"}),
+				Transport:     http.DefaultTransport,
+				UserAgent:     "test-agent",
+			}, nil)
+
+			// First pull the model if we're testing local access
+			if !tt.remote && !strings.Contains(tt.modelName, "nonexistent") {
+				r := httptest.NewRequest("POST", "/models/create", strings.NewReader(`{"from": "`+tt.modelName+`"}`))
+				w := httptest.NewRecorder()
+				if err := m.PullModel(tt.modelName, r, w); err != nil {
+					t.Fatalf("Failed to pull model: %v", err)
+				}
+			}
+
+			// Create request with remote query param
+			path := inference.ModelsPrefix + "/" + tt.modelName
+			if tt.remote {
+				path += "?remote=true"
+			}
+			r := httptest.NewRequest("GET", path, nil)
+			w := httptest.NewRecorder()
+
+			// Set the path value for {name} so r.PathValue("name") works
+			r.SetPathValue("name", tt.modelName)
+
+			// Call the handler directly
+			m.handleGetModel(w, r)
+
+			// Check response
+			if w.Code != tt.expectedCode {
+				t.Errorf("Expected status code %d, got %d", tt.expectedCode, w.Code)
+			}
+
+			if tt.expectedError != "" {
+				if !strings.Contains(w.Body.String(), tt.expectedError) {
+					t.Errorf("Expected error containing %q, got %q", tt.expectedError, w.Body.String())
+				}
+			} else {
+				// For successful responses, verify we got a valid JSON response
+				var response Model
+				if err := json.NewDecoder(w.Body).Decode(&response); err != nil {
+					t.Errorf("Failed to decode response body: %v", err)
+				}
+			}
+		})
+	}
+}