Skip to content

Commit 9861625

Browse files
authored
Get remote models (#72)
* Adds test for get model * Get remote model * Use the request context * Add log and use nil for tags field of remote model, as we don't know their tags
1 parent 1658aa6 commit 9861625

File tree

3 files changed

+224
-12
lines changed

3 files changed

+224
-12
lines changed

pkg/inference/models/api.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ type Model struct {
8888
// ID is the globally unique model identifier.
8989
ID string `json:"id"`
9090
// Tags are the list of tags associated with the model.
91-
Tags []string `json:"tags"`
91+
Tags []string `json:"tags,omitempty"`
9292
// Created is the Unix epoch timestamp corresponding to the model creation.
9393
Created int64 `json:"created"`
9494
// Config describes the model.

pkg/inference/models/manager.go

Lines changed: 82 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ type Manager struct {
3737
router *http.ServeMux
3838
// distributionClient is the client for model distribution.
3939
distributionClient *distribution.Client
40+
// registryClient is the client for model registry.
41+
registryClient *registry.Client
4042
}
4143

4244
type ClientConfig struct {
@@ -65,12 +67,19 @@ func NewManager(log logging.Logger, c ClientConfig, allowedOrigins []string) *Ma
6567
// respond to requests, but may return errors if the client is required.
6668
}
6769

70+
// Create the model registry client.
71+
registryClient := registry.NewClient(
72+
registry.WithTransport(c.Transport),
73+
registry.WithUserAgent(c.UserAgent),
74+
)
75+
6876
// Create the manager.
6977
m := &Manager{
7078
log: log,
7179
pullTokens: make(chan struct{}, maximumConcurrentModelPulls),
7280
router: http.NewServeMux(),
7381
distributionClient: distributionClient,
82+
registryClient: registryClient,
7483
}
7584

7685
// Register routes.
@@ -189,24 +198,36 @@ func (m *Manager) handleGetModels(w http.ResponseWriter, r *http.Request) {
189198

190199
// handleGetModel handles GET <inference-prefix>/models/{name} requests.
191200
func (m *Manager) handleGetModel(w http.ResponseWriter, r *http.Request) {
192-
if m.distributionClient == nil {
193-
http.Error(w, "model distribution service unavailable", http.StatusServiceUnavailable)
201+
// Parse remote query parameter
202+
remote := false
203+
if r.URL.Query().Has("remote") {
204+
if val, err := strconv.ParseBool(r.URL.Query().Get("remote")); err != nil {
205+
m.log.Warnln("Error while parsing remote query parameter:", err)
206+
} else {
207+
remote = val
208+
}
209+
}
210+
211+
if remote && m.registryClient == nil {
212+
http.Error(w, "registry client unavailable", http.StatusServiceUnavailable)
194213
return
195214
}
196215

197-
// Query the model.
198-
model, err := m.GetModel(r.PathValue("name"))
216+
var apiModel *Model
217+
var err error
218+
219+
if remote {
220+
apiModel, err = getRemoteModel(r.Context(), m, r.PathValue("name"))
221+
} else {
222+
apiModel, err = getLocalModel(m, r.PathValue("name"))
223+
}
224+
199225
if err != nil {
200-
if errors.Is(err, distribution.ErrModelNotFound) {
226+
if errors.Is(err, distribution.ErrModelNotFound) || errors.Is(err, registry.ErrModelNotFound) {
201227
http.Error(w, err.Error(), http.StatusNotFound)
202-
} else {
203-
http.Error(w, err.Error(), http.StatusInternalServerError)
228+
return
204229
}
205-
return
206-
}
207230

208-
apiModel, err := ToModel(model)
209-
if err != nil {
210231
http.Error(w, err.Error(), http.StatusInternalServerError)
211232
return
212233
}
@@ -218,6 +239,56 @@ func (m *Manager) handleGetModel(w http.ResponseWriter, r *http.Request) {
218239
}
219240
}
220241

242+
func getLocalModel(m *Manager, name string) (*Model, error) {
243+
if m.distributionClient == nil {
244+
return nil, errors.New("model distribution service unavailable")
245+
}
246+
247+
// Query the model.
248+
model, err := m.GetModel(name)
249+
if err != nil {
250+
return nil, err
251+
}
252+
253+
return ToModel(model)
254+
}
255+
256+
func getRemoteModel(ctx context.Context, m *Manager, name string) (*Model, error) {
257+
if m.registryClient == nil {
258+
return nil, errors.New("registry client unavailable")
259+
}
260+
261+
m.log.Infoln("Getting remote model:", name)
262+
model, err := m.registryClient.Model(ctx, name)
263+
if err != nil {
264+
return nil, err
265+
}
266+
267+
id, err := model.ID()
268+
if err != nil {
269+
return nil, err
270+
}
271+
272+
descriptor, err := model.Descriptor()
273+
if err != nil {
274+
return nil, err
275+
}
276+
277+
config, err := model.Config()
278+
if err != nil {
279+
return nil, err
280+
}
281+
282+
apiModel := &Model{
283+
ID: id,
284+
Tags: nil,
285+
Created: descriptor.Created.Unix(),
286+
Config: config,
287+
}
288+
289+
return apiModel, nil
290+
}
291+
221292
// handleDeleteModel handles DELETE <inference-prefix>/models/{name} requests.
222293
// query params:
223294
// - force: if true, delete the model even if it has multiple tags

pkg/inference/models/manager_test.go

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ package models
22

33
import (
44
"context"
5+
"encoding/json"
6+
"net/http"
57
"net/http/httptest"
68
"net/url"
79
"os"
@@ -13,6 +15,7 @@ import (
1315

1416
"github.com/docker/model-distribution/builder"
1517
reg "github.com/docker/model-distribution/registry"
18+
"github.com/docker/model-runner/pkg/inference"
1619

1720
"github.com/sirupsen/logrus"
1821
)
@@ -136,3 +139,141 @@ func TestPullModel(t *testing.T) {
136139
})
137140
}
138141
}
142+
143+
func TestHandleGetModel(t *testing.T) {
144+
// Create temp directory for store
145+
tempDir, err := os.MkdirTemp("", "model-distribution-test-*")
146+
if err != nil {
147+
t.Fatalf("Failed to create temp directory: %v", err)
148+
}
149+
defer os.RemoveAll(tempDir)
150+
151+
// Create a test registry
152+
server := httptest.NewServer(registry.New())
153+
defer server.Close()
154+
155+
uri, err := url.Parse(server.URL)
156+
if err != nil {
157+
t.Fatalf("Failed to parse registry URL: %v", err)
158+
}
159+
160+
// Prepare the OCI model artifact
161+
projectRoot := getProjectRoot(t)
162+
model, err := builder.FromGGUF(filepath.Join(projectRoot, "assets", "dummy.gguf"))
163+
if err != nil {
164+
t.Fatalf("Failed to create model builder: %v", err)
165+
}
166+
167+
license, err := model.WithLicense(filepath.Join(projectRoot, "assets", "license.txt"))
168+
if err != nil {
169+
t.Fatalf("Failed to add license to model: %v", err)
170+
}
171+
172+
// Build the OCI model artifact + push it
173+
tag := uri.Host + "/ai/model:v1.0.0"
174+
client := reg.NewClient()
175+
target, err := client.NewTarget(tag)
176+
if err != nil {
177+
t.Fatalf("Failed to create model target: %v", err)
178+
}
179+
err = license.Build(context.Background(), target, os.Stdout)
180+
if err != nil {
181+
t.Fatalf("Failed to build model: %v", err)
182+
}
183+
184+
tests := []struct {
185+
name string
186+
remote bool
187+
modelName string
188+
expectedCode int
189+
expectedError string
190+
}{
191+
{
192+
name: "get local model - success",
193+
remote: false,
194+
modelName: tag,
195+
expectedCode: http.StatusOK,
196+
},
197+
{
198+
name: "get local model - not found",
199+
remote: false,
200+
modelName: "nonexistent:v1",
201+
expectedCode: http.StatusNotFound,
202+
expectedError: "error while getting model",
203+
},
204+
{
205+
name: "get remote model - success",
206+
remote: true,
207+
modelName: tag,
208+
expectedCode: http.StatusOK,
209+
},
210+
{
211+
name: "get remote model - not found",
212+
remote: true,
213+
modelName: uri.Host + "/ai/nonexistent:v1",
214+
expectedCode: http.StatusNotFound,
215+
expectedError: "not found",
216+
},
217+
}
218+
219+
for _, tt := range tests {
220+
t.Run(tt.name, func(t *testing.T) {
221+
log := logrus.NewEntry(logrus.StandardLogger())
222+
m := NewManager(log, ClientConfig{
223+
StoreRootPath: tempDir,
224+
Logger: log.WithFields(logrus.Fields{"component": "model-manager"}),
225+
Transport: http.DefaultTransport,
226+
UserAgent: "test-agent",
227+
}, nil)
228+
229+
// First pull the model if we're testing local access
230+
if !tt.remote && !strings.Contains(tt.modelName, "nonexistent") {
231+
r := httptest.NewRequest("POST", "/models/create", strings.NewReader(`{"from": "`+tt.modelName+`"}`))
232+
w := httptest.NewRecorder()
233+
err = m.PullModel(tt.modelName, r, w)
234+
if err != nil {
235+
t.Fatalf("Failed to pull model: %v", err)
236+
}
237+
}
238+
239+
// Create request with remote query param
240+
path := inference.ModelsPrefix + "/" + tt.modelName
241+
if tt.remote {
242+
path += "?remote=true"
243+
}
244+
r := httptest.NewRequest("GET", path, nil)
245+
w := httptest.NewRecorder()
246+
247+
// Set the path value for {name} so r.PathValue("name") works
248+
r.SetPathValue("name", tt.modelName)
249+
250+
// Call the handler directly
251+
m.handleGetModel(w, r)
252+
253+
// Check response
254+
if w.Code != tt.expectedCode {
255+
t.Errorf("Expected status code %d, got %d", tt.expectedCode, w.Code)
256+
}
257+
258+
if tt.expectedError != "" {
259+
if !strings.Contains(w.Body.String(), tt.expectedError) {
260+
t.Errorf("Expected error containing %q, got %q", tt.expectedError, w.Body.String())
261+
}
262+
} else {
263+
// For successful responses, verify we got a valid JSON response
264+
var response Model
265+
if err := json.NewDecoder(w.Body).Decode(&response); err != nil {
266+
t.Errorf("Failed to decode response body: %v", err)
267+
}
268+
}
269+
270+
// Clean tempDir after each test
271+
if err := os.RemoveAll(tempDir); err != nil {
272+
t.Fatalf("Failed to clean temp directory: %v", err)
273+
}
274+
if err := os.MkdirAll(tempDir, 0755); err != nil {
275+
t.Fatalf("Failed to recreate temp directory: %v", err)
276+
}
277+
})
278+
}
279+
}

0 commit comments

Comments
 (0)