diff --git a/main.go b/main.go
index 7e188b71..76545881 100644
--- a/main.go
+++ b/main.go
@@ -11,6 +11,7 @@ import (
 	"syscall"
 	"time"

+	"github.com/docker/model-runner/pkg/anthropic"
 	"github.com/docker/model-runner/pkg/inference"
 	"github.com/docker/model-runner/pkg/inference/backends/llamacpp"
 	"github.com/docker/model-runner/pkg/inference/backends/mlx"
@@ -174,6 +175,10 @@ func main() {
 	ollamaHandler := ollama.NewHTTPHandler(log, scheduler, schedulerHTTP, nil, modelManager)
 	router.Handle(ollama.APIPrefix+"/", ollamaHandler)

+	// Add Anthropic Messages API compatibility layer
+	anthropicHandler := anthropic.NewHandler(log, schedulerHTTP, nil, modelManager)
+	router.Handle(anthropic.APIPrefix+"/", anthropicHandler)
+
 	// Register root handler LAST - it will only catch exact "/" requests that don't match other patterns
 	router.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
 		// Only respond to exact root path
diff --git a/pkg/anthropic/handler.go b/pkg/anthropic/handler.go
new file mode 100644
index 00000000..a03ff57f
--- /dev/null
+++ b/pkg/anthropic/handler.go
@@ -0,0 +1,174 @@
+package anthropic
+
+import (
+	"bytes"
+	"encoding/json"
+	"errors"
+	"io"
+	"net/http"
+
+	"github.com/docker/model-runner/pkg/inference"
+	"github.com/docker/model-runner/pkg/inference/models"
+	"github.com/docker/model-runner/pkg/inference/scheduling"
+	"github.com/docker/model-runner/pkg/internal/utils"
+	"github.com/docker/model-runner/pkg/logging"
+	"github.com/docker/model-runner/pkg/middleware"
+)
+
+const (
+	// APIPrefix is the prefix for Anthropic API routes.
+	// llama.cpp implements Anthropic API at /v1/messages, matching the official Anthropic API structure.
+	APIPrefix = "/anthropic"
+
+	// maxRequestBodySize is the maximum allowed size for request bodies (10MB).
+	maxRequestBodySize = 10 * 1024 * 1024
+)
+
+// Handler implements the Anthropic Messages API compatibility layer.
+// It forwards requests to the scheduler which proxies to llama.cpp,
+// which natively supports the Anthropic Messages API format.
+type Handler struct {
+	log           logging.Logger
+	router        *http.ServeMux
+	httpHandler   http.Handler
+	modelManager  *models.Manager
+	schedulerHTTP *scheduling.HTTPHandler
+}
+
+// NewHandler creates a new Anthropic API handler.
+func NewHandler(log logging.Logger, schedulerHTTP *scheduling.HTTPHandler, allowedOrigins []string, modelManager *models.Manager) *Handler {
+	h := &Handler{
+		log:           log,
+		router:        http.NewServeMux(),
+		schedulerHTTP: schedulerHTTP,
+		modelManager:  modelManager,
+	}
+
+	// Register routes
+	for route, handler := range h.routeHandlers() {
+		h.router.HandleFunc(route, handler)
+	}
+
+	// Apply CORS middleware
+	h.httpHandler = middleware.CorsMiddleware(allowedOrigins, h.router)
+
+	return h
+}
+
+// ServeHTTP implements the http.Handler interface.
+func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+	safeMethod := utils.SanitizeForLog(r.Method, -1)
+	safePath := utils.SanitizeForLog(r.URL.Path, -1)
+	h.log.Infof("Anthropic API request: %s %s", safeMethod, safePath)
+	h.httpHandler.ServeHTTP(w, r)
+}
+
+// routeHandlers returns the mapping of routes to their handlers.
+func (h *Handler) routeHandlers() map[string]http.HandlerFunc {
+	return map[string]http.HandlerFunc{
+		// Messages API endpoint - main chat completion endpoint
+		"POST " + APIPrefix + "/v1/messages": h.handleMessages,
+		// Token counting endpoint
+		"POST " + APIPrefix + "/v1/messages/count_tokens": h.handleCountTokens,
+	}
+}
+
+// MessagesRequest represents an Anthropic Messages API request.
+// This is used to extract the model field for routing purposes.
+type MessagesRequest struct {
+	Model string `json:"model"`
+}
+
+// handleMessages handles POST /anthropic/v1/messages requests.
+// It forwards the request to the scheduler which proxies to the llama.cpp backend.
+// The llama.cpp backend natively handles the Anthropic Messages API format conversion.
+func (h *Handler) handleMessages(w http.ResponseWriter, r *http.Request) {
+	h.proxyToBackend(w, r, "/v1/messages")
+}
+
+// handleCountTokens handles POST /anthropic/v1/messages/count_tokens requests.
+// It forwards the request to the scheduler which proxies to the llama.cpp backend.
+func (h *Handler) handleCountTokens(w http.ResponseWriter, r *http.Request) {
+	h.proxyToBackend(w, r, "/v1/messages/count_tokens")
+}
+
+// proxyToBackend proxies the request to the llama.cpp backend via the scheduler.
+func (h *Handler) proxyToBackend(w http.ResponseWriter, r *http.Request, targetPath string) {
+	ctx := r.Context()
+
+	// Read the request body
+	body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maxRequestBodySize))
+	if err != nil {
+		var maxBytesError *http.MaxBytesError
+		if errors.As(err, &maxBytesError) {
+			h.writeAnthropicError(w, http.StatusRequestEntityTooLarge, "request_too_large", "Request body too large")
+		} else {
+			h.writeAnthropicError(w, http.StatusInternalServerError, "internal_error", "Failed to read request body")
+		}
+		return
+	}
+
+	// Parse the model field from the request to route to the correct backend
+	var req MessagesRequest
+	if err := json.Unmarshal(body, &req); err != nil {
+		h.writeAnthropicError(w, http.StatusBadRequest, "invalid_request_error", "Invalid JSON in request body")
+		return
+	}
+
+	if req.Model == "" {
+		h.writeAnthropicError(w, http.StatusBadRequest, "invalid_request_error", "Missing required field: model")
+		return
+	}
+
+	// Use model name from request
+	modelName := req.Model
+
+	// Verify the model exists locally
+	_, err = h.modelManager.GetLocal(modelName)
+	if err != nil {
+		h.writeAnthropicError(w, http.StatusNotFound, "not_found_error", "Model not found: "+modelName)
+		return
+	}
+
+	// Create the proxied request to the inference endpoint
+	// The scheduler will route to the appropriate backend
+	newReq := r.Clone(ctx)
+	newReq.URL.Path = inference.InferencePrefix + targetPath
+	newReq.Body = io.NopCloser(bytes.NewReader(body))
+	newReq.ContentLength = int64(len(body))
+	newReq.Header.Set("Content-Type", "application/json")
+	newReq.Header.Set(inference.RequestOriginHeader, inference.OriginAnthropicMessages)
+
+	// Forward to the scheduler HTTP handler
+	h.schedulerHTTP.ServeHTTP(w, newReq)
+}
+
+// AnthropicError represents an error response in the Anthropic API format.
+type AnthropicError struct {
+	Type  string            `json:"type"`
+	Error AnthropicErrorObj `json:"error"`
+}
+
+// AnthropicErrorObj represents the error object in an Anthropic error response.
+type AnthropicErrorObj struct {
+	Type    string `json:"type"`
+	Message string `json:"message"`
+}
+
+// writeAnthropicError writes an error response in the Anthropic API format.
+func (h *Handler) writeAnthropicError(w http.ResponseWriter, statusCode int, errorType, message string) {
+	w.Header().Set("Content-Type", "application/json")
+	w.WriteHeader(statusCode)
+
+	errResp := AnthropicError{
+		Type: "error",
+		Error: AnthropicErrorObj{
+			Type:    errorType,
+			Message: message,
+		},
+	}
+
+	if err := json.NewEncoder(w).Encode(errResp); err != nil {
+		h.log.Errorf("Failed to encode error response: %v", err)
+	}
+}
diff --git a/pkg/anthropic/handler_test.go b/pkg/anthropic/handler_test.go
new file mode 100644
index 00000000..b2925fab
--- /dev/null
+++ b/pkg/anthropic/handler_test.go
@@ -0,0 +1,204 @@
+package anthropic
+
+import (
+	"io"
+	"net/http"
+	"net/http/httptest"
+	"strings"
+	"testing"
+
+	"github.com/sirupsen/logrus"
+)
+
+func TestWriteAnthropicError(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name       string
+		statusCode int
+		errorType  string
+		message    string
+		wantBody   string
+	}{
+		{
+			name:       "invalid request error",
+			statusCode: http.StatusBadRequest,
+			errorType:  "invalid_request_error",
+			message:    "Missing required field: model",
+			wantBody:   `{"type":"error","error":{"type":"invalid_request_error","message":"Missing required field: model"}}`,
+		},
+		{
+			name:       "not found error",
+			statusCode: http.StatusNotFound,
+			errorType:  "not_found_error",
+			message:    "Model not found: test-model",
+			wantBody:   `{"type":"error","error":{"type":"not_found_error","message":"Model not found: test-model"}}`,
+		},
+		{
+			name:       "internal error",
+			statusCode: http.StatusInternalServerError,
+			errorType:  "internal_error",
+			message:    "An internal error occurred",
+			wantBody:   `{"type":"error","error":{"type":"internal_error","message":"An internal error occurred"}}`,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+
+			rec := httptest.NewRecorder()
+			discard := logrus.New()
+			discard.SetOutput(io.Discard)
+			h := &Handler{log: logrus.NewEntry(discard)}
+			h.writeAnthropicError(rec, tt.statusCode, tt.errorType, tt.message)
+
+			if rec.Code != tt.statusCode {
+				t.Errorf("expected status %d, got %d", tt.statusCode, rec.Code)
+			}
+
+			if contentType := rec.Header().Get("Content-Type"); contentType != "application/json" {
+				t.Errorf("expected Content-Type application/json, got %s", contentType)
+			}
+
+			body := strings.TrimSpace(rec.Body.String())
+			if body != tt.wantBody {
+				t.Errorf("expected body %s, got %s", tt.wantBody, body)
+			}
+		})
+	}
+}
+
+func TestRouteHandlers(t *testing.T) {
+	t.Parallel()
+
+	h := &Handler{
+		router: http.NewServeMux(),
+	}
+
+	routes := h.routeHandlers()
+
+	expectedRoutes := []string{
+		"POST " + APIPrefix + "/v1/messages",
+		"POST " + APIPrefix + "/v1/messages/count_tokens",
+	}
+
+	for _, route := range expectedRoutes {
+		if _, exists := routes[route]; !exists {
+			t.Errorf("expected route %s to be registered", route)
+		}
+	}
+
+	if len(routes) != len(expectedRoutes) {
+		t.Errorf("expected %d routes, got %d", len(expectedRoutes), len(routes))
+	}
+}
+
+func TestAPIPrefix(t *testing.T) {
+	t.Parallel()
+
+	if APIPrefix != "/anthropic" {
+		t.Errorf("expected APIPrefix to be /anthropic, got %s", APIPrefix)
+	}
+}
+
+func TestProxyToBackend_InvalidJSON(t *testing.T) {
+	t.Parallel()
+
+	discard := logrus.New()
+	discard.SetOutput(io.Discard)
+	h := &Handler{log: logrus.NewEntry(discard)}
+
+	rec := httptest.NewRecorder()
+	req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", strings.NewReader(`{invalid json`))
+
+	h.proxyToBackend(rec, req, "/v1/messages")
+
+	if rec.Code != http.StatusBadRequest {
+		t.Errorf("expected status %d, got %d", http.StatusBadRequest, rec.Code)
+	}
+
+	body := rec.Body.String()
+	if !strings.Contains(body, "invalid_request_error") {
+		t.Errorf("expected body to contain 'invalid_request_error', got %s", body)
+	}
+	if !strings.Contains(body, "Invalid JSON") {
+		t.Errorf("expected body to contain 'Invalid JSON', got %s", body)
+	}
+}
+
+func TestProxyToBackend_MissingModel(t *testing.T) {
+	t.Parallel()
+
+	discard := logrus.New()
+	discard.SetOutput(io.Discard)
+	h := &Handler{log: logrus.NewEntry(discard)}
+
+	rec := httptest.NewRecorder()
+	req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", strings.NewReader(`{"messages": []}`))
+
+	h.proxyToBackend(rec, req, "/v1/messages")
+
+	if rec.Code != http.StatusBadRequest {
+		t.Errorf("expected status %d, got %d", http.StatusBadRequest, rec.Code)
+	}
+
+	body := rec.Body.String()
+	if !strings.Contains(body, "invalid_request_error") {
+		t.Errorf("expected body to contain 'invalid_request_error', got %s", body)
+	}
+	if !strings.Contains(body, "Missing required field: model") {
+		t.Errorf("expected body to contain 'Missing required field: model', got %s", body)
+	}
+}
+
+func TestProxyToBackend_EmptyModel(t *testing.T) {
+	t.Parallel()
+
+	discard := logrus.New()
+	discard.SetOutput(io.Discard)
+	h := &Handler{log: logrus.NewEntry(discard)}
+
+	rec := httptest.NewRecorder()
+	req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", strings.NewReader(`{"model": ""}`))
+
+	h.proxyToBackend(rec, req, "/v1/messages")
+
+	if rec.Code != http.StatusBadRequest {
+		t.Errorf("expected status %d, got %d", http.StatusBadRequest, rec.Code)
+	}
+
+	body := rec.Body.String()
+	if !strings.Contains(body, "invalid_request_error") {
+		t.Errorf("expected body to contain 'invalid_request_error', got %s", body)
+	}
+	if !strings.Contains(body, "Missing required field: model") {
+		t.Errorf("expected body to contain 'Missing required field: model', got %s", body)
+	}
+}
+
+func TestProxyToBackend_RequestTooLarge(t *testing.T) {
+	t.Parallel()
+
+	discard := logrus.New()
+	discard.SetOutput(io.Discard)
+	h := &Handler{log: logrus.NewEntry(discard)}
+
+	// Create a request body that exceeds the maxRequestBodySize (10MB)
+	// We'll use a reader that simulates a large body without actually allocating it
+	largeBody := strings.NewReader(`{"model": "test-model", "data": "` + strings.Repeat("x", maxRequestBodySize+1) + `"}`)
+
+	rec := httptest.NewRecorder()
+	req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", largeBody)
+
+	h.proxyToBackend(rec, req, "/v1/messages")
+
+	if rec.Code != http.StatusRequestEntityTooLarge {
+		t.Errorf("expected status %d, got %d", http.StatusRequestEntityTooLarge, rec.Code)
+	}
+
+	body := rec.Body.String()
+	if !strings.Contains(body, "request_too_large") {
+		t.Errorf("expected body to contain 'request_too_large', got %s", body)
+	}
+}
diff --git a/pkg/inference/api.go b/pkg/inference/api.go
index 0a1093e4..367d18ab 100644
--- a/pkg/inference/api.go
+++ b/pkg/inference/api.go
@@ -20,4 +20,6 @@ const RequestOriginHeader = "X-Request-Origin"
 const (
 	// OriginOllamaCompletion indicates the request came from the Ollama /api/chat or /api/generate endpoints
 	OriginOllamaCompletion = "ollama/completion"
+	// OriginAnthropicMessages indicates the request came from the Anthropic /v1/messages endpoint
+	OriginAnthropicMessages = "anthropic/messages"
 )
