@@ -2,6 +2,8 @@ package models
22
33import (
44 "context"
5+ "encoding/json"
6+ "net/http"
57 "net/http/httptest"
68 "net/url"
79 "os"
@@ -13,6 +15,7 @@ import (
1315
1416 "github.com/docker/model-distribution/builder"
1517 reg "github.com/docker/model-distribution/registry"
18+ "github.com/docker/model-runner/pkg/inference"
1619
1720 "github.com/sirupsen/logrus"
1821)
@@ -136,3 +139,141 @@ func TestPullModel(t *testing.T) {
136139 })
137140 }
138141}
142+
143+ func TestHandleGetModel (t * testing.T ) {
144+ // Create temp directory for store
145+ tempDir , err := os .MkdirTemp ("" , "model-distribution-test-*" )
146+ if err != nil {
147+ t .Fatalf ("Failed to create temp directory: %v" , err )
148+ }
149+ defer os .RemoveAll (tempDir )
150+
151+ // Create a test registry
152+ server := httptest .NewServer (registry .New ())
153+ defer server .Close ()
154+
155+ uri , err := url .Parse (server .URL )
156+ if err != nil {
157+ t .Fatalf ("Failed to parse registry URL: %v" , err )
158+ }
159+
160+ // Prepare the OCI model artifact
161+ projectRoot := getProjectRoot (t )
162+ model , err := builder .FromGGUF (filepath .Join (projectRoot , "assets" , "dummy.gguf" ))
163+ if err != nil {
164+ t .Fatalf ("Failed to create model builder: %v" , err )
165+ }
166+
167+ license , err := model .WithLicense (filepath .Join (projectRoot , "assets" , "license.txt" ))
168+ if err != nil {
169+ t .Fatalf ("Failed to add license to model: %v" , err )
170+ }
171+
172+ // Build the OCI model artifact + push it
173+ tag := uri .Host + "/ai/model:v1.0.0"
174+ client := reg .NewClient ()
175+ target , err := client .NewTarget (tag )
176+ if err != nil {
177+ t .Fatalf ("Failed to create model target: %v" , err )
178+ }
179+ err = license .Build (context .Background (), target , os .Stdout )
180+ if err != nil {
181+ t .Fatalf ("Failed to build model: %v" , err )
182+ }
183+
184+ tests := []struct {
185+ name string
186+ remote bool
187+ modelName string
188+ expectedCode int
189+ expectedError string
190+ }{
191+ {
192+ name : "get local model - success" ,
193+ remote : false ,
194+ modelName : tag ,
195+ expectedCode : http .StatusOK ,
196+ },
197+ {
198+ name : "get local model - not found" ,
199+ remote : false ,
200+ modelName : "nonexistent:v1" ,
201+ expectedCode : http .StatusNotFound ,
202+ expectedError : "error while getting model" ,
203+ },
204+ {
205+ name : "get remote model - success" ,
206+ remote : true ,
207+ modelName : tag ,
208+ expectedCode : http .StatusOK ,
209+ },
210+ {
211+ name : "get remote model - not found" ,
212+ remote : true ,
213+ modelName : uri .Host + "/ai/nonexistent:v1" ,
214+ expectedCode : http .StatusNotFound ,
215+ expectedError : "not found" ,
216+ },
217+ }
218+
219+ for _ , tt := range tests {
220+ t .Run (tt .name , func (t * testing.T ) {
221+ log := logrus .NewEntry (logrus .StandardLogger ())
222+ m := NewManager (log , ClientConfig {
223+ StoreRootPath : tempDir ,
224+ Logger : log .WithFields (logrus.Fields {"component" : "model-manager" }),
225+ Transport : http .DefaultTransport ,
226+ UserAgent : "test-agent" ,
227+ }, nil )
228+
229+ // First pull the model if we're testing local access
230+ if ! tt .remote && ! strings .Contains (tt .modelName , "nonexistent" ) {
231+ r := httptest .NewRequest ("POST" , "/models/create" , strings .NewReader (`{"from": "` + tt .modelName + `"}` ))
232+ w := httptest .NewRecorder ()
233+ err = m .PullModel (tt .modelName , r , w )
234+ if err != nil {
235+ t .Fatalf ("Failed to pull model: %v" , err )
236+ }
237+ }
238+
239+ // Create request with remote query param
240+ path := inference .ModelsPrefix + "/" + tt .modelName
241+ if tt .remote {
242+ path += "?remote=true"
243+ }
244+ r := httptest .NewRequest ("GET" , path , nil )
245+ w := httptest .NewRecorder ()
246+
247+ // Set the path value for {name} so r.PathValue("name") works
248+ r .SetPathValue ("name" , tt .modelName )
249+
250+ // Call the handler directly
251+ m .handleGetModel (w , r )
252+
253+ // Check response
254+ if w .Code != tt .expectedCode {
255+ t .Errorf ("Expected status code %d, got %d" , tt .expectedCode , w .Code )
256+ }
257+
258+ if tt .expectedError != "" {
259+ if ! strings .Contains (w .Body .String (), tt .expectedError ) {
260+ t .Errorf ("Expected error containing %q, got %q" , tt .expectedError , w .Body .String ())
261+ }
262+ } else {
263+ // For successful responses, verify we got a valid JSON response
264+ var response Model
265+ if err := json .NewDecoder (w .Body ).Decode (& response ); err != nil {
266+ t .Errorf ("Failed to decode response body: %v" , err )
267+ }
268+ }
269+
270+ // Clean tempDir after each test
271+ if err := os .RemoveAll (tempDir ); err != nil {
272+ t .Fatalf ("Failed to clean temp directory: %v" , err )
273+ }
274+ if err := os .MkdirAll (tempDir , 0755 ); err != nil {
275+ t .Fatalf ("Failed to recreate temp directory: %v" , err )
276+ }
277+ })
278+ }
279+ }
0 commit comments