Skip to content

Commit 5a2117a

Browse files
committed
Add endpoint for tagging a model
Signed-off-by: Emily Casey <[email protected]>
1 parent 77acaee commit 5a2117a

File tree

1 file changed

+48
-0
lines changed

1 file changed

+48
-0
lines changed

pkg/inference/models/manager.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ func (m *Manager) routeHandlers() map[string]http.HandlerFunc {
9292
"GET " + inference.ModelsPrefix: m.handleGetModels,
9393
"GET " + inference.ModelsPrefix + "/{name...}": m.handleGetModel,
9494
"DELETE " + inference.ModelsPrefix + "/{name...}": m.handleDeleteModel,
95+
"POST " + inference.ModelsPrefix + "/{name}/tag": m.handleTagModel,
9596
"GET " + inference.InferencePrefix + "/{backend}/v1/models": m.handleOpenAIGetModels,
9697
"GET " + inference.InferencePrefix + "/{backend}/v1/models/{name...}": m.handleOpenAIGetModel,
9798
"GET " + inference.InferencePrefix + "/v1/models": m.handleOpenAIGetModels,
@@ -288,6 +289,53 @@ func (m *Manager) handleOpenAIGetModel(w http.ResponseWriter, r *http.Request) {
288289
}
289290
}
290291

292+
// handleTagModel handles POST <inference-prefix>/models/{name}/tag requests.
293+
// The query parameters are:
294+
// - repo: the repository to tag the model with (required)
295+
// - tag: the tag to tag the model with (optional, defaults to "latest")
296+
func (m *Manager) handleTagModel(w http.ResponseWriter, r *http.Request) {
297+
if m.distributionClient == nil {
298+
http.Error(w, "model distribution service unavailable", http.StatusServiceUnavailable)
299+
return
300+
}
301+
302+
// Extract the model name from the request path.
303+
model := r.PathValue("name")
304+
305+
// Extract query parameters.
306+
repo := r.URL.Query().Get("repo")
307+
tag := r.URL.Query().Get("tag")
308+
309+
// Validate query parameters.
310+
if repo == "" {
311+
http.Error(w, "missing repo or tag query parameter", http.StatusBadRequest)
312+
return
313+
}
314+
if tag == "" {
315+
tag = "latest"
316+
}
317+
318+
// Construct the target string.
319+
target := fmt.Sprintf("%s:%s", repo, tag)
320+
321+
// Call the Tag method on the distribution client with source and modelName.
322+
if err := m.distributionClient.Tag(model, target); err != nil {
323+
m.log.Warnf("Failed to tag model %q: %v", model, err)
324+
325+
if errors.Is(err, distribution.ErrModelNotFound) {
326+
http.Error(w, err.Error(), http.StatusNotFound)
327+
return
328+
}
329+
330+
http.Error(w, err.Error(), http.StatusInternalServerError)
331+
return
332+
}
333+
334+
// Respond with success.
335+
w.WriteHeader(http.StatusOK)
336+
w.Write([]byte(fmt.Sprintf("Model %q tagged successfully with source %q", modelName, model)))
337+
}
338+
291339
// ServeHTTP implements net/http.Handler.ServeHTTP.
292340
func (m *Manager) ServeHTTP(w http.ResponseWriter, r *http.Request) {
293341
m.router.ServeHTTP(w, r)

0 commit comments

Comments
 (0)