diff --git a/pkg/inference/scheduling/api.go b/pkg/inference/scheduling/api.go
index e66a28ff..f9460dc2 100644
--- a/pkg/inference/scheduling/api.go
+++ b/pkg/inference/scheduling/api.go
@@ -29,7 +29,7 @@ func trimRequestPathToOpenAIRoot(path string) string {
 }

 // backendModeForRequest determines the backend operation mode to handle an
-// OpenAI inference request. Its second parameter is true if and only if a valid
+// OpenAI or Anthropic inference request. Its second parameter is true if and only if a valid
 // mode could be determined.
 func backendModeForRequest(path string) (inference.BackendMode, bool) {
 	if strings.HasSuffix(path, "/v1/chat/completions") || strings.HasSuffix(path, "/v1/completions") {
@@ -38,6 +38,9 @@ func backendModeForRequest(path string) (inference.BackendMode, bool) {
 		return inference.BackendModeEmbedding, true
 	} else if strings.HasSuffix(path, "/rerank") || strings.HasSuffix(path, "/score") {
 		return inference.BackendModeReranking, true
+	} else if strings.HasSuffix(path, "/v1/messages") || strings.HasSuffix(path, "/v1/messages/count_tokens") {
+		// Anthropic Messages API - treated as completion mode
+		return inference.BackendModeCompletion, true
 	}
 	return inference.BackendMode(0), false
 }
diff --git a/pkg/inference/scheduling/http_handler.go b/pkg/inference/scheduling/http_handler.go
index 4d2ccb43..82e57b4d 100644
--- a/pkg/inference/scheduling/http_handler.go
+++ b/pkg/inference/scheduling/http_handler.go
@@ -67,8 +67,17 @@ func (h *HTTPHandler) routeHandlers() map[string]http.HandlerFunc {
 		"POST " + inference.InferencePrefix + "/{backend}/score",
 		"POST " + inference.InferencePrefix + "/score",
 	}
+
+	// Anthropic Messages API routes
+	anthropicRoutes := []string{
+		"POST " + inference.InferencePrefix + "/{backend}/v1/messages",
+		"POST " + inference.InferencePrefix + "/v1/messages",
+		"POST " + inference.InferencePrefix + "/{backend}/v1/messages/count_tokens",
+		"POST " + inference.InferencePrefix + "/v1/messages/count_tokens",
+	}
+
 	m := make(map[string]http.HandlerFunc)
-	for _, route := range openAIRoutes {
+	for _, route := range append(openAIRoutes, anthropicRoutes...) {
 		m[route] = h.handleOpenAIInference
 	}