diff --git a/main.go b/main.go index 76545881..36fdeb8c 100644 --- a/main.go +++ b/main.go @@ -23,6 +23,7 @@ import ( "github.com/docker/model-runner/pkg/metrics" "github.com/docker/model-runner/pkg/middleware" "github.com/docker/model-runner/pkg/ollama" + "github.com/docker/model-runner/pkg/responses" "github.com/docker/model-runner/pkg/routing" "github.com/sirupsen/logrus" ) @@ -165,6 +166,16 @@ func main() { router.Handle(inference.ModelsPrefix, modelHandler) router.Handle(inference.ModelsPrefix+"/", modelHandler) router.Handle(inference.InferencePrefix+"/", schedulerHTTP) + // Add OpenAI Responses API compatibility layer + responsesHandler := responses.NewHTTPHandler(log, schedulerHTTP, nil) + router.Handle(responses.APIPrefix+"/", responsesHandler) + router.Handle(responses.APIPrefix, responsesHandler) // Also register for exact match without trailing slash + router.Handle("/v1"+responses.APIPrefix+"/", responsesHandler) + router.Handle("/v1"+responses.APIPrefix, responsesHandler) + // Also register Responses API under inference prefix to support all inference engines + router.Handle(inference.InferencePrefix+responses.APIPrefix+"/", responsesHandler) + router.Handle(inference.InferencePrefix+responses.APIPrefix, responsesHandler) + // Add path aliases: /v1 -> /engines/v1, /rerank -> /engines/rerank, /score -> /engines/score. 
aliasHandler := &middleware.AliasHandler{Handler: schedulerHTTP} router.Handle("/v1/", aliasHandler) diff --git a/pkg/inference/models/adapter.go b/pkg/inference/models/adapter.go index a686199f..2ca6d7fe 100644 --- a/pkg/inference/models/adapter.go +++ b/pkg/inference/models/adapter.go @@ -1,6 +1,7 @@ package models import ( + "encoding/json" "fmt" "github.com/docker/model-runner/pkg/distribution/types" @@ -27,12 +28,23 @@ func ToModel(m types.Model) (*Model, error) { created = desc.Created.Unix() } - return &Model{ + model := &Model{ ID: id, Tags: m.Tags(), Created: created, Config: cfg, - }, nil + } + + // Marshal the config to populate RawConfig + if cfg != nil { + configData, err := json.Marshal(cfg) + if err != nil { + return nil, fmt.Errorf("marshal config: %w", err) + } + model.RawConfig = configData + } + + return model, nil } // ToModelFromArtifact converts a types.ModelArtifact (typically from remote registry) @@ -58,10 +70,21 @@ func ToModelFromArtifact(artifact types.ModelArtifact) (*Model, error) { created = desc.Created.Unix() } - return &Model{ + model := &Model{ ID: id, Tags: nil, // Remote models don't have local tags Created: created, Config: cfg, - }, nil + } + + // Marshal the config to populate RawConfig + if cfg != nil { + configData, err := json.Marshal(cfg) + if err != nil { + return nil, fmt.Errorf("marshal config: %w", err) + } + model.RawConfig = configData + } + + return model, nil } diff --git a/pkg/inference/models/api.go b/pkg/inference/models/api.go index ffb724c1..e963603c 100644 --- a/pkg/inference/models/api.go +++ b/pkg/inference/models/api.go @@ -112,7 +112,35 @@ type Model struct { Created int64 `json:"created"` // Config describes the model. Can be either Docker format (*types.Config) // or ModelPack format (*modelpack.Model). 
- Config types.ModelConfig `json:"config"` + Config types.ModelConfig `json:"-"` + // RawConfig is used for JSON marshaling/unmarshaling + RawConfig json.RawMessage `json:"config"` +} + +// MarshalJSON implements custom marshaling for Model +func (m Model) MarshalJSON() ([]byte, error) { + // Define a temporary struct to avoid recursion + type Alias Model + aux := struct { + *Alias + RawConfig json.RawMessage `json:"config"` + }{ + Alias: (*Alias)(&m), + } + + // Marshal the config separately + if m.Config != nil { + configData, err := json.Marshal(m.Config) + if err != nil { + return nil, err + } + aux.RawConfig = configData + } else { + // If Config is nil, use the RawConfig if available + aux.RawConfig = m.RawConfig + } + + return json.Marshal(aux) } // UnmarshalJSON implements custom JSON unmarshaling for Model. diff --git a/pkg/responses/api.go b/pkg/responses/api.go new file mode 100644 index 00000000..214401f5 --- /dev/null +++ b/pkg/responses/api.go @@ -0,0 +1,404 @@ +// Package responses implements the OpenAI Responses API compatibility layer. +// The Responses API is a stateful API that combines chat completions with +// conversation state management and tool use capabilities. +package responses + +import ( + "crypto/rand" + "encoding/json" + "time" +) + +// APIPrefix is the URL prefix for the Responses API. 
+const APIPrefix = "/responses" + +// Response status values +const ( + StatusQueued = "queued" + StatusInProgress = "in_progress" + StatusCompleted = "completed" + StatusCancelled = "cancelled" + StatusFailed = "failed" +) + +// Content types +const ( + ContentTypeInputText = "input_text" + ContentTypeOutputText = "output_text" + ContentTypeInputImage = "input_image" + ContentTypeInputFile = "input_file" + ContentTypeRefusal = "refusal" + ContentTypeFunctionCall = "function_call" + ContentTypeFunctionCallOutput = "function_call_output" +) + +// Item types +const ( + ItemTypeMessage = "message" + ItemTypeFunctionCall = "function_call" + ItemTypeFunctionCallOutput = "function_call_output" +) + +// Streaming event types +const ( + EventResponseCreated = "response.created" + EventResponseInProgress = "response.in_progress" + EventResponseCompleted = "response.completed" + EventResponseFailed = "response.failed" + EventResponseIncomplete = "response.incomplete" + EventOutputItemAdded = "response.output_item.added" + EventOutputItemDone = "response.output_item.done" + EventContentPartAdded = "response.content_part.added" + EventContentPartDone = "response.content_part.done" + EventOutputTextDelta = "response.output_text.delta" + EventOutputTextDone = "response.output_text.done" + EventRefusalDelta = "response.refusal.delta" + EventRefusalDone = "response.refusal.done" + EventFunctionCallArgsDelta = "response.function_call_arguments.delta" + EventFunctionCallArgsDone = "response.function_call_arguments.done" + EventError = "error" +) + +// CreateRequest represents a request to create a response. +type CreateRequest struct { + // Model is the model to use for generating the response. + Model string `json:"model"` + + // Input is the input to the model. Can be a string or array of input items. + Input json.RawMessage `json:"input"` + + // Instructions is an optional system prompt/instructions for the model. 
+ Instructions string `json:"instructions,omitempty"` + + // PreviousResponseID links this request to a previous response for conversation chaining. + PreviousResponseID string `json:"previous_response_id,omitempty"` + + // Tools is the list of tools available to the model. + Tools []Tool `json:"tools,omitempty"` + + // ToolChoice controls how the model uses tools. + ToolChoice interface{} `json:"tool_choice,omitempty"` + + // ParallelToolCalls enables parallel tool calls. + ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` + + // Temperature controls randomness (0-2). + Temperature *float64 `json:"temperature,omitempty"` + + // TopP controls nucleus sampling. + TopP *float64 `json:"top_p,omitempty"` + + // MaxOutputTokens limits the response length. + MaxOutputTokens *int `json:"max_output_tokens,omitempty"` + + // Stream enables streaming responses. + Stream bool `json:"stream,omitempty"` + + // Metadata is user-defined metadata for the response. + Metadata map[string]string `json:"metadata,omitempty"` + + // ReasoningEffort controls reasoning model effort (low, medium, high). + ReasoningEffort string `json:"reasoning_effort,omitempty"` + + // User is an optional user identifier. + User string `json:"user,omitempty"` +} + +// Response represents a complete response from the API. +type Response struct { + // ID is the unique identifier for this response. + ID string `json:"id"` + + // Object is always "response". + Object string `json:"object"` + + // CreatedAt is the Unix timestamp when the response was created. + CreatedAt float64 `json:"created_at"` + + // Model is the model used to generate the response. + Model string `json:"model"` + + // Status is the current status of the response. + Status string `json:"status"` + + // Output is the array of output items (messages, function calls, etc.). + Output []OutputItem `json:"output"` + + // OutputText is a convenience field containing concatenated text output. 
+ OutputText string `json:"output_text,omitempty"` + + // Error contains error details if status is "failed". + Error *ErrorDetail `json:"error"` + + // IncompleteDetails contains details if the response was incomplete. + IncompleteDetails *IncompleteDetails `json:"incomplete_details"` + + // FinishReason contains the reason the model stopped generating (e.g., stop, length, function_call, etc.). + FinishReason string `json:"finish_reason,omitempty"` + + // Instructions is the system instructions used. + Instructions *string `json:"instructions"` + + // Metadata is user-defined metadata. + Metadata map[string]string `json:"metadata"` + + // ParallelToolCalls indicates if parallel tool calls were enabled. + ParallelToolCalls *bool `json:"parallel_tool_calls"` + + // Temperature used for generation. + Temperature *float64 `json:"temperature"` + + // ToolChoice used for generation. + ToolChoice interface{} `json:"tool_choice"` + + // Tools available during generation. + Tools []Tool `json:"tools"` + + // TopP used for generation. + TopP *float64 `json:"top_p"` + + // MaxOutputTokens limit used. + MaxOutputTokens *int `json:"max_output_tokens"` + + // PreviousResponseID is the ID of the previous response in the chain. + PreviousResponseID *string `json:"previous_response_id"` + + // Reasoning contains reasoning details for reasoning models. + Reasoning *ReasoningDetails `json:"reasoning"` + + // Usage contains token usage statistics. + Usage *Usage `json:"usage"` + + // User identifier if provided. + User *string `json:"user"` + + // ReasoningEffort used for reasoning models. + ReasoningEffort *string `json:"reasoning_effort"` +} + +// OutputItem represents an item in the response output. +type OutputItem struct { + // ID is the unique identifier for this output item. + ID string `json:"id"` + + // Type is the type of output item (message, function_call, etc.). + Type string `json:"type"` + + // Role is the role for message items (assistant). 
+ Role string `json:"role,omitempty"` + + // Content is the content array for message items. + Content []ContentPart `json:"content,omitempty"` + + // Status is the status of this output item. + Status string `json:"status,omitempty"` + + // CallID is the ID for function call items. + CallID string `json:"call_id,omitempty"` + + // Name is the function name for function call items. + Name string `json:"name,omitempty"` + + // Arguments is the function arguments for function call items. + Arguments string `json:"arguments,omitempty"` + + // Output is the function output for function_call_output items. + Output string `json:"output,omitempty"` +} + +// ContentPart represents a part of content within an output item. +type ContentPart struct { + // Type is the content type (output_text, refusal, etc.). + Type string `json:"type"` + + // Text is the text content for output_text type. + Text string `json:"text,omitempty"` + + // Refusal is the refusal message for refusal type. + Refusal string `json:"refusal,omitempty"` + + // Annotations contains any annotations on the content. + Annotations []Annotation `json:"annotations,omitempty"` +} + +// Annotation represents an annotation on content. +type Annotation struct { + Type string `json:"type"` + StartIndex int `json:"start_index,omitempty"` + EndIndex int `json:"end_index,omitempty"` + URL string `json:"url,omitempty"` + Title string `json:"title,omitempty"` +} + +// InputItem represents an input item in the request. +type InputItem struct { + // Type is the type of input item. + Type string `json:"type,omitempty"` + + // Role is the role for message-style inputs. + Role string `json:"role,omitempty"` + + // Content can be a string or array of content parts. + Content json.RawMessage `json:"content,omitempty"` + + // CallID is for function_call_output items. + CallID string `json:"call_id,omitempty"` + + // Output is the function output for function_call_output items. 
+ Output string `json:"output,omitempty"` + + // ID is for referencing items. + ID string `json:"id,omitempty"` + + // Name is for function calls. + Name string `json:"name,omitempty"` + + // Arguments is for function calls. + Arguments string `json:"arguments,omitempty"` +} + +// InputContentPart represents a content part in the input. +type InputContentPart struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + ImageURL string `json:"image_url,omitempty"` + FileID string `json:"file_id,omitempty"` + FileData string `json:"file_data,omitempty"` + Filename string `json:"filename,omitempty"` +} + +// Tool represents a tool available to the model. +type Tool struct { + // Type is the tool type (function, etc.). + Type string `json:"type"` + + // Name is the function name (for function tools). + Name string `json:"name,omitempty"` + + // Description is the function description. + Description string `json:"description,omitempty"` + + // Parameters is the JSON schema for function parameters. + Parameters interface{} `json:"parameters,omitempty"` + + // Function contains function details (alternative structure). + Function *FunctionDef `json:"function,omitempty"` +} + +// FunctionDef defines a function tool. +type FunctionDef struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Parameters interface{} `json:"parameters,omitempty"` +} + +// Usage contains token usage statistics. +type Usage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + TotalTokens int `json:"total_tokens"` + OutputTokensDetails *OutputTokensDetails `json:"output_tokens_details,omitempty"` +} + +// OutputTokensDetails contains detailed output token breakdown. +type OutputTokensDetails struct { + ReasoningTokens int `json:"reasoning_tokens,omitempty"` +} + +// ErrorDetail contains error information. 
+type ErrorDetail struct { + Code string `json:"code,omitempty"` + Message string `json:"message"` +} + +// IncompleteDetails contains details about why a response was incomplete. +type IncompleteDetails struct { + Reason string `json:"reason,omitempty"` +} + +// ReasoningDetails contains reasoning information for reasoning models. +type ReasoningDetails struct { + EncryptedContent string `json:"encrypted_content,omitempty"` +} + +// StreamEvent represents a streaming event. +type StreamEvent struct { + // Type is the event type. + Type string `json:"type"` + + // SequenceNumber is the sequence number for ordering. + SequenceNumber int `json:"sequence_number"` + + // Response is included in response lifecycle events. + Response *Response `json:"response,omitempty"` + + // Item is included in output item events. + Item *OutputItem `json:"item,omitempty"` + + // OutputIndex is the index in the output array. + OutputIndex int `json:"output_index,omitempty"` + + // ContentIndex is the index within content array. + ContentIndex int `json:"content_index,omitempty"` + + // Part is the content part for content events. + Part *ContentPart `json:"part,omitempty"` + + // Delta is the text delta for delta events. + Delta string `json:"delta,omitempty"` + + // ItemID is the ID of the item being modified. + ItemID string `json:"item_id,omitempty"` + + // Error is included in error events. + Error *ErrorDetail `json:"error,omitempty"` +} + +// NewResponse creates a new Response with default values. +func NewResponse(id, model string) *Response { + return &Response{ + ID: id, + Object: "response", + CreatedAt: float64(time.Now().Unix()), + Model: model, + Status: StatusInProgress, + Output: []OutputItem{}, + Tools: []Tool{}, + Metadata: map[string]string{}, + } +} + +// GenerateResponseID generates a unique response ID. +func GenerateResponseID() string { + return "resp_" + GenerateID(24) +} + +// GenerateItemID generates a unique item ID. 
+func GenerateItemID() string { + return "item_" + GenerateID(24) +} + +// GenerateMessageID generates a unique message ID. +func GenerateMessageID() string { + return "msg_" + GenerateID(24) +} + +// GenerateCallID generates a unique call ID for function calls. +func GenerateCallID() string { + return "call_" + GenerateID(24) +} + +// GenerateID generates a random alphanumeric ID of the specified length. +func GenerateID(length int) string { + const charset = "abcdefghijklmnopqrstuvwxyz0123456789" + b := make([]byte, length) + if _, err := rand.Read(b); err != nil { + // A failure of crypto/rand indicates a critical problem with the OS's + // entropy source, so we should panic. + panic("failed to read random bytes for ID generation: " + err.Error()) + } + for i := range b { + b[i] = charset[b[i]%byte(len(charset))] + } + return string(b) +} diff --git a/pkg/responses/handler.go b/pkg/responses/handler.go new file mode 100644 index 00000000..125efe6c --- /dev/null +++ b/pkg/responses/handler.go @@ -0,0 +1,344 @@ +package responses + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/docker/model-runner/pkg/logging" + "github.com/docker/model-runner/pkg/middleware" +) + +// HTTPHandler handles Responses API HTTP requests. +type HTTPHandler struct { + log logging.Logger + router *http.ServeMux + httpHandler http.Handler + schedulerHTTP http.Handler + store *Store + maxRequestBodyBytes int64 +} + +// NewHTTPHandler creates a new Responses API handler. 
+func NewHTTPHandler(log logging.Logger, schedulerHTTP http.Handler, allowedOrigins []string) *HTTPHandler { + h := &HTTPHandler{ + log: log, + router: http.NewServeMux(), + schedulerHTTP: schedulerHTTP, + store: NewStore(DefaultTTL), + maxRequestBodyBytes: 10 * 1024 * 1024, // Default to 10MB + } + + // Register routes + h.router.HandleFunc("POST "+APIPrefix, h.handleCreate) + h.router.HandleFunc("GET "+APIPrefix+"/{id}", h.handleGet) + h.router.HandleFunc("GET "+APIPrefix+"/{id}/input_items", h.handleListInputItems) + h.router.HandleFunc("DELETE "+APIPrefix+"/{id}", h.handleDelete) + // Also register /v1/responses routes + h.router.HandleFunc("POST /v1"+APIPrefix, h.handleCreate) + h.router.HandleFunc("GET /v1"+APIPrefix+"/{id}", h.handleGet) + h.router.HandleFunc("GET /v1"+APIPrefix+"/{id}/input_items", h.handleListInputItems) + h.router.HandleFunc("DELETE /v1"+APIPrefix+"/{id}", h.handleDelete) + + // Apply CORS middleware + h.httpHandler = middleware.CorsMiddleware(allowedOrigins, h.router) + + return h +} + +// ServeHTTP implements http.Handler. +func (h *HTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + cleanPath := strings.ReplaceAll(r.URL.Path, "\n", "") + cleanPath = strings.ReplaceAll(cleanPath, "\r", "") + h.log.Infof("Responses API request: %s %s", r.Method, cleanPath) + h.httpHandler.ServeHTTP(w, r) +} + +// handleCreate handles POST /responses (or /v1/responses). 
+func (h *HTTPHandler) handleCreate(w http.ResponseWriter, r *http.Request) { + // Read request body with a configurable limit + reader := http.MaxBytesReader(w, r.Body, h.maxRequestBodyBytes) + body, err := io.ReadAll(reader) + if err != nil { + var maxBytesErr *http.MaxBytesError + if errors.As(err, &maxBytesErr) { + h.sendError( + w, + http.StatusRequestEntityTooLarge, + "request_too_large", + fmt.Sprintf("Request body too large (max %d bytes)", h.maxRequestBodyBytes), + ) + return + } + + h.sendError(w, http.StatusBadRequest, "invalid_request", "Failed to read request body") + return + } + + // Parse request + var req CreateRequest + if err := json.Unmarshal(body, &req); err != nil { + h.sendError(w, http.StatusBadRequest, "invalid_request", "Invalid JSON: "+err.Error()) + return + } + + // Validate required fields + if req.Model == "" { + h.sendError(w, http.StatusBadRequest, "invalid_request", "model is required") + return + } + + // Create a new response + respID := GenerateResponseID() + resp := NewResponse(respID, req.Model) + resp.Instructions = nilIfEmpty(req.Instructions) + resp.Temperature = req.Temperature + resp.TopP = req.TopP + resp.MaxOutputTokens = req.MaxOutputTokens + resp.Tools = req.Tools + resp.ToolChoice = req.ToolChoice + resp.ParallelToolCalls = req.ParallelToolCalls + resp.Metadata = req.Metadata + if req.PreviousResponseID != "" { + resp.PreviousResponseID = &req.PreviousResponseID + } + if req.ReasoningEffort != "" { + resp.ReasoningEffort = &req.ReasoningEffort + } + if req.User != "" { + resp.User = &req.User + } + + // Transform to chat completion request + chatReq, err := TransformRequestToChatCompletion(&req, h.store) + if err != nil { + h.sendError(w, http.StatusBadRequest, "invalid_request", err.Error()) + return + } + + // Marshal chat request + chatBody, err := MarshalChatCompletionRequest(chatReq) + if err != nil { + h.sendError(w, http.StatusInternalServerError, "internal_error", "Failed to marshal request") + return + } + + 
// Create upstream request + upstreamReq, err := http.NewRequestWithContext(r.Context(), http.MethodPost, "/engines/v1/chat/completions", bytes.NewReader(chatBody)) + if err != nil { + h.sendError(w, http.StatusInternalServerError, "internal_error", "Failed to create request") + return + } + upstreamReq.Header.Set("Content-Type", "application/json") + // Copy relevant headers + if auth := r.Header.Get("Authorization"); auth != "" { + upstreamReq.Header.Set("Authorization", auth) + } + + if req.Stream { + // Handle streaming response + h.handleStreaming(w, upstreamReq, resp) + } else { + // Handle non-streaming response + h.handleNonStreaming(w, upstreamReq, resp) + } +} + +// handleStreaming handles streaming responses. +func (h *HTTPHandler) handleStreaming(w http.ResponseWriter, upstreamReq *http.Request, resp *Response) { + // Create streaming writer + streamWriter := NewStreamingResponseWriter(w, resp, h.store) + + // Forward to scheduler + h.schedulerHTTP.ServeHTTP(streamWriter, upstreamReq) +} + +// handleNonStreaming handles non-streaming responses. 
+func (h *HTTPHandler) handleNonStreaming(w http.ResponseWriter, upstreamReq *http.Request, resp *Response) { + // Capture upstream response + capture := NewNonStreamingResponseCapture() + + // Forward to scheduler + h.schedulerHTTP.ServeHTTP(capture, upstreamReq) + + // Check for errors + if capture.StatusCode != http.StatusOK { + // Try to parse error + var errResp struct { + Error struct { + Message string `json:"message"` + Type string `json:"type"` + Code string `json:"code"` + } `json:"error"` + } + if err := json.Unmarshal([]byte(capture.Body.String()), &errResp); err == nil && errResp.Error.Message != "" { + resp.Status = StatusFailed + resp.Error = &ErrorDetail{ + Code: errResp.Error.Code, + Message: errResp.Error.Message, + } + h.store.Save(resp) + h.sendJSON(w, capture.StatusCode, resp) + return + } + // Generic error + resp.Status = StatusFailed + resp.Error = &ErrorDetail{ + Code: "upstream_error", + Message: capture.Body.String(), + } + h.store.Save(resp) + h.sendJSON(w, capture.StatusCode, resp) + return + } + + // Parse chat completion response + var chatResp ChatCompletionResponse + if err := json.Unmarshal([]byte(capture.Body.String()), &chatResp); err != nil { + resp.Status = StatusFailed + resp.Error = &ErrorDetail{ + Code: "parse_error", + Message: "Failed to parse upstream response", + } + h.store.Save(resp) + h.sendJSON(w, http.StatusInternalServerError, resp) + return + } + + // Transform response + finalResp := TransformChatCompletionToResponse(&chatResp, resp.ID, resp.Model) + // Preserve request parameters + finalResp.Instructions = resp.Instructions + finalResp.Temperature = resp.Temperature + finalResp.TopP = resp.TopP + finalResp.MaxOutputTokens = resp.MaxOutputTokens + finalResp.Tools = resp.Tools + finalResp.ToolChoice = resp.ToolChoice + finalResp.ParallelToolCalls = resp.ParallelToolCalls + finalResp.Metadata = resp.Metadata + finalResp.PreviousResponseID = resp.PreviousResponseID + finalResp.ReasoningEffort = resp.ReasoningEffort 
+ finalResp.User = resp.User + finalResp.CreatedAt = resp.CreatedAt + + // Store the response + h.store.Save(finalResp) + + // Send response + h.sendJSON(w, http.StatusOK, finalResp) +} + +// handleGet handles GET /responses/{id}. +func (h *HTTPHandler) handleGet(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + if id == "" { + h.sendError(w, http.StatusBadRequest, "invalid_request", "Response ID is required") + return + } + + resp, ok := h.store.Get(id) + if !ok { + h.sendError(w, http.StatusNotFound, "not_found", "Response not found") + return + } + + // Check if streaming is requested + if r.URL.Query().Get("stream") == "true" { + // For completed responses, we can't re-stream + // Just return the response as JSON + h.sendJSON(w, http.StatusOK, resp) + return + } + + h.sendJSON(w, http.StatusOK, resp) +} + +// handleListInputItems handles GET /responses/{id}/input_items. +func (h *HTTPHandler) handleListInputItems(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + if id == "" { + h.sendError(w, http.StatusBadRequest, "invalid_request", "Response ID is required") + return + } + + _, ok := h.store.Get(id) + if !ok { + h.sendError(w, http.StatusNotFound, "not_found", "Response not found") + return + } + + // For now, return an empty list since input items are not stored separately + // In a real implementation, this would return the input items associated with the response + h.sendJSON(w, http.StatusOK, map[string]interface{}{ + "object": "list", + "data": []interface{}{}, + }) +} + +// handleDelete handles DELETE /responses/{id}. 
+func (h *HTTPHandler) handleDelete(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + if id == "" { + h.sendError(w, http.StatusBadRequest, "invalid_request", "Response ID is required") + return + } + + if !h.store.Delete(id) { + h.sendError(w, http.StatusNotFound, "not_found", "Response not found") + return + } + + h.sendJSON(w, http.StatusOK, map[string]interface{}{ + "id": id, + "object": "response.deleted", + "deleted": true, + }) +} + +// sendJSON sends a JSON response. +func (h *HTTPHandler) sendJSON(w http.ResponseWriter, statusCode int, data interface{}) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(statusCode) + if err := json.NewEncoder(w).Encode(data); err != nil { + h.log.Errorf("Failed to encode JSON response: %v", err) + } +} + +// sendError sends an error response. +func (h *HTTPHandler) sendError(w http.ResponseWriter, statusCode int, code, message string) { + h.sendJSON(w, statusCode, map[string]interface{}{ + "error": map[string]interface{}{ + "code": code, + "message": message, + }, + }) +} + +// nilIfEmpty returns a pointer to the string if non-empty, otherwise nil. +func nilIfEmpty(s string) *string { + if s == "" { + return nil + } + return &s +} + +// GetStore returns the response store (for testing). +func (h *HTTPHandler) GetStore() *Store { + return h.store +} + +// SetMaxRequestBodyBytes sets the maximum request body size in bytes. +func (h *HTTPHandler) SetMaxRequestBodyBytes(bytes int64) { + h.maxRequestBodyBytes = bytes +} + +// ResponseWithTimestamp adds a timestamp helper. 
+func ResponseWithTimestamp(resp *Response) *Response { + resp.CreatedAt = float64(time.Now().Unix()) + return resp +} diff --git a/pkg/responses/handler_test.go b/pkg/responses/handler_test.go new file mode 100644 index 00000000..9a2444fe --- /dev/null +++ b/pkg/responses/handler_test.go @@ -0,0 +1,707 @@ +package responses + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/sirupsen/logrus" +) + +// mockSchedulerHTTP is a mock scheduler that returns predefined responses. +type mockSchedulerHTTP struct { + response string + statusCode int + streaming bool + streamChunks []string +} + +func (m *mockSchedulerHTTP) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if m.streaming { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + for _, chunk := range m.streamChunks { + w.Write([]byte(chunk)) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + } + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(m.statusCode) + w.Write([]byte(m.response)) +} + +func newTestHandler(mock *mockSchedulerHTTP) *HTTPHandler { + log := logrus.New() + log.SetOutput(io.Discard) + return NewHTTPHandler(log, mock, nil) +} + +func TestHandler_CreateResponse_NonStreaming(t *testing.T) { + mock := &mockSchedulerHTTP{ + statusCode: http.StatusOK, + response: `{ + "id": "chatcmpl-123", + "object": "chat.completion", + "created": 1234567890, + "model": "gpt-4", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Hello! How can I help you?" 
+ }, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 10, + "completion_tokens": 7, + "total_tokens": 17 + } + }`, + } + + handler := newTestHandler(mock) + + reqBody := `{ + "model": "gpt-4", + "input": "Hello" + }` + + req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(reqBody)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + handler.handleCreate(w, req) + + resp := w.Result() + if resp.StatusCode != http.StatusOK { + t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK) + } + + var result Response + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + + if result.Object != "response" { + t.Errorf("object = %s, want response", result.Object) + } + if result.Model != "gpt-4" { + t.Errorf("model = %s, want gpt-4", result.Model) + } + if result.Status != StatusCompleted { + t.Errorf("status = %s, want %s", result.Status, StatusCompleted) + } + if result.OutputText != "Hello! How can I help you?" { + t.Errorf("output_text = %s, want Hello! 
How can I help you?", result.OutputText) + } + if !strings.HasPrefix(result.ID, "resp_") { + t.Errorf("id should start with resp_, got %s", result.ID) + } +} + +func TestHandler_CreateResponse_MissingModel(t *testing.T) { + mock := &mockSchedulerHTTP{} + handler := newTestHandler(mock) + + reqBody := `{"input": "Hello"}` + + req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(reqBody)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + handler.handleCreate(w, req) + + resp := w.Result() + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusBadRequest) + } + + var errResp map[string]interface{} + json.NewDecoder(resp.Body).Decode(&errResp) + + if errResp["error"] == nil { + t.Error("expected error in response") + } +} + +func TestHandler_CreateResponse_InvalidJSON(t *testing.T) { + mock := &mockSchedulerHTTP{} + handler := newTestHandler(mock) + + req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(`{invalid`)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + handler.handleCreate(w, req) + + resp := w.Result() + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusBadRequest) + } +} + +func TestHandler_GetResponse(t *testing.T) { + mock := &mockSchedulerHTTP{} + handler := newTestHandler(mock) + + // First, store a response + testResp := NewResponse("resp_test123", "gpt-4") + testResp.Status = StatusCompleted + testResp.OutputText = "Test output" + handler.store.Save(testResp) + + // Now retrieve it + req := httptest.NewRequest(http.MethodGet, "/v1/responses/resp_test123", http.NoBody) + req.SetPathValue("id", "resp_test123") + w := httptest.NewRecorder() + + handler.handleGet(w, req) + + resp := w.Result() + if resp.StatusCode != http.StatusOK { + t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK) + } + + var 
result Response + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + + if result.ID != "resp_test123" { + t.Errorf("id = %s, want resp_test123", result.ID) + } + if result.OutputText != "Test output" { + t.Errorf("output_text = %s, want Test output", result.OutputText) + } +} + +func TestHandler_GetResponse_NotFound(t *testing.T) { + mock := &mockSchedulerHTTP{} + handler := newTestHandler(mock) + + req := httptest.NewRequest(http.MethodGet, "/v1/responses/nonexistent", http.NoBody) + req.SetPathValue("id", "nonexistent") + w := httptest.NewRecorder() + + handler.handleGet(w, req) + + resp := w.Result() + if resp.StatusCode != http.StatusNotFound { + t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusNotFound) + } +} + +func TestHandler_DeleteResponse(t *testing.T) { + mock := &mockSchedulerHTTP{} + handler := newTestHandler(mock) + + // First, store a response + testResp := NewResponse("resp_test123", "gpt-4") + handler.store.Save(testResp) + + // Delete it + req := httptest.NewRequest(http.MethodDelete, "/v1/responses/resp_test123", http.NoBody) + req.SetPathValue("id", "resp_test123") + w := httptest.NewRecorder() + + handler.handleDelete(w, req) + + resp := w.Result() + if resp.StatusCode != http.StatusOK { + t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK) + } + + // Verify it's deleted + _, ok := handler.store.Get("resp_test123") + if ok { + t.Error("expected response to be deleted") + } +} + +func TestHandler_DeleteResponse_NotFound(t *testing.T) { + mock := &mockSchedulerHTTP{} + handler := newTestHandler(mock) + + req := httptest.NewRequest(http.MethodDelete, "/v1/responses/nonexistent", http.NoBody) + req.SetPathValue("id", "nonexistent") + w := httptest.NewRecorder() + + handler.handleDelete(w, req) + + resp := w.Result() + if resp.StatusCode != http.StatusNotFound { + t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusNotFound) + } +} + +func 
TestHandler_CreateResponse_WithPreviousResponse(t *testing.T) { + mock := &mockSchedulerHTTP{ + statusCode: http.StatusOK, + response: `{ + "id": "chatcmpl-456", + "object": "chat.completion", + "created": 1234567891, + "model": "gpt-4", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "I'm doing well, thanks!" + }, + "finish_reason": "stop" + } + ] + }`, + } + + handler := newTestHandler(mock) + + // Create a previous response + prevResp := NewResponse("resp_prev123", "gpt-4") + prevResp.Status = StatusCompleted + prevResp.Output = []OutputItem{ + { + ID: "msg_1", + Type: ItemTypeMessage, + Role: "assistant", + Content: []ContentPart{ + {Type: ContentTypeOutputText, Text: "Hello!"}, + }, + }, + } + handler.store.Save(prevResp) + + // Create new request chained to previous + reqBody := `{ + "model": "gpt-4", + "input": "How are you?", + "previous_response_id": "resp_prev123" + }` + + req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(reqBody)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + handler.handleCreate(w, req) + + resp := w.Result() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + t.Fatalf("status = %d, want %d, body: %s", resp.StatusCode, http.StatusOK, body) + } + + var result Response + json.NewDecoder(resp.Body).Decode(&result) + + if result.PreviousResponseID == nil || *result.PreviousResponseID != "resp_prev123" { + t.Errorf("previous_response_id = %v, want resp_prev123", result.PreviousResponseID) + } +} + +func TestHandler_CreateResponse_UpstreamError(t *testing.T) { + mock := &mockSchedulerHTTP{ + statusCode: http.StatusInternalServerError, + response: `{ + "error": { + "message": "Model overloaded", + "type": "server_error", + "code": "model_overloaded" + } + }`, + } + + handler := newTestHandler(mock) + + reqBody := `{ + "model": "gpt-4", + "input": "Hello" + }` + + req := httptest.NewRequest(http.MethodPost, 
"/v1/responses", strings.NewReader(reqBody)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + handler.handleCreate(w, req) + + resp := w.Result() + if resp.StatusCode != http.StatusInternalServerError { + t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusInternalServerError) + } + + var result Response + json.NewDecoder(resp.Body).Decode(&result) + + if result.Status != StatusFailed { + t.Errorf("status = %s, want %s", result.Status, StatusFailed) + } + if result.Error == nil { + t.Error("expected error to be set") + } +} + +func TestHandler_CreateResponse_UpstreamError_NonJSONBody(t *testing.T) { + mock := &mockSchedulerHTTP{ + statusCode: http.StatusInternalServerError, + // non-JSON / malformed body to exercise the fallback branch in handleNonStreaming + response: "upstream exploded in a non-json way", + } + + handler := newTestHandler(mock) + + reqBody := `{ + "model": "gpt-4", + "input": "Hello" + }` + + req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(reqBody)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + handler.handleCreate(w, req) + + resp := w.Result() + if resp.StatusCode != http.StatusInternalServerError { + t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusInternalServerError) + } + + var result Response + json.NewDecoder(resp.Body).Decode(&result) + + // Assert: non-streaming error handling falls back correctly + if result.Status != StatusFailed { + t.Errorf("status = %s, want %s", result.Status, StatusFailed) + } + + if result.Error == nil { + t.Fatalf("expected error, got nil") + } + + if result.Error.Code != "upstream_error" { + t.Errorf("error.code = %v, want upstream_error", result.Error.Code) + } + + if !strings.Contains(result.Error.Message, "upstream exploded in a non-json way") { + t.Errorf("error.message = %q, want to contain raw upstream body", result.Error.Message) + } +} + +func 
TestHandler_CreateResponse_Streaming(t *testing.T) { + // Mock streaming response + mock := &mockSchedulerHTTP{ + streaming: true, + streamChunks: []string{ + "data: {\"id\":\"chatcmpl-123\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4\",\"choices\":[{\"index\":0,\"delta\":{\"role\":\"assistant\"},\"finish_reason\":null}]}\n\n", + "data: {\"id\":\"chatcmpl-123\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"},\"finish_reason\":null}]}\n\n", + "data: {\"id\":\"chatcmpl-123\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"!\"},\"finish_reason\":null}]}\n\n", + "data: {\"id\":\"chatcmpl-123\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4\",\"choices\":[{\"index\":0,\"delta\":{},\"finish_reason\":\"stop\"}]}\n\n", + "data: [DONE]\n\n", + }, + } + + handler := newTestHandler(mock) + + reqBody := `{ + "model": "gpt-4", + "input": "Hello", + "stream": true + }` + + req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(reqBody)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + handler.handleCreate(w, req) + + resp := w.Result() + if resp.StatusCode != http.StatusOK { + t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK) + } + + // Check content type is SSE + contentType := resp.Header.Get("Content-Type") + if !strings.Contains(contentType, "text/event-stream") { + t.Errorf("Content-Type = %s, want text/event-stream", contentType) + } + + // Read all body + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("failed to read body: %v", err) + } + + bodyStr := string(body) + + // Verify we got the expected events + if !strings.Contains(bodyStr, "response.created") { + t.Error("expected response.created event") + } + if !strings.Contains(bodyStr, 
"response.output_text.delta") { + t.Error("expected response.output_text.delta event") + } + if !strings.Contains(bodyStr, "response.completed") { + t.Error("expected response.completed event") + } +} + +func TestHandler_CreateResponse_WithTools(t *testing.T) { + // Mock response with tool call + mock := &mockSchedulerHTTP{ + statusCode: http.StatusOK, + response: `{ + "id": "chatcmpl-123", + "object": "chat.completion", + "created": 1234567890, + "model": "gpt-4", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "tool_calls": [ + { + "id": "call_abc123", + "type": "function", + "function": { + "name": "get_weather", + "arguments": "{\"location\": \"San Francisco\"}" + } + } + ] + }, + "finish_reason": "tool_calls" + } + ] + }`, + } + + handler := newTestHandler(mock) + + reqBody := `{ + "model": "gpt-4", + "input": "What's the weather in San Francisco?", + "tools": [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather information", + "parameters": { + "type": "object", + "properties": { + "location": {"type": "string"} + } + } + } + } + ] + }` + + req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(reqBody)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + handler.handleCreate(w, req) + + resp := w.Result() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + t.Fatalf("status = %d, want %d, body: %s", resp.StatusCode, http.StatusOK, body) + } + + var result Response + json.NewDecoder(resp.Body).Decode(&result) + + if len(result.Output) == 0 { + t.Fatal("expected output items") + } + + // Find the function call item + var funcCall *OutputItem + for i := range result.Output { + if result.Output[i].Type == ItemTypeFunctionCall { + funcCall = &result.Output[i] + break + } + } + + if funcCall == nil { + t.Fatal("expected function call in output") + } + + if funcCall.Name != "get_weather" { + t.Errorf("function 
name = %s, want get_weather", funcCall.Name) + } + if funcCall.CallID != "call_abc123" { + t.Errorf("call_id = %s, want call_abc123", funcCall.CallID) + } +} + +// Test that stored responses persist across requests +func TestHandler_ResponsePersistence(t *testing.T) { + mock := &mockSchedulerHTTP{ + statusCode: http.StatusOK, + response: `{ + "id": "chatcmpl-123", + "choices": [ + { + "message": { + "role": "assistant", + "content": "Hello!" + } + } + ] + }`, + } + + handler := newTestHandler(mock) + + // Create a response + reqBody := `{"model": "gpt-4", "input": "Hi"}` + req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(reqBody)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + handler.handleCreate(w, req) + + var createResult Response + json.NewDecoder(w.Result().Body).Decode(&createResult) + + // Retrieve it + req2 := httptest.NewRequest(http.MethodGet, "/v1/responses/"+createResult.ID, http.NoBody) + req2.SetPathValue("id", createResult.ID) + w2 := httptest.NewRecorder() + + handler.handleGet(w2, req2) + + var getResult Response + json.NewDecoder(w2.Result().Body).Decode(&getResult) + + if getResult.ID != createResult.ID { + t.Errorf("IDs don't match: %s vs %s", getResult.ID, createResult.ID) + } +} + +// Test that streaming responses are properly persisted in the store +func TestHandler_CreateResponse_Streaming_Persistence(t *testing.T) { + // Mock streaming response + mock := &mockSchedulerHTTP{ + streaming: true, + streamChunks: []string{ + "data: {\"id\":\"chatcmpl-123\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4\",\"choices\":[{\"index\":0,\"delta\":{\"role\":\"assistant\"},\"finish_reason\":null}]}\n\n", + "data: {\"id\":\"chatcmpl-123\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"},\"finish_reason\":null}]}\n\n", + "data: 
{\"id\":\"chatcmpl-123\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"!\"},\"finish_reason\":null}]}\n\n", + "data: {\"id\":\"chatcmpl-123\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4\",\"choices\":[{\"index\":0,\"delta\":{},\"finish_reason\":\"stop\"}]}\n\n", + "data: [DONE]\n\n", + }, + } + + handler := newTestHandler(mock) + + reqBody := `{ + "model": "gpt-4", + "input": "Hello", + "stream": true + }` + + req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(reqBody)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + handler.handleCreate(w, req) + + resp := w.Result() + if resp.StatusCode != http.StatusOK { + t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusOK) + } + + // Verify that the StreamingResponseWriter persisted a coherent Response in the store + memStore := handler.store + + if memStore.Count() != 1 { + t.Fatalf("expected exactly one response in store, got %d", memStore.Count()) + } + + // Get the response ID from the store + responseIDs := memStore.GetResponseIDs() + if len(responseIDs) != 1 { + t.Fatalf("expected exactly one response ID in store, got %d", len(responseIDs)) + } + + // Retrieve the response using the public API + persistedResp, ok := memStore.Get(responseIDs[0]) + if !ok { + t.Fatal("expected to retrieve persisted Response from store") + } + + // Status should be completed after streaming finishes + if persistedResp.Status != StatusCompleted { + t.Errorf("persisted response status = %s, want %s", persistedResp.Status, StatusCompleted) + } + + // OutputText should match concatenated streamed chunks: "Hello" + "!" => "Hello!" + if persistedResp.OutputText != "Hello!" 
{ + t.Errorf("persisted response OutputText = %q, want %q", persistedResp.OutputText, "Hello!") + } + + // There should be at least one OutputItem whose message content matches "Hello!" + found := false + for _, item := range persistedResp.Output { + if item.Type != ItemTypeMessage { + continue + } + // Check if the message contains the expected text + for _, contentPart := range item.Content { + if contentPart.Type == ContentTypeOutputText && contentPart.Text == "Hello!" { + found = true + break + } + } + if found { + break + } + } + if !found { + t.Errorf("expected an OutputItem message with text %q in persisted response", "Hello!") + } +} + +// Benchmark for response creation +func BenchmarkHandler_CreateResponse(b *testing.B) { + mock := &mockSchedulerHTTP{ + statusCode: http.StatusOK, + response: `{ + "id": "chatcmpl-123", + "choices": [ + { + "message": { + "role": "assistant", + "content": "Hello!" + } + } + ] + }`, + } + + handler := newTestHandler(mock) + reqBody := []byte(`{"model": "gpt-4", "input": "Hello"}`) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + req := httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(reqBody)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + handler.handleCreate(w, req) + } +} diff --git a/pkg/responses/store.go b/pkg/responses/store.go new file mode 100644 index 00000000..85743484 --- /dev/null +++ b/pkg/responses/store.go @@ -0,0 +1,123 @@ +package responses + +import ( + "sync" + "time" +) + +// DefaultTTL is the default time-to-live for stored responses. +const DefaultTTL = 1 * time.Hour + +// Store provides in-memory storage for responses with TTL-based expiration. +type Store struct { + mu sync.RWMutex + responses map[string]*storedResponse + ttl time.Duration +} + +type storedResponse struct { + response *Response + expiresAt time.Time +} + +// NewStore creates a new response store with the given TTL. 
+func NewStore(ttl time.Duration) *Store { + if ttl <= 0 { + ttl = DefaultTTL + } + s := &Store{ + responses: make(map[string]*storedResponse), + ttl: ttl, + } + // Start background cleanup goroutine + go s.cleanupLoop() + return s +} + +// Save stores a response. +func (s *Store) Save(resp *Response) { + s.mu.Lock() + defer s.mu.Unlock() + s.responses[resp.ID] = &storedResponse{ + response: resp, + expiresAt: time.Now().Add(s.ttl), + } +} + +// Get retrieves a response by ID. +func (s *Store) Get(id string) (*Response, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + stored, ok := s.responses[id] + if !ok { + return nil, false + } + if time.Now().After(stored.expiresAt) { + return nil, false + } + return stored.response, true +} + +// Delete removes a response by ID. +func (s *Store) Delete(id string) bool { + s.mu.Lock() + defer s.mu.Unlock() + if _, ok := s.responses[id]; ok { + delete(s.responses, id) + return true + } + return false +} + +// Update updates a response in place. +func (s *Store) Update(id string, updateFn func(*Response)) bool { + s.mu.Lock() + defer s.mu.Unlock() + stored, ok := s.responses[id] + if !ok || time.Now().After(stored.expiresAt) { + return false + } + updateFn(stored.response) + // Refresh TTL on update + stored.expiresAt = time.Now().Add(s.ttl) + return true +} + +// cleanupLoop periodically removes expired responses. +func (s *Store) cleanupLoop() { + ticker := time.NewTicker(5 * time.Minute) + defer ticker.Stop() + for range ticker.C { + s.cleanup() + } +} + +// cleanup removes all expired responses. +func (s *Store) cleanup() { + s.mu.Lock() + defer s.mu.Unlock() + now := time.Now() + for id, stored := range s.responses { + if now.After(stored.expiresAt) { + delete(s.responses, id) + } + } +} + +// Count returns the number of stored responses. +func (s *Store) Count() int { + s.mu.RLock() + defer s.mu.RUnlock() + return len(s.responses) +} + +// GetResponseIDs returns all response IDs in the store (for testing purposes). 
+func (s *Store) GetResponseIDs() []string { + s.mu.RLock() + defer s.mu.RUnlock() + ids := make([]string, 0, len(s.responses)) + for id := range s.responses { + ids = append(ids, id) + } + return ids +} diff --git a/pkg/responses/store_test.go b/pkg/responses/store_test.go new file mode 100644 index 00000000..4d61fea4 --- /dev/null +++ b/pkg/responses/store_test.go @@ -0,0 +1,141 @@ +package responses + +import ( + "testing" + "time" +) + +func TestStore_SaveAndGet(t *testing.T) { + store := NewStore(1 * time.Hour) + + resp := NewResponse("resp_test123", "gpt-4") + resp.Status = StatusCompleted + resp.OutputText = "Hello, world!" + + store.Save(resp) + + // Get should return the response + got, ok := store.Get("resp_test123") + if !ok { + t.Fatal("expected to find response") + } + if got.ID != resp.ID { + t.Errorf("got ID %s, want %s", got.ID, resp.ID) + } + if got.OutputText != resp.OutputText { + t.Errorf("got OutputText %s, want %s", got.OutputText, resp.OutputText) + } +} + +func TestStore_GetNotFound(t *testing.T) { + store := NewStore(1 * time.Hour) + + _, ok := store.Get("nonexistent") + if ok { + t.Error("expected response not to be found") + } +} + +func TestStore_Delete(t *testing.T) { + store := NewStore(1 * time.Hour) + + resp := NewResponse("resp_test123", "gpt-4") + store.Save(resp) + + // Delete should succeed + if !store.Delete("resp_test123") { + t.Error("expected delete to succeed") + } + + // Get should now fail + _, ok := store.Get("resp_test123") + if ok { + t.Error("expected response to be deleted") + } + + // Delete again should return false + if store.Delete("resp_test123") { + t.Error("expected second delete to return false") + } +} + +func TestStore_Update(t *testing.T) { + store := NewStore(1 * time.Hour) + + resp := NewResponse("resp_test123", "gpt-4") + resp.Status = StatusInProgress + store.Save(resp) + + // Update should succeed + ok := store.Update("resp_test123", func(r *Response) { + r.Status = StatusCompleted + r.OutputText = 
"Updated content" + }) + if !ok { + t.Error("expected update to succeed") + } + + // Get should return updated response + got, ok := store.Get("resp_test123") + if !ok { + t.Fatal("expected to find response") + } + if got.Status != StatusCompleted { + t.Errorf("got Status %s, want %s", got.Status, StatusCompleted) + } + if got.OutputText != "Updated content" { + t.Errorf("got OutputText %s, want Updated content", got.OutputText) + } +} + +func TestStore_UpdateNotFound(t *testing.T) { + store := NewStore(1 * time.Hour) + + ok := store.Update("nonexistent", func(r *Response) { + r.Status = StatusCompleted + }) + if ok { + t.Error("expected update to fail for nonexistent response") + } +} + +func TestStore_Count(t *testing.T) { + store := NewStore(1 * time.Hour) + + if store.Count() != 0 { + t.Errorf("expected count 0, got %d", store.Count()) + } + + store.Save(NewResponse("resp_1", "gpt-4")) + store.Save(NewResponse("resp_2", "gpt-4")) + + if store.Count() != 2 { + t.Errorf("expected count 2, got %d", store.Count()) + } +} + +func TestStore_TTLExpiration(t *testing.T) { + // Use a very short TTL for testing + store := &Store{ + responses: make(map[string]*storedResponse), + ttl: 1 * time.Millisecond, + } + + resp := NewResponse("resp_test123", "gpt-4") + store.Save(resp) + + // Should be found immediately + _, ok := store.Get("resp_test123") + if !ok { + t.Fatal("expected to find response immediately after save") + } + + // Wait for TTL to expire + time.Sleep(10 * time.Millisecond) + + // Should not be found after expiration + _, ok = store.Get("resp_test123") + if ok { + t.Error("expected response to be expired") + } +} diff --git a/pkg/responses/streaming.go b/pkg/responses/streaming.go new file mode 100644 index 00000000..bd06b8e1 --- /dev/null +++ b/pkg/responses/streaming.go @@ -0,0 +1,580 @@ +package responses + +import ( + "bufio" + "encoding/json" + "fmt" + "net/http" + "strings" + "time" +) + +// StreamingResponseWriter transforms chat completion SSE events 
to Responses API SSE events. +type StreamingResponseWriter struct { + w http.ResponseWriter + flusher http.Flusher + response *Response + store *Store + sequenceNumber int + headersSent bool + buffer strings.Builder + + // State for building the response + currentItemID string + currentContentIdx int + accumulatedContent strings.Builder + toolCalls []OutputItem +} + +// NewStreamingResponseWriter creates a new streaming response writer. +func NewStreamingResponseWriter(w http.ResponseWriter, resp *Response, store *Store) *StreamingResponseWriter { + flusher, _ := w.(http.Flusher) + return &StreamingResponseWriter{ + w: w, + flusher: flusher, + response: resp, + store: store, + } +} + +// Header returns the header map. +func (s *StreamingResponseWriter) Header() http.Header { + return s.w.Header() +} + +// WriteHeader writes the HTTP status code. +func (s *StreamingResponseWriter) WriteHeader(statusCode int) { + if s.headersSent { + return + } + s.headersSent = true + + if statusCode != http.StatusOK { + // Send error event before writing the status code + s.response.Status = StatusFailed + s.sendEvent(EventError, &StreamEvent{ + Type: EventError, + SequenceNumber: s.nextSeq(), + Error: &ErrorDetail{ + Code: "upstream_error", + Message: fmt.Sprintf("Upstream service returned status code: %d", statusCode), + }, + }) + + // Send response.failed event + s.sendEvent(EventResponseFailed, &StreamEvent{ + Type: EventResponseFailed, + SequenceNumber: s.nextSeq(), + Response: s.response, + }) + + // Store the failed response + if s.store != nil { + s.store.Save(s.response) + } + + s.w.WriteHeader(statusCode) + return + } + + s.w.Header().Set("Content-Type", "text/event-stream") + s.w.Header().Set("Cache-Control", "no-cache") + s.w.Header().Set("Connection", "keep-alive") + s.w.WriteHeader(statusCode) + + // Send response.created event + s.sendEvent(EventResponseCreated, &StreamEvent{ + Type: EventResponseCreated, + SequenceNumber: s.nextSeq(), + Response: s.response, + }) + 
+ // Send response.in_progress event + s.response.Status = StatusInProgress + s.sendEvent(EventResponseInProgress, &StreamEvent{ + Type: EventResponseInProgress, + SequenceNumber: s.nextSeq(), + Response: s.response, + }) +} + +// Write processes incoming chat completion SSE data. +func (s *StreamingResponseWriter) Write(data []byte) (int, error) { + if !s.headersSent { + s.WriteHeader(http.StatusOK) + } + + // Buffer the data + s.buffer.Write(data) + + // Process complete lines + bufferStr := s.buffer.String() + lines := strings.Split(bufferStr, "\n") + + // Keep incomplete line in buffer + if !strings.HasSuffix(bufferStr, "\n") && len(lines) > 0 { + s.buffer.Reset() + s.buffer.WriteString(lines[len(lines)-1]) + lines = lines[:len(lines)-1] + } else { + s.buffer.Reset() + } + + for _, line := range lines { + line = strings.TrimSpace(line) + if !strings.HasPrefix(line, "data: ") { + continue + } + + dataStr := strings.TrimPrefix(line, "data: ") + + // Detect and propagate upstream metadata-only final chunks (e.g. usage, finish_reason) + // before we consider the stream finalized. This ensures that Response.Usage and + // Response.Status.FinishReason are preserved even when the final chunk has no text. 
+ if dataStr != "" && dataStr != "[DONE]" { + var metaEnvelope struct { + Usage *Usage `json:"usage,omitempty"` + Choices []struct { + FinishReason string `json:"finish_reason,omitempty"` + Delta json.RawMessage `json:"delta,omitempty"` + } `json:"choices,omitempty"` + } + + if err := json.Unmarshal([]byte(dataStr), &metaEnvelope); err != nil { + // Send error event for malformed JSON in metadata chunk + s.response.Status = StatusFailed + s.sendEvent(EventError, &StreamEvent{ + Type: EventError, + SequenceNumber: s.nextSeq(), + Error: &ErrorDetail{ + Code: "parse_error", + Message: fmt.Sprintf("Failed to parse SSE metadata chunk: %v", err), + }, + }) + + // Send response.failed event + s.sendEvent(EventResponseFailed, &StreamEvent{ + Type: EventResponseFailed, + SequenceNumber: s.nextSeq(), + Response: s.response, + }) + + // Store the failed response + if s.store != nil { + s.store.Save(s.response) + } + return len(data), nil + } + + // If upstream sent usage only in the final chunk, capture it once. + if metaEnvelope.Usage != nil && s.response != nil && s.response.Usage == nil { + s.response.Usage = &Usage{ + InputTokens: metaEnvelope.Usage.InputTokens, + OutputTokens: metaEnvelope.Usage.OutputTokens, + TotalTokens: metaEnvelope.Usage.TotalTokens, + } + } + + // If we have a finish_reason but an empty delta, treat this as a + // metadata-only final chunk and propagate the finish state. 
+ if s.response != nil { + for _, choice := range metaEnvelope.Choices { + if choice.FinishReason != "" { + s.response.FinishReason = choice.FinishReason + // Update status based on finish reason + switch choice.FinishReason { + case "stop", "tool_calls": + s.response.Status = StatusCompleted + case "length": + s.response.Status = StatusCompleted // or potentially a different status for truncation + if s.response.IncompleteDetails == nil { + s.response.IncompleteDetails = &IncompleteDetails{Reason: "max_tokens"} + } + case "content_filter": + s.response.Status = StatusFailed + if s.response.Error == nil { + s.response.Error = &ErrorDetail{ + Code: "content_filter", + Message: "Content filtered", + } + } + } + break + } + } + } + } + + if dataStr == "[DONE]" { + s.finalize() + continue + } + + s.processChunk(dataStr) + } + + return len(data), nil +} + +// processChunk processes a single SSE chunk from chat completions. +func (s *StreamingResponseWriter) processChunk(dataStr string) { + var chunk ChatStreamChunk + if err := json.Unmarshal([]byte(dataStr), &chunk); err != nil { + // Send error event for malformed JSON + s.response.Status = StatusFailed + s.sendEvent(EventError, &StreamEvent{ + Type: EventError, + SequenceNumber: s.nextSeq(), + Error: &ErrorDetail{ + Code: "parse_error", + Message: fmt.Sprintf("Failed to parse SSE chunk: %v", err), + }, + }) + + // Send response.failed event + s.sendEvent(EventResponseFailed, &StreamEvent{ + Type: EventResponseFailed, + SequenceNumber: s.nextSeq(), + Response: s.response, + }) + + // Store the failed response + if s.store != nil { + s.store.Save(s.response) + } + return + } + + if len(chunk.Choices) == 0 { + return + } + + delta := chunk.Choices[0].Delta + + // Handle tool calls + if len(delta.ToolCalls) > 0 { + s.handleToolCallDelta(delta.ToolCalls) + return + } + + // Handle content delta + if delta.Content != "" { + s.handleContentDelta(delta.Content) + } +} + +// handleContentDelta handles a content delta from the 
chat completion stream. +func (s *StreamingResponseWriter) handleContentDelta(content string) { + // Initialize message item if needed + if s.currentItemID == "" { + s.currentItemID = GenerateMessageID() + s.currentContentIdx = 0 + + // Send output_item.added + item := &OutputItem{ + ID: s.currentItemID, + Type: ItemTypeMessage, + Role: "assistant", + Content: []ContentPart{{ + Type: ContentTypeOutputText, + Text: "", + Annotations: []Annotation{}, + }}, + Status: StatusInProgress, + } + s.sendEvent(EventOutputItemAdded, &StreamEvent{ + Type: EventOutputItemAdded, + SequenceNumber: s.nextSeq(), + Item: item, + OutputIndex: 0, + }) + + // Send content_part.added + s.sendEvent(EventContentPartAdded, &StreamEvent{ + Type: EventContentPartAdded, + SequenceNumber: s.nextSeq(), + ItemID: s.currentItemID, + OutputIndex: 0, + ContentIndex: 0, + Part: &ContentPart{ + Type: ContentTypeOutputText, + Text: "", + Annotations: []Annotation{}, + }, + }) + } + + // Accumulate content + s.accumulatedContent.WriteString(content) + + // Send output_text.delta + s.sendEvent(EventOutputTextDelta, &StreamEvent{ + Type: EventOutputTextDelta, + SequenceNumber: s.nextSeq(), + ItemID: s.currentItemID, + OutputIndex: 0, + ContentIndex: 0, + Delta: content, + }) +} + +// handleToolCallDelta handles tool call deltas from the chat completion stream. 
+func (s *StreamingResponseWriter) handleToolCallDelta(toolCalls []ChatToolCall) { + for _, tc := range toolCalls { + // Find or create the tool call item + var item *OutputItem + for i := range s.toolCalls { + if s.toolCalls[i].CallID == tc.ID { + item = &s.toolCalls[i] + break + } + } + + if item == nil { + // New tool call + callID := tc.ID + if callID == "" { + callID = GenerateCallID() + } + newItem := OutputItem{ + ID: GenerateItemID(), + Type: ItemTypeFunctionCall, + CallID: callID, + Name: tc.Function.Name, + Arguments: "", + Status: StatusInProgress, + } + s.toolCalls = append(s.toolCalls, newItem) + item = &s.toolCalls[len(s.toolCalls)-1] + + // Send output_item.added for function call + s.sendEvent(EventOutputItemAdded, &StreamEvent{ + Type: EventOutputItemAdded, + SequenceNumber: s.nextSeq(), + Item: item, + OutputIndex: len(s.toolCalls) - 1, + }) + } + + // Accumulate arguments + if tc.Function.Arguments != "" { + item.Arguments += tc.Function.Arguments + + // Send function_call_arguments.delta + s.sendEvent(EventFunctionCallArgsDelta, &StreamEvent{ + Type: EventFunctionCallArgsDelta, + SequenceNumber: s.nextSeq(), + ItemID: item.ID, + OutputIndex: len(s.toolCalls) - 1, + Delta: tc.Function.Arguments, + }) + } + } +} + +// finalize completes the streaming response. 
+func (s *StreamingResponseWriter) finalize() { + // Finalize any accumulated content + if s.currentItemID != "" { + finalText := s.accumulatedContent.String() + + // Send output_text.done + s.sendEvent(EventOutputTextDone, &StreamEvent{ + Type: EventOutputTextDone, + SequenceNumber: s.nextSeq(), + ItemID: s.currentItemID, + OutputIndex: 0, + ContentIndex: 0, + Part: &ContentPart{ + Type: ContentTypeOutputText, + Text: finalText, + Annotations: []Annotation{}, + }, + }) + + // Send content_part.done + s.sendEvent(EventContentPartDone, &StreamEvent{ + Type: EventContentPartDone, + SequenceNumber: s.nextSeq(), + ItemID: s.currentItemID, + OutputIndex: 0, + ContentIndex: 0, + Part: &ContentPart{ + Type: ContentTypeOutputText, + Text: finalText, + Annotations: []Annotation{}, + }, + }) + + // Send output_item.done for message + s.sendEvent(EventOutputItemDone, &StreamEvent{ + Type: EventOutputItemDone, + SequenceNumber: s.nextSeq(), + OutputIndex: 0, + Item: &OutputItem{ + ID: s.currentItemID, + Type: ItemTypeMessage, + Role: "assistant", + Content: []ContentPart{{ + Type: ContentTypeOutputText, + Text: finalText, + Annotations: []Annotation{}, + }}, + Status: StatusCompleted, + }, + }) + + // Add to response output + s.response.Output = append(s.response.Output, OutputItem{ + ID: s.currentItemID, + Type: ItemTypeMessage, + Role: "assistant", + Content: []ContentPart{{ + Type: ContentTypeOutputText, + Text: finalText, + Annotations: []Annotation{}, + }}, + Status: StatusCompleted, + }) + s.response.OutputText = finalText + } + + // Finalize tool calls + for i, tc := range s.toolCalls { + // Send function_call_arguments.done + s.sendEvent(EventFunctionCallArgsDone, &StreamEvent{ + Type: EventFunctionCallArgsDone, + SequenceNumber: s.nextSeq(), + ItemID: tc.ID, + OutputIndex: i, + Delta: tc.Arguments, + }) + + // Send output_item.done for function call + tc.Status = StatusCompleted + s.sendEvent(EventOutputItemDone, &StreamEvent{ + Type: EventOutputItemDone, + 
SequenceNumber: s.nextSeq(), + OutputIndex: i, + Item: &tc, + }) + + // Add to response output + s.response.Output = append(s.response.Output, tc) + } + + // Update response status + s.response.Status = StatusCompleted + + // Send response.completed + s.sendEvent(EventResponseCompleted, &StreamEvent{ + Type: EventResponseCompleted, + SequenceNumber: s.nextSeq(), + Response: s.response, + }) + + // Store the final response + if s.store != nil { + s.store.Save(s.response) + } +} + +// sendEvent sends an SSE event. +func (s *StreamingResponseWriter) sendEvent(eventType string, event *StreamEvent) { + data, err := json.Marshal(event) + if err != nil { + return + } + + fmt.Fprintf(s.w, "event: %s\n", eventType) + fmt.Fprintf(s.w, "data: %s\n\n", data) + + if s.flusher != nil { + s.flusher.Flush() + } +} + +// nextSeq returns the next sequence number. +func (s *StreamingResponseWriter) nextSeq() int { + s.sequenceNumber++ + return s.sequenceNumber +} + +// NonStreamingResponseCapture captures a non-streaming response. +type NonStreamingResponseCapture struct { + StatusCode int + Headers http.Header + Body strings.Builder +} + +// NewNonStreamingResponseCapture creates a new response capture. +func NewNonStreamingResponseCapture() *NonStreamingResponseCapture { + return &NonStreamingResponseCapture{ + StatusCode: http.StatusOK, + Headers: make(http.Header), + } +} + +// Header returns the header map. +func (c *NonStreamingResponseCapture) Header() http.Header { + return c.Headers +} + +// Write writes data to the body. +func (c *NonStreamingResponseCapture) Write(data []byte) (int, error) { + return c.Body.Write(data) +} + +// WriteHeader sets the status code. +func (c *NonStreamingResponseCapture) WriteHeader(statusCode int) { + c.StatusCode = statusCode +} + +// ProcessSSEStream reads an SSE stream from a reader and processes it. 
+func ProcessSSEStream(reader *bufio.Reader, handler func(event, data string)) error { + var currentEvent string + var currentData strings.Builder + + for { + line, err := reader.ReadString('\n') + if err != nil { + // Process any remaining data + if currentData.Len() > 0 { + handler(currentEvent, currentData.String()) + } + return err + } + + line = strings.TrimRight(line, "\r\n") + + if line == "" { + // Empty line signals end of event + if currentData.Len() > 0 { + handler(currentEvent, currentData.String()) + currentEvent = "" + currentData.Reset() + } + continue + } + + if strings.HasPrefix(line, "event:") { + currentEvent = strings.TrimSpace(strings.TrimPrefix(line, "event:")) + } else if strings.HasPrefix(line, "data:") { + data := strings.TrimPrefix(line, "data:") + data = strings.TrimPrefix(data, " ") + currentData.WriteString(data) + } + } +} + +// CreateErrorResponse creates an error response. +func CreateErrorResponse(respID, model, code, message string) *Response { + resp := NewResponse(respID, model) + resp.Status = StatusFailed + resp.CreatedAt = float64(time.Now().Unix()) + resp.Error = &ErrorDetail{ + Code: code, + Message: message, + } + return resp +} diff --git a/pkg/responses/transform.go b/pkg/responses/transform.go new file mode 100644 index 00000000..e1a8ce89 --- /dev/null +++ b/pkg/responses/transform.go @@ -0,0 +1,449 @@ +package responses + +import ( + "encoding/json" + "fmt" + "strings" +) + +// ChatCompletionRequest represents an OpenAI chat completion request. 
type ChatCompletionRequest struct {
	Model             string         `json:"model"`
	Messages          []ChatMessage  `json:"messages"`
	Tools             []ChatTool     `json:"tools,omitempty"`
	ToolChoice        interface{}    `json:"tool_choice,omitempty"`
	Temperature       *float64       `json:"temperature,omitempty"`
	TopP              *float64       `json:"top_p,omitempty"`
	MaxTokens         *int           `json:"max_tokens,omitempty"`
	Stream            bool           `json:"stream,omitempty"`
	User              string         `json:"user,omitempty"`
	ParallelToolCalls *bool          `json:"parallel_tool_calls,omitempty"`
}

// ChatMessage represents a message in the chat completion format.
type ChatMessage struct {
	Role       string         `json:"role"`
	Content    interface{}    `json:"content"` // string or []ContentPart
	Name       string         `json:"name,omitempty"`
	ToolCalls  []ChatToolCall `json:"tool_calls,omitempty"`
	ToolCallID string         `json:"tool_call_id,omitempty"`
}

// ChatContentPart represents a content part in chat format.
type ChatContentPart struct {
	Type     string        `json:"type"`
	Text     string        `json:"text,omitempty"`
	ImageURL *ChatImageURL `json:"image_url,omitempty"`
}

// ChatImageURL represents an image URL in chat format.
type ChatImageURL struct {
	URL    string `json:"url"`
	Detail string `json:"detail,omitempty"`
}

// ChatTool represents a tool in chat completion format.
type ChatTool struct {
	Type     string       `json:"type"`
	Function ChatFunction `json:"function"`
}

// ChatFunction represents a function definition.
type ChatFunction struct {
	Name        string      `json:"name"`
	Description string      `json:"description,omitempty"`
	Parameters  interface{} `json:"parameters,omitempty"`
}

// ChatToolCall represents a tool call in chat format.
//
// Index is populated by chat-completion streaming chunks, which identify a
// tool call by position (the ID is typically present only on the first
// chunk of a call); it is a pointer so that its absence is distinguishable
// from index 0.
type ChatToolCall struct {
	Index    *int             `json:"index,omitempty"`
	ID       string           `json:"id"`
	Type     string           `json:"type"`
	Function ChatFunctionCall `json:"function"`
}

// ChatFunctionCall represents a function call.
type ChatFunctionCall struct {
	Name      string `json:"name"`
	Arguments string `json:"arguments"` // JSON-encoded argument object
}

// ChatCompletionResponse represents an OpenAI chat completion response.
type ChatCompletionResponse struct {
	ID      string       `json:"id"`
	Object  string       `json:"object"`
	Created int64        `json:"created"`
	Model   string       `json:"model"`
	Choices []ChatChoice `json:"choices"`
	Usage   *ChatUsage   `json:"usage,omitempty"`
}

// ChatChoice represents a choice in the chat completion response.
type ChatChoice struct {
	Index        int         `json:"index"`
	Message      ChatMessage `json:"message"`
	FinishReason string      `json:"finish_reason"`
}

// ChatUsage represents token usage in chat completion format.
type ChatUsage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

// ChatStreamChunk represents a streaming chunk from chat completions.
type ChatStreamChunk struct {
	ID      string             `json:"id"`
	Object  string             `json:"object"`
	Created int64              `json:"created"`
	Model   string             `json:"model"`
	Choices []ChatStreamChoice `json:"choices"`
	Usage   *ChatUsage         `json:"usage,omitempty"`
}

// ChatStreamChoice represents a choice in a streaming chunk.
type ChatStreamChoice struct {
	Index        int       `json:"index"`
	Delta        ChatDelta `json:"delta"`
	FinishReason *string   `json:"finish_reason"`
}

// ChatDelta represents the delta in a streaming chunk.
type ChatDelta struct {
	Role             string         `json:"role,omitempty"`
	Content          string         `json:"content,omitempty"`
	ToolCalls        []ChatToolCall `json:"tool_calls,omitempty"`
	ReasoningContent string         `json:"reasoning_content,omitempty"`
}

// TransformRequestToChatCompletion converts a Responses API request to a chat completion request.
+func TransformRequestToChatCompletion(req *CreateRequest, store *Store) (*ChatCompletionRequest, error) { + chatReq := &ChatCompletionRequest{ + Model: req.Model, + Temperature: req.Temperature, + TopP: req.TopP, + MaxTokens: req.MaxOutputTokens, + Stream: req.Stream, + User: req.User, + ParallelToolCalls: req.ParallelToolCalls, + ToolChoice: req.ToolChoice, + } + + // Convert tools + if len(req.Tools) > 0 { + chatReq.Tools = make([]ChatTool, 0, len(req.Tools)) + for _, tool := range req.Tools { + if tool.Type == "function" { + chatTool := ChatTool{ + Type: "function", + } + if tool.Function != nil { + chatTool.Function = ChatFunction{ + Name: tool.Function.Name, + Description: tool.Function.Description, + Parameters: tool.Function.Parameters, + } + } else { + chatTool.Function = ChatFunction{ + Name: tool.Name, + Description: tool.Description, + Parameters: tool.Parameters, + } + } + chatReq.Tools = append(chatReq.Tools, chatTool) + } + } + } + + // Build messages array + var messages []ChatMessage + + // Add system message from instructions + if req.Instructions != "" { + messages = append(messages, ChatMessage{ + Role: "system", + Content: req.Instructions, + }) + } + + // If there's a previous response, include its conversation + if req.PreviousResponseID != "" && store != nil { + prevResp, ok := store.Get(req.PreviousResponseID) + if ok { + // Recursively get the conversation history + prevMessages := getConversationHistory(prevResp, store) + messages = append(messages, prevMessages...) + } + } + + // Parse and convert input + inputMessages, err := parseInput(req.Input) + if err != nil { + return nil, fmt.Errorf("failed to parse input: %w", err) + } + messages = append(messages, inputMessages...) + + chatReq.Messages = messages + return chatReq, nil +} + +// parseInput parses the input field which can be a string or array of items. 
+func parseInput(input json.RawMessage) ([]ChatMessage, error) { + if len(input) == 0 { + return nil, nil + } + + // Try parsing as a string first + var strInput string + if err := json.Unmarshal(input, &strInput); err == nil { + return []ChatMessage{{ + Role: "user", + Content: strInput, + }}, nil + } + + // Try parsing as an array of input items + var items []InputItem + if err := json.Unmarshal(input, &items); err != nil { + return nil, fmt.Errorf("input must be a string or array of items: %w", err) + } + + return convertInputItems(items) +} + +// convertInputItems converts input items to chat messages. +func convertInputItems(items []InputItem) ([]ChatMessage, error) { + var messages []ChatMessage + + for _, item := range items { + switch { + case item.Type == ItemTypeFunctionCallOutput || item.CallID != "": + // Function call output -> tool message + messages = append(messages, ChatMessage{ + Role: "tool", + Content: item.Output, + ToolCallID: item.CallID, + }) + + case item.Role != "": + // Message-style input + msg := ChatMessage{ + Role: item.Role, + } + + // Parse content + if len(item.Content) > 0 { + content, err := parseContent(item.Content) + if err != nil { + return nil, err + } + msg.Content = content + } + + messages = append(messages, msg) + + default: + // Try to interpret as a simple message + if len(item.Content) > 0 { + content, err := parseContent(item.Content) + if err != nil { + return nil, err + } + messages = append(messages, ChatMessage{ + Role: "user", + Content: content, + }) + } + } + } + + return messages, nil +} + +// parseContent parses content which can be a string or array of content parts. 
+func parseContent(content json.RawMessage) (interface{}, error) { + // Try string first + var strContent string + if err := json.Unmarshal(content, &strContent); err == nil { + return strContent, nil + } + + // Try array of content parts + var parts []InputContentPart + if err := json.Unmarshal(content, &parts); err != nil { + return nil, fmt.Errorf("content must be string or array: %w", err) + } + + // Convert to chat format content parts + chatParts := make([]ChatContentPart, 0, len(parts)) + for _, part := range parts { + switch part.Type { + case ContentTypeInputText, "text": + chatParts = append(chatParts, ChatContentPart{ + Type: "text", + Text: part.Text, + }) + case ContentTypeInputImage, "image_url": + chatParts = append(chatParts, ChatContentPart{ + Type: "image_url", + ImageURL: &ChatImageURL{ + URL: part.ImageURL, + }, + }) + } + } + + return chatParts, nil +} + +// getConversationHistory recursively builds conversation history from a response chain. +func getConversationHistory(resp *Response, store *Store) []ChatMessage { + var messages []ChatMessage + + // First, get history from previous response + if resp.PreviousResponseID != nil && *resp.PreviousResponseID != "" && store != nil { + prevResp, ok := store.Get(*resp.PreviousResponseID) + if ok { + messages = append(messages, getConversationHistory(prevResp, store)...) 
+ } + } + + // Add this response's output as assistant messages + for _, item := range resp.Output { + switch item.Type { + case ItemTypeMessage: + msg := ChatMessage{ + Role: item.Role, + } + // Extract text content + var textParts []string + for _, part := range item.Content { + if part.Type == ContentTypeOutputText { + textParts = append(textParts, part.Text) + } + } + msg.Content = strings.Join(textParts, "") + messages = append(messages, msg) + + case ItemTypeFunctionCall: + // Add assistant message with tool call + messages = append(messages, ChatMessage{ + Role: "assistant", + ToolCalls: []ChatToolCall{{ + ID: item.CallID, + Type: "function", + Function: ChatFunctionCall{ + Name: item.Name, + Arguments: item.Arguments, + }, + }}, + }) + } + } + + return messages +} + +// TransformChatCompletionToResponse converts a chat completion response to a Responses API response. +func TransformChatCompletionToResponse(chatResp *ChatCompletionResponse, respID, model string) *Response { + resp := NewResponse(respID, model) + resp.Status = StatusCompleted + + if len(chatResp.Choices) > 0 { + choice := chatResp.Choices[0] + + // Handle tool calls + if len(choice.Message.ToolCalls) > 0 { + for _, tc := range choice.Message.ToolCalls { + resp.Output = append(resp.Output, OutputItem{ + ID: GenerateItemID(), + Type: ItemTypeFunctionCall, + Status: StatusCompleted, + CallID: tc.ID, + Name: tc.Function.Name, + Arguments: tc.Function.Arguments, + }) + } + } + + // Handle text content (separately from tool calls, so both can exist in the same response) + if content, ok := choice.Message.Content.(string); ok && content != "" { + msgItem := OutputItem{ + ID: GenerateMessageID(), + Type: ItemTypeMessage, + Role: "assistant", + Status: StatusCompleted, + Content: []ContentPart{{ + Type: ContentTypeOutputText, + Text: content, + Annotations: []Annotation{}, + }}, + } + resp.Output = append(resp.Output, msgItem) + resp.OutputText = content + } else if contentParts, ok := 
choice.Message.Content.([]ChatContentPart); ok && len(contentParts) > 0 { + // Handle multi-part content (e.g., text + images) + var outputText string + var contentPartsList []ContentPart + + for _, part := range contentParts { + var contentPart ContentPart + switch part.Type { + case "text": + contentPart = ContentPart{ + Type: ContentTypeOutputText, + Text: part.Text, + Annotations: []Annotation{}, + } + outputText += part.Text + case "image_url": + if part.ImageURL != nil { + contentPart = ContentPart{ + Type: ContentTypeOutputText, // Map to output text for compatibility + Text: fmt.Sprintf("[Image: %s]", part.ImageURL.URL), // Include URL reference + Annotations: []Annotation{}, + } + // Add image reference to output text + if outputText != "" { + outputText += " " + } + outputText += fmt.Sprintf("[Image: %s]", part.ImageURL.URL) + } + default: + // Skip unknown content types + continue + } + contentPartsList = append(contentPartsList, contentPart) + } + + if len(contentPartsList) > 0 { + msgItem := OutputItem{ + ID: GenerateMessageID(), + Type: ItemTypeMessage, + Role: "assistant", + Status: StatusCompleted, + Content: contentPartsList, + } + resp.Output = append(resp.Output, msgItem) + resp.OutputText = outputText + } + } + } + + // Convert usage + if chatResp.Usage != nil { + resp.Usage = &Usage{ + InputTokens: chatResp.Usage.PromptTokens, + OutputTokens: chatResp.Usage.CompletionTokens, + TotalTokens: chatResp.Usage.TotalTokens, + } + } + + return resp +} + +// MarshalChatCompletionRequest marshals a chat completion request to JSON. 
+func MarshalChatCompletionRequest(req *ChatCompletionRequest) ([]byte, error) { + return json.Marshal(req) +} diff --git a/pkg/responses/transform_test.go b/pkg/responses/transform_test.go new file mode 100644 index 00000000..a7a22552 --- /dev/null +++ b/pkg/responses/transform_test.go @@ -0,0 +1,494 @@ +package responses + +import ( + "encoding/json" + "testing" +) + +func TestTransformRequestToChatCompletion_SimpleText(t *testing.T) { + req := &CreateRequest{ + Model: "gpt-4", + Input: json.RawMessage(`"Hello, how are you?"`), + } + + chatReq, err := TransformRequestToChatCompletion(req, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if chatReq.Model != "gpt-4" { + t.Errorf("got model %s, want gpt-4", chatReq.Model) + } + + if len(chatReq.Messages) != 1 { + t.Fatalf("got %d messages, want 1", len(chatReq.Messages)) + } + + if chatReq.Messages[0].Role != "user" { + t.Errorf("got role %s, want user", chatReq.Messages[0].Role) + } + + content, ok := chatReq.Messages[0].Content.(string) + if !ok { + t.Fatalf("expected string content") + } + if content != "Hello, how are you?" { + t.Errorf("got content %s, want Hello, how are you?", content) + } +} + +func TestTransformRequestToChatCompletion_WithInstructions(t *testing.T) { + req := &CreateRequest{ + Model: "gpt-4", + Input: json.RawMessage(`"Tell me a joke"`), + Instructions: "You are a helpful assistant.", + } + + chatReq, err := TransformRequestToChatCompletion(req, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(chatReq.Messages) != 2 { + t.Fatalf("got %d messages, want 2", len(chatReq.Messages)) + } + + // First message should be system + if chatReq.Messages[0].Role != "system" { + t.Errorf("first message role = %s, want system", chatReq.Messages[0].Role) + } + if chatReq.Messages[0].Content != "You are a helpful assistant." 
{ + t.Errorf("system content = %v, want You are a helpful assistant.", chatReq.Messages[0].Content) + } + + // Second message should be user + if chatReq.Messages[1].Role != "user" { + t.Errorf("second message role = %s, want user", chatReq.Messages[1].Role) + } +} + +func TestTransformRequestToChatCompletion_MessageArray(t *testing.T) { + input := `[ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + {"role": "user", "content": "How are you?"} + ]` + + req := &CreateRequest{ + Model: "gpt-4", + Input: json.RawMessage(input), + } + + chatReq, err := TransformRequestToChatCompletion(req, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(chatReq.Messages) != 3 { + t.Fatalf("got %d messages, want 3", len(chatReq.Messages)) + } + + expectedRoles := []string{"user", "assistant", "user"} + for i, msg := range chatReq.Messages { + if msg.Role != expectedRoles[i] { + t.Errorf("message %d role = %s, want %s", i, msg.Role, expectedRoles[i]) + } + } +} + +func TestTransformRequestToChatCompletion_MessageArrayWithMultiPartContent(t *testing.T) { + input := `[ + { + "role": "user", + "content": [ + { + "type": "input_text", + "text": "Hi" + }, + { + "type": "input_image", + "image_url": "https://example.com/image.png" + } + ] + } + ]` + + req := &CreateRequest{ + Model: "gpt-4", + Input: json.RawMessage(input), + } + + chatReq, err := TransformRequestToChatCompletion(req, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(chatReq.Messages) != 1 { + t.Fatalf("got %d messages, want 1", len(chatReq.Messages)) + } + + msg := chatReq.Messages[0] + if msg.Role != "user" { + t.Fatalf("message role = %s, want user", msg.Role) + } + + contentParts, ok := msg.Content.([]ChatContentPart) + if !ok { + t.Fatalf("message content has type %T, want []ChatContentPart", msg.Content) + } + + if len(contentParts) != 2 { + t.Fatalf("got %d content parts, want 2", len(contentParts)) + } + + if 
contentParts[0].Type != "text" { + t.Errorf("first content part type = %s, want text", contentParts[0].Type) + } + if contentParts[0].Text != "Hi" { + t.Errorf("first content part text = %q, want %q", contentParts[0].Text, "Hi") + } + + if contentParts[1].Type != "image_url" { + t.Errorf("second content part type = %s, want image_url", contentParts[1].Type) + } + if contentParts[1].ImageURL == nil || contentParts[1].ImageURL.URL != "https://example.com/image.png" { + t.Errorf("second content part image_url = %v, want https://example.com/image.png", contentParts[1].ImageURL) + } +} + +func TestTransformRequestToChatCompletion_WithTools(t *testing.T) { + req := &CreateRequest{ + Model: "gpt-4", + Input: json.RawMessage(`"What's the weather?"`), + Tools: []Tool{ + { + Type: "function", + Function: &FunctionDef{ + Name: "get_weather", + Description: "Get the weather", + Parameters: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "location": map[string]interface{}{ + "type": "string", + }, + }, + }, + }, + }, + }, + } + + chatReq, err := TransformRequestToChatCompletion(req, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(chatReq.Tools) != 1 { + t.Fatalf("got %d tools, want 1", len(chatReq.Tools)) + } + + if chatReq.Tools[0].Type != "function" { + t.Errorf("tool type = %s, want function", chatReq.Tools[0].Type) + } + if chatReq.Tools[0].Function.Name != "get_weather" { + t.Errorf("function name = %s, want get_weather", chatReq.Tools[0].Function.Name) + } +} + +func TestTransformRequestToChatCompletion_FunctionCallOutput(t *testing.T) { + input := `[ + {"type": "function_call_output", "call_id": "call_123", "output": "{\"temperature\": 72}"} + ]` + + req := &CreateRequest{ + Model: "gpt-4", + Input: json.RawMessage(input), + } + + chatReq, err := TransformRequestToChatCompletion(req, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(chatReq.Messages) != 1 { + t.Fatalf("got %d 
messages, want 1", len(chatReq.Messages)) + } + + if chatReq.Messages[0].Role != "tool" { + t.Errorf("role = %s, want tool", chatReq.Messages[0].Role) + } + if chatReq.Messages[0].ToolCallID != "call_123" { + t.Errorf("tool_call_id = %s, want call_123", chatReq.Messages[0].ToolCallID) + } +} + +func TestTransformRequestToChatCompletion_WithParameters(t *testing.T) { + temp := 0.7 + topP := 0.9 + maxTokens := 100 + + req := &CreateRequest{ + Model: "gpt-4", + Input: json.RawMessage(`"Test"`), + Temperature: &temp, + TopP: &topP, + MaxOutputTokens: &maxTokens, + Stream: true, + } + + chatReq, err := TransformRequestToChatCompletion(req, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if chatReq.Temperature == nil || *chatReq.Temperature != 0.7 { + t.Errorf("temperature = %v, want 0.7", chatReq.Temperature) + } + if chatReq.TopP == nil || *chatReq.TopP != 0.9 { + t.Errorf("top_p = %v, want 0.9", chatReq.TopP) + } + if chatReq.MaxTokens == nil || *chatReq.MaxTokens != 100 { + t.Errorf("max_tokens = %v, want 100", chatReq.MaxTokens) + } + if !chatReq.Stream { + t.Error("stream should be true") + } +} + +func TestTransformChatCompletionToResponse_TextContent(t *testing.T) { + chatResp := &ChatCompletionResponse{ + ID: "chatcmpl-123", + Object: "chat.completion", + Created: 1234567890, + Model: "gpt-4", + Choices: []ChatChoice{ + { + Index: 0, + Message: ChatMessage{ + Role: "assistant", + Content: "Hello! How can I help you today?", + }, + FinishReason: "stop", + }, + }, + Usage: &ChatUsage{ + PromptTokens: 10, + CompletionTokens: 8, + TotalTokens: 18, + }, + } + + resp := TransformChatCompletionToResponse(chatResp, "resp_test123", "gpt-4") + + if resp.ID != "resp_test123" { + t.Errorf("ID = %s, want resp_test123", resp.ID) + } + if resp.Model != "gpt-4" { + t.Errorf("Model = %s, want gpt-4", resp.Model) + } + if resp.Status != StatusCompleted { + t.Errorf("Status = %s, want %s", resp.Status, StatusCompleted) + } + if resp.OutputText != "Hello! 
How can I help you today?" { + t.Errorf("OutputText = %s, want Hello! How can I help you today?", resp.OutputText) + } + + if len(resp.Output) != 1 { + t.Fatalf("got %d output items, want 1", len(resp.Output)) + } + + if resp.Output[0].Type != ItemTypeMessage { + t.Errorf("output type = %s, want %s", resp.Output[0].Type, ItemTypeMessage) + } + if resp.Output[0].Role != "assistant" { + t.Errorf("output role = %s, want assistant", resp.Output[0].Role) + } + + if resp.Usage == nil { + t.Fatal("expected usage to be set") + } + if resp.Usage.InputTokens != 10 { + t.Errorf("input_tokens = %d, want 10", resp.Usage.InputTokens) + } + if resp.Usage.OutputTokens != 8 { + t.Errorf("output_tokens = %d, want 8", resp.Usage.OutputTokens) + } +} + +func TestTransformChatCompletionToResponse_ToolCalls(t *testing.T) { + chatResp := &ChatCompletionResponse{ + ID: "chatcmpl-123", + Object: "chat.completion", + Created: 1234567890, + Model: "gpt-4", + Choices: []ChatChoice{ + { + Index: 0, + Message: ChatMessage{ + Role: "assistant", + ToolCalls: []ChatToolCall{ + { + ID: "call_abc123", + Type: "function", + Function: ChatFunctionCall{ + Name: "get_weather", + Arguments: `{"location": "San Francisco"}`, + }, + }, + }, + }, + FinishReason: "tool_calls", + }, + }, + } + + resp := TransformChatCompletionToResponse(chatResp, "resp_test123", "gpt-4") + + if len(resp.Output) != 1 { + t.Fatalf("got %d output items, want 1", len(resp.Output)) + } + + if resp.Output[0].Type != ItemTypeFunctionCall { + t.Errorf("output type = %s, want %s", resp.Output[0].Type, ItemTypeFunctionCall) + } + if resp.Output[0].CallID != "call_abc123" { + t.Errorf("call_id = %s, want call_abc123", resp.Output[0].CallID) + } + if resp.Output[0].Name != "get_weather" { + t.Errorf("name = %s, want get_weather", resp.Output[0].Name) + } + if resp.Output[0].Arguments != `{"location": "San Francisco"}` { + t.Errorf("arguments = %s, want {\"location\": \"San Francisco\"}", resp.Output[0].Arguments) + } +} + +func 
TestTransformChatCompletionToResponse_MixedToolCallsAndText(t *testing.T) { + chatResp := &ChatCompletionResponse{ + ID: "chatcmpl-123", + Object: "chat.completion", + Created: 1234567890, + Model: "gpt-4", + Choices: []ChatChoice{ + { + Index: 0, + Message: ChatMessage{ + Role: "assistant", + Content: "Here's the information you requested:", + ToolCalls: []ChatToolCall{ + { + ID: "call_abc123", + Type: "function", + Function: ChatFunctionCall{ + Name: "get_weather", + Arguments: `{"location": "San Francisco"}`, + }, + }, + }, + }, + FinishReason: "stop", + }, + }, + } + + resp := TransformChatCompletionToResponse(chatResp, "resp_test123", "gpt-4") + + // Should have both function call and message outputs + if len(resp.Output) != 2 { + t.Fatalf("got %d output items, want 2", len(resp.Output)) + } + + // Check for function call item + var funcCallItem *OutputItem + var messageItem *OutputItem + + for i := range resp.Output { + switch resp.Output[i].Type { + case ItemTypeFunctionCall: + funcCallItem = &resp.Output[i] + case ItemTypeMessage: + messageItem = &resp.Output[i] + } + } + + if funcCallItem == nil { + t.Fatal("expected function call item in output") + } + if messageItem == nil { + t.Fatal("expected message item in output") + } + + // Verify function call details + if funcCallItem.CallID != "call_abc123" { + t.Errorf("function call ID = %s, want call_abc123", funcCallItem.CallID) + } + if funcCallItem.Name != "get_weather" { + t.Errorf("function name = %s, want get_weather", funcCallItem.Name) + } + if funcCallItem.Arguments != `{"location": "San Francisco"}` { + t.Errorf("function arguments = %s, want {\"location\": \"San Francisco\"}", funcCallItem.Arguments) + } + + // Verify message details + if messageItem.Role != "assistant" { + t.Errorf("message role = %s, want assistant", messageItem.Role) + } + + // Check if the message contains the expected text + foundText := false + for _, contentPart := range messageItem.Content { + if contentPart.Type == 
ContentTypeOutputText && contentPart.Text == "Here's the information you requested:" { + foundText = true + break + } + } + if !foundText { + t.Errorf("expected message content 'Here's the information you requested:', but not found") + } + + // Verify OutputText field + if resp.OutputText != "Here's the information you requested:" { + t.Errorf("OutputText = %s, want 'Here's the information you requested:'", resp.OutputText) + } +} + +func TestParseInput_InvalidJSON(t *testing.T) { + _, err := parseInput(json.RawMessage(`{invalid`)) + if err == nil { + t.Error("expected error for invalid JSON") + } +} + +func TestGenerateIDs(t *testing.T) { + // Test that IDs have correct prefixes + respID := GenerateResponseID() + if !startsWith(respID, "resp_") { + t.Errorf("response ID should start with resp_, got %s", respID) + } + + itemID := GenerateItemID() + if !startsWith(itemID, "item_") { + t.Errorf("item ID should start with item_, got %s", itemID) + } + + msgID := GenerateMessageID() + if !startsWith(msgID, "msg_") { + t.Errorf("message ID should start with msg_, got %s", msgID) + } + + callID := GenerateCallID() + if !startsWith(callID, "call_") { + t.Errorf("call ID should start with call_, got %s", callID) + } + + // Test uniqueness + id1 := GenerateResponseID() + id2 := GenerateResponseID() + if id1 == id2 { + t.Error("generated IDs should be unique") + } +} + +func startsWith(s, prefix string) bool { + return len(s) >= len(prefix) && s[:len(prefix)] == prefix +}