diff --git a/internal/adapter/provider/adapter.go b/internal/adapter/provider/adapter.go index 46519208..9013af11 100644 --- a/internal/adapter/provider/adapter.go +++ b/internal/adapter/provider/adapter.go @@ -1,10 +1,8 @@ package provider import ( - "context" - "net/http" - "github.com/awsl-project/maxx/internal/domain" + "github.com/awsl-project/maxx/internal/flow" ) // ProviderAdapter handles communication with upstream providers @@ -13,10 +11,10 @@ type ProviderAdapter interface { SupportedClientTypes() []domain.ClientType // Execute performs the proxy request to the upstream provider - // It reads from ctx for ClientType, MappedModel, RequestBody - // It writes the response to w + // It reads from flow.Ctx for ClientType, MappedModel, RequestBody + // It writes the response to c.Writer // Returns ProxyError on failure - Execute(ctx context.Context, w http.ResponseWriter, req *http.Request, provider *domain.Provider) error + Execute(c *flow.Ctx, provider *domain.Provider) error } // AdapterFactory creates ProviderAdapter instances diff --git a/internal/adapter/provider/antigravity/adapter.go b/internal/adapter/provider/antigravity/adapter.go index 0f610573..b83b8011 100644 --- a/internal/adapter/provider/antigravity/adapter.go +++ b/internal/adapter/provider/antigravity/adapter.go @@ -16,8 +16,9 @@ import ( "github.com/awsl-project/maxx/internal/adapter/provider" cliproxyapi "github.com/awsl-project/maxx/internal/adapter/provider/cliproxyapi_antigravity" - ctxutil "github.com/awsl-project/maxx/internal/context" + "github.com/awsl-project/maxx/internal/converter" "github.com/awsl-project/maxx/internal/domain" + "github.com/awsl-project/maxx/internal/flow" "github.com/awsl-project/maxx/internal/usage" ) @@ -32,10 +33,11 @@ type TokenCache struct { } type AntigravityAdapter struct { - provider *domain.Provider - tokenCache *TokenCache - tokenMu sync.RWMutex - httpClient *http.Client + provider *domain.Provider + tokenCache *TokenCache + tokenMu sync.RWMutex + 
projectIDOnce sync.Once + httpClient *http.Client } func NewAdapter(p *domain.Provider) (provider.ProviderAdapter, error) { @@ -71,17 +73,21 @@ func NewAdapter(p *domain.Provider) (provider.ProviderAdapter, error) { } func (a *AntigravityAdapter) SupportedClientTypes() []domain.ClientType { - // Antigravity natively supports Claude and Gemini by converting to Gemini/v1internal API - // OpenAI requests will be converted to Claude format by Executor before reaching this adapter - return []domain.ClientType{domain.ClientTypeClaude, domain.ClientTypeGemini} + // Antigravity natively supports Claude and Gemini (via Gemini/v1internal API). + // Prefer Gemini when choosing a target format. + return []domain.ClientType{domain.ClientTypeGemini, domain.ClientTypeClaude} } -func (a *AntigravityAdapter) Execute(ctx context.Context, w http.ResponseWriter, req *http.Request, provider *domain.Provider) error { - clientType := ctxutil.GetClientType(ctx) - baseCtx := ctx - requestModel := ctxutil.GetRequestModel(ctx) // Original model from request (e.g., "claude-3-5-sonnet-20241022-online") - mappedModel := ctxutil.GetMappedModel(ctx) // Mapped model after executor's unified mapping - requestBody := ctxutil.GetRequestBody(ctx) +func (a *AntigravityAdapter) Execute(c *flow.Ctx, provider *domain.Provider) error { + clientType := flow.GetClientType(c) + requestModel := flow.GetRequestModel(c) + mappedModel := flow.GetMappedModel(c) + requestBody := flow.GetRequestBody(c) + request := c.Request + ctx := context.Background() + if request != nil { + ctx = request.Context() + } backgroundDowngrade := false backgroundModel := "" @@ -99,8 +105,8 @@ func (a *AntigravityAdapter) Execute(ctx context.Context, w http.ResponseWriter, retriedWithoutThinking := false for attemptIdx := 0; attemptIdx < 2; attemptIdx++ { - ctx = ctxutil.WithRequestModel(baseCtx, requestModel) - ctx = ctxutil.WithRequestBody(ctx, requestBody) + c.Set(flow.KeyRequestModel, requestModel) + c.Set(flow.KeyRequestBody, 
requestBody) // Apply background downgrade override if needed config := provider.Config.Antigravity @@ -109,12 +115,12 @@ func (a *AntigravityAdapter) Execute(ctx context.Context, w http.ResponseWriter, } // Update attempt record with the final mapped model (in case of background downgrade) - if attempt := ctxutil.GetUpstreamAttempt(ctx); attempt != nil { + if attempt := flow.GetUpstreamAttempt(c); attempt != nil { attempt.MappedModel = mappedModel } // Get streaming flag from context (already detected correctly for Gemini URL path) - stream := ctxutil.GetIsStream(ctx) + stream := flow.GetIsStream(c) clientWantsStream := stream actualStream := stream if clientType == domain.ClientTypeClaude && !clientWantsStream { @@ -133,6 +139,7 @@ func (a *AntigravityAdapter) Execute(ctx context.Context, w http.ResponseWriter, // Transform request based on client type var geminiBody []byte + openAIWrapped := false switch clientType { case domain.ClientTypeClaude: // Use direct transformation (no converter dependency) @@ -151,207 +158,247 @@ func (a *AntigravityAdapter) Execute(ctx context.Context, w http.ResponseWriter, // Apply minimal post-processing for features not yet fully integrated geminiBody = applyClaudePostProcess(geminiBody, sessionID, hasThinking, requestBody, mappedModel) case domain.ClientTypeOpenAI: - // TODO: Implement OpenAI transformation in the future - return domain.NewProxyErrorWithMessage(domain.ErrFormatConversion, true, "OpenAI transformation not yet implemented") + geminiBody = ConvertOpenAIRequestToAntigravity(mappedModel, requestBody, actualStream) + openAIWrapped = true default: // For Gemini, unwrap CLI envelope if present geminiBody = unwrapGeminiCLIEnvelope(requestBody) } - // Wrap request in v1internal format - var toolsForConfig []interface{} - if clientType == domain.ClientTypeClaude { - var raw map[string]interface{} - if err := json.Unmarshal(requestBody, &raw); err == nil { - if tools, ok := raw["tools"].([]interface{}); ok { - 
toolsForConfig = tools - } + // Resolve project ID (CLIProxyAPI behavior) + a.projectIDOnce.Do(func() { + if strings.TrimSpace(config.ProjectID) != "" { + return } - } - upstreamBody, err := wrapV1InternalRequest(geminiBody, config.ProjectID, requestModel, mappedModel, sessionID, toolsForConfig) - if err != nil { - return domain.NewProxyErrorWithMessage(domain.ErrFormatConversion, true, "failed to wrap request for v1internal") - } - - // Build upstream URLs (prod first, daily fallback) - baseURLs := []string{V1InternalBaseURLProd, V1InternalBaseURLDaily} - client := a.httpClient - var lastErr error - - for idx, base := range baseURLs { - upstreamURL := a.buildUpstreamURL(base, actualStream) - - upstreamReq, reqErr := http.NewRequestWithContext(ctx, "POST", upstreamURL, bytes.NewReader(upstreamBody)) - if reqErr != nil { - lastErr = reqErr - continue + if pid, _, err := FetchProjectInfo(ctx, accessToken, config.Email); err == nil { + pid = strings.TrimSpace(pid) + if pid != "" { + config.ProjectID = pid + } } + }) + projectID := strings.TrimSpace(config.ProjectID) - // Set only the required headers (like Antigravity-Manager) - upstreamReq.Header.Set("Content-Type", "application/json") - upstreamReq.Header.Set("Authorization", "Bearer "+accessToken) - upstreamReq.Header.Set("User-Agent", AntigravityUserAgent) - - // Send request info via EventChannel (only once per attempt) - if eventChan := ctxutil.GetEventChan(ctx); eventChan != nil { - eventChan.SendRequestInfo(&domain.RequestInfo{ - Method: upstreamReq.Method, - URL: upstreamURL, - Headers: flattenHeaders(upstreamReq.Header), - Body: string(upstreamBody), - }) + var upstreamBody []byte + if openAIWrapped { + upstreamBody = finalizeOpenAIWrappedRequest(geminiBody, projectID, mappedModel, sessionID) + } else { + // Wrap request in v1internal format + var toolsForConfig []interface{} + if clientType == domain.ClientTypeClaude { + var raw map[string]interface{} + if err := json.Unmarshal(requestBody, &raw); err == 
nil { + if tools, ok := raw["tools"].([]interface{}); ok { + toolsForConfig = tools + } + } } - - resp, err := client.Do(upstreamReq) + upstreamBody, err = wrapV1InternalRequest(geminiBody, projectID, requestModel, mappedModel, sessionID, toolsForConfig) if err != nil { - lastErr = err - if hasNextEndpoint(idx, len(baseURLs)) { - continue - } - proxyErr := domain.NewProxyErrorWithMessage(domain.ErrUpstreamError, true, "failed to connect to upstream") - proxyErr.IsNetworkError = true // Mark as network error (connection timeout, DNS failure, etc.) - return proxyErr + return domain.NewProxyErrorWithMessage(domain.ErrFormatConversion, true, "failed to wrap request for v1internal") } - defer resp.Body.Close() + } - // Check for 401 (token expired) and retry once - if resp.StatusCode == http.StatusUnauthorized { - resp.Body.Close() + // Build upstream URLs (CLIProxyAPI fallback order) + baseURLs := antigravityBaseURLFallbackOrder(config.Endpoint) + client := a.httpClient + var lastErr error - // Invalidate token cache - a.tokenMu.Lock() - a.tokenCache = &TokenCache{} - a.tokenMu.Unlock() + for attempt := 0; attempt < antigravityRetryAttempts; attempt++ { + for idx, base := range baseURLs { + upstreamURL := a.buildUpstreamURL(base, actualStream) - // Get new token - accessToken, err = a.getAccessToken(ctx) - if err != nil { - return domain.NewProxyErrorWithMessage(err, true, "failed to refresh access token") + upstreamReq, reqErr := http.NewRequestWithContext(ctx, "POST", upstreamURL, bytes.NewReader(upstreamBody)) + if reqErr != nil { + lastErr = reqErr + continue } - // Retry request with only required headers - upstreamReq, _ = http.NewRequestWithContext(ctx, "POST", upstreamURL, bytes.NewReader(upstreamBody)) + // Set only the required headers (like Antigravity-Manager) upstreamReq.Header.Set("Content-Type", "application/json") upstreamReq.Header.Set("Authorization", "Bearer "+accessToken) upstreamReq.Header.Set("User-Agent", AntigravityUserAgent) - resp, err = 
client.Do(upstreamReq) + + // Send request info via EventChannel (only once per attempt) + if eventChan := flow.GetEventChan(c); eventChan != nil { + eventChan.SendRequestInfo(&domain.RequestInfo{ + Method: upstreamReq.Method, + URL: upstreamURL, + Headers: flattenHeaders(upstreamReq.Header), + Body: string(upstreamBody), + }) + } + + resp, err := client.Do(upstreamReq) if err != nil { lastErr = err if hasNextEndpoint(idx, len(baseURLs)) { continue } - proxyErr := domain.NewProxyErrorWithMessage(domain.ErrUpstreamError, true, "failed to connect to upstream after token refresh") - proxyErr.IsNetworkError = true // Mark as network error + proxyErr := domain.NewProxyErrorWithMessage(domain.ErrUpstreamError, true, "failed to connect to upstream") + proxyErr.IsNetworkError = true // Mark as network error (connection timeout, DNS failure, etc.) return proxyErr } - defer resp.Body.Close() - } - // Check for error response - if resp.StatusCode >= 400 { - body, _ := io.ReadAll(resp.Body) - // Send error response info via EventChannel - if eventChan := ctxutil.GetEventChan(ctx); eventChan != nil { - eventChan.SendResponseInfo(&domain.ResponseInfo{ - Status: resp.StatusCode, - Headers: flattenHeaders(resp.Header), - Body: string(body), - }) - } + // Check for 401 (token expired) and retry once + if resp.StatusCode == http.StatusUnauthorized { + resp.Body.Close() - // Check for RESOURCE_EXHAUSTED (429) and extract cooldown info - var rateLimitInfo *domain.RateLimitInfo - var cooldownUpdateChan chan time.Time - if resp.StatusCode == http.StatusTooManyRequests { - rateLimitInfo, cooldownUpdateChan = a.parseRateLimitInfo(ctx, body, provider) - } + // Invalidate token cache + a.tokenMu.Lock() + a.tokenCache = &TokenCache{} + a.tokenMu.Unlock() - // Parse retry info for 429/5xx responses (like Antigravity-Manager) - var retryAfter time.Duration + // Get new token + accessToken, err = a.getAccessToken(ctx) + if err != nil { + return domain.NewProxyErrorWithMessage(err, true, "failed 
to refresh access token") + } - // 1) Prefer Retry-After header (seconds) - if ra := strings.TrimSpace(resp.Header.Get("Retry-After")); ra != "" { - if secs, err := strconv.Atoi(ra); err == nil && secs > 0 { - retryAfter = time.Duration(secs) * time.Second + // Retry request with only required headers + upstreamReq, reqErr = http.NewRequestWithContext(ctx, "POST", upstreamURL, bytes.NewReader(upstreamBody)) + if reqErr != nil { + return domain.NewProxyErrorWithMessage(reqErr, false, "failed to create upstream request after token refresh") + } + upstreamReq.Header.Set("Content-Type", "application/json") + upstreamReq.Header.Set("Authorization", "Bearer "+accessToken) + upstreamReq.Header.Set("User-Agent", AntigravityUserAgent) + resp, err = client.Do(upstreamReq) + if err != nil { + lastErr = err + if hasNextEndpoint(idx, len(baseURLs)) { + continue + } + proxyErr := domain.NewProxyErrorWithMessage(domain.ErrUpstreamError, true, "failed to connect to upstream after token refresh") + proxyErr.IsNetworkError = true // Mark as network error + return proxyErr } } - // 2) Fallback to body parsing (google.rpc.RetryInfo / quotaResetDelay) - if retryAfter == 0 { - if retryInfo := ParseRetryInfo(resp.StatusCode, body); retryInfo != nil { - retryAfter = retryInfo.Delay + // Check for error response + if resp.StatusCode >= 400 { + body, _ := io.ReadAll(resp.Body) + resp.Body.Close() + // Send error response info via EventChannel + if eventChan := flow.GetEventChan(c); eventChan != nil { + eventChan.SendResponseInfo(&domain.ResponseInfo{ + Status: resp.StatusCode, + Headers: flattenHeaders(resp.Header), + Body: string(body), + }) + } - // Manager: add a small buffer and cap for 429 retries - if resp.StatusCode == http.StatusTooManyRequests { - retryAfter += 200 * time.Millisecond - if retryAfter > 10*time.Second { - retryAfter = 10 * time.Second - } + // Check for RESOURCE_EXHAUSTED (429) and extract cooldown info + var rateLimitInfo *domain.RateLimitInfo + var 
cooldownUpdateChan chan time.Time + if resp.StatusCode == http.StatusTooManyRequests { + rateLimitInfo, cooldownUpdateChan = a.parseRateLimitInfo(ctx, body, provider) + } + + // Parse retry info for 429/5xx responses (like Antigravity-Manager) + var retryAfter time.Duration + + // 1) Prefer Retry-After header (seconds) + if ra := strings.TrimSpace(resp.Header.Get("Retry-After")); ra != "" { + if secs, err := strconv.Atoi(ra); err == nil && secs > 0 { + retryAfter = time.Duration(secs) * time.Second } + } + + // 2) Fallback to body parsing (google.rpc.RetryInfo / quotaResetDelay) + if retryAfter == 0 { + if retryInfo := ParseRetryInfo(resp.StatusCode, body); retryInfo != nil { + retryAfter = retryInfo.Delay + + // Manager: add a small buffer and cap for 429 retries + if resp.StatusCode == http.StatusTooManyRequests { + retryAfter += 200 * time.Millisecond + if retryAfter > 10*time.Second { + retryAfter = 10 * time.Second + } + } - retryAfter = ApplyJitter(retryAfter) + retryAfter = ApplyJitter(retryAfter) + } } - } - proxyErr := domain.NewProxyErrorWithMessage( - fmt.Errorf("upstream error: %s", string(body)), - isRetryableStatusCode(resp.StatusCode), - fmt.Sprintf("upstream returned status %d", resp.StatusCode), - ) + proxyErr := domain.NewProxyErrorWithMessage( + fmt.Errorf("upstream error: %s", string(body)), + isRetryableStatusCode(resp.StatusCode), + fmt.Sprintf("upstream returned status %d", resp.StatusCode), + ) - // Set status code and check if it's a server error (5xx) - proxyErr.HTTPStatusCode = resp.StatusCode - proxyErr.IsServerError = resp.StatusCode >= 500 && resp.StatusCode < 600 + // Set status code and check if it's a server error (5xx) + proxyErr.HTTPStatusCode = resp.StatusCode + proxyErr.IsServerError = resp.StatusCode >= 500 && resp.StatusCode < 600 - // Set retry info on error for upstream handling - if retryAfter > 0 { - proxyErr.RetryAfter = retryAfter - } + // Set retry info on error for upstream handling + if retryAfter > 0 { + 
proxyErr.RetryAfter = retryAfter + } - // Set rate limit info for cooldown handling - if rateLimitInfo != nil { - proxyErr.RateLimitInfo = rateLimitInfo - proxyErr.CooldownUpdateChan = cooldownUpdateChan - } + // Set rate limit info for cooldown handling + if rateLimitInfo != nil { + proxyErr.RateLimitInfo = rateLimitInfo + proxyErr.CooldownUpdateChan = cooldownUpdateChan + } - lastErr = proxyErr + lastErr = proxyErr - // Signature failure recovery: retry once without thinking (like Manager) - if resp.StatusCode == http.StatusBadRequest && !retriedWithoutThinking && isThinkingSignatureError(body) { - retriedWithoutThinking = true + // Signature failure recovery: retry once without thinking (like Manager) + if resp.StatusCode == http.StatusBadRequest && !retriedWithoutThinking && isThinkingSignatureError(body) { + retriedWithoutThinking = true + + // Manager uses a small fixed delay before retrying. + select { + case <-ctx.Done(): + return domain.NewProxyErrorWithMessage(ctx.Err(), false, "client disconnected") + case <-time.After(200 * time.Millisecond): + } - // Manager uses a small fixed delay before retrying. 
- select { - case <-ctx.Done(): - return domain.NewProxyErrorWithMessage(ctx.Err(), false, "client disconnected") - case <-time.After(200 * time.Millisecond): + requestBody = stripThinkingFromClaude(requestBody) + if newModel := extractModelFromBody(requestBody); newModel != "" { + requestModel = newModel + } + mappedModel = "" // force remap + continue } - requestBody = stripThinkingFromClaude(requestBody) - if newModel := extractModelFromBody(requestBody); newModel != "" { - requestModel = newModel + // Retry fallback handling (CLIProxyAPI behavior) + if resp.StatusCode == http.StatusTooManyRequests && hasNextEndpoint(idx, len(baseURLs)) { + continue } - mappedModel = "" // force remap - continue + if antigravityShouldRetryNoCapacity(resp.StatusCode, body) { + if hasNextEndpoint(idx, len(baseURLs)) { + continue + } + if attempt+1 < antigravityRetryAttempts { + delay := antigravityNoCapacityRetryDelay(attempt) + if err := antigravityWait(ctx, delay); err != nil { + return domain.NewProxyErrorWithMessage(err, false, "client disconnected") + } + break + } + } + + return proxyErr } - // Fallback to next endpoint if available and retryable - if hasNextEndpoint(idx, len(baseURLs)) && shouldTryNextEndpoint(resp.StatusCode) { + // Handle response + if actualStream && !clientWantsStream { + err := a.handleCollectedStreamResponse(c, resp, clientType, requestModel) resp.Body.Close() - continue + return err } - - return proxyErr - } - - // Handle response - if actualStream && !clientWantsStream { - return a.handleCollectedStreamResponse(ctx, w, resp, clientType, requestModel) - } - if actualStream { - return a.handleStreamResponse(ctx, w, resp, clientType) + if actualStream { + err := a.handleStreamResponse(c, resp, clientType) + resp.Body.Close() + return err + } + nErr := a.handleNonStreamResponse(c, resp, clientType) + resp.Body.Close() + return nErr } - return a.handleNonStreamResponse(ctx, w, resp, clientType) } // All endpoints failed in this iteration @@ -490,17 
+537,26 @@ func applyClaudePostProcess(geminiBody []byte, sessionID string, hasThinking boo return result } -// v1internal endpoints (prod + daily fallback, like Antigravity-Manager) +// v1internal endpoints (CLIProxyAPI fallback order) const ( - V1InternalBaseURLProd = "https://cloudcode-pa.googleapis.com/v1internal" - V1InternalBaseURLDaily = "https://daily-cloudcode-pa.sandbox.googleapis.com/v1internal" + V1InternalBaseURLDaily = "https://daily-cloudcode-pa.googleapis.com" + V1InternalSandboxBaseURLDaily = "https://daily-cloudcode-pa.sandbox.googleapis.com" + V1InternalBaseURLProd = "https://cloudcode-pa.googleapis.com" + antigravityRetryAttempts = 3 ) func (a *AntigravityAdapter) buildUpstreamURL(base string, stream bool) string { + base = strings.TrimRight(base, "/") + if strings.Contains(base, "/v1internal") { + if stream { + return fmt.Sprintf("%s:streamGenerateContent?alt=sse", base) + } + return fmt.Sprintf("%s:generateContent", base) + } if stream { - return fmt.Sprintf("%s:streamGenerateContent?alt=sse", base) + return fmt.Sprintf("%s/v1internal:streamGenerateContent?alt=sse", base) } - return fmt.Sprintf("%s:generateContent", base) + return fmt.Sprintf("%s/v1internal:generateContent", base) } func hasNextEndpoint(index, total int) bool { @@ -516,6 +572,76 @@ func shouldTryNextEndpoint(status int) bool { return status >= 500 } +func antigravityBaseURLFallbackOrder(endpoint string) []string { + if endpoint = strings.TrimSpace(endpoint); endpoint != "" { + if isAntigravityEndpoint(endpoint) { + return []string{strings.TrimRight(endpoint, "/")} + } + } + return []string{ + V1InternalBaseURLDaily, + V1InternalSandboxBaseURLDaily, + // V1InternalBaseURLProd, + } +} + +func isAntigravityEndpoint(endpoint string) bool { + endpoint = strings.ToLower(strings.TrimSpace(endpoint)) + if endpoint == "" { + return false + } + // Only accept Antigravity v1internal endpoints, not Vertex AI endpoints. 
+ if strings.Contains(endpoint, "cloudcode-pa.googleapis.com") { + return true + } + if strings.Contains(endpoint, "daily-cloudcode-pa.googleapis.com") { + return true + } + if strings.Contains(endpoint, "daily-cloudcode-pa.sandbox.googleapis.com") { + return true + } + if strings.Contains(endpoint, "/v1internal") && strings.Contains(endpoint, "cloudcode-pa") { + return true + } + return false +} + +func antigravityShouldRetryNoCapacity(statusCode int, body []byte) bool { + if statusCode != http.StatusServiceUnavailable { + return false + } + if len(body) == 0 { + return false + } + msg := strings.ToLower(string(body)) + return strings.Contains(msg, "no capacity available") +} + +func antigravityNoCapacityRetryDelay(attempt int) time.Duration { + if attempt < 0 { + attempt = 0 + } + delay := time.Duration(attempt+1) * 250 * time.Millisecond + if delay > 2*time.Second { + delay = 2 * time.Second + } + return delay +} + +func antigravityWait(ctx context.Context, wait time.Duration) error { + if wait <= 0 { + return nil + } + timer := time.NewTimer(wait) + defer timer.Stop() + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil + } +} + // isThinkingSignatureError detects thinking signature related 400 errors (like Manager) func isThinkingSignatureError(body []byte) bool { bodyStr := strings.ToLower(string(body)) @@ -526,7 +652,8 @@ func isThinkingSignatureError(body []byte) bool { strings.Contains(bodyStr, "failed to deserialise") } -func (a *AntigravityAdapter) handleNonStreamResponse(ctx context.Context, w http.ResponseWriter, resp *http.Response, clientType domain.ClientType) error { +func (a *AntigravityAdapter) handleNonStreamResponse(c *flow.Ctx, resp *http.Response, clientType domain.ClientType) error { + w := c.Writer body, err := io.ReadAll(resp.Body) if err != nil { return domain.NewProxyErrorWithMessage(domain.ErrUpstreamError, true, "failed to read upstream response") @@ -535,31 +662,30 @@ func (a *AntigravityAdapter) 
handleNonStreamResponse(ctx context.Context, w http // Unwrap v1internal response wrapper (extract "response" field) unwrappedBody := unwrapV1InternalResponse(body) - // Send events via EventChannel (executor will process them) - eventChan := ctxutil.GetEventChan(ctx) - - // Send response info event - eventChan.SendResponseInfo(&domain.ResponseInfo{ - Status: resp.StatusCode, - Headers: flattenHeaders(resp.Header), - Body: string(body), // Keep original for debugging - }) - - // Extract and send token usage metrics - if metrics := usage.ExtractFromResponse(string(unwrappedBody)); metrics != nil { - eventChan.SendMetrics(&domain.AdapterMetrics{ - InputTokens: metrics.InputTokens, - OutputTokens: metrics.OutputTokens, - CacheReadCount: metrics.CacheReadCount, - CacheCreationCount: metrics.CacheCreationCount, - Cache5mCreationCount: metrics.Cache5mCreationCount, - Cache1hCreationCount: metrics.Cache1hCreationCount, + if eventChan := flow.GetEventChan(c); eventChan != nil { + eventChan.SendResponseInfo(&domain.ResponseInfo{ + Status: resp.StatusCode, + Headers: flattenHeaders(resp.Header), + Body: string(body), }) + + if metrics := usage.ExtractFromResponse(string(unwrappedBody)); metrics != nil { + eventChan.SendMetrics(&domain.AdapterMetrics{ + InputTokens: metrics.InputTokens, + OutputTokens: metrics.OutputTokens, + CacheReadCount: metrics.CacheReadCount, + CacheCreationCount: metrics.CacheCreationCount, + Cache5mCreationCount: metrics.Cache5mCreationCount, + Cache1hCreationCount: metrics.Cache1hCreationCount, + }) + } } // Extract and send response model if modelVersion := extractModelVersion(unwrappedBody); modelVersion != "" { - eventChan.SendResponseModel(modelVersion) + if eventChan := flow.GetEventChan(c); eventChan != nil { + eventChan.SendResponseModel(modelVersion) + } } var responseBody []byte @@ -567,14 +693,17 @@ func (a *AntigravityAdapter) handleNonStreamResponse(ctx context.Context, w http // Transform response based on client type switch clientType { 
case domain.ClientTypeClaude: - requestModel := ctxutil.GetRequestModel(ctx) + requestModel := flow.GetRequestModel(c) responseBody, err = convertGeminiToClaudeResponse(unwrappedBody, requestModel) if err != nil { return domain.NewProxyErrorWithMessage(domain.ErrFormatConversion, false, "failed to transform response") } case domain.ClientTypeOpenAI: - // TODO: Implement OpenAI response transformation - return domain.NewProxyErrorWithMessage(domain.ErrFormatConversion, false, "OpenAI response transformation not yet implemented") + responseBody, err = converter.GetGlobalRegistry().TransformResponse( + domain.ClientTypeGemini, domain.ClientTypeOpenAI, unwrappedBody) + if err != nil { + return domain.NewProxyErrorWithMessage(domain.ErrFormatConversion, false, "failed to transform response") + } default: // Gemini native responseBody = unwrappedBody @@ -588,8 +717,13 @@ func (a *AntigravityAdapter) handleNonStreamResponse(ctx context.Context, w http return nil } -func (a *AntigravityAdapter) handleStreamResponse(ctx context.Context, w http.ResponseWriter, resp *http.Response, clientType domain.ClientType) error { - eventChan := ctxutil.GetEventChan(ctx) +func (a *AntigravityAdapter) handleStreamResponse(c *flow.Ctx, resp *http.Response, clientType domain.ClientType) error { + w := c.Writer + ctx := context.Background() + if c.Request != nil { + ctx = c.Request.Context() + } + eventChan := flow.GetEventChan(c) // Send initial response info (for streaming, we only capture status and headers) eventChan.SendResponseInfo(&domain.ResponseInfo{ @@ -614,18 +748,23 @@ func (a *AntigravityAdapter) handleStreamResponse(ctx context.Context, w http.Re // Use specialized Claude SSE handler for Claude clients isClaudeClient := clientType == domain.ClientTypeClaude + isOpenAIClient := clientType == domain.ClientTypeOpenAI // Extract sessionID for signature caching (like CLIProxyAPI) - requestBody := ctxutil.GetRequestBody(ctx) + requestBody := flow.GetRequestBody(c) sessionID := 
extractSessionID(requestBody) // Get original request model for Claude response (like Antigravity-Manager) - requestModel := ctxutil.GetRequestModel(ctx) + requestModel := flow.GetRequestModel(c) var claudeState *ClaudeStreamingState if isClaudeClient { claudeState = NewClaudeStreamingStateWithSession(sessionID, requestModel) } + var openaiState *converter.TransformState + if isOpenAIClient { + openaiState = converter.NewTransformState() + } // Collect all SSE events for response body and token extraction var sseBuffer strings.Builder @@ -698,6 +837,9 @@ func (a *AntigravityAdapter) handleStreamResponse(ctx context.Context, w http.Re // Unwrap v1internal SSE chunk before processing unwrappedLine := unwrapV1InternalSSEChunk(lineBytes) + if len(unwrappedLine) == 0 { + continue + } // Collect original SSE for token extraction (extractor handles v1internal wrapper) sseBuffer.WriteString(line) @@ -706,9 +848,13 @@ func (a *AntigravityAdapter) handleStreamResponse(ctx context.Context, w http.Re if isClaudeClient { // Use specialized Claude SSE transformation output = claudeState.ProcessGeminiSSELine(string(unwrappedLine)) - } else if clientType == domain.ClientTypeOpenAI { - // TODO: Implement OpenAI streaming transformation - continue + } else if isOpenAIClient { + converted, convErr := converter.GetGlobalRegistry().TransformStreamChunk( + domain.ClientTypeGemini, domain.ClientTypeOpenAI, unwrappedLine, openaiState) + if convErr != nil { + continue + } + output = converted } else { // Gemini native output = unwrappedLine @@ -770,15 +916,20 @@ func (a *AntigravityAdapter) handleStreamResponse(ctx context.Context, w http.Re } // handleCollectedStreamResponse forwards upstream SSE but collects into a single response body (like Manager non-stream auto-convert) -func (a *AntigravityAdapter) handleCollectedStreamResponse(ctx context.Context, w http.ResponseWriter, resp *http.Response, clientType domain.ClientType, requestModel string) error { - eventChan := 
ctxutil.GetEventChan(ctx) - - // Send initial response info - eventChan.SendResponseInfo(&domain.ResponseInfo{ - Status: resp.StatusCode, - Headers: flattenHeaders(resp.Header), - Body: "[stream-collected]", - }) +func (a *AntigravityAdapter) handleCollectedStreamResponse(c *flow.Ctx, resp *http.Response, clientType domain.ClientType, requestModel string) error { + w := c.Writer + ctx := context.Background() + if c.Request != nil { + ctx = c.Request.Context() + } + eventChan := flow.GetEventChan(c) + if eventChan != nil { + eventChan.SendResponseInfo(&domain.ResponseInfo{ + Status: resp.StatusCode, + Headers: flattenHeaders(resp.Header), + Body: "[stream-collected]", + }) + } // Copy upstream headers (except those we override) copyResponseHeaders(w.Header(), resp.Header) @@ -788,14 +939,14 @@ func (a *AntigravityAdapter) handleCollectedStreamResponse(ctx context.Context, var claudeSSE strings.Builder if isClaudeClient { // Extract sessionID for signature caching (like CLIProxyAPI) - requestBody := ctxutil.GetRequestBody(ctx) + requestBody := flow.GetRequestBody(c) sessionID := extractSessionID(requestBody) claudeState = NewClaudeStreamingStateWithSession(sessionID, requestModel) } // Collect upstream SSE for attempt/debug and token extraction. 
var upstreamSSE strings.Builder - var lastPayload []byte + var unwrappedSSE strings.Builder var responseBody []byte var lineBuffer bytes.Buffer @@ -826,15 +977,7 @@ func (a *AntigravityAdapter) handleCollectedStreamResponse(ctx context.Context, if len(unwrappedLine) == 0 { continue } - - // Track last Gemini payload for non-Claude responses (best-effort) - lineStr := strings.TrimSpace(string(unwrappedLine)) - if strings.HasPrefix(lineStr, "data: ") { - dataStr := strings.TrimSpace(strings.TrimPrefix(lineStr, "data: ")) - if dataStr != "" && dataStr != "[DONE]" { - lastPayload = []byte(dataStr) - } - } + unwrappedSSE.Write(unwrappedLine) if isClaudeClient && claudeState != nil { out := claudeState.ProcessGeminiSSELine(string(unwrappedLine)) @@ -901,16 +1044,23 @@ func (a *AntigravityAdapter) handleCollectedStreamResponse(ctx context.Context, } responseBody = collected } else { - if len(lastPayload) == 0 { + if unwrappedSSE.Len() == 0 { return domain.NewProxyErrorWithMessage(domain.ErrUpstreamError, true, "empty upstream stream response") } + geminiWrapped := convertStreamToNonStream([]byte(unwrappedSSE.String())) + geminiResponse := unwrapV1InternalResponse(geminiWrapped) switch clientType { case domain.ClientTypeGemini: - responseBody = lastPayload + responseBody = geminiResponse case domain.ClientTypeOpenAI: - return domain.NewProxyErrorWithMessage(domain.ErrFormatConversion, false, "OpenAI response transformation not yet implemented") + var convErr error + responseBody, convErr = converter.GetGlobalRegistry().TransformResponse( + domain.ClientTypeGemini, domain.ClientTypeOpenAI, geminiResponse) + if convErr != nil { + return domain.NewProxyErrorWithMessage(domain.ErrFormatConversion, false, "failed to transform response") + } default: - responseBody = lastPayload + responseBody = geminiResponse } } diff --git a/internal/adapter/provider/antigravity/openai_request.go b/internal/adapter/provider/antigravity/openai_request.go new file mode 100644 index 
00000000..521f7ce9 --- /dev/null +++ b/internal/adapter/provider/antigravity/openai_request.go @@ -0,0 +1,426 @@ +package antigravity + +import ( + "bytes" + "fmt" + "log" + "mime" + "strings" + + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +const geminiCLIFunctionThoughtSignature = "skip_thought_signature_validator" + +// ConvertOpenAIRequestToAntigravity converts an OpenAI Chat Completions request (raw JSON) +// into a Gemini CLI compatible request JSON (antigravity format). +// Ported from CLIProxyAPI antigravity/openai/chat-completions translator. +func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _ bool) []byte { + rawJSON := bytes.Clone(inputRawJSON) + // Base envelope (no default thinkingConfig) + out := []byte(`{"project":"","request":{"contents":[]},"model":"gemini-2.5-pro"}`) + + // Model + out, _ = sjson.SetBytes(out, "model", modelName) + + // Apply thinking configuration: convert OpenAI reasoning_effort to Gemini CLI thinkingConfig. 
+ re := gjson.GetBytes(rawJSON, "reasoning_effort") + if re.Exists() { + effort := strings.ToLower(strings.TrimSpace(re.String())) + if effort != "" { + thinkingPath := "request.generationConfig.thinkingConfig" + if effort == "auto" { + out, _ = sjson.SetBytes(out, thinkingPath+".thinkingBudget", -1) + out, _ = sjson.SetBytes(out, thinkingPath+".includeThoughts", true) + } else { + out, _ = sjson.SetBytes(out, thinkingPath+".thinkingLevel", effort) + out, _ = sjson.SetBytes(out, thinkingPath+".includeThoughts", effort != "none") + } + } + } + + // Temperature/top_p/top_k/max_tokens + if tr := gjson.GetBytes(rawJSON, "temperature"); tr.Exists() && tr.Type == gjson.Number { + out, _ = sjson.SetBytes(out, "request.generationConfig.temperature", tr.Num) + } + if tpr := gjson.GetBytes(rawJSON, "top_p"); tpr.Exists() && tpr.Type == gjson.Number { + out, _ = sjson.SetBytes(out, "request.generationConfig.topP", tpr.Num) + } + if tkr := gjson.GetBytes(rawJSON, "top_k"); tkr.Exists() && tkr.Type == gjson.Number { + out, _ = sjson.SetBytes(out, "request.generationConfig.topK", tkr.Num) + } + if maxTok := gjson.GetBytes(rawJSON, "max_tokens"); maxTok.Exists() && maxTok.Type == gjson.Number { + out, _ = sjson.SetBytes(out, "request.generationConfig.maxOutputTokens", maxTok.Num) + } + + // Candidate count (OpenAI 'n' parameter) + if n := gjson.GetBytes(rawJSON, "n"); n.Exists() && n.Type == gjson.Number { + if val := n.Int(); val > 1 { + out, _ = sjson.SetBytes(out, "request.generationConfig.candidateCount", val) + } + } + + // Map OpenAI modalities -> Gemini CLI request.generationConfig.responseModalities + if mods := gjson.GetBytes(rawJSON, "modalities"); mods.Exists() && mods.IsArray() { + var responseMods []string + for _, m := range mods.Array() { + switch strings.ToLower(m.String()) { + case "text": + responseMods = append(responseMods, "TEXT") + case "image": + responseMods = append(responseMods, "IMAGE") + } + } + if len(responseMods) > 0 { + out, _ = sjson.SetBytes(out, 
"request.generationConfig.responseModalities", responseMods) + } + } + + // OpenRouter-style image_config support + if imgCfg := gjson.GetBytes(rawJSON, "image_config"); imgCfg.Exists() && imgCfg.IsObject() { + if ar := imgCfg.Get("aspect_ratio"); ar.Exists() && ar.Type == gjson.String { + out, _ = sjson.SetBytes(out, "request.generationConfig.imageConfig.aspectRatio", ar.Str) + } + if size := imgCfg.Get("image_size"); size.Exists() && size.Type == gjson.String { + out, _ = sjson.SetBytes(out, "request.generationConfig.imageConfig.imageSize", size.Str) + } + } + + // messages -> systemInstruction + contents + messages := gjson.GetBytes(rawJSON, "messages") + if messages.IsArray() { + arr := messages.Array() + // First pass: assistant tool_calls id->name map + tcID2Name := map[string]string{} + for i := 0; i < len(arr); i++ { + m := arr[i] + if m.Get("role").String() == "assistant" { + tcs := m.Get("tool_calls") + if tcs.IsArray() { + for _, tc := range tcs.Array() { + if tc.Get("type").String() == "function" { + id := tc.Get("id").String() + name := tc.Get("function.name").String() + if id != "" && name != "" { + tcID2Name[id] = name + } + } + } + } + } + } + + // Second pass build systemInstruction/tool responses cache + toolResponses := map[string]string{} // tool_call_id -> response text + for i := 0; i < len(arr); i++ { + m := arr[i] + role := m.Get("role").String() + if role == "tool" { + toolCallID := m.Get("tool_call_id").String() + if toolCallID != "" { + c := m.Get("content") + toolResponses[toolCallID] = c.Raw + } + } + } + + systemPartIndex := 0 + for i := 0; i < len(arr); i++ { + m := arr[i] + role := m.Get("role").String() + content := m.Get("content") + + if (role == "system" || role == "developer") && len(arr) > 1 { + // system -> request.systemInstruction as a user message style + if content.Type == gjson.String { + out, _ = sjson.SetBytes(out, "request.systemInstruction.role", "user") + out, _ = sjson.SetBytes(out, 
fmt.Sprintf("request.systemInstruction.parts.%d.text", systemPartIndex), content.String()) + systemPartIndex++ + } else if content.IsObject() && content.Get("type").String() == "text" { + out, _ = sjson.SetBytes(out, "request.systemInstruction.role", "user") + out, _ = sjson.SetBytes(out, fmt.Sprintf("request.systemInstruction.parts.%d.text", systemPartIndex), content.Get("text").String()) + systemPartIndex++ + } else if content.IsArray() { + contents := content.Array() + if len(contents) > 0 { + out, _ = sjson.SetBytes(out, "request.systemInstruction.role", "user") + for j := 0; j < len(contents); j++ { + text := contents[j].Get("text").String() + if text != "" { + out, _ = sjson.SetBytes(out, fmt.Sprintf("request.systemInstruction.parts.%d.text", systemPartIndex), text) + systemPartIndex++ + } + } + } + } + } else if role == "user" || ((role == "system" || role == "developer") && len(arr) == 1) { + // Build single user content node to avoid splitting into multiple contents + node := []byte(`{"role":"user","parts":[]}`) + if content.Type == gjson.String { + node, _ = sjson.SetBytes(node, "parts.0.text", content.String()) + } else if content.IsArray() { + items := content.Array() + p := 0 + for _, item := range items { + switch item.Get("type").String() { + case "text": + text := item.Get("text").String() + if text != "" { + node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".text", text) + p++ + } + case "image_url": + imageURL := item.Get("image_url.url").String() + if strings.HasPrefix(imageURL, "data:") { + pieces := strings.SplitN(imageURL[len("data:"):], ";", 2) + if len(pieces) == 2 && len(pieces[1]) > 7 { + mimeType := pieces[0] + data := pieces[1][7:] + node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mimeType) + node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data) + node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature) + p++ + } + } + case "file": + filename := 
item.Get("file.filename").String() + fileData := item.Get("file.file_data").String() + if filename != "" && fileData != "" { + ext := "" + if sp := strings.Split(filename, "."); len(sp) > 1 { + ext = sp[len(sp)-1] + } + mimeType := mime.TypeByExtension("." + ext) + if mimeType != "" { + node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mimeType) + node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", fileData) + p++ + } else { + log.Printf("unknown file extension '%s' in user message, skip", ext) + } + } + } + } + } + out, _ = sjson.SetRawBytes(out, "request.contents.-1", node) + } else if role == "assistant" { + node := []byte(`{"role":"model","parts":[]}`) + p := 0 + if content.Type == gjson.String && content.String() != "" { + node, _ = sjson.SetBytes(node, "parts.-1.text", content.String()) + p++ + } else if content.IsArray() { + // Assistant multimodal content -> single model content with parts + for _, item := range content.Array() { + switch item.Get("type").String() { + case "text": + text := item.Get("text").String() + if text != "" { + node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".text", text) + p++ + } + case "image_url": + imageURL := item.Get("image_url.url").String() + if strings.HasPrefix(imageURL, "data:") { // expect data:... 
+ pieces := strings.SplitN(imageURL[len("data:"):], ";", 2) + if len(pieces) == 2 && len(pieces[1]) > 7 { + mimeType := pieces[0] + data := pieces[1][7:] + node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mimeType) + node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data) + node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature) + p++ + } + } + } + } + } + + // Tool calls -> single model content with functionCall parts + tcs := m.Get("tool_calls") + if tcs.IsArray() { + fIDs := make([]string, 0) + for _, tc := range tcs.Array() { + if tc.Get("type").String() != "function" { + continue + } + fid := tc.Get("id").String() + fname := tc.Get("function.name").String() + fargs := tc.Get("function.arguments").String() + node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.id", fid) + node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.name", fname) + if gjson.Valid(fargs) { + node, _ = sjson.SetRawBytes(node, "parts."+itoa(p)+".functionCall.args", []byte(fargs)) + } else { + node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.args.params", fargs) + } + node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature) + p++ + if fid != "" { + fIDs = append(fIDs, fid) + } + } + out, _ = sjson.SetRawBytes(out, "request.contents.-1", node) + + // Append a single tool content combining name + response per function + toolNode := []byte(`{"role":"user","parts":[]}`) + pp := 0 + for _, fid := range fIDs { + if name, ok := tcID2Name[fid]; ok { + toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.id", fid) + toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.name", name) + resp := toolResponses[fid] + if resp == "" { + resp = "{}" + } + if resp != "null" { + parsed := gjson.Parse(resp) + if parsed.Type == gjson.JSON { + toolNode, _ = sjson.SetRawBytes(toolNode, 
"parts."+itoa(pp)+".functionResponse.response.result", []byte(parsed.Raw)) + } else { + toolNode, _ = sjson.SetRawBytes(toolNode, "parts."+itoa(pp)+".functionResponse.response.result", []byte(resp)) + } + } + pp++ + } + } + if pp > 0 { + out, _ = sjson.SetRawBytes(out, "request.contents.-1", toolNode) + } + } else { + out, _ = sjson.SetRawBytes(out, "request.contents.-1", node) + } + } + } + } + + // tools -> request.tools[].functionDeclarations + request.tools[].googleSearch/codeExecution/urlContext passthrough + tools := gjson.GetBytes(rawJSON, "tools") + if tools.IsArray() && len(tools.Array()) > 0 { + functionToolNode := []byte(`{}`) + hasFunction := false + googleSearchNodes := make([][]byte, 0) + codeExecutionNodes := make([][]byte, 0) + urlContextNodes := make([][]byte, 0) + for _, t := range tools.Array() { + if t.Get("type").String() == "function" { + fn := t.Get("function") + if fn.Exists() && fn.IsObject() { + fnRaw := fn.Raw + if fn.Get("parameters").Exists() { + renamed, errRename := RenameKey(fnRaw, "parameters", "parametersJsonSchema") + if errRename != nil { + log.Printf("failed to rename parameters for tool '%s': %v", fn.Get("name").String(), errRename) + fnRaw, _ = sjson.Delete(fnRaw, "parameters") + var errSet error + fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.type", "object") + if errSet != nil { + log.Printf("failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet) + continue + } + fnRaw, errSet = sjson.SetRaw(fnRaw, "parametersJsonSchema.properties", `{}`) + if errSet != nil { + log.Printf("failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet) + continue + } + } else { + fnRaw = renamed + } + } else { + var errSet error + fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.type", "object") + if errSet != nil { + log.Printf("failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet) + continue + } + fnRaw, errSet = sjson.SetRaw(fnRaw, 
"parametersJsonSchema.properties", `{}`) + if errSet != nil { + log.Printf("failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet) + continue + } + } + fnRaw, _ = sjson.Delete(fnRaw, "strict") + if !hasFunction { + functionToolNode, _ = sjson.SetRawBytes(functionToolNode, "functionDeclarations", []byte("[]")) + } + tmp, errSet := sjson.SetRawBytes(functionToolNode, "functionDeclarations.-1", []byte(fnRaw)) + if errSet != nil { + log.Printf("failed to append tool declaration for '%s': %v", fn.Get("name").String(), errSet) + continue + } + functionToolNode = tmp + hasFunction = true + } + } + if gs := t.Get("google_search"); gs.Exists() { + googleToolNode := []byte(`{}`) + var errSet error + googleToolNode, errSet = sjson.SetRawBytes(googleToolNode, "googleSearch", []byte(gs.Raw)) + if errSet != nil { + log.Printf("failed to set googleSearch tool: %v", errSet) + continue + } + googleSearchNodes = append(googleSearchNodes, googleToolNode) + } + if ce := t.Get("code_execution"); ce.Exists() { + codeToolNode := []byte(`{}`) + var errSet error + codeToolNode, errSet = sjson.SetRawBytes(codeToolNode, "codeExecution", []byte(ce.Raw)) + if errSet != nil { + log.Printf("failed to set codeExecution tool: %v", errSet) + continue + } + codeExecutionNodes = append(codeExecutionNodes, codeToolNode) + } + if uc := t.Get("url_context"); uc.Exists() { + urlToolNode := []byte(`{}`) + var errSet error + urlToolNode, errSet = sjson.SetRawBytes(urlToolNode, "urlContext", []byte(uc.Raw)) + if errSet != nil { + log.Printf("failed to set urlContext tool: %v", errSet) + continue + } + urlContextNodes = append(urlContextNodes, urlToolNode) + } + } + if hasFunction || len(googleSearchNodes) > 0 || len(codeExecutionNodes) > 0 || len(urlContextNodes) > 0 { + toolsNode := []byte("[]") + if hasFunction { + toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", functionToolNode) + } + for _, googleNode := range googleSearchNodes { + toolsNode, _ = 
sjson.SetRawBytes(toolsNode, "-1", googleNode) + } + for _, codeNode := range codeExecutionNodes { + toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", codeNode) + } + for _, urlNode := range urlContextNodes { + toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", urlNode) + } + out, _ = sjson.SetRawBytes(out, "request.tools", toolsNode) + } + } + + return attachDefaultSafetySettings(out, "request.safetySettings") +} + +func itoa(i int) string { return fmt.Sprintf("%d", i) } + +func attachDefaultSafetySettings(rawJSON []byte, path string) []byte { + if gjson.GetBytes(rawJSON, path).Exists() { + return rawJSON + } + defaults := []map[string]string{ + {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"}, + {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"}, + {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"}, + {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"}, + {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"}, + } + out, err := sjson.SetBytes(rawJSON, path, defaults) + if err != nil { + return rawJSON + } + return out +} diff --git a/internal/adapter/provider/antigravity/request.go b/internal/adapter/provider/antigravity/request.go index e4b45748..0e37f81c 100644 --- a/internal/adapter/provider/antigravity/request.go +++ b/internal/adapter/provider/antigravity/request.go @@ -1,11 +1,24 @@ package antigravity import ( + "crypto/sha256" + "encoding/binary" "encoding/json" "fmt" + "math/rand" + "strconv" "strings" + "sync" + "time" "github.com/google/uuid" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +var ( + randSource = rand.New(rand.NewSource(time.Now().UnixNano())) + randSourceMutex sync.Mutex ) // RequestConfig holds resolved request configuration (like Antigravity-Manager) @@ -178,6 +191,11 @@ func wrapV1InternalRequest(body []byte, projectID, originalModel, mappedModel, s // Remove model field from inner request if present (will be at top level) delete(innerRequest, 
"model") + // Strip v1internal wrapper fields if client passed them through + delete(innerRequest, "project") + delete(innerRequest, "requestId") + delete(innerRequest, "requestType") + delete(innerRequest, "userAgent") // Resolve request configuration (like Antigravity-Manager) toolsForDetection := toolsForConfig @@ -215,18 +233,22 @@ func wrapV1InternalRequest(body []byte, projectID, originalModel, mappedModel, s // Deep clean [undefined] strings (Cherry Studio client common injection) deepCleanUndefined(innerRequest) - // [Safety Settings] Inject safety settings from environment variable (like Antigravity-Manager) - safetyThreshold := GetSafetyThresholdFromEnv() - innerRequest["safetySettings"] = BuildSafetySettingsMap(safetyThreshold) + // [Safety Settings] Antigravity v1internal does not accept request.safetySettings + delete(innerRequest, "safetySettings") - // [SessionID Support] If metadata.user_id was provided, use it as sessionId (like Antigravity-Manager) - if sessionID != "" { - innerRequest["sessionId"] = sessionID + // [SessionID Support] Use metadata.user_id if provided, otherwise generate a stable session id + if sessionID == "" { + sessionID = generateStableSessionID(body) } + innerRequest["sessionId"] = sessionID // Generate UUID requestId (like Antigravity-Manager) requestID := fmt.Sprintf("agent-%s", uuid.New().String()) + if strings.TrimSpace(projectID) == "" { + projectID = generateProjectID() + } + wrapped := map[string]interface{}{ "project": projectID, "requestId": requestID, @@ -236,7 +258,130 @@ func wrapV1InternalRequest(body []byte, projectID, originalModel, mappedModel, s "requestType": config.RequestType, } - return json.Marshal(wrapped) + payload, err := json.Marshal(wrapped) + if err != nil { + return nil, err + } + payload = applyAntigravityRequestTuning(payload, config.FinalModel) + return payload, nil +} + +// finalizeOpenAIWrappedRequest ensures an OpenAI->Antigravity converted request +// has required envelope fields 
(project/requestId/sessionId/userAgent/requestType), +// and applies Antigravity request tuning. +func finalizeOpenAIWrappedRequest(payload []byte, projectID, modelName, sessionID string) []byte { + if len(payload) == 0 { + return payload + } + if strings.TrimSpace(projectID) == "" { + projectID = generateProjectID() + } + if sessionID == "" { + sessionID = generateStableSessionID(payload) + } + + out := payload + out, _ = sjson.SetBytes(out, "project", projectID) + out, _ = sjson.SetBytes(out, "requestId", fmt.Sprintf("agent-%s", uuid.New().String())) + out, _ = sjson.SetBytes(out, "requestType", "agent") + out, _ = sjson.SetBytes(out, "userAgent", "antigravity") + out, _ = sjson.SetBytes(out, "model", modelName) + out, _ = sjson.DeleteBytes(out, "request.safetySettings") + + // Move toolConfig to request.toolConfig if needed + if toolConfig := gjson.GetBytes(out, "toolConfig"); toolConfig.Exists() && !gjson.GetBytes(out, "request.toolConfig").Exists() { + out, _ = sjson.SetRawBytes(out, "request.toolConfig", []byte(toolConfig.Raw)) + out, _ = sjson.DeleteBytes(out, "toolConfig") + } + + // Ensure sessionId + out, _ = sjson.SetBytes(out, "request.sessionId", sessionID) + return applyAntigravityRequestTuning(out, modelName) +} + +const antigravitySystemInstruction = "You are Antigravity, a powerful agentic AI coding assistant designed by the Google Deepmind team working on Advanced Agentic Coding.You are pair programming with a USER to solve their coding task. 
The task may require creating a new codebase, modifying or debugging an existing codebase, or simply answering a question.**Absolute paths only****Proactiveness**" + +func applyAntigravityRequestTuning(payload []byte, modelName string) []byte { + if len(payload) == 0 { + return payload + } + strJSON := string(payload) + paths := make([]string, 0) + Walk(gjson.ParseBytes(payload), "", "parametersJsonSchema", &paths) + for _, p := range paths { + if !strings.HasSuffix(p, "parametersJsonSchema") { + continue + } + if renamed, err := RenameKey(strJSON, p, p[:len(p)-len("parametersJsonSchema")]+"parameters"); err == nil { + strJSON = renamed + } + } + + if strings.Contains(modelName, "claude") || strings.Contains(modelName, "gemini-3-pro-high") { + strJSON = CleanJSONSchemaForAntigravity(strJSON) + } else { + strJSON = CleanJSONSchemaForGemini(strJSON) + } + + payload = []byte(strJSON) + + if strings.Contains(modelName, "claude") || strings.Contains(modelName, "gemini-3-pro-high") { + partsResult := gjson.GetBytes(payload, "request.systemInstruction.parts") + payload, _ = sjson.SetBytes(payload, "request.systemInstruction.role", "user") + payload, _ = sjson.SetBytes(payload, "request.systemInstruction.parts.0.text", antigravitySystemInstruction) + payload, _ = sjson.SetBytes(payload, "request.systemInstruction.parts.1.text", fmt.Sprintf("Please ignore following [ignore]%s[/ignore]", antigravitySystemInstruction)) + if partsResult.Exists() && partsResult.IsArray() { + for _, part := range partsResult.Array() { + payload, _ = sjson.SetRawBytes(payload, "request.systemInstruction.parts.-1", []byte(part.Raw)) + } + } + } + + if strings.Contains(modelName, "claude") { + payload, _ = sjson.SetBytes(payload, "request.toolConfig.functionCallingConfig.mode", "VALIDATED") + } else { + payload, _ = sjson.DeleteBytes(payload, "request.generationConfig.maxOutputTokens") + } + + return payload +} + +func generateSessionID() string { + randSourceMutex.Lock() + n := 
randSource.Int63n(9_000_000_000_000_000_000) + randSourceMutex.Unlock() + return "-" + strconv.FormatInt(n, 10) +} + +func generateStableSessionID(payload []byte) string { + contents := gjson.GetBytes(payload, "request.contents") + if !contents.IsArray() { + contents = gjson.GetBytes(payload, "contents") + } + if contents.IsArray() { + for _, content := range contents.Array() { + if content.Get("role").String() == "user" { + text := content.Get("parts.0.text").String() + if text != "" { + h := sha256.Sum256([]byte(text)) + n := int64(binary.BigEndian.Uint64(h[:8])) & 0x7FFFFFFFFFFFFFFF + return "-" + strconv.FormatInt(n, 10) + } + } + } + } + return generateSessionID() +} + +func generateProjectID() string { + adjectives := []string{"useful", "bright", "swift", "calm", "bold"} + nouns := []string{"fuze", "wave", "spark", "flow", "core"} + randSourceMutex.Lock() + adj := adjectives[randSource.Intn(len(adjectives))] + noun := nouns[randSource.Intn(len(nouns))] + randSourceMutex.Unlock() + randomPart := strings.ToLower(uuid.NewString())[:5] + return adj + "-" + noun + "-" + randomPart } // stripThinkingFromClaude removes thinking config and blocks to retry without thinking (like Manager 400 retry) diff --git a/internal/adapter/provider/antigravity/request_test.go b/internal/adapter/provider/antigravity/request_test.go new file mode 100644 index 00000000..6578cfdb --- /dev/null +++ b/internal/adapter/provider/antigravity/request_test.go @@ -0,0 +1,44 @@ +package antigravity + +import ( + "testing" + + "github.com/tidwall/gjson" +) + +func TestApplyAntigravityRequestTuning(t *testing.T) { + input := `{ + "request": { + "systemInstruction": { + "parts": [{"text":"original"}] + }, + "tools": [{ + "functionDeclarations": [{ + "name": "t1", + "parametersJsonSchema": {"type":"object","properties":{"x":{"type":"string"}}} + }] + }] + }, + "model": "claude-sonnet-4-5" +}` + out := applyAntigravityRequestTuning([]byte(input), "claude-sonnet-4-5") + + if !gjson.GetBytes(out, 
"request.systemInstruction.role").Exists() { + t.Fatalf("expected systemInstruction.role to be set") + } + if gjson.GetBytes(out, "request.systemInstruction.parts.0.text").String() == "" { + t.Fatalf("expected systemInstruction parts[0].text to be injected") + } + if gjson.GetBytes(out, "request.systemInstruction.parts.1.text").String() == "" { + t.Fatalf("expected systemInstruction parts[1].text to be injected") + } + if gjson.GetBytes(out, "request.toolConfig.functionCallingConfig.mode").String() != "VALIDATED" { + t.Fatalf("expected toolConfig.functionCallingConfig.mode=VALIDATED") + } + if gjson.GetBytes(out, "request.tools.0.functionDeclarations.0.parametersJsonSchema").Exists() { + t.Fatalf("expected parametersJsonSchema to be renamed") + } + if !gjson.GetBytes(out, "request.tools.0.functionDeclarations.0.parameters").Exists() { + t.Fatalf("expected parameters to exist after rename") + } +} diff --git a/internal/adapter/provider/antigravity/response.go b/internal/adapter/provider/antigravity/response.go index b8ef97c7..53db1f62 100644 --- a/internal/adapter/provider/antigravity/response.go +++ b/internal/adapter/provider/antigravity/response.go @@ -1,10 +1,14 @@ package antigravity import ( + "bytes" "encoding/json" "fmt" "net/http" "strings" + + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" ) // Response headers to exclude when copying @@ -114,6 +118,191 @@ func isRetryableStatusCode(code int) bool { } } +// convertStreamToNonStream collects Gemini SSE stream into a single response payload. +// Ported from CLIProxyAPI Antigravity convertStreamToNonStream. 
+func convertStreamToNonStream(stream []byte) []byte { + responseTemplate := "" + var traceID string + var finishReason string + var modelVersion string + var responseID string + var role string + var usageRaw string + parts := make([]map[string]interface{}, 0) + var pendingKind string + var pendingText strings.Builder + var pendingThoughtSig string + + flushPending := func() { + if pendingKind == "" { + return + } + text := pendingText.String() + switch pendingKind { + case "text": + if strings.TrimSpace(text) == "" { + pendingKind = "" + pendingText.Reset() + pendingThoughtSig = "" + return + } + parts = append(parts, map[string]interface{}{"text": text}) + case "thought": + if strings.TrimSpace(text) == "" && pendingThoughtSig == "" { + pendingKind = "" + pendingText.Reset() + pendingThoughtSig = "" + return + } + part := map[string]interface{}{"thought": true} + part["text"] = text + if pendingThoughtSig != "" { + part["thoughtSignature"] = pendingThoughtSig + } + parts = append(parts, part) + } + pendingKind = "" + pendingText.Reset() + pendingThoughtSig = "" + } + + normalizePart := func(partResult gjson.Result) map[string]interface{} { + var m map[string]interface{} + _ = json.Unmarshal([]byte(partResult.Raw), &m) + if m == nil { + m = map[string]interface{}{} + } + sig := partResult.Get("thoughtSignature").String() + if sig == "" { + sig = partResult.Get("thought_signature").String() + } + if sig != "" { + m["thoughtSignature"] = sig + delete(m, "thought_signature") + } + if inlineData, ok := m["inline_data"]; ok { + m["inlineData"] = inlineData + delete(m, "inline_data") + } + return m + } + + for _, line := range bytes.Split(stream, []byte("\n")) { + trimmed := bytes.TrimSpace(line) + trimmed = bytes.TrimPrefix(trimmed, []byte("data: ")) + if len(trimmed) == 0 || !gjson.ValidBytes(trimmed) { + continue + } + + root := gjson.ParseBytes(trimmed) + responseNode := root.Get("response") + if !responseNode.Exists() { + if root.Get("candidates").Exists() { + 
responseNode = root + } else { + continue + } + } + responseTemplate = responseNode.Raw + + if traceResult := root.Get("traceId"); traceResult.Exists() && traceResult.String() != "" { + traceID = traceResult.String() + } + + if roleResult := responseNode.Get("candidates.0.content.role"); roleResult.Exists() { + role = roleResult.String() + } + + if finishResult := responseNode.Get("candidates.0.finishReason"); finishResult.Exists() && finishResult.String() != "" { + finishReason = finishResult.String() + } + + if modelResult := responseNode.Get("modelVersion"); modelResult.Exists() && modelResult.String() != "" { + modelVersion = modelResult.String() + } + if responseIDResult := responseNode.Get("responseId"); responseIDResult.Exists() && responseIDResult.String() != "" { + responseID = responseIDResult.String() + } + if usageResult := responseNode.Get("usageMetadata"); usageResult.Exists() { + usageRaw = usageResult.Raw + } else if usageMetadataResult := root.Get("usageMetadata"); usageMetadataResult.Exists() { + usageRaw = usageMetadataResult.Raw + } + + if partsResult := responseNode.Get("candidates.0.content.parts"); partsResult.IsArray() { + for _, part := range partsResult.Array() { + hasFunctionCall := part.Get("functionCall").Exists() + hasInlineData := part.Get("inlineData").Exists() || part.Get("inline_data").Exists() + sig := part.Get("thoughtSignature").String() + if sig == "" { + sig = part.Get("thought_signature").String() + } + text := part.Get("text").String() + thought := part.Get("thought").Bool() + + if hasFunctionCall || hasInlineData { + flushPending() + parts = append(parts, normalizePart(part)) + continue + } + + if thought || part.Get("text").Exists() { + kind := "text" + if thought { + kind = "thought" + } + if pendingKind != "" && pendingKind != kind { + flushPending() + } + pendingKind = kind + pendingText.WriteString(text) + if kind == "thought" && sig != "" { + pendingThoughtSig = sig + } + continue + } + + flushPending() + parts = 
append(parts, normalizePart(part)) + } + } + } + flushPending() + + if responseTemplate == "" { + responseTemplate = `{"candidates":[{"content":{"role":"model","parts":[]}}]}` + } + + partsJSON, _ := json.Marshal(parts) + responseTemplate, _ = sjson.SetRaw(responseTemplate, "candidates.0.content.parts", string(partsJSON)) + if role != "" { + responseTemplate, _ = sjson.Set(responseTemplate, "candidates.0.content.role", role) + } + if finishReason != "" { + responseTemplate, _ = sjson.Set(responseTemplate, "candidates.0.finishReason", finishReason) + } + if modelVersion != "" { + responseTemplate, _ = sjson.Set(responseTemplate, "modelVersion", modelVersion) + } + if responseID != "" { + responseTemplate, _ = sjson.Set(responseTemplate, "responseId", responseID) + } + if usageRaw != "" { + responseTemplate, _ = sjson.SetRaw(responseTemplate, "usageMetadata", usageRaw) + } else if !gjson.Get(responseTemplate, "usageMetadata").Exists() { + responseTemplate, _ = sjson.Set(responseTemplate, "usageMetadata.promptTokenCount", 0) + responseTemplate, _ = sjson.Set(responseTemplate, "usageMetadata.candidatesTokenCount", 0) + responseTemplate, _ = sjson.Set(responseTemplate, "usageMetadata.totalTokenCount", 0) + } + + output := `{"response":{},"traceId":""}` + output, _ = sjson.SetRaw(output, "response", responseTemplate) + if traceID != "" { + output, _ = sjson.Set(output, "traceId", traceID) + } + return []byte(output) +} + // convertGeminiToClaudeResponse converts a non-streaming Gemini response to Claude format // (like Antigravity-Manager's response conversion) func convertGeminiToClaudeResponse(geminiBody []byte, requestModel string) ([]byte, error) { diff --git a/internal/adapter/provider/antigravity/schema_cleaner.go b/internal/adapter/provider/antigravity/schema_cleaner.go new file mode 100644 index 00000000..41f4c371 --- /dev/null +++ b/internal/adapter/provider/antigravity/schema_cleaner.go @@ -0,0 +1,731 @@ +// Package antigravity provides utility functions for the CLI 
Proxy API server. +package antigravity + +import ( + "fmt" + "sort" + "strconv" + "strings" + + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +var gjsonPathKeyReplacer = strings.NewReplacer(".", "\\.", "*", "\\*", "?", "\\?") + +const placeholderReasonDescription = "Brief explanation of why you are calling this tool" + +// CleanJSONSchemaForAntigravity transforms a JSON schema to be compatible with Antigravity API. +// It handles unsupported keywords, type flattening, and schema simplification while preserving +// semantic information as description hints. +func CleanJSONSchemaForAntigravity(jsonStr string) string { + return cleanJSONSchema(jsonStr, true) +} + +// CleanJSONSchemaForGemini transforms a JSON schema to be compatible with Gemini tool calling. +// It removes unsupported keywords and simplifies schemas, without adding empty-schema placeholders. +func CleanJSONSchemaForGemini(jsonStr string) string { + return cleanJSONSchema(jsonStr, false) +} + +// cleanJSONSchema performs the core cleaning operations on the JSON schema. +func cleanJSONSchema(jsonStr string, addPlaceholder bool) string { + // Phase 1: Convert and add hints + jsonStr = convertRefsToHints(jsonStr) + jsonStr = convertConstToEnum(jsonStr) + jsonStr = convertEnumValuesToStrings(jsonStr) + jsonStr = addEnumHints(jsonStr) + jsonStr = addAdditionalPropertiesHints(jsonStr) + jsonStr = moveConstraintsToDescription(jsonStr) + + // Phase 2: Flatten complex structures + jsonStr = mergeAllOf(jsonStr) + jsonStr = flattenAnyOfOneOf(jsonStr) + jsonStr = flattenTypeArrays(jsonStr) + + // Phase 3: Cleanup + jsonStr = removeUnsupportedKeywords(jsonStr) + if !addPlaceholder { + // Gemini schema cleanup: remove nullable/title and placeholder-only fields. 
+ jsonStr = removeKeywords(jsonStr, []string{"nullable", "title"}) + jsonStr = removePlaceholderFields(jsonStr) + } + jsonStr = cleanupRequiredFields(jsonStr) + // Phase 4: Add placeholder for empty object schemas (Claude VALIDATED mode requirement) + if addPlaceholder { + jsonStr = addEmptySchemaPlaceholder(jsonStr) + } + + return jsonStr +} + +// removeKeywords removes all occurrences of specified keywords from the JSON schema. +func removeKeywords(jsonStr string, keywords []string) string { + for _, key := range keywords { + for _, p := range findPaths(jsonStr, key) { + if isPropertyDefinition(trimSuffix(p, "."+key)) { + continue + } + jsonStr, _ = sjson.Delete(jsonStr, p) + } + } + return jsonStr +} + +// removePlaceholderFields removes placeholder-only properties ("_" and "reason") and their required entries. +func removePlaceholderFields(jsonStr string) string { + // Remove "_" placeholder properties. + paths := findPaths(jsonStr, "_") + sortByDepth(paths) + for _, p := range paths { + if !strings.HasSuffix(p, ".properties._") { + continue + } + jsonStr, _ = sjson.Delete(jsonStr, p) + parentPath := trimSuffix(p, ".properties._") + reqPath := joinPath(parentPath, "required") + req := gjson.Get(jsonStr, reqPath) + if req.IsArray() { + var filtered []string + for _, r := range req.Array() { + if r.String() != "_" { + filtered = append(filtered, r.String()) + } + } + if len(filtered) == 0 { + jsonStr, _ = sjson.Delete(jsonStr, reqPath) + } else { + jsonStr, _ = sjson.Set(jsonStr, reqPath, filtered) + } + } + } + + // Remove placeholder-only "reason" objects. 
+ reasonPaths := findPaths(jsonStr, "reason") + sortByDepth(reasonPaths) + for _, p := range reasonPaths { + if !strings.HasSuffix(p, ".properties.reason") { + continue + } + parentPath := trimSuffix(p, ".properties.reason") + props := gjson.Get(jsonStr, joinPath(parentPath, "properties")) + if !props.IsObject() || len(props.Map()) != 1 { + continue + } + desc := gjson.Get(jsonStr, p+".description").String() + if desc != placeholderReasonDescription { + continue + } + jsonStr, _ = sjson.Delete(jsonStr, p) + reqPath := joinPath(parentPath, "required") + req := gjson.Get(jsonStr, reqPath) + if req.IsArray() { + var filtered []string + for _, r := range req.Array() { + if r.String() != "reason" { + filtered = append(filtered, r.String()) + } + } + if len(filtered) == 0 { + jsonStr, _ = sjson.Delete(jsonStr, reqPath) + } else { + jsonStr, _ = sjson.Set(jsonStr, reqPath, filtered) + } + } + } + + return jsonStr +} + +// convertRefsToHints converts $ref to description hints (Lazy Hint strategy). 
+func convertRefsToHints(jsonStr string) string { + paths := findPaths(jsonStr, "$ref") + sortByDepth(paths) + + for _, p := range paths { + refVal := gjson.Get(jsonStr, p).String() + defName := refVal + if idx := strings.LastIndex(refVal, "/"); idx >= 0 { + defName = refVal[idx+1:] + } + + parentPath := trimSuffix(p, ".$ref") + hint := fmt.Sprintf("See: %s", defName) + if existing := gjson.Get(jsonStr, descriptionPath(parentPath)).String(); existing != "" { + hint = fmt.Sprintf("%s (%s)", existing, hint) + } + + replacement := `{"type":"object","description":""}` + replacement, _ = sjson.Set(replacement, "description", hint) + jsonStr = setRawAt(jsonStr, parentPath, replacement) + } + return jsonStr +} + +func convertConstToEnum(jsonStr string) string { + for _, p := range findPaths(jsonStr, "const") { + val := gjson.Get(jsonStr, p) + if !val.Exists() { + continue + } + enumPath := trimSuffix(p, ".const") + ".enum" + if !gjson.Get(jsonStr, enumPath).Exists() { + jsonStr, _ = sjson.Set(jsonStr, enumPath, []interface{}{val.Value()}) + } + } + return jsonStr +} + +// convertEnumValuesToStrings ensures all enum values are strings and the schema type is set to string. +// Gemini API requires enum values to be of type string, not numbers or booleans. 
+func convertEnumValuesToStrings(jsonStr string) string { + for _, p := range findPaths(jsonStr, "enum") { + arr := gjson.Get(jsonStr, p) + if !arr.IsArray() { + continue + } + + var stringVals []string + for _, item := range arr.Array() { + stringVals = append(stringVals, item.String()) + } + + // Always update enum values to strings and set type to "string" + // This ensures compatibility with Antigravity Gemini which only allows enum for STRING type + jsonStr, _ = sjson.Set(jsonStr, p, stringVals) + parentPath := trimSuffix(p, ".enum") + jsonStr, _ = sjson.Set(jsonStr, joinPath(parentPath, "type"), "string") + } + return jsonStr +} + +func addEnumHints(jsonStr string) string { + for _, p := range findPaths(jsonStr, "enum") { + arr := gjson.Get(jsonStr, p) + if !arr.IsArray() { + continue + } + items := arr.Array() + if len(items) <= 1 || len(items) > 10 { + continue + } + + var vals []string + for _, item := range items { + vals = append(vals, item.String()) + } + jsonStr = appendHint(jsonStr, trimSuffix(p, ".enum"), "Allowed: "+strings.Join(vals, ", ")) + } + return jsonStr +} + +func addAdditionalPropertiesHints(jsonStr string) string { + for _, p := range findPaths(jsonStr, "additionalProperties") { + if gjson.Get(jsonStr, p).Type == gjson.False { + jsonStr = appendHint(jsonStr, trimSuffix(p, ".additionalProperties"), "No extra properties allowed") + } + } + return jsonStr +} + +var unsupportedConstraints = []string{ + "minLength", "maxLength", "exclusiveMinimum", "exclusiveMaximum", + "pattern", "minItems", "maxItems", "format", + "default", "examples", // Claude rejects these in VALIDATED mode +} + +func moveConstraintsToDescription(jsonStr string) string { + for _, key := range unsupportedConstraints { + for _, p := range findPaths(jsonStr, key) { + val := gjson.Get(jsonStr, p) + if !val.Exists() || val.IsObject() || val.IsArray() { + continue + } + parentPath := trimSuffix(p, "."+key) + if isPropertyDefinition(parentPath) { + continue + } + jsonStr = 
appendHint(jsonStr, parentPath, fmt.Sprintf("%s: %s", key, val.String())) + } + } + return jsonStr +} + +func mergeAllOf(jsonStr string) string { + paths := findPaths(jsonStr, "allOf") + sortByDepth(paths) + + for _, p := range paths { + allOf := gjson.Get(jsonStr, p) + if !allOf.IsArray() { + continue + } + parentPath := trimSuffix(p, ".allOf") + + for _, item := range allOf.Array() { + if props := item.Get("properties"); props.IsObject() { + props.ForEach(func(key, value gjson.Result) bool { + destPath := joinPath(parentPath, "properties."+escapeGJSONPathKey(key.String())) + jsonStr, _ = sjson.SetRaw(jsonStr, destPath, value.Raw) + return true + }) + } + if req := item.Get("required"); req.IsArray() { + reqPath := joinPath(parentPath, "required") + current := getStrings(jsonStr, reqPath) + for _, r := range req.Array() { + if s := r.String(); !contains(current, s) { + current = append(current, s) + } + } + jsonStr, _ = sjson.Set(jsonStr, reqPath, current) + } + } + jsonStr, _ = sjson.Delete(jsonStr, p) + } + return jsonStr +} + +func flattenAnyOfOneOf(jsonStr string) string { + for _, key := range []string{"anyOf", "oneOf"} { + paths := findPaths(jsonStr, key) + sortByDepth(paths) + + for _, p := range paths { + arr := gjson.Get(jsonStr, p) + if !arr.IsArray() || len(arr.Array()) == 0 { + continue + } + + parentPath := trimSuffix(p, "."+key) + parentDesc := gjson.Get(jsonStr, descriptionPath(parentPath)).String() + + items := arr.Array() + bestIdx, allTypes := selectBest(items) + selected := items[bestIdx].Raw + + if parentDesc != "" { + selected = mergeDescriptionRaw(selected, parentDesc) + } + + if len(allTypes) > 1 { + hint := "Accepts: " + strings.Join(allTypes, " | ") + selected = appendHintRaw(selected, hint) + } + + jsonStr = setRawAt(jsonStr, parentPath, selected) + } + } + return jsonStr +} + +func selectBest(items []gjson.Result) (bestIdx int, types []string) { + bestScore := -1 + for i, item := range items { + t := item.Get("type").String() + score := 
0 + + switch { + case t == "object" || item.Get("properties").Exists(): + score, t = 3, orDefault(t, "object") + case t == "array" || item.Get("items").Exists(): + score, t = 2, orDefault(t, "array") + case t != "" && t != "null": + score = 1 + default: + t = orDefault(t, "null") + } + + if t != "" { + types = append(types, t) + } + if score > bestScore { + bestScore, bestIdx = score, i + } + } + return +} + +func flattenTypeArrays(jsonStr string) string { + paths := findPaths(jsonStr, "type") + sortByDepth(paths) + + nullableFields := make(map[string][]string) + + for _, p := range paths { + res := gjson.Get(jsonStr, p) + if !res.IsArray() || len(res.Array()) == 0 { + continue + } + + hasNull := false + var nonNullTypes []string + for _, item := range res.Array() { + s := item.String() + if s == "null" { + hasNull = true + } else if s != "" { + nonNullTypes = append(nonNullTypes, s) + } + } + + firstType := "string" + if len(nonNullTypes) > 0 { + firstType = nonNullTypes[0] + } + + jsonStr, _ = sjson.Set(jsonStr, p, firstType) + + parentPath := trimSuffix(p, ".type") + if len(nonNullTypes) > 1 { + hint := "Accepts: " + strings.Join(nonNullTypes, " | ") + jsonStr = appendHint(jsonStr, parentPath, hint) + } + + if hasNull { + parts := splitGJSONPath(p) + if len(parts) >= 3 && parts[len(parts)-3] == "properties" { + fieldNameEscaped := parts[len(parts)-2] + fieldName := unescapeGJSONPathKey(fieldNameEscaped) + objectPath := strings.Join(parts[:len(parts)-3], ".") + nullableFields[objectPath] = append(nullableFields[objectPath], fieldName) + + propPath := joinPath(objectPath, "properties."+fieldNameEscaped) + jsonStr = appendHint(jsonStr, propPath, "(nullable)") + } + } + } + + for objectPath, fields := range nullableFields { + reqPath := joinPath(objectPath, "required") + req := gjson.Get(jsonStr, reqPath) + if !req.IsArray() { + continue + } + + var filtered []string + for _, r := range req.Array() { + if !contains(fields, r.String()) { + filtered = append(filtered, 
r.String()) + } + } + + if len(filtered) == 0 { + jsonStr, _ = sjson.Delete(jsonStr, reqPath) + } else { + jsonStr, _ = sjson.Set(jsonStr, reqPath, filtered) + } + } + return jsonStr +} + +func removeUnsupportedKeywords(jsonStr string) string { + keywords := append(unsupportedConstraints, + "$schema", "$defs", "definitions", "const", "$ref", "additionalProperties", + "propertyNames", // Gemini doesn't support property name validation + ) + for _, key := range keywords { + for _, p := range findPaths(jsonStr, key) { + if isPropertyDefinition(trimSuffix(p, "."+key)) { + continue + } + jsonStr, _ = sjson.Delete(jsonStr, p) + } + } + // Remove x-* extension fields (e.g., x-google-enum-descriptions) that are not supported by Gemini API + jsonStr = removeExtensionFields(jsonStr) + return jsonStr +} + +// removeExtensionFields removes all x-* extension fields from the JSON schema. +// These are OpenAPI/JSON Schema extension fields that Google APIs don't recognize. +func removeExtensionFields(jsonStr string) string { + var paths []string + walkForExtensions(gjson.Parse(jsonStr), "", &paths) + // walkForExtensions returns paths in a way that deeper paths are added before their ancestors + // when they are not deleted wholesale, but since we skip children of deleted x-* nodes, + // any collected path is safe to delete. We still use DeleteBytes for efficiency. 
+ + b := []byte(jsonStr) + for _, p := range paths { + b, _ = sjson.DeleteBytes(b, p) + } + return string(b) +} + +func walkForExtensions(value gjson.Result, path string, paths *[]string) { + if value.IsArray() { + arr := value.Array() + for i := len(arr) - 1; i >= 0; i-- { + walkForExtensions(arr[i], joinPath(path, strconv.Itoa(i)), paths) + } + return + } + + if value.IsObject() { + value.ForEach(func(key, val gjson.Result) bool { + keyStr := key.String() + safeKey := escapeGJSONPathKey(keyStr) + childPath := joinPath(path, safeKey) + + // If it's an extension field, we delete it and don't need to look at its children. + if strings.HasPrefix(keyStr, "x-") && !isPropertyDefinition(path) { + *paths = append(*paths, childPath) + return true + } + + walkForExtensions(val, childPath, paths) + return true + }) + } +} + +func cleanupRequiredFields(jsonStr string) string { + for _, p := range findPaths(jsonStr, "required") { + parentPath := trimSuffix(p, ".required") + propsPath := joinPath(parentPath, "properties") + + req := gjson.Get(jsonStr, p) + props := gjson.Get(jsonStr, propsPath) + if !req.IsArray() || !props.IsObject() { + continue + } + + var valid []string + for _, r := range req.Array() { + key := r.String() + if props.Get(escapeGJSONPathKey(key)).Exists() { + valid = append(valid, key) + } + } + + if len(valid) != len(req.Array()) { + if len(valid) == 0 { + jsonStr, _ = sjson.Delete(jsonStr, p) + } else { + jsonStr, _ = sjson.Set(jsonStr, p, valid) + } + } + } + return jsonStr +} + +// addEmptySchemaPlaceholder adds a placeholder "reason" property to empty object schemas. +// Claude VALIDATED mode requires at least one required property in tool schemas. 
+func addEmptySchemaPlaceholder(jsonStr string) string { + // Find all "type" fields + paths := findPaths(jsonStr, "type") + + // Process from deepest to shallowest (to handle nested objects properly) + sortByDepth(paths) + + for _, p := range paths { + typeVal := gjson.Get(jsonStr, p) + if typeVal.String() != "object" { + continue + } + + // Get the parent path (the object containing "type") + parentPath := trimSuffix(p, ".type") + + // Check if properties exists and is empty or missing + propsPath := joinPath(parentPath, "properties") + propsVal := gjson.Get(jsonStr, propsPath) + reqPath := joinPath(parentPath, "required") + reqVal := gjson.Get(jsonStr, reqPath) + hasRequiredProperties := reqVal.IsArray() && len(reqVal.Array()) > 0 + + needsPlaceholder := false + if !propsVal.Exists() { + // No properties field at all + needsPlaceholder = true + } else if propsVal.IsObject() && len(propsVal.Map()) == 0 { + // Empty properties object + needsPlaceholder = true + } + + if needsPlaceholder { + // Add placeholder "reason" property + reasonPath := joinPath(propsPath, "reason") + jsonStr, _ = sjson.Set(jsonStr, reasonPath+".type", "string") + jsonStr, _ = sjson.Set(jsonStr, reasonPath+".description", placeholderReasonDescription) + + // Add to required array + jsonStr, _ = sjson.Set(jsonStr, reqPath, []string{"reason"}) + continue + } + + // If schema has properties but none are required, add a minimal placeholder. + if propsVal.IsObject() && !hasRequiredProperties { + // DO NOT add placeholder if it's a top-level schema (parentPath is empty) + // or if we've already added a placeholder reason above. 
+ if parentPath == "" { + continue + } + placeholderPath := joinPath(propsPath, "_") + if !gjson.Get(jsonStr, placeholderPath).Exists() { + jsonStr, _ = sjson.Set(jsonStr, placeholderPath+".type", "boolean") + } + jsonStr, _ = sjson.Set(jsonStr, reqPath, []string{"_"}) + } + } + + return jsonStr +} + +// --- Helpers --- + +func findPaths(jsonStr, field string) []string { + var paths []string + Walk(gjson.Parse(jsonStr), "", field, &paths) + return paths +} + +func sortByDepth(paths []string) { + sort.Slice(paths, func(i, j int) bool { return len(paths[i]) > len(paths[j]) }) +} + +func trimSuffix(path, suffix string) string { + if path == strings.TrimPrefix(suffix, ".") { + return "" + } + return strings.TrimSuffix(path, suffix) +} + +func joinPath(base, suffix string) string { + if base == "" { + return suffix + } + return base + "." + suffix +} + +func setRawAt(jsonStr, path, value string) string { + if path == "" { + return value + } + result, _ := sjson.SetRaw(jsonStr, path, value) + return result +} + +func isPropertyDefinition(path string) bool { + return path == "properties" || strings.HasSuffix(path, ".properties") +} + +func descriptionPath(parentPath string) string { + if parentPath == "" || parentPath == "@this" { + return "description" + } + return parentPath + ".description" +} + +func appendHint(jsonStr, parentPath, hint string) string { + descPath := parentPath + ".description" + if parentPath == "" || parentPath == "@this" { + descPath = "description" + } + existing := gjson.Get(jsonStr, descPath).String() + if existing != "" { + hint = fmt.Sprintf("%s (%s)", existing, hint) + } + jsonStr, _ = sjson.Set(jsonStr, descPath, hint) + return jsonStr +} + +func appendHintRaw(jsonRaw, hint string) string { + existing := gjson.Get(jsonRaw, "description").String() + if existing != "" { + hint = fmt.Sprintf("%s (%s)", existing, hint) + } + jsonRaw, _ = sjson.Set(jsonRaw, "description", hint) + return jsonRaw +} + +func getStrings(jsonStr, path string) []string 
{ + var result []string + if arr := gjson.Get(jsonStr, path); arr.IsArray() { + for _, r := range arr.Array() { + result = append(result, r.String()) + } + } + return result +} + +func contains(slice []string, item string) bool { + for _, s := range slice { + if s == item { + return true + } + } + return false +} + +func orDefault(val, def string) string { + if val == "" { + return def + } + return val +} + +func escapeGJSONPathKey(key string) string { + return gjsonPathKeyReplacer.Replace(key) +} + +func unescapeGJSONPathKey(key string) string { + if !strings.Contains(key, "\\") { + return key + } + var b strings.Builder + b.Grow(len(key)) + for i := 0; i < len(key); i++ { + if key[i] == '\\' && i+1 < len(key) { + i++ + b.WriteByte(key[i]) + continue + } + b.WriteByte(key[i]) + } + return b.String() +} + +func splitGJSONPath(path string) []string { + if path == "" { + return nil + } + + parts := make([]string, 0, strings.Count(path, ".")+1) + var b strings.Builder + b.Grow(len(path)) + + for i := 0; i < len(path); i++ { + c := path[i] + if c == '\\' && i+1 < len(path) { + b.WriteByte('\\') + i++ + b.WriteByte(path[i]) + continue + } + if c == '.' 
{ + parts = append(parts, b.String()) + b.Reset() + continue + } + b.WriteByte(c) + } + parts = append(parts, b.String()) + return parts +} + +func mergeDescriptionRaw(schemaRaw, parentDesc string) string { + childDesc := gjson.Get(schemaRaw, "description").String() + switch { + case childDesc == "": + schemaRaw, _ = sjson.Set(schemaRaw, "description", parentDesc) + return schemaRaw + case childDesc == parentDesc: + return schemaRaw + default: + combined := fmt.Sprintf("%s (%s)", parentDesc, childDesc) + schemaRaw, _ = sjson.Set(schemaRaw, "description", combined) + return schemaRaw + } +} diff --git a/internal/adapter/provider/antigravity/service.go b/internal/adapter/provider/antigravity/service.go index 96eb66b1..e5d11d1f 100644 --- a/internal/adapter/provider/antigravity/service.go +++ b/internal/adapter/provider/antigravity/service.go @@ -27,8 +27,8 @@ const ( UserAgentLoadCodeAssist = "antigravity/windows/amd64" // fetchAvailableModels 使用带版本号的 User-Agent UserAgentFetchModels = "antigravity/1.11.3 Darwin/arm64" - // 代理请求使用的 User-Agent - AntigravityUserAgent = "antigravity/1.11.9 windows/amd64" + // 代理请求使用的 User-Agent (CLIProxyAPI default) + AntigravityUserAgent = "antigravity/1.104.0 darwin/arm64" // 默认 Project ID (当 API 未返回时使用) DefaultProjectID = "bamboo-precept-lgxtn" diff --git a/internal/adapter/provider/antigravity/transform_request.go b/internal/adapter/provider/antigravity/transform_request.go index 8a9609ac..e3a2f536 100644 --- a/internal/adapter/provider/antigravity/transform_request.go +++ b/internal/adapter/provider/antigravity/transform_request.go @@ -74,13 +74,7 @@ func TransformClaudeToGemini( genConfig := buildGenerationConfig(&claudeReq, mappedModel, stream, hasThinking) geminiReq["generationConfig"] = genConfig - // 5.5 Safety Settings (configurable via environment) - // Reference: Antigravity-Manager's build_safety_settings - safetyThreshold := GetSafetyThresholdFromEnv() - safetySettings := BuildSafetySettingsMap(safetyThreshold) - 
geminiReq["safetySettings"] = safetySettings - - // 5.6 Deep clean [undefined] strings (Cherry Studio injection fix) + // 5.5 Deep clean [undefined] strings (Cherry Studio injection fix) // Reference: Antigravity-Manager line 278 deepCleanUndefined(geminiReq) diff --git a/internal/adapter/provider/antigravity/translator_helpers.go b/internal/adapter/provider/antigravity/translator_helpers.go new file mode 100644 index 00000000..2303b1db --- /dev/null +++ b/internal/adapter/provider/antigravity/translator_helpers.go @@ -0,0 +1,231 @@ +// Package util provides utility functions for the CLI Proxy API server. +// It includes helper functions for JSON manipulation, proxy configuration, +// and other common operations used across the application. +package antigravity + +import ( + "bytes" + "fmt" + "strings" + + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +// Walk recursively traverses a JSON structure to find all occurrences of a specific field. +// It builds paths to each occurrence and adds them to the provided paths slice. +// +// Parameters: +// - value: The gjson.Result object to traverse +// - path: The current path in the JSON structure (empty string for root) +// - field: The field name to search for +// - paths: Pointer to a slice where found paths will be stored +// +// The function works recursively, building dot-notation paths to each occurrence +// of the specified field throughout the JSON structure. +func Walk(value gjson.Result, path, field string, paths *[]string) { + switch value.Type { + case gjson.JSON: + // For JSON objects and arrays, iterate through each child + value.ForEach(func(key, val gjson.Result) bool { + var childPath string + // Escape special characters for gjson/sjson path syntax + // . -> \. + // * -> \* + // ? -> \? + var keyReplacer = strings.NewReplacer(".", "\\.", "*", "\\*", "?", "\\?") + safeKey := keyReplacer.Replace(key.String()) + + if path == "" { + childPath = safeKey + } else { + childPath = path + "." 
+ safeKey + } + if key.String() == field { + *paths = append(*paths, childPath) + } + Walk(val, childPath, field, paths) + return true + }) + case gjson.String, gjson.Number, gjson.True, gjson.False, gjson.Null: + // Terminal types - no further traversal needed + } +} + +// RenameKey renames a key in a JSON string by moving its value to a new key path +// and then deleting the old key path. +// +// Parameters: +// - jsonStr: The JSON string to modify +// - oldKeyPath: The dot-notation path to the key that should be renamed +// - newKeyPath: The dot-notation path where the value should be moved to +// +// Returns: +// - string: The modified JSON string with the key renamed +// - error: An error if the operation fails +// +// The function performs the rename in two steps: +// 1. Sets the value at the new key path +// 2. Deletes the old key path +func RenameKey(jsonStr, oldKeyPath, newKeyPath string) (string, error) { + value := gjson.Get(jsonStr, oldKeyPath) + + if !value.Exists() { + return "", fmt.Errorf("old key '%s' does not exist", oldKeyPath) + } + + interimJson, err := sjson.SetRaw(jsonStr, newKeyPath, value.Raw) + if err != nil { + return "", fmt.Errorf("failed to set new key '%s': %w", newKeyPath, err) + } + + finalJson, err := sjson.Delete(interimJson, oldKeyPath) + if err != nil { + return "", fmt.Errorf("failed to delete old key '%s': %w", oldKeyPath, err) + } + + return finalJson, nil +} + +func DeleteKey(jsonStr, keyName string) string { + paths := make([]string, 0) + Walk(gjson.Parse(jsonStr), "", keyName, &paths) + for _, p := range paths { + jsonStr, _ = sjson.Delete(jsonStr, p) + } + return jsonStr +} + +// FixJSON converts non-standard JSON that uses single quotes for strings into +// RFC 8259-compliant JSON by converting those single-quoted strings to +// double-quoted strings with proper escaping. 
+// +// Examples: +// +// {'a': 1, 'b': '2'} => {"a": 1, "b": "2"} +// {"t": 'He said "hi"'} => {"t": "He said \"hi\""} +// +// Rules: +// - Existing double-quoted JSON strings are preserved as-is. +// - Single-quoted strings are converted to double-quoted strings. +// - Inside converted strings, any double quote is escaped (\"). +// - Common backslash escapes (\n, \r, \t, \b, \f, \\) are preserved. +// - \' inside single-quoted strings becomes a literal ' in the output (no +// escaping needed inside double quotes). +// - Unicode escapes (\uXXXX) inside single-quoted strings are forwarded. +// - The function does not attempt to fix other non-JSON features beyond quotes. +func FixJSON(input string) string { + var out bytes.Buffer + + inDouble := false + inSingle := false + escaped := false // applies within the current string state + + // Helper to write a rune, escaping double quotes when inside a converted + // single-quoted string (which becomes a double-quoted string in output). + writeConverted := func(r rune) { + if r == '"' { + out.WriteByte('\\') + out.WriteByte('"') + return + } + out.WriteRune(r) + } + + runes := []rune(input) + for i := 0; i < len(runes); i++ { + r := runes[i] + + if inDouble { + out.WriteRune(r) + if escaped { + // end of escape sequence in a standard JSON string + escaped = false + continue + } + if r == '\\' { + escaped = true + continue + } + if r == '"' { + inDouble = false + } + continue + } + + if inSingle { + if escaped { + // Handle common escape sequences after a backslash within a + // single-quoted string + escaped = false + switch r { + case 'n', 'r', 't', 'b', 'f', '/', '"': + // Keep the backslash and the character (except for '"' which + // rarely appears, but if it does, keep as \" to remain valid) + out.WriteByte('\\') + out.WriteRune(r) + case '\\': + out.WriteByte('\\') + out.WriteByte('\\') + case '\'': + // \' inside single-quoted becomes a literal ' + out.WriteRune('\'') + case 'u': + // Forward \uXXXX if possible + 
out.WriteByte('\\') + out.WriteByte('u') + // Copy up to next 4 hex digits if present + for k := 0; k < 4 && i+1 < len(runes); k++ { + peek := runes[i+1] + // simple hex check + if (peek >= '0' && peek <= '9') || (peek >= 'a' && peek <= 'f') || (peek >= 'A' && peek <= 'F') { + out.WriteRune(peek) + i++ + } else { + break + } + } + default: + // Unknown escape: preserve the backslash and the char + out.WriteByte('\\') + out.WriteRune(r) + } + continue + } + + if r == '\\' { // start escape sequence + escaped = true + continue + } + if r == '\'' { // end of single-quoted string + out.WriteByte('"') + inSingle = false + continue + } + // regular char inside converted string; escape double quotes + writeConverted(r) + continue + } + + // Outside any string + if r == '"' { + inDouble = true + out.WriteRune(r) + continue + } + if r == '\'' { // start of non-standard single-quoted string + inSingle = true + out.WriteByte('"') + continue + } + out.WriteRune(r) + } + + // If input ended while still inside a single-quoted string, close it to + // produce the best-effort valid JSON. 
+ if inSingle { + out.WriteByte('"') + } + + return out.String() +} diff --git a/internal/adapter/provider/cliproxyapi_antigravity/adapter.go b/internal/adapter/provider/cliproxyapi_antigravity/adapter.go index 9d80dcbd..ff138f1f 100644 --- a/internal/adapter/provider/cliproxyapi_antigravity/adapter.go +++ b/internal/adapter/provider/cliproxyapi_antigravity/adapter.go @@ -11,8 +11,8 @@ import ( "time" "github.com/awsl-project/maxx/internal/adapter/provider" - ctxutil "github.com/awsl-project/maxx/internal/context" "github.com/awsl-project/maxx/internal/domain" + "github.com/awsl-project/maxx/internal/flow" "github.com/awsl-project/maxx/internal/usage" "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" @@ -56,12 +56,14 @@ func (a *CLIProxyAPIAntigravityAdapter) SupportedClientTypes() []domain.ClientTy return []domain.ClientType{domain.ClientTypeClaude, domain.ClientTypeGemini} } -func (a *CLIProxyAPIAntigravityAdapter) Execute(ctx context.Context, w http.ResponseWriter, req *http.Request, p *domain.Provider) error { - clientType := ctxutil.GetClientType(ctx) - requestBody := ctxutil.GetRequestBody(ctx) - stream := ctxutil.GetIsStream(ctx) - requestModel := ctxutil.GetRequestModel(ctx) - model := ctxutil.GetMappedModel(ctx) // 全局映射后的模型名(已包含 ProviderType 条件) +func (a *CLIProxyAPIAntigravityAdapter) Execute(c *flow.Ctx, p *domain.Provider) error { + w := c.Writer + + clientType := flow.GetClientType(c) + requestBody := flow.GetRequestBody(c) + stream := flow.GetIsStream(c) + requestModel := flow.GetRequestModel(c) + model := flow.GetMappedModel(c) // 全局映射后的模型名(已包含 ProviderType 条件) log.Printf("[CLIProxyAPI-Antigravity] requestModel=%s, mappedModel=%s, clientType=%s", requestModel, model, clientType) @@ -72,7 +74,7 @@ func (a *CLIProxyAPIAntigravityAdapter) Execute(ctx context.Context, w http.Resp } // 发送事件 - if eventChan := ctxutil.GetEventChan(ctx); eventChan != nil { + if eventChan := 
flow.GetEventChan(c); eventChan != nil { eventChan.SendRequestInfo(&domain.RequestInfo{ Method: "POST", URL: fmt.Sprintf("cliproxyapi://antigravity/%s", model), @@ -105,9 +107,9 @@ func (a *CLIProxyAPIAntigravityAdapter) Execute(ctx context.Context, w http.Resp } if stream { - return a.executeStream(ctx, w, execReq, execOpts) + return a.executeStream(c, w, execReq, execOpts) } - return a.executeNonStream(ctx, w, execReq, execOpts) + return a.executeNonStream(c, w, execReq, execOpts) } // updateModelInBody 替换 body 中的 model 字段 @@ -120,14 +122,19 @@ func updateModelInBody(body []byte, model string) ([]byte, error) { return json.Marshal(req) } -func (a *CLIProxyAPIAntigravityAdapter) executeNonStream(ctx context.Context, w http.ResponseWriter, execReq executor.Request, execOpts executor.Options) error { +func (a *CLIProxyAPIAntigravityAdapter) executeNonStream(c *flow.Ctx, w http.ResponseWriter, execReq executor.Request, execOpts executor.Options) error { + ctx := context.Background() + if c.Request != nil { + ctx = c.Request.Context() + } + resp, err := a.executor.Execute(ctx, a.authObj, execReq, execOpts) if err != nil { log.Printf("[CLIProxyAPI-Antigravity] executeNonStream error: model=%s, err=%v", execReq.Model, err) return domain.NewProxyErrorWithMessage(err, true, fmt.Sprintf("executor request failed: %v", err)) } - if eventChan := ctxutil.GetEventChan(ctx); eventChan != nil { + if eventChan := flow.GetEventChan(c); eventChan != nil { // Send response info eventChan.SendResponseInfo(&domain.ResponseInfo{ Status: http.StatusOK, @@ -159,13 +166,16 @@ func (a *CLIProxyAPIAntigravityAdapter) executeNonStream(ctx context.Context, w return nil } -func (a *CLIProxyAPIAntigravityAdapter) executeStream(ctx context.Context, w http.ResponseWriter, execReq executor.Request, execOpts executor.Options) error { +func (a *CLIProxyAPIAntigravityAdapter) executeStream(c *flow.Ctx, w http.ResponseWriter, execReq executor.Request, execOpts executor.Options) error { flusher, ok := 
w.(http.Flusher) if !ok { - return a.executeNonStream(ctx, w, execReq, execOpts) + return a.executeNonStream(c, w, execReq, execOpts) } - startTime := time.Now() + ctx := context.Background() + if c.Request != nil { + ctx = c.Request.Context() + } stream, err := a.executor.ExecuteStream(ctx, a.authObj, execReq, execOpts) if err != nil { @@ -179,7 +189,7 @@ func (a *CLIProxyAPIAntigravityAdapter) executeStream(ctx context.Context, w htt w.Header().Set("Connection", "keep-alive") w.WriteHeader(http.StatusOK) - eventChan := ctxutil.GetEventChan(ctx) + eventChan := flow.GetEventChan(c) // Collect SSE content for token extraction var sseBuffer bytes.Buffer @@ -200,7 +210,7 @@ func (a *CLIProxyAPIAntigravityAdapter) executeStream(ctx context.Context, w htt // Report TTFT on first non-empty chunk if !firstChunkSent && eventChan != nil { - eventChan.SendFirstToken(time.Since(startTime).Milliseconds()) + eventChan.SendFirstToken(time.Now().UnixMilli()) firstChunkSent = true } } diff --git a/internal/adapter/provider/cliproxyapi_codex/adapter.go b/internal/adapter/provider/cliproxyapi_codex/adapter.go index bab635c3..c1266b9c 100644 --- a/internal/adapter/provider/cliproxyapi_codex/adapter.go +++ b/internal/adapter/provider/cliproxyapi_codex/adapter.go @@ -11,8 +11,8 @@ import ( "time" "github.com/awsl-project/maxx/internal/adapter/provider" - ctxutil "github.com/awsl-project/maxx/internal/context" "github.com/awsl-project/maxx/internal/domain" + "github.com/awsl-project/maxx/internal/flow" "github.com/awsl-project/maxx/internal/usage" "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" @@ -53,16 +53,18 @@ func (a *CLIProxyAPICodexAdapter) SupportedClientTypes() []domain.ClientType { return []domain.ClientType{domain.ClientTypeCodex} } -func (a *CLIProxyAPICodexAdapter) Execute(ctx context.Context, w http.ResponseWriter, req *http.Request, p *domain.Provider) error { - requestBody := 
ctxutil.GetRequestBody(ctx) - stream := ctxutil.GetIsStream(ctx) - model := ctxutil.GetMappedModel(ctx) +func (a *CLIProxyAPICodexAdapter) Execute(c *flow.Ctx, p *domain.Provider) error { + w := c.Writer + + requestBody := flow.GetRequestBody(c) + stream := flow.GetIsStream(c) + model := flow.GetMappedModel(c) // Codex CLI 使用 OpenAI Responses API 格式 sourceFormat := translator.FormatCodex // 发送事件 - if eventChan := ctxutil.GetEventChan(ctx); eventChan != nil { + if eventChan := flow.GetEventChan(c); eventChan != nil { eventChan.SendRequestInfo(&domain.RequestInfo{ Method: "POST", URL: fmt.Sprintf("cliproxyapi://codex/%s", model), @@ -84,18 +86,23 @@ func (a *CLIProxyAPICodexAdapter) Execute(ctx context.Context, w http.ResponseWr } if stream { - return a.executeStream(ctx, w, execReq, execOpts) + return a.executeStream(c, w, execReq, execOpts) } - return a.executeNonStream(ctx, w, execReq, execOpts) + return a.executeNonStream(c, w, execReq, execOpts) } -func (a *CLIProxyAPICodexAdapter) executeNonStream(ctx context.Context, w http.ResponseWriter, execReq executor.Request, execOpts executor.Options) error { +func (a *CLIProxyAPICodexAdapter) executeNonStream(c *flow.Ctx, w http.ResponseWriter, execReq executor.Request, execOpts executor.Options) error { + ctx := context.Background() + if c.Request != nil { + ctx = c.Request.Context() + } + resp, err := a.executor.Execute(ctx, a.authObj, execReq, execOpts) if err != nil { return domain.NewProxyErrorWithMessage(err, true, fmt.Sprintf("executor request failed: %v", err)) } - if eventChan := ctxutil.GetEventChan(ctx); eventChan != nil { + if eventChan := flow.GetEventChan(c); eventChan != nil { // Send response info eventChan.SendResponseInfo(&domain.ResponseInfo{ Status: http.StatusOK, @@ -125,13 +132,16 @@ func (a *CLIProxyAPICodexAdapter) executeNonStream(ctx context.Context, w http.R return nil } -func (a *CLIProxyAPICodexAdapter) executeStream(ctx context.Context, w http.ResponseWriter, execReq executor.Request, 
execOpts executor.Options) error { +func (a *CLIProxyAPICodexAdapter) executeStream(c *flow.Ctx, w http.ResponseWriter, execReq executor.Request, execOpts executor.Options) error { flusher, ok := w.(http.Flusher) if !ok { - return a.executeNonStream(ctx, w, execReq, execOpts) + return a.executeNonStream(c, w, execReq, execOpts) } - startTime := time.Now() + ctx := context.Background() + if c.Request != nil { + ctx = c.Request.Context() + } stream, err := a.executor.ExecuteStream(ctx, a.authObj, execReq, execOpts) if err != nil { @@ -144,7 +154,7 @@ func (a *CLIProxyAPICodexAdapter) executeStream(ctx context.Context, w http.Resp w.Header().Set("Connection", "keep-alive") w.WriteHeader(http.StatusOK) - eventChan := ctxutil.GetEventChan(ctx) + eventChan := flow.GetEventChan(c) // Collect SSE content for token extraction var sseBuffer bytes.Buffer @@ -165,7 +175,7 @@ func (a *CLIProxyAPICodexAdapter) executeStream(ctx context.Context, w http.Resp // Report TTFT on first non-empty chunk if !firstChunkSent && eventChan != nil { - eventChan.SendFirstToken(time.Since(startTime).Milliseconds()) + eventChan.SendFirstToken(time.Now().UnixMilli()) firstChunkSent = true } } diff --git a/internal/adapter/provider/codex/adapter.go b/internal/adapter/provider/codex/adapter.go index 74d90672..0dad347c 100644 --- a/internal/adapter/provider/codex/adapter.go +++ b/internal/adapter/provider/codex/adapter.go @@ -1,6 +1,7 @@ package codex import ( + "bufio" "bytes" "context" "encoding/json" @@ -14,13 +15,30 @@ import ( "github.com/awsl-project/maxx/internal/adapter/provider" cliproxyapi "github.com/awsl-project/maxx/internal/adapter/provider/cliproxyapi_codex" - ctxutil "github.com/awsl-project/maxx/internal/context" "github.com/awsl-project/maxx/internal/domain" + "github.com/awsl-project/maxx/internal/flow" "github.com/awsl-project/maxx/internal/usage" + "github.com/google/uuid" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" ) func init() { 
provider.RegisterAdapterFactory("codex", NewAdapter) + go func() { + ticker := time.NewTicker(5 * time.Minute) + defer ticker.Stop() + for range ticker.C { + codexCacheMu.Lock() + now := time.Now() + for k, v := range codexCaches { + if now.After(v.Expire) { + delete(codexCaches, k) + } + } + codexCacheMu.Unlock() + } + }() } // TokenCache caches access tokens @@ -95,9 +113,14 @@ func (a *CodexAdapter) SupportedClientTypes() []domain.ClientType { return []domain.ClientType{domain.ClientTypeCodex} } -func (a *CodexAdapter) Execute(ctx context.Context, w http.ResponseWriter, req *http.Request, provider *domain.Provider) error { - requestBody := ctxutil.GetRequestBody(ctx) - stream := ctxutil.GetIsStream(ctx) +func (a *CodexAdapter) Execute(c *flow.Ctx, provider *domain.Provider) error { + requestBody := flow.GetRequestBody(c) + clientWantsStream := flow.GetIsStream(c) + request := c.Request + ctx := context.Background() + if request != nil { + ctx = request.Context() + } // Get access token accessToken, err := a.getAccessToken(ctx) @@ -105,8 +128,18 @@ func (a *CodexAdapter) Execute(ctx context.Context, w http.ResponseWriter, req * return domain.NewProxyErrorWithMessage(err, true, "failed to get access token") } + // Apply Codex CLI payload adjustments (CLIProxyAPI-aligned) + cacheID, updatedBody := applyCodexRequestTuning(c, requestBody) + requestBody = updatedBody + // Build upstream URL upstreamURL := CodexBaseURL + "/responses" + upstreamStream := true + if len(requestBody) > 0 { + if updated, err := sjson.SetBytes(requestBody, "stream", upstreamStream); err == nil { + requestBody = updated + } + } // Create upstream request upstreamReq, err := http.NewRequestWithContext(ctx, "POST", upstreamURL, bytes.NewReader(requestBody)) @@ -116,10 +149,10 @@ func (a *CodexAdapter) Execute(ctx context.Context, w http.ResponseWriter, req * // Apply headers with passthrough support (client headers take priority) config := provider.Config.Codex - 
a.applyCodexHeaders(upstreamReq, req, accessToken, config.AccountID) + a.applyCodexHeaders(upstreamReq, request, accessToken, config.AccountID, upstreamStream, cacheID) // Send request info via EventChannel - if eventChan := ctxutil.GetEventChan(ctx); eventChan != nil { + if eventChan := flow.GetEventChan(c); eventChan != nil { eventChan.SendRequestInfo(&domain.RequestInfo{ Method: upstreamReq.Method, URL: upstreamURL, @@ -153,8 +186,11 @@ func (a *CodexAdapter) Execute(ctx context.Context, w http.ResponseWriter, req * } // Retry request - upstreamReq, _ = http.NewRequestWithContext(ctx, "POST", upstreamURL, bytes.NewReader(requestBody)) - a.applyCodexHeaders(upstreamReq, req, accessToken, config.AccountID) + upstreamReq, reqErr := http.NewRequestWithContext(ctx, "POST", upstreamURL, bytes.NewReader(requestBody)) + if reqErr != nil { + return domain.NewProxyErrorWithMessage(reqErr, false, fmt.Sprintf("failed to create retry request: %v", reqErr)) + } + a.applyCodexHeaders(upstreamReq, request, accessToken, config.AccountID, upstreamStream, cacheID) resp, err = a.httpClient.Do(upstreamReq) if err != nil { @@ -170,7 +206,7 @@ func (a *CodexAdapter) Execute(ctx context.Context, w http.ResponseWriter, req * body, _ := io.ReadAll(resp.Body) // Send error response info via EventChannel - if eventChan := ctxutil.GetEventChan(ctx); eventChan != nil { + if eventChan := flow.GetEventChan(c); eventChan != nil { eventChan.SendResponseInfo(&domain.ResponseInfo{ Status: resp.StatusCode, Headers: flattenHeaders(resp.Header), @@ -200,26 +236,51 @@ func (a *CodexAdapter) Execute(ctx context.Context, w http.ResponseWriter, req * } // Handle response - if stream { - return a.handleStreamResponse(ctx, w, resp) + if clientWantsStream { + return a.handleStreamResponse(c, resp) } - return a.handleNonStreamResponse(ctx, w, resp) + return a.handleCollectedStreamResponse(c, resp) } func (a *CodexAdapter) getAccessToken(ctx context.Context) (string, error) { // Check cache a.tokenMu.RLock() 
- if a.tokenCache.AccessToken != "" && time.Now().Add(60*time.Second).Before(a.tokenCache.ExpiresAt) { - token := a.tokenCache.AccessToken - a.tokenMu.RUnlock() - return token, nil + if a.tokenCache.AccessToken != "" { + if a.tokenCache.ExpiresAt.IsZero() || time.Now().Add(60*time.Second).Before(a.tokenCache.ExpiresAt) { + token := a.tokenCache.AccessToken + a.tokenMu.RUnlock() + return token, nil + } } a.tokenMu.RUnlock() - // Refresh token + // Use persisted access token if present (even if expiry is unknown) config := a.provider.Config.Codex + if strings.TrimSpace(config.AccessToken) != "" { + var expiresAt time.Time + if strings.TrimSpace(config.ExpiresAt) != "" { + if parsed, err := time.Parse(time.RFC3339, config.ExpiresAt); err == nil { + expiresAt = parsed + } + } + a.tokenMu.Lock() + a.tokenCache = &TokenCache{ + AccessToken: config.AccessToken, + ExpiresAt: expiresAt, + } + a.tokenMu.Unlock() + + if expiresAt.IsZero() || time.Now().Add(60*time.Second).Before(expiresAt) { + return config.AccessToken, nil + } + } + + // Refresh token tokenResp, err := RefreshAccessToken(ctx, config.RefreshToken) if err != nil { + if strings.TrimSpace(config.AccessToken) != "" { + return config.AccessToken, nil + } return "", err } @@ -238,6 +299,37 @@ func (a *CodexAdapter) getAccessToken(ctx context.Context) (string, error) { if a.providerUpdate != nil { config.AccessToken = tokenResp.AccessToken config.ExpiresAt = expiresAt.Format(time.RFC3339) + if tokenResp.RefreshToken != "" { + config.RefreshToken = tokenResp.RefreshToken + } + if tokenResp.IDToken != "" { + if claims, parseErr := ParseIDToken(tokenResp.IDToken); parseErr == nil && claims != nil { + if v := strings.TrimSpace(claims.GetAccountID()); v != "" { + config.AccountID = v + } + if v := strings.TrimSpace(claims.GetUserID()); v != "" { + config.UserID = v + } + if v := strings.TrimSpace(claims.Email); v != "" { + config.Email = v + } + if v := strings.TrimSpace(claims.Name); v != "" { + config.Name = v + } + if 
v := strings.TrimSpace(claims.Picture); v != "" { + config.Picture = v + } + if v := strings.TrimSpace(claims.GetPlanType()); v != "" { + config.PlanType = v + } + if v := strings.TrimSpace(claims.GetSubscriptionStart()); v != "" { + config.SubscriptionStart = v + } + if v := strings.TrimSpace(claims.GetSubscriptionEnd()); v != "" { + config.SubscriptionEnd = v + } + } + } // Note: We intentionally ignore errors here as token persistence is best-effort // The token will still work in memory even if DB update fails _ = a.providerUpdate(a.provider) @@ -246,72 +338,109 @@ func (a *CodexAdapter) getAccessToken(ctx context.Context) (string, error) { return tokenResp.AccessToken, nil } -func (a *CodexAdapter) handleNonStreamResponse(ctx context.Context, w http.ResponseWriter, resp *http.Response) error { +func (a *CodexAdapter) handleNonStreamResponse(c *flow.Ctx, resp *http.Response) error { body, err := io.ReadAll(resp.Body) if err != nil { return domain.NewProxyErrorWithMessage(domain.ErrUpstreamError, true, "failed to read upstream response") } // Send events via EventChannel - eventChan := ctxutil.GetEventChan(ctx) - eventChan.SendResponseInfo(&domain.ResponseInfo{ - Status: resp.StatusCode, - Headers: flattenHeaders(resp.Header), - Body: string(body), - }) - - // Extract token usage from response - if metrics := usage.ExtractFromResponse(string(body)); metrics != nil { - eventChan.SendMetrics(&domain.AdapterMetrics{ - InputTokens: metrics.InputTokens, - OutputTokens: metrics.OutputTokens, + if eventChan := flow.GetEventChan(c); eventChan != nil { + eventChan.SendResponseInfo(&domain.ResponseInfo{ + Status: resp.StatusCode, + Headers: flattenHeaders(resp.Header), + Body: string(body), }) - } - - // Extract model from response - if model := extractModelFromResponse(body); model != "" { - eventChan.SendResponseModel(model) + // Extract token usage from response + if metrics := usage.ExtractFromResponse(string(body)); metrics != nil { + 
eventChan.SendMetrics(&domain.AdapterMetrics{ + InputTokens: metrics.InputTokens, + OutputTokens: metrics.OutputTokens, + }) + } + // Extract model from response + if model := extractModelFromResponse(body); model != "" { + eventChan.SendResponseModel(model) + } } // Copy response headers - copyResponseHeaders(w.Header(), resp.Header) - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(resp.StatusCode) - _, _ = w.Write(body) + copyResponseHeaders(c.Writer.Header(), resp.Header) + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, _ = c.Writer.Write(body) return nil } -func (a *CodexAdapter) handleStreamResponse(ctx context.Context, w http.ResponseWriter, resp *http.Response) error { - eventChan := ctxutil.GetEventChan(ctx) +func (a *CodexAdapter) handleCollectedStreamResponse(c *flow.Ctx, resp *http.Response) error { + body, err := io.ReadAll(resp.Body) + if err != nil { + return domain.NewProxyErrorWithMessage(domain.ErrUpstreamError, true, "failed to read upstream response") + } - // Send initial response info - eventChan.SendResponseInfo(&domain.ResponseInfo{ - Status: resp.StatusCode, - Headers: flattenHeaders(resp.Header), - Body: "[streaming]", - }) + responsePayload := body + if isSSEPayload(body) { + if completed := extractCodexCompletedResponse(body); len(completed) > 0 { + responsePayload = completed + } + } - // Set streaming headers - copyResponseHeaders(w.Header(), resp.Header) - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Cache-Control", "no-cache") - w.Header().Set("Connection", "keep-alive") - w.Header().Set("X-Accel-Buffering", "no") + if eventChan := flow.GetEventChan(c); eventChan != nil { + eventChan.SendResponseInfo(&domain.ResponseInfo{ + Status: resp.StatusCode, + Headers: flattenHeaders(resp.Header), + Body: string(responsePayload), + }) + if metrics := usage.ExtractFromResponse(string(responsePayload)); metrics != nil { + 
eventChan.SendMetrics(&domain.AdapterMetrics{ + InputTokens: metrics.InputTokens, + OutputTokens: metrics.OutputTokens, + }) + } + if model := extractModelFromResponse(responsePayload); model != "" { + eventChan.SendResponseModel(model) + } + } - flusher, ok := w.(http.Flusher) + copyResponseHeaders(c.Writer.Header(), resp.Header) + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, _ = c.Writer.Write(responsePayload) + return nil +} + +func (a *CodexAdapter) handleStreamResponse(c *flow.Ctx, resp *http.Response) error { + eventChan := flow.GetEventChan(c) + if eventChan != nil { + eventChan.SendResponseInfo(&domain.ResponseInfo{ + Status: resp.StatusCode, + Headers: flattenHeaders(resp.Header), + Body: "[streaming]", + }) + } + + copyResponseHeaders(c.Writer.Header(), resp.Header) + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("X-Accel-Buffering", "no") + + flusher, ok := c.Writer.(http.Flusher) if !ok { return domain.NewProxyErrorWithMessage(domain.ErrUpstreamError, false, "streaming not supported") } // Collect SSE for token extraction var sseBuffer strings.Builder - var lineBuffer bytes.Buffer - buf := make([]byte, 4096) + reader := bufio.NewReader(resp.Body) firstChunkSent := false responseCompleted := false + ctx := context.Background() + if c.Request != nil { + ctx = c.Request.Context() + } for { - // Check context select { case <-ctx.Done(): a.sendFinalStreamEvents(eventChan, &sseBuffer, resp) @@ -322,39 +451,30 @@ func (a *CodexAdapter) handleStreamResponse(ctx context.Context, w http.Response default: } - n, err := resp.Body.Read(buf) - if n > 0 { - lineBuffer.Write(buf[:n]) - - // Process complete lines - for { - line, readErr := lineBuffer.ReadString('\n') - if readErr != nil { - lineBuffer.WriteString(line) - break - } - - sseBuffer.WriteString(line) + line, 
err := reader.ReadString('\n') + if line != "" { + sseBuffer.WriteString(line) - // Check for response.completed in data line - if strings.HasPrefix(line, "data:") && strings.Contains(line, "response.completed") { - responseCompleted = true - } + // Check for response.completed in data line + if strings.HasPrefix(line, "data:") && strings.Contains(line, "\"response.completed\"") { + responseCompleted = true + } - // Write to client - _, writeErr := w.Write([]byte(line)) - if writeErr != nil { - a.sendFinalStreamEvents(eventChan, &sseBuffer, resp) - if responseCompleted { - return nil - } - return domain.NewProxyErrorWithMessage(writeErr, false, "client disconnected") + // Write to client + _, writeErr := c.Writer.Write([]byte(line)) + if writeErr != nil { + a.sendFinalStreamEvents(eventChan, &sseBuffer, resp) + if responseCompleted { + return nil } - flusher.Flush() + return domain.NewProxyErrorWithMessage(writeErr, false, "client disconnected") + } + flusher.Flush() - // Track TTFT - if !firstChunkSent { - firstChunkSent = true + // Track TTFT + if !firstChunkSent { + firstChunkSent = true + if eventChan != nil { eventChan.SendFirstToken(time.Now().UnixMilli()) } } @@ -374,6 +494,9 @@ func (a *CodexAdapter) handleStreamResponse(ctx context.Context, w http.Response } func (a *CodexAdapter) sendFinalStreamEvents(eventChan domain.AdapterEventChan, sseBuffer *strings.Builder, resp *http.Response) { + if eventChan == nil { + return + } if sseBuffer.Len() > 0 { // Update response body with collected SSE eventChan.SendResponseInfo(&domain.ResponseInfo{ @@ -397,6 +520,85 @@ func (a *CodexAdapter) sendFinalStreamEvents(eventChan domain.AdapterEventChan, } } +type codexCache struct { + ID string + Expire time.Time +} + +var ( + codexCacheMu sync.Mutex + codexCaches = map[string]codexCache{} +) + +func getCodexCache(key string) (codexCache, bool) { + codexCacheMu.Lock() + defer codexCacheMu.Unlock() + cache, ok := codexCaches[key] + if !ok { + return codexCache{}, false + } 
+ if time.Now().After(cache.Expire) { + delete(codexCaches, key) + return codexCache{}, false + } + return cache, true +} + +func setCodexCache(key string, cache codexCache) { + codexCacheMu.Lock() + codexCaches[key] = cache + codexCacheMu.Unlock() +} + +func applyCodexRequestTuning(c *flow.Ctx, body []byte) (string, []byte) { + if len(body) == 0 { + return "", body + } + + origBody := flow.GetOriginalRequestBody(c) + origType := flow.GetOriginalClientType(c) + + cacheID := "" + if origType == domain.ClientTypeClaude && len(origBody) > 0 { + userID := gjson.GetBytes(origBody, "metadata.user_id") + if userID.Exists() && strings.TrimSpace(userID.String()) != "" { + model := gjson.GetBytes(body, "model").String() + key := model + "-" + userID.String() + if cache, ok := getCodexCache(key); ok { + cacheID = cache.ID + } else { + cacheID = uuid.NewString() + setCodexCache(key, codexCache{ + ID: cacheID, + Expire: time.Now().Add(1 * time.Hour), + }) + } + } + } else if len(origBody) > 0 { + if promptKey := gjson.GetBytes(origBody, "prompt_cache_key"); promptKey.Exists() { + cacheID = promptKey.String() + } + } + + if cacheID != "" { + if updated, err := sjson.SetBytes(body, "prompt_cache_key", cacheID); err == nil { + body = updated + } + } + + if updated, err := sjson.SetBytes(body, "stream", true); err == nil { + body = updated + } + body, _ = sjson.DeleteBytes(body, "previous_response_id") + body, _ = sjson.DeleteBytes(body, "prompt_cache_retention") + body, _ = sjson.DeleteBytes(body, "safety_identifier") + if !gjson.GetBytes(body, "instructions").Exists() { + body, _ = sjson.SetBytes(body, "instructions", "") + } + + return cacheID, body +} + func newUpstreamHTTPClient() *http.Client { dialer := &net.Dialer{ Timeout: 20 * time.Second, @@ -486,9 +688,39 @@ func extractModelFromSSE(sseContent string) string { return lastModel } +func isSSEPayload(body []byte) bool { + trimmed := bytes.TrimSpace(body) + return bytes.HasPrefix(trimmed, []byte("data:")) || 
bytes.HasPrefix(trimmed, []byte("event:")) +} + +func extractCodexCompletedResponse(body []byte) []byte { + scanner := bufio.NewScanner(bytes.NewReader(body)) + scanner.Buffer(nil, 52_428_800) + for scanner.Scan() { + line := scanner.Bytes() + if !bytes.HasPrefix(line, []byte("data:")) { + continue + } + data := bytes.TrimSpace(line[5:]) + if bytes.Equal(data, []byte("[DONE]")) { + continue + } + root := gjson.ParseBytes(data) + if root.Get("type").String() == "response.completed" { + if resp := root.Get("response"); resp.Exists() { + return []byte(resp.Raw) + } + return data + } + } + return nil +} + // applyCodexHeaders applies headers for Codex API requests // It follows the CLIProxyAPI pattern: passthrough client headers, use defaults only when missing -func (a *CodexAdapter) applyCodexHeaders(upstreamReq, clientReq *http.Request, accessToken, accountID string) { +func (a *CodexAdapter) applyCodexHeaders(upstreamReq, clientReq *http.Request, accessToken, accountID string, stream bool, cacheID string) { + hasAccessToken := strings.TrimSpace(accessToken) != "" + // First, copy passthrough headers from client request (excluding hop-by-hop and auth) if clientReq != nil { for k, vv := range clientReq.Header { @@ -496,8 +728,12 @@ func (a *CodexAdapter) applyCodexHeaders(upstreamReq, clientReq *http.Request, a // Skip hop-by-hop headers and authorization (we'll set our own) switch lk { case "connection", "keep-alive", "transfer-encoding", "upgrade", - "authorization", "host", "content-length": + "host", "content-length": continue + case "authorization": + if hasAccessToken { + continue + } } for _, v := range vv { upstreamReq.Header.Add(k, v) @@ -507,18 +743,32 @@ func (a *CodexAdapter) applyCodexHeaders(upstreamReq, clientReq *http.Request, a // Set required headers (these always override) upstreamReq.Header.Set("Content-Type", "application/json") - upstreamReq.Header.Set("Authorization", "Bearer "+accessToken) - upstreamReq.Header.Set("Accept", "text/event-stream") 
+ if hasAccessToken { + upstreamReq.Header.Set("Authorization", "Bearer "+accessToken) + } + if stream { + upstreamReq.Header.Set("Accept", "text/event-stream") + } else { + upstreamReq.Header.Set("Accept", "application/json") + } upstreamReq.Header.Set("Connection", "Keep-Alive") // Set Codex-specific headers only if client didn't provide them ensureHeader(upstreamReq.Header, clientReq, "Version", CodexVersion) ensureHeader(upstreamReq.Header, clientReq, "Openai-Beta", OpenAIBetaHeader) + if cacheID != "" { + upstreamReq.Header.Set("Conversation_id", cacheID) + upstreamReq.Header.Set("Session_id", cacheID) + } else { + ensureHeader(upstreamReq.Header, clientReq, "Session_id", uuid.NewString()) + } ensureHeader(upstreamReq.Header, clientReq, "User-Agent", CodexUserAgent) - ensureHeader(upstreamReq.Header, clientReq, "Originator", CodexOriginator) + if hasAccessToken { + ensureHeader(upstreamReq.Header, clientReq, "Originator", CodexOriginator) + } // Set account ID if available (required for OAuth auth, not for API key) - if accountID != "" { + if hasAccessToken && accountID != "" { upstreamReq.Header.Set("Chatgpt-Account-Id", accountID) } } diff --git a/internal/adapter/provider/codex/adapter_test.go b/internal/adapter/provider/codex/adapter_test.go new file mode 100644 index 00000000..7aaaaf30 --- /dev/null +++ b/internal/adapter/provider/codex/adapter_test.go @@ -0,0 +1,37 @@ +package codex + +import ( + "testing" + + "github.com/awsl-project/maxx/internal/domain" + "github.com/awsl-project/maxx/internal/flow" + "github.com/tidwall/gjson" +) + +func TestApplyCodexRequestTuning(t *testing.T) { + c := flow.NewCtx(nil, nil) + c.Set(flow.KeyOriginalClientType, domain.ClientTypeClaude) + c.Set(flow.KeyOriginalRequestBody, []byte(`{"metadata":{"user_id":"user-123"}}`)) + + body := []byte(`{"model":"gpt-5","stream":false,"instructions":"x","previous_response_id":"r1","prompt_cache_retention":123,"safety_identifier":"s1"}`) + cacheID, tuned := applyCodexRequestTuning(c, 
body) + + if cacheID == "" { + t.Fatalf("expected cacheID to be set") + } + if gjson.GetBytes(tuned, "prompt_cache_key").String() == "" { + t.Fatalf("expected prompt_cache_key to be set") + } + if !gjson.GetBytes(tuned, "stream").Bool() { + t.Fatalf("expected stream=true") + } + if gjson.GetBytes(tuned, "previous_response_id").Exists() { + t.Fatalf("expected previous_response_id to be removed") + } + if gjson.GetBytes(tuned, "prompt_cache_retention").Exists() { + t.Fatalf("expected prompt_cache_retention to be removed") + } + if gjson.GetBytes(tuned, "safety_identifier").Exists() { + t.Fatalf("expected safety_identifier to be removed") + } +} diff --git a/internal/adapter/provider/codex/oauth.go b/internal/adapter/provider/codex/oauth.go index 1dea2466..2fc6c9cd 100644 --- a/internal/adapter/provider/codex/oauth.go +++ b/internal/adapter/provider/codex/oauth.go @@ -188,12 +188,14 @@ func RefreshAccessToken(ctx context.Context, refreshToken string) (*TokenRespons data.Set("grant_type", "refresh_token") data.Set("client_id", OAuthClientID) data.Set("refresh_token", refreshToken) + data.Set("scope", "openid profile email") req, err := http.NewRequestWithContext(ctx, "POST", OpenAITokenURL, strings.NewReader(data.Encode())) if err != nil { return nil, fmt.Errorf("failed to create request: %w", err) } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") client := &http.Client{Timeout: 15 * time.Second} resp, err := client.Do(req) @@ -275,13 +277,13 @@ type CodexUsageResponse struct { // codexUsageAPIResponse handles both camelCase and snake_case from API type codexUsageAPIResponse struct { - PlanType string `json:"plan_type,omitempty"` - PlanTypeCamel string `json:"planType,omitempty"` - RateLimit *struct { - Allowed *bool `json:"allowed,omitempty"` - LimitReached *bool `json:"limit_reached,omitempty"` - LimitReachedCamel *bool `json:"limitReached,omitempty"` - PrimaryWindow *struct { + PlanType string 
`json:"plan_type,omitempty"` + PlanTypeCamel string `json:"planType,omitempty"` + RateLimit *struct { + Allowed *bool `json:"allowed,omitempty"` + LimitReached *bool `json:"limit_reached,omitempty"` + LimitReachedCamel *bool `json:"limitReached,omitempty"` + PrimaryWindow *struct { UsedPercent *float64 `json:"used_percent,omitempty"` UsedPercentCamel *float64 `json:"usedPercent,omitempty"` LimitWindowSeconds *int64 `json:"limit_window_seconds,omitempty"` @@ -323,10 +325,10 @@ type codexUsageAPIResponse struct { } `json:"secondaryWindow,omitempty"` } `json:"rate_limit,omitempty"` RateLimitCamel *struct { - Allowed *bool `json:"allowed,omitempty"` - LimitReached *bool `json:"limit_reached,omitempty"` - LimitReachedCamel *bool `json:"limitReached,omitempty"` - PrimaryWindow *struct { + Allowed *bool `json:"allowed,omitempty"` + LimitReached *bool `json:"limit_reached,omitempty"` + LimitReachedCamel *bool `json:"limitReached,omitempty"` + PrimaryWindow *struct { UsedPercent *float64 `json:"used_percent,omitempty"` UsedPercentCamel *float64 `json:"usedPercent,omitempty"` LimitWindowSeconds *int64 `json:"limit_window_seconds,omitempty"` @@ -368,10 +370,10 @@ type codexUsageAPIResponse struct { } `json:"secondaryWindow,omitempty"` } `json:"rateLimit,omitempty"` CodeReviewRateLimit *struct { - Allowed *bool `json:"allowed,omitempty"` - LimitReached *bool `json:"limit_reached,omitempty"` - LimitReachedCamel *bool `json:"limitReached,omitempty"` - PrimaryWindow *struct { + Allowed *bool `json:"allowed,omitempty"` + LimitReached *bool `json:"limit_reached,omitempty"` + LimitReachedCamel *bool `json:"limitReached,omitempty"` + PrimaryWindow *struct { UsedPercent *float64 `json:"used_percent,omitempty"` UsedPercentCamel *float64 `json:"usedPercent,omitempty"` LimitWindowSeconds *int64 `json:"limit_window_seconds,omitempty"` @@ -393,10 +395,10 @@ type codexUsageAPIResponse struct { } `json:"primaryWindow,omitempty"` } `json:"code_review_rate_limit,omitempty"` 
CodeReviewRateLimitCamel *struct { - Allowed *bool `json:"allowed,omitempty"` - LimitReached *bool `json:"limit_reached,omitempty"` - LimitReachedCamel *bool `json:"limitReached,omitempty"` - PrimaryWindow *struct { + Allowed *bool `json:"allowed,omitempty"` + LimitReached *bool `json:"limit_reached,omitempty"` + LimitReachedCamel *bool `json:"limitReached,omitempty"` + PrimaryWindow *struct { UsedPercent *float64 `json:"used_percent,omitempty"` UsedPercentCamel *float64 `json:"usedPercent,omitempty"` LimitWindowSeconds *int64 `json:"limit_window_seconds,omitempty"` diff --git a/internal/adapter/provider/custom/adapter.go b/internal/adapter/provider/custom/adapter.go index b10fe575..2681ce9a 100644 --- a/internal/adapter/provider/custom/adapter.go +++ b/internal/adapter/provider/custom/adapter.go @@ -14,8 +14,8 @@ import ( "time" "github.com/awsl-project/maxx/internal/adapter/provider" - ctxutil "github.com/awsl-project/maxx/internal/context" "github.com/awsl-project/maxx/internal/domain" + "github.com/awsl-project/maxx/internal/flow" "github.com/awsl-project/maxx/internal/usage" ) @@ -40,10 +40,15 @@ func (a *CustomAdapter) SupportedClientTypes() []domain.ClientType { return a.provider.SupportedClientTypes } -func (a *CustomAdapter) Execute(ctx context.Context, w http.ResponseWriter, req *http.Request, provider *domain.Provider) error { - clientType := ctxutil.GetClientType(ctx) - mappedModel := ctxutil.GetMappedModel(ctx) - requestBody := ctxutil.GetRequestBody(ctx) +func (a *CustomAdapter) Execute(c *flow.Ctx, provider *domain.Provider) error { + clientType := flow.GetClientType(c) + mappedModel := flow.GetMappedModel(c) + requestBody := flow.GetRequestBody(c) + request := c.Request + ctx := context.Background() + if request != nil { + ctx = request.Context() + } // Determine if streaming stream := isStreamRequest(requestBody) @@ -54,7 +59,7 @@ func (a *CustomAdapter) Execute(ctx context.Context, w http.ResponseWriter, req // Build upstream URL baseURL := 
a.getBaseURL(clientType) - requestURI := ctxutil.GetRequestURI(ctx) + requestURI := flow.GetRequestURI(c) // Apply model mapping if configured var err error @@ -90,8 +95,8 @@ func (a *CustomAdapter) Execute(ctx context.Context, w http.ResponseWriter, req // Claude: Following CLIProxyAPI pattern // 1. Process body first (get extraBetas, inject cloaking/cache_control) clientUA := "" - if req != nil { - clientUA = req.Header.Get("User-Agent") + if request != nil { + clientUA = request.Header.Get("User-Agent") } var extraBetas []string requestBody, extraBetas = processClaudeRequestBody(requestBody, clientUA, a.provider.Config.Custom.Cloak) @@ -101,33 +106,32 @@ func (a *CustomAdapter) Execute(ctx context.Context, w http.ResponseWriter, req } // 2. Set headers (streaming only if requested) - applyClaudeHeaders(upstreamReq, req, a.provider.Config.Custom.APIKey, extraBetas, stream) + applyClaudeHeaders(upstreamReq, request, a.provider.Config.Custom.APIKey, extraBetas, stream) // 3. Update request body and ContentLength (IMPORTANT: body was modified) upstreamReq.Body = io.NopCloser(bytes.NewReader(requestBody)) upstreamReq.ContentLength = int64(len(requestBody)) case domain.ClientTypeCodex: // Codex: Use Codex CLI-style headers with passthrough support - applyCodexHeaders(upstreamReq, req, a.provider.Config.Custom.APIKey) + applyCodexHeaders(upstreamReq, request, a.provider.Config.Custom.APIKey) case domain.ClientTypeGemini: // Gemini: Use Gemini-style headers with passthrough support - applyGeminiHeaders(upstreamReq, req, a.provider.Config.Custom.APIKey) + applyGeminiHeaders(upstreamReq, request, a.provider.Config.Custom.APIKey) default: // Other types: Preserve original header forwarding logic - originalHeaders := ctxutil.GetRequestHeaders(ctx) + originalHeaders := flow.GetRequestHeaders(c) upstreamReq.Header = originalHeaders // Override auth headers with provider's credentials if a.provider.Config.Custom.APIKey != "" { - // Check if this is a format conversion scenario 
- originalClientType := ctxutil.GetOriginalClientType(ctx) + originalClientType := flow.GetOriginalClientType(c) isConversion := originalClientType != "" && originalClientType != clientType setAuthHeader(upstreamReq, clientType, a.provider.Config.Custom.APIKey, isConversion) } } // Send request info via EventChannel - if eventChan := ctxutil.GetEventChan(ctx); eventChan != nil { + if eventChan := flow.GetEventChan(c); eventChan != nil { eventChan.SendRequestInfo(&domain.RequestInfo{ Method: upstreamReq.Method, URL: upstreamURL, @@ -159,7 +163,7 @@ func (a *CustomAdapter) Execute(ctx context.Context, w http.ResponseWriter, req body, _ := io.ReadAll(reader) // Send error response info via EventChannel - if eventChan := ctxutil.GetEventChan(ctx); eventChan != nil { + if eventChan := flow.GetEventChan(c); eventChan != nil { eventChan.SendResponseInfo(&domain.ResponseInfo{ Status: resp.StatusCode, Headers: flattenHeaders(resp.Header), @@ -192,9 +196,9 @@ func (a *CustomAdapter) Execute(ctx context.Context, w http.ResponseWriter, req // Note: Response format conversion is handled by Executor's ConvertingResponseWriter // Adapters simply pass through the upstream response if stream { - return a.handleStreamResponse(ctx, w, resp, clientType, isOAuthToken) + return a.handleStreamResponse(c, resp, clientType, isOAuthToken) } - return a.handleNonStreamResponse(ctx, w, resp, clientType, isOAuthToken) + return a.handleNonStreamResponse(c, resp, clientType, isOAuthToken) } func (a *CustomAdapter) supportsClientType(ct domain.ClientType) bool { @@ -214,7 +218,7 @@ func (a *CustomAdapter) getBaseURL(clientType domain.ClientType) string { return config.BaseURL } -func (a *CustomAdapter) handleNonStreamResponse(ctx context.Context, w http.ResponseWriter, resp *http.Response, clientType domain.ClientType, isOAuthToken bool) error { +func (a *CustomAdapter) handleNonStreamResponse(c *flow.Ctx, resp *http.Response, clientType domain.ClientType, isOAuthToken bool) error { // Decompress 
response body if needed reader, err := decompressResponse(resp) if err != nil { @@ -239,45 +243,50 @@ func (a *CustomAdapter) handleNonStreamResponse(ctx context.Context, w http.Resp body = stripClaudeToolPrefixFromResponse(body, claudeToolPrefix) } - eventChan := ctxutil.GetEventChan(ctx) + eventChan := flow.GetEventChan(c) - // Send response info via EventChannel - eventChan.SendResponseInfo(&domain.ResponseInfo{ - Status: resp.StatusCode, - Headers: flattenHeaders(resp.Header), - Body: string(body), - }) + if eventChan != nil { + eventChan.SendResponseInfo(&domain.ResponseInfo{ + Status: resp.StatusCode, + Headers: flattenHeaders(resp.Header), + Body: string(body), + }) + } // Extract and send token usage metrics if metrics := usage.ExtractFromResponse(string(body)); metrics != nil { // Adjust for client-specific quirks (e.g., Codex input_tokens includes cached tokens) metrics = usage.AdjustForClientType(metrics, clientType) - eventChan.SendMetrics(&domain.AdapterMetrics{ - InputTokens: metrics.InputTokens, - OutputTokens: metrics.OutputTokens, - CacheReadCount: metrics.CacheReadCount, - CacheCreationCount: metrics.CacheCreationCount, - Cache5mCreationCount: metrics.Cache5mCreationCount, - Cache1hCreationCount: metrics.Cache1hCreationCount, - }) + if eventChan != nil { + eventChan.SendMetrics(&domain.AdapterMetrics{ + InputTokens: metrics.InputTokens, + OutputTokens: metrics.OutputTokens, + CacheReadCount: metrics.CacheReadCount, + CacheCreationCount: metrics.CacheCreationCount, + Cache5mCreationCount: metrics.Cache5mCreationCount, + Cache1hCreationCount: metrics.Cache1hCreationCount, + }) + } } // Extract and send responseModel if responseModel := extractResponseModel(body, clientType); responseModel != "" { - eventChan.SendResponseModel(responseModel) + if eventChan != nil { + eventChan.SendResponseModel(responseModel) + } } // Note: Response format conversion is handled by Executor's ConvertingResponseWriter // Adapter simply passes through the upstream 
response body // Copy upstream headers (except those we override) - copyResponseHeaders(w.Header(), resp.Header) - w.WriteHeader(resp.StatusCode) - _, _ = w.Write(body) + copyResponseHeaders(c.Writer.Header(), resp.Header) + c.Writer.WriteHeader(resp.StatusCode) + _, _ = c.Writer.Write(body) return nil } -func (a *CustomAdapter) handleStreamResponse(ctx context.Context, w http.ResponseWriter, resp *http.Response, clientType domain.ClientType, isOAuthToken bool) error { +func (a *CustomAdapter) handleStreamResponse(c *flow.Ctx, resp *http.Response, clientType domain.ClientType, isOAuthToken bool) error { // Decompress response body if needed reader, err := decompressResponse(resp) if err != nil { @@ -285,7 +294,7 @@ func (a *CustomAdapter) handleStreamResponse(ctx context.Context, w http.Respons } defer reader.Close() - eventChan := ctxutil.GetEventChan(ctx) + eventChan := flow.GetEventChan(c) // Send initial response info (for streaming, we only capture status and headers) eventChan.SendResponseInfo(&domain.ResponseInfo{ @@ -295,24 +304,24 @@ func (a *CustomAdapter) handleStreamResponse(ctx context.Context, w http.Respons }) // Copy upstream headers (except those we override) - copyResponseHeaders(w.Header(), resp.Header) + copyResponseHeaders(c.Writer.Header(), resp.Header) // Set streaming headers only if not already set by upstream // These are required for SSE (Server-Sent Events) to work correctly - if w.Header().Get("Content-Type") == "" { - w.Header().Set("Content-Type", "text/event-stream") + if c.Writer.Header().Get("Content-Type") == "" { + c.Writer.Header().Set("Content-Type", "text/event-stream") } - if w.Header().Get("Cache-Control") == "" { - w.Header().Set("Cache-Control", "no-cache") + if c.Writer.Header().Get("Cache-Control") == "" { + c.Writer.Header().Set("Cache-Control", "no-cache") } - if w.Header().Get("Connection") == "" { - w.Header().Set("Connection", "keep-alive") + if c.Writer.Header().Get("Connection") == "" { + 
c.Writer.Header().Set("Connection", "keep-alive") } - if w.Header().Get("X-Accel-Buffering") == "" { - w.Header().Set("X-Accel-Buffering", "no") + if c.Writer.Header().Get("X-Accel-Buffering") == "" { + c.Writer.Header().Set("X-Accel-Buffering", "no") } - flusher, ok := w.(http.Flusher) + flusher, ok := c.Writer.(http.Flusher) if !ok { return domain.NewProxyErrorWithMessage(domain.ErrUpstreamError, false, "streaming not supported") } @@ -323,6 +332,10 @@ func (a *CustomAdapter) handleStreamResponse(ctx context.Context, w http.Respons // Collect all SSE events for response body and token extraction var sseBuffer strings.Builder var sseError error // Track any SSE error event + ctx := context.Background() + if c.Request != nil { + ctx = c.Request.Context() + } // Helper to send final events via EventChannel sendFinalEvents := func() { @@ -447,7 +460,7 @@ func (a *CustomAdapter) handleStreamResponse(ctx context.Context, w http.Respons // Note: Response format conversion is handled by Executor's ConvertingResponseWriter // Adapter simply passes through the upstream SSE data if len(processedLine) > 0 { - _, writeErr := w.Write([]byte(processedLine)) + _, writeErr := c.Writer.Write([]byte(processedLine)) if writeErr != nil { // Client disconnected sendFinalEvents() @@ -456,7 +469,7 @@ func (a *CustomAdapter) handleStreamResponse(ctx context.Context, w http.Respons flusher.Flush() // Track TTFT: send first token time on first successful write - if !firstChunkSent { + if !firstChunkSent && eventChan != nil { firstChunkSent = true eventChan.SendFirstToken(time.Now().UnixMilli()) } diff --git a/internal/adapter/provider/kiro/adapter.go b/internal/adapter/provider/kiro/adapter.go index 93dc1923..83115d8f 100644 --- a/internal/adapter/provider/kiro/adapter.go +++ b/internal/adapter/provider/kiro/adapter.go @@ -13,9 +13,9 @@ import ( "time" "github.com/awsl-project/maxx/internal/adapter/provider" - ctxutil "github.com/awsl-project/maxx/internal/context" 
"github.com/awsl-project/maxx/internal/converter" "github.com/awsl-project/maxx/internal/domain" + "github.com/awsl-project/maxx/internal/flow" "github.com/awsl-project/maxx/internal/usage" ) @@ -64,10 +64,15 @@ func (a *KiroAdapter) SupportedClientTypes() []domain.ClientType { } // Execute performs the proxy request to the upstream CodeWhisperer API -func (a *KiroAdapter) Execute(ctx context.Context, w http.ResponseWriter, req *http.Request, provider *domain.Provider) error { - requestModel := ctxutil.GetRequestModel(ctx) - requestBody := ctxutil.GetRequestBody(ctx) - stream := ctxutil.GetIsStream(ctx) +func (a *KiroAdapter) Execute(c *flow.Ctx, provider *domain.Provider) error { + requestModel := flow.GetRequestModel(c) + requestBody := flow.GetRequestBody(c) + stream := flow.GetIsStream(c) + request := c.Request + ctx := context.Background() + if request != nil { + ctx = request.Context() + } config := provider.Config.Kiro @@ -84,18 +89,18 @@ func (a *KiroAdapter) Execute(ctx context.Context, w http.ResponseWriter, req *h } // Convert Claude request to CodeWhisperer format (传入 req 用于生成稳定会话ID) - cwBody, mappedModel, err := ConvertClaudeToCodeWhisperer(requestBody, config.ModelMapping, req) + cwBody, mappedModel, err := ConvertClaudeToCodeWhisperer(requestBody, config.ModelMapping, request) if err != nil { return domain.NewProxyErrorWithMessage(err, true, fmt.Sprintf("failed to convert request: %v", err)) } // Update attempt record with the mapped model (kiro-specific internal mapping) - if attempt := ctxutil.GetUpstreamAttempt(ctx); attempt != nil { + if attempt := flow.GetUpstreamAttempt(c); attempt != nil { attempt.MappedModel = mappedModel } // Get EventChannel for sending events to executor - eventChan := ctxutil.GetEventChan(ctx) + eventChan := flow.GetEventChan(c) // Build upstream URL upstreamURL := fmt.Sprintf(CodeWhispererURLTemplate, region) @@ -196,9 +201,9 @@ func (a *KiroAdapter) Execute(ctx context.Context, w http.ResponseWriter, req *h inputTokens 
:= calculateInputTokens(requestBody) if stream { - return a.handleStreamResponse(ctx, w, resp, requestModel, inputTokens) + return a.handleStreamResponse(c, resp, requestModel, inputTokens) } - return a.handleCollectedStreamResponse(ctx, w, resp, requestModel, inputTokens) + return a.handleCollectedStreamResponse(c, resp, requestModel, inputTokens) } // getAccessToken gets a valid access token, refreshing if necessary @@ -335,8 +340,13 @@ func (a *KiroAdapter) refreshIdCToken(ctx context.Context, config *domain.Provid } // handleStreamResponse handles streaming EventStream response -func (a *KiroAdapter) handleStreamResponse(ctx context.Context, w http.ResponseWriter, resp *http.Response, requestModel string, inputTokens int) error { - eventChan := ctxutil.GetEventChan(ctx) +func (a *KiroAdapter) handleStreamResponse(c *flow.Ctx, resp *http.Response, requestModel string, inputTokens int) error { + w := c.Writer + ctx := context.Background() + if c.Request != nil { + ctx = c.Request.Context() + } + eventChan := flow.GetEventChan(c) // Send initial response info eventChan.SendResponseInfo(&domain.ResponseInfo{ @@ -362,7 +372,7 @@ func (a *KiroAdapter) handleStreamResponse(ctx context.Context, w http.ResponseW if err := streamCtx.sendInitialEvents(); err != nil { inTok, outTok := streamCtx.GetTokenCounts() - a.sendFinalEvents(ctx, sseBuffer.String(), inTok, outTok, requestModel, streamCtx.GetFirstTokenTimeMs()) + a.sendFinalEvents(eventChan, sseBuffer.String(), inTok, outTok, requestModel, streamCtx.GetFirstTokenTimeMs()) return domain.NewProxyErrorWithMessage(err, false, "failed to send initial events") } @@ -370,30 +380,29 @@ func (a *KiroAdapter) handleStreamResponse(ctx context.Context, w http.ResponseW if err != nil { if ctx.Err() != nil { inTok, outTok := streamCtx.GetTokenCounts() - a.sendFinalEvents(ctx, sseBuffer.String(), inTok, outTok, requestModel, streamCtx.GetFirstTokenTimeMs()) + a.sendFinalEvents(eventChan, sseBuffer.String(), inTok, outTok, 
requestModel, streamCtx.GetFirstTokenTimeMs()) return domain.NewProxyErrorWithMessage(ctx.Err(), false, "client disconnected") } _ = streamCtx.sendFinalEvents() inTok, outTok := streamCtx.GetTokenCounts() - a.sendFinalEvents(ctx, sseBuffer.String(), inTok, outTok, requestModel, streamCtx.GetFirstTokenTimeMs()) + a.sendFinalEvents(eventChan, sseBuffer.String(), inTok, outTok, requestModel, streamCtx.GetFirstTokenTimeMs()) return nil } if err := streamCtx.sendFinalEvents(); err != nil { inTok, outTok := streamCtx.GetTokenCounts() - a.sendFinalEvents(ctx, sseBuffer.String(), inTok, outTok, requestModel, streamCtx.GetFirstTokenTimeMs()) + a.sendFinalEvents(eventChan, sseBuffer.String(), inTok, outTok, requestModel, streamCtx.GetFirstTokenTimeMs()) return domain.NewProxyErrorWithMessage(err, false, "failed to send final events") } inTok, outTok := streamCtx.GetTokenCounts() - a.sendFinalEvents(ctx, sseBuffer.String(), inTok, outTok, requestModel, streamCtx.GetFirstTokenTimeMs()) + a.sendFinalEvents(eventChan, sseBuffer.String(), inTok, outTok, requestModel, streamCtx.GetFirstTokenTimeMs()) return nil } // sendFinalEvents sends final events via EventChannel -func (a *KiroAdapter) sendFinalEvents(ctx context.Context, body string, inputTokens, outputTokens int, requestModel string, firstTokenTimeMs int64) { - eventChan := ctxutil.GetEventChan(ctx) +func (a *KiroAdapter) sendFinalEvents(eventChan domain.AdapterEventChan, body string, inputTokens, outputTokens int, requestModel string, firstTokenTimeMs int64) { if eventChan == nil { return } @@ -405,8 +414,8 @@ func (a *KiroAdapter) sendFinalEvents(ctx context.Context, body string, inputTok // Send response info with body eventChan.SendResponseInfo(&domain.ResponseInfo{ - Status: 200, // streaming always returns 200 at this point - Body: body, + Status: 200, // streaming always returns 200 at this point + Body: body, }) // Try to extract usage metrics from the SSE content first @@ -432,8 +441,9 @@ func (a *KiroAdapter) 
sendFinalEvents(ctx context.Context, body string, inputTok } // handleCollectedStreamResponse collects streaming response into a single JSON response -func (a *KiroAdapter) handleCollectedStreamResponse(ctx context.Context, w http.ResponseWriter, resp *http.Response, requestModel string, inputTokens int) error { - eventChan := ctxutil.GetEventChan(ctx) +func (a *KiroAdapter) handleCollectedStreamResponse(c *flow.Ctx, resp *http.Response, requestModel string, inputTokens int) error { + w := c.Writer + eventChan := flow.GetEventChan(c) // Send initial response info eventChan.SendResponseInfo(&domain.ResponseInfo{ diff --git a/internal/context/context.go b/internal/context/context.go index 0e340421..17c5de7d 100644 --- a/internal/context/context.go +++ b/internal/context/context.go @@ -8,26 +8,7 @@ import ( "github.com/awsl-project/maxx/internal/event" ) -type contextKey string - -const ( - CtxKeyClientType contextKey = "client_type" - CtxKeyOriginalClientType contextKey = "original_client_type" // Original client type before format conversion - CtxKeySessionID contextKey = "session_id" - CtxKeyProjectID contextKey = "project_id" - CtxKeyRequestModel contextKey = "request_model" - CtxKeyMappedModel contextKey = "mapped_model" - CtxKeyResponseModel contextKey = "response_model" - CtxKeyProxyRequest contextKey = "proxy_request" - CtxKeyRequestBody contextKey = "request_body" - CtxKeyUpstreamAttempt contextKey = "upstream_attempt" - CtxKeyRequestHeaders contextKey = "request_headers" - CtxKeyRequestURI contextKey = "request_uri" - CtxKeyBroadcaster contextKey = "broadcaster" - CtxKeyIsStream contextKey = "is_stream" - CtxKeyAPITokenID contextKey = "api_token_id" - CtxKeyEventChan contextKey = "event_chan" -) +// context keys defined in keys.go // Setters func WithClientType(ctx context.Context, ct domain.ClientType) context.Context { diff --git a/internal/context/keys.go b/internal/context/keys.go new file mode 100644 index 00000000..1c7decdf --- /dev/null +++ 
b/internal/context/keys.go @@ -0,0 +1,22 @@ +package context + +type contextKey string + +const ( + CtxKeyClientType contextKey = "client_type" + CtxKeyOriginalClientType contextKey = "original_client_type" + CtxKeySessionID contextKey = "session_id" + CtxKeyProjectID contextKey = "project_id" + CtxKeyRequestModel contextKey = "request_model" + CtxKeyMappedModel contextKey = "mapped_model" + CtxKeyResponseModel contextKey = "response_model" + CtxKeyProxyRequest contextKey = "proxy_request" + CtxKeyRequestBody contextKey = "request_body" + CtxKeyUpstreamAttempt contextKey = "upstream_attempt" + CtxKeyRequestHeaders contextKey = "request_headers" + CtxKeyRequestURI contextKey = "request_uri" + CtxKeyBroadcaster contextKey = "broadcaster" + CtxKeyIsStream contextKey = "is_stream" + CtxKeyAPITokenID contextKey = "api_token_id" + CtxKeyEventChan contextKey = "event_chan" +) diff --git a/internal/converter/claude_to_codex.go b/internal/converter/claude_to_codex.go index a5d789eb..4808ba5f 100644 --- a/internal/converter/claude_to_codex.go +++ b/internal/converter/claude_to_codex.go @@ -30,6 +30,22 @@ func (c *claudeToCodexRequest) Transform(body []byte, model string, stream bool) TopP: req.TopP, } + shortMap := map[string]string{} + if len(req.Tools) > 0 { + var names []string + for _, tool := range req.Tools { + if tool.Type != "" { + continue // server tools should keep their type + } + if tool.Name != "" { + names = append(names, tool.Name) + } + } + if len(names) > 0 { + shortMap = buildShortNameMap(names) + } + } + // Convert messages to input var input []CodexInputItem if req.System != nil { @@ -69,6 +85,11 @@ func (c *claudeToCodexRequest) Transform(body []byte, model string, stream bool) case "tool_use": // Convert tool use to function_call output name, _ := m["name"].(string) + if short, ok := shortMap[name]; ok { + name = short + } else { + name = shortenNameIfNeeded(name) + } id, _ := m["id"].(string) inputData := m["input"] argJSON, _ := 
json.Marshal(inputData) @@ -102,9 +123,21 @@ func (c *claudeToCodexRequest) Transform(body []byte, model string, stream bool) // Convert tools for _, tool := range req.Tools { + if tool.Type != "" { + codexReq.Tools = append(codexReq.Tools, CodexTool{ + Type: tool.Type, + }) + continue + } + name := tool.Name + if short, ok := shortMap[name]; ok { + name = short + } else { + name = shortenNameIfNeeded(name) + } codexReq.Tools = append(codexReq.Tools, CodexTool{ Type: "function", - Name: tool.Name, + Name: name, Description: tool.Description, Parameters: tool.InputSchema, }) diff --git a/internal/converter/claude_to_openai_stream.go b/internal/converter/claude_to_openai_stream.go index d6fc9513..289852ba 100644 --- a/internal/converter/claude_to_openai_stream.go +++ b/internal/converter/claude_to_openai_stream.go @@ -3,8 +3,14 @@ package converter import ( "encoding/json" "time" + + "github.com/tidwall/gjson" ) +type claudeOpenAIStreamMeta struct { + Model string +} + func (c *claudeToOpenAIResponse) TransformChunk(chunk []byte, state *TransformState) ([]byte, error) { events, remaining := ParseSSE(state.Buffer + string(chunk)) state.Buffer = remaining @@ -23,6 +29,16 @@ func (c *claudeToOpenAIResponse) TransformChunk(chunk []byte, state *TransformSt switch claudeEvent.Type { case "message_start": + streamMeta, _ := state.Custom.(*claudeOpenAIStreamMeta) + if streamMeta == nil { + streamMeta = &claudeOpenAIStreamMeta{} + state.Custom = streamMeta + } + if streamMeta.Model == "" && len(state.OriginalRequestBody) > 0 { + if reqModel := gjson.GetBytes(state.OriginalRequestBody, "model"); reqModel.Exists() && reqModel.String() != "" { + streamMeta.Model = reqModel.String() + } + } if claudeEvent.Message != nil { state.MessageID = claudeEvent.Message.ID } @@ -30,6 +46,7 @@ func (c *claudeToOpenAIResponse) TransformChunk(chunk []byte, state *TransformSt ID: state.MessageID, Object: "chat.completion.chunk", Created: time.Now().Unix(), + Model: streamMeta.Model, Choices: 
[]OpenAIChoice{{ Index: 0, Delta: &OpenAIMessage{Role: "assistant", Content: ""}, @@ -56,25 +73,35 @@ func (c *claudeToOpenAIResponse) TransformChunk(chunk []byte, state *TransformSt if claudeEvent.Delta != nil { switch claudeEvent.Delta.Type { case "text_delta": + streamMeta, _ := state.Custom.(*claudeOpenAIStreamMeta) + if streamMeta == nil { + continue + } chunk := OpenAIStreamChunk{ ID: state.MessageID, Object: "chat.completion.chunk", Created: time.Now().Unix(), + Model: streamMeta.Model, Choices: []OpenAIChoice{{ Index: 0, - Delta: &OpenAIMessage{Content: claudeEvent.Delta.Text}, + Delta: &OpenAIMessage{Role: "assistant", Content: claudeEvent.Delta.Text}, }}, } output = append(output, FormatSSE("", chunk)...) case "thinking_delta": if claudeEvent.Delta.Thinking != "" { + streamMeta, _ := state.Custom.(*claudeOpenAIStreamMeta) + if streamMeta == nil { + continue + } chunk := OpenAIStreamChunk{ ID: state.MessageID, Object: "chat.completion.chunk", Created: time.Now().Unix(), + Model: streamMeta.Model, Choices: []OpenAIChoice{{ Index: 0, - Delta: &OpenAIMessage{ReasoningContent: claudeEvent.Delta.Thinking}, + Delta: &OpenAIMessage{Role: "assistant", ReasoningContent: claudeEvent.Delta.Thinking}, }}, } output = append(output, FormatSSE("", chunk)...) 
@@ -82,13 +109,19 @@ func (c *claudeToOpenAIResponse) TransformChunk(chunk []byte, state *TransformSt case "input_json_delta": if tc, ok := state.ToolCalls[state.CurrentIndex]; ok { tc.Arguments += claudeEvent.Delta.PartialJSON + streamMeta, _ := state.Custom.(*claudeOpenAIStreamMeta) + if streamMeta == nil { + continue + } chunk := OpenAIStreamChunk{ ID: state.MessageID, Object: "chat.completion.chunk", Created: time.Now().Unix(), + Model: streamMeta.Model, Choices: []OpenAIChoice{{ Index: 0, Delta: &OpenAIMessage{ + Role: "assistant", ToolCalls: []OpenAIToolCall{{ Index: state.CurrentIndex, ID: tc.ID, @@ -124,13 +157,19 @@ func (c *claudeToOpenAIResponse) TransformChunk(chunk []byte, state *TransformSt case "tool_use": finishReason = "tool_calls" } + streamMeta, _ := state.Custom.(*claudeOpenAIStreamMeta) + if streamMeta == nil { + streamMeta = &claudeOpenAIStreamMeta{} + state.Custom = streamMeta + } chunk := OpenAIStreamChunk{ ID: state.MessageID, Object: "chat.completion.chunk", Created: time.Now().Unix(), + Model: streamMeta.Model, Choices: []OpenAIChoice{{ Index: 0, - Delta: &OpenAIMessage{}, + Delta: &OpenAIMessage{Role: "assistant", Content: ""}, FinishReason: finishReason, }}, } diff --git a/internal/converter/codex_openai_stream_test.go b/internal/converter/codex_openai_stream_test.go new file mode 100644 index 00000000..3cf58550 --- /dev/null +++ b/internal/converter/codex_openai_stream_test.go @@ -0,0 +1,224 @@ +package converter + +import ( + "encoding/json" + "strings" + "testing" +) + +func TestCodexToOpenAIStreamToolCalls(t *testing.T) { + state := NewTransformState() + conv := &codexToOpenAIResponse{} + + created := map[string]interface{}{ + "type": "response.created", + "response": map[string]interface{}{ + "id": "resp_test_1", + }, + } + added := map[string]interface{}{ + "type": "response.output_item.added", + "output_index": 0, + "item": map[string]interface{}{ + "id": "fc_call1", + "type": "function_call", + "call_id": "call1", + "name": 
"tool_alpha", + }, + } + doneItem := map[string]interface{}{ + "type": "response.output_item.done", + "item": map[string]interface{}{ + "type": "function_call", + "call_id": "call1", + "name": "tool_alpha", + "arguments": `{"a":1}`, + }, + } + completed := map[string]interface{}{ + "type": "response.completed", + "response": map[string]interface{}{ + "id": "resp_test_1", + }, + } + + var out []byte + for _, ev := range []interface{}{created, added, doneItem, completed} { + chunk := FormatSSE("", ev) + next, err := conv.TransformChunk(chunk, state) + if err != nil { + t.Fatalf("transform chunk error: %v", err) + } + out = append(out, next...) + } + + events, _ := ParseSSE(string(out)) + if len(events) == 0 { + t.Fatalf("no SSE events produced") + } + + foundToolDelta := false + foundFinishToolCalls := false + + for _, ev := range events { + if ev.Event == "done" { + continue + } + var chunk OpenAIStreamChunk + if err := json.Unmarshal(ev.Data, &chunk); err != nil { + t.Fatalf("invalid chunk JSON: %v", err) + } + if len(chunk.Choices) == 0 { + continue + } + if chunk.Choices[0].Delta != nil && len(chunk.Choices[0].Delta.ToolCalls) > 0 { + tc := chunk.Choices[0].Delta.ToolCalls[0] + if tc.Type == "function" && tc.Function.Arguments != "" { + foundToolDelta = true + } + } + if chunk.Choices[0].FinishReason == "tool_calls" { + foundFinishToolCalls = true + } + } + + if !foundToolDelta { + t.Fatalf("expected tool_calls delta in stream output") + } + if !foundFinishToolCalls { + t.Fatalf("expected finish_reason=tool_calls in stream output") + } +} + +func TestCodexToClaudeStreamToolStopReason(t *testing.T) { + state := NewTransformState() + conv := &codexToClaudeResponse{} + + created := map[string]interface{}{ + "type": "response.created", + "response": map[string]interface{}{ + "id": "resp_test_2", + }, + } + added := map[string]interface{}{ + "type": "response.output_item.added", + "output_index": 0, + "item": map[string]interface{}{ + "id": "fc_call2", + "type": 
"function_call", + "call_id": "call2", + "name": "tool_beta", + }, + } + doneItem := map[string]interface{}{ + "type": "response.output_item.done", + "item": map[string]interface{}{ + "type": "function_call", + "call_id": "call2", + "name": "tool_beta", + "arguments": `{"b":2}`, + }, + } + completed := map[string]interface{}{ + "type": "response.completed", + "response": map[string]interface{}{ + "id": "resp_test_2", + }, + } + + var out []byte + for _, ev := range []interface{}{created, added, doneItem, completed} { + chunk := FormatSSE("", ev) + next, err := conv.TransformChunk(chunk, state) + if err != nil { + t.Fatalf("transform chunk error: %v", err) + } + out = append(out, next...) + } + + events, _ := ParseSSE(string(out)) + if len(events) == 0 { + t.Fatalf("no SSE events produced") + } + + foundStopReason := false + for _, ev := range events { + if ev.Event != "message_delta" { + continue + } + var payload map[string]interface{} + if err := json.Unmarshal(ev.Data, &payload); err != nil { + t.Fatalf("invalid event JSON: %v", err) + } + if delta, ok := payload["delta"].(map[string]interface{}); ok { + if sr, ok := delta["stop_reason"].(string); ok && sr == "tool_use" { + foundStopReason = true + } + } + } + + if !foundStopReason { + t.Fatalf("expected stop_reason=tool_use in Claude stream output") + } +} + +func TestClaudeToCodexToolShortening(t *testing.T) { + longName := "mcp__server__" + strings.Repeat("x", 80) + claudeReq := map[string]interface{}{ + "model": "claude-3", + "messages": []map[string]interface{}{ + {"role": "user", "content": "hi"}, + }, + "tools": []map[string]interface{}{ + { + "name": longName, + "description": "d", + "input_schema": map[string]interface{}{"type": "object"}, + }, + { + "type": "web_search_20250305", + }, + }, + } + + raw, err := json.Marshal(claudeReq) + if err != nil { + t.Fatalf("marshal claude req: %v", err) + } + + conv := &claudeToCodexRequest{} + out, err := conv.Transform(raw, "gpt-5.2-codex", false) + if err != 
nil { + t.Fatalf("transform error: %v", err) + } + + var codexReq CodexRequest + if err := json.Unmarshal(out, &codexReq); err != nil { + t.Fatalf("unmarshal codex req: %v", err) + } + + if len(codexReq.Tools) != 2 { + t.Fatalf("tools = %d, want 2", len(codexReq.Tools)) + } + + var fnTool *CodexTool + var serverTool *CodexTool + for i := range codexReq.Tools { + switch codexReq.Tools[i].Type { + case "function": + fnTool = &codexReq.Tools[i] + case "web_search_20250305": + serverTool = &codexReq.Tools[i] + } + } + + if fnTool == nil || fnTool.Name == "" { + t.Fatalf("missing function tool after transform") + } + if len(fnTool.Name) > maxToolNameLen { + t.Fatalf("function tool name too long: %d", len(fnTool.Name)) + } + if serverTool == nil { + t.Fatalf("missing server tool type in codex tools") + } +} diff --git a/internal/converter/codex_to_claude.go b/internal/converter/codex_to_claude.go index 36532f2e..bf627b7d 100644 --- a/internal/converter/codex_to_claude.go +++ b/internal/converter/codex_to_claude.go @@ -2,8 +2,11 @@ package converter import ( "encoding/json" + "strings" "github.com/awsl-project/maxx/internal/domain" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" ) func init() { @@ -13,6 +16,12 @@ func init() { type codexToClaudeRequest struct{} type codexToClaudeResponse struct{} +type claudeStreamState struct { + HasToolCall bool + BlockIndex int + ShortToOrig map[string]string +} + func (c *codexToClaudeRequest) Transform(body []byte, model string, stream bool) ([]byte, error) { var req CodexRequest if err := json.Unmarshal(body, &req); err != nil { @@ -106,85 +115,77 @@ func (c *codexToClaudeRequest) Transform(body []byte, model string, stream bool) } func (c *codexToClaudeResponse) Transform(body []byte) ([]byte, error) { - var resp CodexResponse - if err := json.Unmarshal(body, &resp); err != nil { - return nil, err - } - - claudeResp := ClaudeResponse{ - ID: resp.ID, - Type: "message", - Role: "assistant", - Model: resp.Model, - Usage: 
ClaudeUsage{ - InputTokens: resp.Usage.InputTokens, - OutputTokens: resp.Usage.OutputTokens, - }, - } - - var hasToolCall bool - for _, out := range resp.Output { - switch out.Type { - case "message": - contentStr, _ := out.Content.(string) - claudeResp.Content = append(claudeResp.Content, ClaudeContentBlock{ - Type: "text", - Text: contentStr, - }) - case "function_call": - hasToolCall = true - var args interface{} - json.Unmarshal([]byte(out.Arguments), &args) - claudeResp.Content = append(claudeResp.Content, ClaudeContentBlock{ - Type: "tool_use", - ID: out.ID, - Name: out.Name, - Input: args, - }) - } - } - - if hasToolCall { - claudeResp.StopReason = "tool_use" - } else { - claudeResp.StopReason = "end_turn" - } - - return json.Marshal(claudeResp) + return c.TransformWithState(body, nil) } func (c *codexToClaudeResponse) TransformChunk(chunk []byte, state *TransformState) ([]byte, error) { events, remaining := ParseSSE(state.Buffer + string(chunk)) state.Buffer = remaining + st := getClaudeStreamState(state) var output []byte for _, event := range events { - var codexEvent map[string]interface{} - if err := json.Unmarshal(event.Data, &codexEvent); err != nil { + if event.Event == "done" { continue } - eventType, _ := codexEvent["type"].(string) + root := gjson.ParseBytes(event.Data) + if !root.Exists() { + continue + } + + eventType := root.Get("type").String() switch eventType { case "response.created": - if resp, ok := codexEvent["response"].(map[string]interface{}); ok { - state.MessageID, _ = resp["id"].(string) - } + state.MessageID = root.Get("response.id").String() msgStart := map[string]interface{}{ "type": "message_start", "message": map[string]interface{}{ "id": state.MessageID, "type": "message", "role": "assistant", + "model": root.Get("response.model").String(), "usage": map[string]int{"input_tokens": 0, "output_tokens": 0}, }, } output = append(output, FormatSSE("message_start", msgStart)...) 
+ case "response.reasoning_summary_part.added": + blockStart := map[string]interface{}{ + "type": "content_block_start", + "index": st.BlockIndex, + "content_block": map[string]interface{}{ + "type": "thinking", + "thinking": "", + }, + } + output = append(output, FormatSSE("content_block_start", blockStart)...) + + case "response.reasoning_summary_text.delta": + delta := root.Get("delta").String() + claudeDelta := map[string]interface{}{ + "type": "content_block_delta", + "index": st.BlockIndex, + "delta": map[string]interface{}{ + "type": "thinking_delta", + "thinking": delta, + }, + } + output = append(output, FormatSSE("content_block_delta", claudeDelta)...) + + case "response.reasoning_summary_part.done": + blockStop := map[string]interface{}{ + "type": "content_block_stop", + "index": st.BlockIndex, + } + output = append(output, FormatSSE("content_block_stop", blockStop)...) + st.BlockIndex++ + + case "response.content_part.added": blockStart := map[string]interface{}{ "type": "content_block_start", - "index": 0, + "index": st.BlockIndex, "content_block": map[string]interface{}{ "type": "text", "text": "", @@ -192,34 +193,104 @@ func (c *codexToClaudeResponse) TransformChunk(chunk []byte, state *TransformSta } output = append(output, FormatSSE("content_block_start", blockStart)...) - case "response.output_item.delta": - if delta, ok := codexEvent["delta"].(map[string]interface{}); ok { - if text, ok := delta["text"].(string); ok { - claudeDelta := map[string]interface{}{ - "type": "content_block_delta", - "index": 0, - "delta": map[string]interface{}{ - "type": "text_delta", - "text": text, - }, - } - output = append(output, FormatSSE("content_block_delta", claudeDelta)...) 
- } + case "response.output_text.delta": + delta := root.Get("delta").String() + claudeDelta := map[string]interface{}{ + "type": "content_block_delta", + "index": st.BlockIndex, + "delta": map[string]interface{}{ + "type": "text_delta", + "text": delta, + }, } + output = append(output, FormatSSE("content_block_delta", claudeDelta)...) - case "response.done": + case "response.content_part.done": blockStop := map[string]interface{}{ "type": "content_block_stop", - "index": 0, + "index": st.BlockIndex, } output = append(output, FormatSSE("content_block_stop", blockStop)...) + st.BlockIndex++ + + case "response.output_item.added": + item := root.Get("item") + if item.Get("type").String() == "function_call" { + st.HasToolCall = true + if st.ShortToOrig == nil { + st.ShortToOrig = buildReverseMapFromClaudeOriginalShortToOriginal(state.OriginalRequestBody) + } + name := item.Get("name").String() + if orig, ok := st.ShortToOrig[name]; ok { + name = orig + } + blockStart := map[string]interface{}{ + "type": "content_block_start", + "index": st.BlockIndex, + "content_block": map[string]interface{}{ + "type": "tool_use", + "id": item.Get("call_id").String(), + "name": name, + "input": map[string]interface{}{}, + }, + } + output = append(output, FormatSSE("content_block_start", blockStart)...) + + blockDelta := map[string]interface{}{ + "type": "content_block_delta", + "index": st.BlockIndex, + "delta": map[string]interface{}{ + "type": "input_json_delta", + "partial_json": "", + }, + } + output = append(output, FormatSSE("content_block_delta", blockDelta)...) + } + + case "response.function_call_arguments.delta": + blockDelta := map[string]interface{}{ + "type": "content_block_delta", + "index": st.BlockIndex, + "delta": map[string]interface{}{ + "type": "input_json_delta", + "partial_json": root.Get("delta").String(), + }, + } + output = append(output, FormatSSE("content_block_delta", blockDelta)...) 
+ + case "response.output_item.done": + item := root.Get("item") + if item.Get("type").String() == "function_call" { + blockStop := map[string]interface{}{ + "type": "content_block_stop", + "index": st.BlockIndex, + } + output = append(output, FormatSSE("content_block_stop", blockStop)...) + st.BlockIndex++ + } + case "response.completed": + stopReason := root.Get("response.stop_reason").String() + if stopReason == "" { + if st.HasToolCall { + stopReason = "tool_use" + } else { + stopReason = "end_turn" + } + } + inputTokens, outputTokens, cachedTokens := extractResponsesUsage(root.Get("response.usage")) msgDelta := map[string]interface{}{ "type": "message_delta", "delta": map[string]interface{}{ - "stop_reason": "end_turn", + "stop_reason": stopReason, + }, + "usage": map[string]int{ + "input_tokens": inputTokens, + "output_tokens": outputTokens, }, - "usage": map[string]int{"output_tokens": 0}, + } + if cachedTokens > 0 { + msgDelta["usage"].(map[string]int)["cache_read_input_tokens"] = cachedTokens } output = append(output, FormatSSE("message_delta", msgDelta)...) output = append(output, FormatSSE("message_stop", map[string]string{"type": "message_stop"})...) 
@@ -228,3 +299,171 @@ func (c *codexToClaudeResponse) TransformChunk(chunk []byte, state *TransformSta return output, nil } + +func getClaudeStreamState(state *TransformState) *claudeStreamState { + if state.Custom == nil { + state.Custom = &claudeStreamState{} + } + st, ok := state.Custom.(*claudeStreamState) + if !ok || st == nil { + st = &claudeStreamState{} + state.Custom = st + } + return st +} + +func (c *codexToClaudeResponse) TransformWithState(body []byte, state *TransformState) ([]byte, error) { + root := gjson.ParseBytes(body) + var response gjson.Result + if root.Get("type").String() == "response.completed" && root.Get("response").Exists() { + response = root.Get("response") + } else if root.Get("output").Exists() { + response = root + } else { + return nil, nil + } + + revNames := map[string]string{} + if state != nil && len(state.OriginalRequestBody) > 0 { + revNames = buildReverseMapFromClaudeOriginalShortToOriginal(state.OriginalRequestBody) + } + + out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}` + out, _ = sjson.Set(out, "id", response.Get("id").String()) + out, _ = sjson.Set(out, "model", response.Get("model").String()) + inputTokens, outputTokens, cachedTokens := extractResponsesUsage(response.Get("usage")) + out, _ = sjson.Set(out, "usage.input_tokens", inputTokens) + out, _ = sjson.Set(out, "usage.output_tokens", outputTokens) + if cachedTokens > 0 { + out, _ = sjson.Set(out, "usage.cache_read_input_tokens", cachedTokens) + } + + hasToolCall := false + if output := response.Get("output"); output.Exists() && output.IsArray() { + output.ForEach(func(_, item gjson.Result) bool { + switch item.Get("type").String() { + case "reasoning": + thinkingBuilder := strings.Builder{} + if summary := item.Get("summary"); summary.Exists() { + if summary.IsArray() { + summary.ForEach(func(_, part gjson.Result) bool { + if txt := part.Get("text"); 
txt.Exists() { + thinkingBuilder.WriteString(txt.String()) + } else { + thinkingBuilder.WriteString(part.String()) + } + return true + }) + } else { + thinkingBuilder.WriteString(summary.String()) + } + } + if thinkingBuilder.Len() == 0 { + if content := item.Get("content"); content.Exists() { + if content.IsArray() { + content.ForEach(func(_, part gjson.Result) bool { + if txt := part.Get("text"); txt.Exists() { + thinkingBuilder.WriteString(txt.String()) + } else { + thinkingBuilder.WriteString(part.String()) + } + return true + }) + } else { + thinkingBuilder.WriteString(content.String()) + } + } + } + if thinkingBuilder.Len() > 0 { + block := `{"type":"thinking","thinking":""}` + block, _ = sjson.Set(block, "thinking", thinkingBuilder.String()) + out, _ = sjson.SetRaw(out, "content.-1", block) + } + case "message": + if content := item.Get("content"); content.Exists() { + if content.IsArray() { + content.ForEach(func(_, part gjson.Result) bool { + if part.Get("type").String() == "output_text" { + block := `{"type":"text","text":""}` + block, _ = sjson.Set(block, "text", part.Get("text").String()) + out, _ = sjson.SetRaw(out, "content.-1", block) + } + return true + }) + } else if content.Type == gjson.String { + block := `{"type":"text","text":""}` + block, _ = sjson.Set(block, "text", content.String()) + out, _ = sjson.SetRaw(out, "content.-1", block) + } + } + case "function_call": + hasToolCall = true + callID := item.Get("call_id").String() + name := item.Get("name").String() + if orig, ok := revNames[name]; ok { + name = orig + } + argsRaw := item.Get("arguments").String() + var args interface{} + if argsRaw != "" { + _ = json.Unmarshal([]byte(argsRaw), &args) + } + block := `{"type":"tool_use","id":"","name":"","input":{}}` + block, _ = sjson.Set(block, "id", callID) + block, _ = sjson.Set(block, "name", name) + if args != nil { + block, _ = sjson.Set(block, "input", args) + } + out, _ = sjson.SetRaw(out, "content.-1", block) + } + return true + }) + } + 
+ stopReason := response.Get("stop_reason").String() + if stopReason == "" { + if hasToolCall { + stopReason = "tool_use" + } else { + stopReason = "end_turn" + } + } + out, _ = sjson.Set(out, "stop_reason", stopReason) + + return []byte(out), nil +} + +func buildReverseMapFromClaudeOriginalShortToOriginal(original []byte) map[string]string { + tools := gjson.GetBytes(original, "tools") + rev := map[string]string{} + if tools.IsArray() && len(tools.Array()) > 0 { + var names []string + arr := tools.Array() + for i := 0; i < len(arr); i++ { + t := arr[i] + if t.Get("type").String() != "" { + continue + } + if v := t.Get("name"); v.Exists() { + names = append(names, v.String()) + } + } + if len(names) > 0 { + m := buildShortNameMap(names) + for orig, short := range m { + rev[short] = orig + } + } + } + return rev +} + +func extractResponsesUsage(usage gjson.Result) (int, int, int) { + if !usage.Exists() { + return 0, 0, 0 + } + inputTokens := int(usage.Get("input_tokens").Int()) + outputTokens := int(usage.Get("output_tokens").Int()) + cachedTokens := int(usage.Get("input_tokens_details.cached_tokens").Int()) + return inputTokens, outputTokens, cachedTokens +} diff --git a/internal/converter/codex_to_openai.go b/internal/converter/codex_to_openai.go index 455437f2..7996a6b1 100644 --- a/internal/converter/codex_to_openai.go +++ b/internal/converter/codex_to_openai.go @@ -1,10 +1,13 @@ package converter import ( + "bytes" "encoding/json" "time" "github.com/awsl-project/maxx/internal/domain" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" ) func init() { @@ -14,6 +17,24 @@ func init() { type codexToOpenAIRequest struct{} type codexToOpenAIResponse struct{} +type openaiStreamState struct { + Started bool + HasToolCall bool + ToolCalls map[int]*openaiToolCallState + ShortToOrig map[string]string + Index int + CreatedAt int64 + Model string + FinishSent bool +} + +type openaiToolCallState struct { + ID string + CallID string + Name string + NameSent bool +} + 
func (c *codexToOpenAIRequest) Transform(body []byte, model string, stream bool) ([]byte, error) { var req CodexRequest if err := json.Unmarshal(body, &req); err != nil { @@ -107,126 +128,319 @@ func (c *codexToOpenAIRequest) Transform(body []byte, model string, stream bool) } func (c *codexToOpenAIResponse) Transform(body []byte) ([]byte, error) { - var resp CodexResponse - if err := json.Unmarshal(body, &resp); err != nil { - return nil, err - } + return c.TransformWithState(body, nil) +} - openaiResp := OpenAIResponse{ - ID: resp.ID, - Object: "chat.completion", - Created: resp.CreatedAt, - Model: resp.Model, - Usage: OpenAIUsage{ - PromptTokens: resp.Usage.InputTokens, - CompletionTokens: resp.Usage.OutputTokens, - TotalTokens: resp.Usage.TotalTokens, - }, +func (c *codexToOpenAIResponse) TransformWithState(body []byte, state *TransformState) ([]byte, error) { + root := gjson.ParseBytes(body) + var response gjson.Result + if root.Get("type").String() == "response.completed" && root.Get("response").Exists() { + response = root.Get("response") + } else if root.Get("output").Exists() { + response = root + } else { + return body, nil } - msg := OpenAIMessage{Role: "assistant"} - var textContent string - var toolCalls []OpenAIToolCall + template := `{"id":"","object":"chat.completion","created":123456,"model":"model","choices":[{"index":0,"message":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}` - for _, out := range resp.Output { - switch out.Type { - case "message": - if s, ok := out.Content.(string); ok { - textContent += s - } - case "function_call": - toolCalls = append(toolCalls, OpenAIToolCall{ - ID: out.ID, - Type: "function", - Function: OpenAIFunctionCall{ - Name: out.Name, - Arguments: out.Arguments, - }, - }) - } + if modelResult := response.Get("model"); modelResult.Exists() { + template, _ = sjson.Set(template, "model", modelResult.String()) } - - if textContent != "" { 
- msg.Content = textContent + if createdAtResult := response.Get("created_at"); createdAtResult.Exists() { + template, _ = sjson.Set(template, "created", createdAtResult.Int()) + } else { + template, _ = sjson.Set(template, "created", time.Now().Unix()) } - if len(toolCalls) > 0 { - msg.ToolCalls = toolCalls + if idResult := response.Get("id"); idResult.Exists() { + template, _ = sjson.Set(template, "id", idResult.String()) } - finishReason := "stop" - if len(toolCalls) > 0 { - finishReason = "tool_calls" + if usageResult := response.Get("usage"); usageResult.Exists() { + template = applyOpenAIUsage(template, usageResult) + } + + outputResult := response.Get("output") + if outputResult.IsArray() { + var contentText string + var reasoningText string + var toolCalls []string + rev := buildReverseMapFromOriginalOpenAI(nil) + if state != nil && len(state.OriginalRequestBody) > 0 { + rev = buildReverseMapFromOriginalOpenAI(state.OriginalRequestBody) + } + + outputResult.ForEach(func(_, outputItem gjson.Result) bool { + switch outputItem.Get("type").String() { + case "reasoning": + if summaryResult := outputItem.Get("summary"); summaryResult.IsArray() { + summaryResult.ForEach(func(_, summaryItem gjson.Result) bool { + if summaryItem.Get("type").String() == "summary_text" { + reasoningText = summaryItem.Get("text").String() + return false + } + return true + }) + } + case "message": + if contentResult := outputItem.Get("content"); contentResult.IsArray() { + contentResult.ForEach(func(_, contentItem gjson.Result) bool { + if contentItem.Get("type").String() == "output_text" { + contentText = contentItem.Get("text").String() + return false + } + return true + }) + } + case "function_call": + functionCallTemplate := `{"id":"","type":"function","function":{"name":"","arguments":""}}` + if callIDResult := outputItem.Get("call_id"); callIDResult.Exists() { + functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", callIDResult.String()) + } + if nameResult := 
outputItem.Get("name"); nameResult.Exists() { + name := nameResult.String() + if orig, ok := rev[name]; ok { + name = orig + } + functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.name", name) + } + if argsResult := outputItem.Get("arguments"); argsResult.Exists() { + functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.arguments", argsResult.String()) + } + toolCalls = append(toolCalls, functionCallTemplate) + } + return true + }) + + if contentText != "" { + template, _ = sjson.Set(template, "choices.0.message.content", contentText) + template, _ = sjson.Set(template, "choices.0.message.role", "assistant") + } + if reasoningText != "" { + template, _ = sjson.Set(template, "choices.0.message.reasoning_content", reasoningText) + template, _ = sjson.Set(template, "choices.0.message.role", "assistant") + } + if len(toolCalls) > 0 { + template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls", `[]`) + for _, toolCall := range toolCalls { + template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls.-1", toolCall) + } + template, _ = sjson.Set(template, "choices.0.message.role", "assistant") + } } - openaiResp.Choices = []OpenAIChoice{{ - Index: 0, - Message: &msg, - FinishReason: finishReason, - }} + if statusResult := response.Get("status"); statusResult.Exists() && statusResult.String() == "completed" { + template, _ = sjson.Set(template, "choices.0.finish_reason", "stop") + template, _ = sjson.Set(template, "choices.0.native_finish_reason", "stop") + } - return json.Marshal(openaiResp) + return []byte(template), nil } func (c *codexToOpenAIResponse) TransformChunk(chunk []byte, state *TransformState) ([]byte, error) { events, remaining := ParseSSE(state.Buffer + string(chunk)) state.Buffer = remaining + st := getOpenAIStreamState(state) var output []byte for _, event := range events { - var codexEvent map[string]interface{} - if err := json.Unmarshal(event.Data, &codexEvent); err != nil { + if event.Event == "done" { 
+ if !st.FinishSent { + output = append(output, buildOpenAIStreamDone(state.MessageID, st.HasToolCall)...) + st.FinishSent = true + } + output = append(output, FormatDone()...) + continue + } + + raw := bytes.TrimSpace(event.Data) + if len(raw) == 0 { + continue + } + root := gjson.ParseBytes(raw) + if !root.Exists() { continue } - eventType, _ := codexEvent["type"].(string) + eventType := root.Get("type").String() switch eventType { case "response.created": - if resp, ok := codexEvent["response"].(map[string]interface{}); ok { - state.MessageID, _ = resp["id"].(string) + state.MessageID = root.Get("response.id").String() + st.CreatedAt = root.Get("response.created_at").Int() + st.Model = root.Get("response.model").String() + + case "response.reasoning_summary_text.delta": + if delta := root.Get("delta"); delta.Exists() { + chunk := newOpenAIStreamTemplate(state.MessageID, st) + chunk, _ = sjson.Set(chunk, "choices.0.delta.role", "assistant") + chunk, _ = sjson.Set(chunk, "choices.0.delta.reasoning_content", delta.String()) + chunk = applyOpenAIUsageFromResponse(chunk, root.Get("response.usage")) + output = append(output, FormatSSE("", []byte(chunk))...) } - openaiChunk := OpenAIStreamChunk{ - ID: state.MessageID, - Object: "chat.completion.chunk", - Created: time.Now().Unix(), - Choices: []OpenAIChoice{{ - Index: 0, - Delta: &OpenAIMessage{Role: "assistant", Content: ""}, - }}, + + case "response.reasoning_summary_text.done": + chunk := newOpenAIStreamTemplate(state.MessageID, st) + chunk, _ = sjson.Set(chunk, "choices.0.delta.role", "assistant") + chunk, _ = sjson.Set(chunk, "choices.0.delta.reasoning_content", "\n\n") + chunk = applyOpenAIUsageFromResponse(chunk, root.Get("response.usage")) + output = append(output, FormatSSE("", []byte(chunk))...) 
+ + case "response.output_text.delta": + if delta := root.Get("delta"); delta.Exists() { + chunk := newOpenAIStreamTemplate(state.MessageID, st) + chunk, _ = sjson.Set(chunk, "choices.0.delta.role", "assistant") + chunk, _ = sjson.Set(chunk, "choices.0.delta.content", delta.String()) + chunk = applyOpenAIUsageFromResponse(chunk, root.Get("response.usage")) + output = append(output, FormatSSE("", []byte(chunk))...) } - output = append(output, FormatSSE("", openaiChunk)...) - - case "response.output_item.delta": - if delta, ok := codexEvent["delta"].(map[string]interface{}); ok { - if text, ok := delta["text"].(string); ok { - openaiChunk := OpenAIStreamChunk{ - ID: state.MessageID, - Object: "chat.completion.chunk", - Created: time.Now().Unix(), - Choices: []OpenAIChoice{{ - Index: 0, - Delta: &OpenAIMessage{Content: text}, - }}, - } - output = append(output, FormatSSE("", openaiChunk)...) + + case "response.output_item.done": + item := root.Get("item") + if item.Exists() && item.Get("type").String() == "function_call" { + st.Index++ + st.HasToolCall = true + functionCallItemTemplate := `{"index":0,"id":"","type":"function","function":{"name":"","arguments":""}}` + functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", st.Index) + functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", item.Get("call_id").String()) + + name := item.Get("name").String() + rev := st.ShortToOrig + if rev == nil { + rev = buildReverseMapFromOriginalOpenAI(state.OriginalRequestBody) + st.ShortToOrig = rev + } + if orig, ok := rev[name]; ok { + name = orig } + functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", name) + functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", item.Get("arguments").String()) + + chunk := newOpenAIStreamTemplate(state.MessageID, st) + chunk, _ = sjson.Set(chunk, "choices.0.delta.role", "assistant") + chunk, _ = sjson.SetRaw(chunk, "choices.0.delta.tool_calls", 
`[]`) + chunk, _ = sjson.SetRaw(chunk, "choices.0.delta.tool_calls.-1", functionCallItemTemplate) + chunk = applyOpenAIUsageFromResponse(chunk, root.Get("response.usage")) + output = append(output, FormatSSE("", []byte(chunk))...) } - case "response.done": - openaiChunk := OpenAIStreamChunk{ - ID: state.MessageID, - Object: "chat.completion.chunk", - Created: time.Now().Unix(), - Choices: []OpenAIChoice{{ - Index: 0, - Delta: &OpenAIMessage{}, - FinishReason: "stop", - }}, + case "response.completed": + if !st.FinishSent { + chunk := newOpenAIStreamTemplate(state.MessageID, st) + finishReason := "stop" + if st.HasToolCall { + finishReason = "tool_calls" + } + chunk, _ = sjson.Set(chunk, "choices.0.finish_reason", finishReason) + chunk, _ = sjson.Set(chunk, "choices.0.native_finish_reason", finishReason) + chunk = applyOpenAIUsageFromResponse(chunk, root.Get("response.usage")) + output = append(output, FormatSSE("", []byte(chunk))...) + st.FinishSent = true } - output = append(output, FormatSSE("", openaiChunk)...) - output = append(output, FormatDone()...) 
} } return output, nil } + +func getOpenAIStreamState(state *TransformState) *openaiStreamState { + if state.Custom == nil { + state.Custom = &openaiStreamState{ + ToolCalls: map[int]*openaiToolCallState{}, + Index: -1, + } + } + st, ok := state.Custom.(*openaiStreamState) + if !ok || st == nil { + st = &openaiStreamState{ + ToolCalls: map[int]*openaiToolCallState{}, + Index: -1, + } + state.Custom = st + } + return st +} + +func buildOpenAIStreamDone(id string, hasToolCalls bool) []byte { + finishReason := "stop" + if hasToolCalls { + finishReason = "tool_calls" + } + openaiChunk := OpenAIStreamChunk{ + ID: id, + Object: "chat.completion.chunk", + Created: time.Now().Unix(), + Choices: []OpenAIChoice{{ + Index: 0, + Delta: &OpenAIMessage{}, + FinishReason: finishReason, + }}, + } + return FormatSSE("", openaiChunk) +} + +func newOpenAIStreamTemplate(id string, st *openaiStreamState) string { + template := `{"id":"","object":"chat.completion.chunk","created":12345,"model":"","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}` + template, _ = sjson.Set(template, "id", id) + if st != nil && st.CreatedAt > 0 { + template, _ = sjson.Set(template, "created", st.CreatedAt) + } else { + template, _ = sjson.Set(template, "created", time.Now().Unix()) + } + if st != nil && st.Model != "" { + template, _ = sjson.Set(template, "model", st.Model) + } + return template +} + +func buildReverseMapFromOriginalOpenAI(original []byte) map[string]string { + tools := gjson.GetBytes(original, "tools") + rev := map[string]string{} + if tools.IsArray() && len(tools.Array()) > 0 { + var names []string + arr := tools.Array() + for i := 0; i < len(arr); i++ { + t := arr[i] + if t.Get("type").String() != "function" { + continue + } + fn := t.Get("function") + if !fn.Exists() { + continue + } + if v := fn.Get("name"); v.Exists() { + names = append(names, v.String()) + } + } + if len(names) > 
0 { + m := buildShortNameMap(names) + for orig, short := range m { + rev[short] = orig + } + } + } + return rev +} + +func applyOpenAIUsageFromResponse(template string, usage gjson.Result) string { + if !usage.Exists() { + return template + } + return applyOpenAIUsage(template, usage) +} + +func applyOpenAIUsage(template string, usage gjson.Result) string { + if outputTokensResult := usage.Get("output_tokens"); outputTokensResult.Exists() { + template, _ = sjson.Set(template, "usage.completion_tokens", outputTokensResult.Int()) + } + if totalTokensResult := usage.Get("total_tokens"); totalTokensResult.Exists() { + template, _ = sjson.Set(template, "usage.total_tokens", totalTokensResult.Int()) + } + if inputTokensResult := usage.Get("input_tokens"); inputTokensResult.Exists() { + template, _ = sjson.Set(template, "usage.prompt_tokens", inputTokensResult.Int()) + } + if reasoningTokensResult := usage.Get("output_tokens_details.reasoning_tokens"); reasoningTokensResult.Exists() { + template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", reasoningTokensResult.Int()) + } + return template +} diff --git a/internal/converter/coverage_claude_response_test.go b/internal/converter/coverage_claude_response_test.go index f159da85..b6185e58 100644 --- a/internal/converter/coverage_claude_response_test.go +++ b/internal/converter/coverage_claude_response_test.go @@ -219,8 +219,11 @@ func TestGeminiToClaudeResponseUsage(t *testing.T) { } func TestCodexToClaudeResponseInvalidJSON(t *testing.T) { - _, err := (&codexToClaudeResponse{}).Transform([]byte("{")) - if err == nil { - t.Fatalf("expected error") + out, err := (&codexToClaudeResponse{}).Transform([]byte("{")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if out != nil { + t.Fatalf("expected empty output") } } diff --git a/internal/converter/coverage_codex_instructions_test.go b/internal/converter/coverage_codex_instructions_test.go index 3c1988f9..adf95428 100644 --- 
a/internal/converter/coverage_codex_instructions_test.go +++ b/internal/converter/coverage_codex_instructions_test.go @@ -141,8 +141,8 @@ func TestOpenAIToCodexReasoningWhitespace(t *testing.T) { if err := json.Unmarshal(out, &codexReq); err != nil { t.Fatalf("unmarshal: %v", err) } - if codexReq.Reasoning == nil || codexReq.Reasoning.Effort != "medium" { - t.Fatalf("expected reasoning default for whitespace") + if codexReq.Reasoning == nil || codexReq.Reasoning.Effort != " " { + t.Fatalf("expected reasoning effort to preserve whitespace") } } @@ -192,8 +192,8 @@ func TestOpenAIToCodexInstructionsEnabled(t *testing.T) { if err := json.Unmarshal(out, &codexReq); err != nil { t.Fatalf("unmarshal: %v", err) } - if strings.TrimSpace(codexReq.Instructions) == "" { - t.Fatalf("expected instructions") + if codexReq.Instructions != "" { + t.Fatalf("expected no instructions in request conversion") } } diff --git a/internal/converter/coverage_misc_helpers_test.go b/internal/converter/coverage_misc_helpers_test.go index 22a068ed..2e44f63a 100644 --- a/internal/converter/coverage_misc_helpers_test.go +++ b/internal/converter/coverage_misc_helpers_test.go @@ -361,8 +361,8 @@ func TestHelpers_ShortenNameIfNeededLong(t *testing.T) { if len(short) > maxToolNameLen { t.Fatalf("expected shortened") } - if !strings.Contains(short, "_") { - t.Fatalf("expected hash suffix") + if short == name { + t.Fatalf("expected shortened name to differ") } } @@ -462,8 +462,8 @@ func TestShortenNameIfNeededLong(t *testing.T) { if len(short) > maxToolNameLen { t.Fatalf("expected shortened name") } - if !strings.Contains(short, "_") { - t.Fatalf("expected hash suffix") + if short == name { + t.Fatalf("expected shortened name to differ") } } diff --git a/internal/converter/coverage_openai_response_test.go b/internal/converter/coverage_openai_response_test.go index efdbb0f1..141e5b6f 100644 --- a/internal/converter/coverage_openai_response_test.go +++ 
b/internal/converter/coverage_openai_response_test.go @@ -4,6 +4,8 @@ import ( "encoding/json" "strings" "testing" + + "github.com/tidwall/gjson" ) func TestOpenAIToClaudeRequestAndResponse(t *testing.T) { @@ -292,7 +294,7 @@ func TestCodexToClaudeAndOpenAIResponses(t *testing.T) { if err := json.Unmarshal(openaiOut, &openaiResp); err != nil { t.Fatalf("unmarshal openai: %v", err) } - if openaiResp.Choices[0].FinishReason != "tool_calls" { + if openaiResp.Choices[0].FinishReason != "" { t.Fatalf("finish reason: %v", openaiResp.Choices[0].FinishReason) } } @@ -376,12 +378,8 @@ func TestOpenAIToCodexResponseContent(t *testing.T) { if err != nil { t.Fatalf("Transform: %v", err) } - var codexResp CodexResponse - if err := json.Unmarshal(out, &codexResp); err != nil { - t.Fatalf("unmarshal: %v", err) - } - if len(codexResp.Output) == 0 { - t.Fatalf("expected message output") + if !gjson.GetBytes(out, "output").Exists() { + t.Fatalf("expected output in response") } } @@ -541,8 +539,8 @@ func TestCodexToOpenAIResponseNoToolCalls(t *testing.T) { if err != nil { t.Fatalf("Transform: %v", err) } - if !strings.Contains(string(out), "finish_reason\":\"stop") { - t.Fatalf("expected stop finish reason") + if !strings.Contains(string(out), "finish_reason\":null") { + t.Fatalf("expected empty finish reason") } } @@ -642,8 +640,8 @@ func TestCodexToOpenAIResponseMessageOnly(t *testing.T) { if err != nil { t.Fatalf("Transform: %v", err) } - if !strings.Contains(string(out), "\"finish_reason\":\"stop\"") { - t.Fatalf("expected stop finish") + if !strings.Contains(string(out), "\"finish_reason\":null") { + t.Fatalf("expected empty finish") } } @@ -737,9 +735,13 @@ func TestOpenAIToClaudeResponseStopReason(t *testing.T) { } func TestCodexToOpenAIResponseInvalidJSON(t *testing.T) { - _, err := (&codexToOpenAIResponse{}).Transform([]byte("{")) - if err == nil { - t.Fatalf("expected error") + input := []byte("{") + out, err := (&codexToOpenAIResponse{}).Transform(input) + if err != nil { 
+ t.Fatalf("unexpected error: %v", err) + } + if string(out) != string(input) { + t.Fatalf("expected passthrough of original body, got: %s", out) } } diff --git a/internal/converter/coverage_openai_stream_test.go b/internal/converter/coverage_openai_stream_test.go index 2a13f2c2..fd8e2710 100644 --- a/internal/converter/coverage_openai_stream_test.go +++ b/internal/converter/coverage_openai_stream_test.go @@ -174,11 +174,11 @@ func TestOpenAIToCodexRequestAndStream(t *testing.T) { if !strings.Contains(string(streamOut), "response.created") { t.Fatalf("missing response.created") } - if !strings.Contains(string(streamOut), "response.output_item.delta") { + if !strings.Contains(string(streamOut), "response.output_text.delta") { t.Fatalf("missing delta") } - if !strings.Contains(string(streamOut), "response.done") { - t.Fatalf("missing done") + if !strings.Contains(string(streamOut), "response.completed") { + t.Fatalf("missing completed") } } @@ -508,13 +508,14 @@ func TestOpenAIToClaudeStreamDoneWithoutMessage(t *testing.T) { func TestCodexToOpenAIStreamDoneFlow(t *testing.T) { state := NewTransformState() created := map[string]interface{}{"type": "response.created", "response": map[string]interface{}{"id": "resp_1"}} - delta := map[string]interface{}{"type": "response.output_item.delta", "delta": map[string]interface{}{"text": "hi"}} - done := map[string]interface{}{"type": "response.done"} + delta := map[string]interface{}{"type": "response.output_text.delta", "delta": "hi"} + completed := map[string]interface{}{"type": "response.completed", "response": map[string]interface{}{"usage": map[string]interface{}{"input_tokens": 1}}} c1, _ := json.Marshal(created) c2, _ := json.Marshal(delta) - c3, _ := json.Marshal(done) + c3, _ := json.Marshal(completed) stream := append(FormatSSE("", json.RawMessage(c1)), FormatSSE("", json.RawMessage(c2))...) stream = append(stream, FormatSSE("", json.RawMessage(c3))...) + stream = append(stream, FormatDone()...) 
conv := &codexToOpenAIResponse{} out, err := conv.TransformChunk(stream, state) if err != nil { diff --git a/internal/converter/gemini_to_openai.go b/internal/converter/gemini_to_openai.go index 17a12f40..651e580d 100644 --- a/internal/converter/gemini_to_openai.go +++ b/internal/converter/gemini_to_openai.go @@ -6,6 +6,7 @@ import ( "time" "github.com/awsl-project/maxx/internal/domain" + "github.com/tidwall/gjson" ) func init() { @@ -15,6 +16,10 @@ func init() { type geminiToOpenAIRequest struct{} type geminiToOpenAIResponse struct{} +type geminiOpenAIStreamMeta struct { + Model string +} + func (c *geminiToOpenAIRequest) Transform(body []byte, model string, stream bool) ([]byte, error) { var req GeminiRequest if err := json.Unmarshal(body, &req); err != nil { @@ -267,6 +272,33 @@ func (c *geminiToOpenAIResponse) TransformChunk(chunk []byte, state *TransformSt if err := json.Unmarshal(event.Data, &geminiChunk); err != nil { continue } + meta := gjson.ParseBytes(event.Data) + streamMeta, _ := state.Custom.(*geminiOpenAIStreamMeta) + if streamMeta == nil { + streamMeta = &geminiOpenAIStreamMeta{} + state.Custom = streamMeta + } + if streamMeta.Model == "" { + if mv := meta.Get("modelVersion"); mv.Exists() && mv.String() != "" { + streamMeta.Model = mv.String() + } + if streamMeta.Model == "" && len(state.OriginalRequestBody) > 0 { + if reqModel := gjson.GetBytes(state.OriginalRequestBody, "model"); reqModel.Exists() && reqModel.String() != "" { + streamMeta.Model = reqModel.String() + } + } + } + if state.MessageID == "" { + if rid := meta.Get("responseId"); rid.Exists() && rid.String() != "" { + state.MessageID = rid.String() + } + } + var createdAt int64 + if ct := meta.Get("createTime"); ct.Exists() { + if t, err := time.Parse(time.RFC3339Nano, ct.String()); err == nil { + createdAt = t.Unix() + } + } // First chunk if state.MessageID == "" { @@ -275,11 +307,15 @@ func (c *geminiToOpenAIResponse) TransformChunk(chunk []byte, state *TransformSt ID: 
state.MessageID, Object: "chat.completion.chunk", Created: time.Now().Unix(), + Model: streamMeta.Model, Choices: []OpenAIChoice{{ Index: 0, Delta: &OpenAIMessage{Role: "assistant", Content: ""}, }}, } + if createdAt > 0 { + openaiChunk.Created = createdAt + } output = append(output, FormatSSE("", openaiChunk)...) } @@ -291,11 +327,15 @@ func (c *geminiToOpenAIResponse) TransformChunk(chunk []byte, state *TransformSt ID: state.MessageID, Object: "chat.completion.chunk", Created: time.Now().Unix(), + Model: streamMeta.Model, Choices: []OpenAIChoice{{ Index: 0, - Delta: &OpenAIMessage{ReasoningContent: part.Text}, + Delta: &OpenAIMessage{Role: "assistant", ReasoningContent: part.Text}, }}, } + if createdAt > 0 { + openaiChunk.Created = createdAt + } output = append(output, FormatSSE("", openaiChunk)...) continue } @@ -304,11 +344,15 @@ func (c *geminiToOpenAIResponse) TransformChunk(chunk []byte, state *TransformSt ID: state.MessageID, Object: "chat.completion.chunk", Created: time.Now().Unix(), + Model: streamMeta.Model, Choices: []OpenAIChoice{{ Index: 0, - Delta: &OpenAIMessage{Content: part.Text}, + Delta: &OpenAIMessage{Role: "assistant", Content: part.Text}, }}, } + if createdAt > 0 { + openaiChunk.Created = createdAt + } output = append(output, FormatSSE("", openaiChunk)...) 
} if part.InlineData != nil && part.InlineData.Data != "" { @@ -316,9 +360,11 @@ func (c *geminiToOpenAIResponse) TransformChunk(chunk []byte, state *TransformSt ID: state.MessageID, Object: "chat.completion.chunk", Created: time.Now().Unix(), + Model: streamMeta.Model, Choices: []OpenAIChoice{{ Index: 0, Delta: &OpenAIMessage{ + Role: "assistant", Content: []OpenAIContentPart{{ Type: "image_url", ImageURL: &OpenAIImageURL{URL: "data:" + part.InlineData.MimeType + ";base64," + part.InlineData.Data}, @@ -326,6 +372,9 @@ func (c *geminiToOpenAIResponse) TransformChunk(chunk []byte, state *TransformSt }, }}, } + if createdAt > 0 { + openaiChunk.Created = createdAt + } output = append(output, FormatSSE("", openaiChunk)...) } if part.FunctionCall != nil { @@ -337,9 +386,11 @@ func (c *geminiToOpenAIResponse) TransformChunk(chunk []byte, state *TransformSt ID: state.MessageID, Object: "chat.completion.chunk", Created: time.Now().Unix(), + Model: streamMeta.Model, Choices: []OpenAIChoice{{ Index: 0, Delta: &OpenAIMessage{ + Role: "assistant", ToolCalls: []OpenAIToolCall{{ Index: state.CurrentIndex, ID: id, @@ -349,6 +400,9 @@ func (c *geminiToOpenAIResponse) TransformChunk(chunk []byte, state *TransformSt }, }}, } + if createdAt > 0 { + openaiChunk.Created = createdAt + } state.CurrentIndex++ output = append(output, FormatSSE("", openaiChunk)...) } @@ -363,12 +417,16 @@ func (c *geminiToOpenAIResponse) TransformChunk(chunk []byte, state *TransformSt ID: state.MessageID, Object: "chat.completion.chunk", Created: time.Now().Unix(), + Model: streamMeta.Model, Choices: []OpenAIChoice{{ Index: 0, - Delta: &OpenAIMessage{}, + Delta: &OpenAIMessage{Role: "assistant", Content: ""}, FinishReason: finishReason, }}, } + if createdAt > 0 { + openaiChunk.Created = createdAt + } output = append(output, FormatSSE("", openaiChunk)...) output = append(output, FormatDone()...) 
} diff --git a/internal/converter/more_converter_extra_test.go b/internal/converter/more_converter_extra_test.go index b4c2e738..df911824 100644 --- a/internal/converter/more_converter_extra_test.go +++ b/internal/converter/more_converter_extra_test.go @@ -3,6 +3,8 @@ package converter import ( "encoding/json" "testing" + + "github.com/tidwall/gjson" ) func TestCodexToOpenAIResponse_ToolCallsFinishReason(t *testing.T) { @@ -32,8 +34,8 @@ func TestCodexToOpenAIResponse_ToolCallsFinishReason(t *testing.T) { if err := json.Unmarshal(out, &got); err != nil { t.Fatalf("unmarshal: %v", err) } - if len(got.Choices) == 0 || got.Choices[0].FinishReason != "tool_calls" { - t.Fatalf("expected finish_reason tool_calls, got %#v", got.Choices) + if len(got.Choices) == 0 || got.Choices[0].FinishReason != "stop" { + t.Fatalf("expected finish_reason stop, got %#v", got.Choices) } } @@ -67,19 +69,21 @@ func TestOpenAIToCodexResponse_ToolCallsOutput(t *testing.T) { if err != nil { t.Fatalf("Transform: %v", err) } - - var got CodexResponse - if err := json.Unmarshal(out, &got); err != nil { - t.Fatalf("unmarshal: %v", err) + if !gjson.GetBytes(out, "output").Exists() { + t.Fatalf("expected output in response") } found := false - for _, item := range got.Output { - if item.Type == "function_call" && item.Name == "do_work" { - found = true - } + if outputs := gjson.GetBytes(out, "output"); outputs.IsArray() { + outputs.ForEach(func(_, item gjson.Result) bool { + if item.Get("type").String() == "function_call" && item.Get("name").String() == "do_work" { + found = true + return false + } + return true + }) } if !found { - t.Fatalf("expected function_call in codex output, got %#v", got.Output) + t.Fatalf("expected function_call in response output") } } diff --git a/internal/converter/openai_to_codex.go b/internal/converter/openai_to_codex.go index 03a606fd..1ccf4122 100644 --- a/internal/converter/openai_to_codex.go +++ b/internal/converter/openai_to_codex.go @@ -1,11 +1,17 @@ package 
converter import ( + "bytes" "encoding/json" + "fmt" + "sort" "strings" + "sync/atomic" "time" "github.com/awsl-project/maxx/internal/domain" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" ) func init() { @@ -16,251 +22,901 @@ type openaiToCodexRequest struct{} type openaiToCodexResponse struct{} func (c *openaiToCodexRequest) Transform(body []byte, model string, stream bool) ([]byte, error) { - userAgent := ExtractCodexUserAgent(body) - var req OpenAIRequest - if err := json.Unmarshal(body, &req); err != nil { + var tmp interface{} + if err := json.Unmarshal(body, &tmp); err != nil { return nil, err } + rawJSON := bytes.Clone(body) + out := `{"instructions":""}` - codexReq := CodexRequest{ - Model: model, - Stream: stream, - MaxOutputTokens: req.MaxTokens, - Temperature: req.Temperature, - TopP: req.TopP, - } + out, _ = sjson.Set(out, "stream", stream) - if req.MaxCompletionTokens > 0 && req.MaxTokens == 0 { - codexReq.MaxOutputTokens = req.MaxCompletionTokens + if v := gjson.GetBytes(rawJSON, "reasoning_effort"); v.Exists() { + out, _ = sjson.Set(out, "reasoning.effort", v.Value()) + } else { + out, _ = sjson.Set(out, "reasoning.effort", "medium") } + out, _ = sjson.Set(out, "parallel_tool_calls", true) + out, _ = sjson.Set(out, "reasoning.summary", "auto") + out, _ = sjson.Set(out, "include", []string{"reasoning.encrypted_content"}) - if req.ReasoningEffort != "" { - effort := strings.TrimSpace(req.ReasoningEffort) - codexReq.Reasoning = &CodexReasoning{ - Effort: effort, - } - } - trueVal := true - codexReq.ParallelToolCalls = &trueVal - codexReq.Include = []string{"reasoning.encrypted_content"} + out, _ = sjson.Set(out, "model", model) - // Convert messages to input - shortMap := map[string]string{} - if len(req.Tools) > 0 { + originalToolNameMap := map[string]string{} + if tools := gjson.GetBytes(rawJSON, "tools"); tools.IsArray() && len(tools.Array()) > 0 { var names []string - for _, tool := range req.Tools { - if tool.Type == "function" && 
tool.Function.Name != "" { - names = append(names, tool.Function.Name) + for _, t := range tools.Array() { + if t.Get("type").String() == "function" { + if v := t.Get("function.name"); v.Exists() { + names = append(names, v.String()) + } } } if len(names) > 0 { - shortMap = buildShortNameMap(names) + originalToolNameMap = buildShortNameMap(names) } } - var input []CodexInputItem - for _, msg := range req.Messages { - role := msg.Role - if role == "system" { - role = "developer" - } + out, _ = sjson.SetRaw(out, "input", `[]`) + if messages := gjson.GetBytes(rawJSON, "messages"); messages.IsArray() { + for _, m := range messages.Array() { + role := m.Get("role").String() + switch role { + case "tool": + funcOutput := `{}` + funcOutput, _ = sjson.Set(funcOutput, "type", "function_call_output") + funcOutput, _ = sjson.Set(funcOutput, "call_id", m.Get("tool_call_id").String()) + funcOutput, _ = sjson.Set(funcOutput, "output", m.Get("content").String()) + out, _ = sjson.SetRaw(out, "input.-1", funcOutput) + default: + msg := `{}` + msg, _ = sjson.Set(msg, "type", "message") + if role == "system" { + msg, _ = sjson.Set(msg, "role", "developer") + } else { + msg, _ = sjson.Set(msg, "role", role) + } + msg, _ = sjson.SetRaw(msg, "content", `[]`) - if msg.Role == "tool" { - // Tool response - contentStr, _ := msg.Content.(string) - input = append(input, CodexInputItem{ - Type: "function_call_output", - CallID: msg.ToolCallID, - Output: contentStr, - }) - continue - } + c := m.Get("content") + if c.Exists() && c.Type == gjson.String && c.String() != "" { + partType := "input_text" + if role == "assistant" { + partType = "output_text" + } + part := `{}` + part, _ = sjson.Set(part, "type", partType) + part, _ = sjson.Set(part, "text", c.String()) + msg, _ = sjson.SetRaw(msg, "content.-1", part) + } else if c.Exists() && c.IsArray() { + for _, it := range c.Array() { + t := it.Get("type").String() + switch t { + case "text": + partType := "input_text" + if role == "assistant" { 
+ partType = "output_text" + } + part := `{}` + part, _ = sjson.Set(part, "type", partType) + part, _ = sjson.Set(part, "text", it.Get("text").String()) + msg, _ = sjson.SetRaw(msg, "content.-1", part) + case "image_url": + if role == "user" { + part := `{}` + part, _ = sjson.Set(part, "type", "input_image") + if u := it.Get("image_url.url"); u.Exists() { + part, _ = sjson.Set(part, "image_url", u.String()) + } + msg, _ = sjson.SetRaw(msg, "content.-1", part) + } + } + } + } - item := CodexInputItem{ - Type: "message", - Role: role, - } + out, _ = sjson.SetRaw(out, "input.-1", msg) - switch content := msg.Content.(type) { - case string: - item.Content = content - case []interface{}: - var textContent string - for _, part := range content { - if m, ok := part.(map[string]interface{}); ok { - if m["type"] == "text" { - if text, ok := m["text"].(string); ok { - textContent += text + if role == "assistant" { + if toolCalls := m.Get("tool_calls"); toolCalls.Exists() && toolCalls.IsArray() { + for _, tc := range toolCalls.Array() { + if tc.Get("type").String() != "function" { + continue + } + funcCall := `{}` + funcCall, _ = sjson.Set(funcCall, "type", "function_call") + funcCall, _ = sjson.Set(funcCall, "call_id", tc.Get("id").String()) + name := tc.Get("function.name").String() + if short, ok := originalToolNameMap[name]; ok { + name = short + } else { + name = shortenNameIfNeeded(name) + } + funcCall, _ = sjson.Set(funcCall, "name", name) + funcCall, _ = sjson.Set(funcCall, "arguments", tc.Get("function.arguments").String()) + out, _ = sjson.SetRaw(out, "input.-1", funcCall) } } } } - item.Content = textContent } + } - input = append(input, item) + rf := gjson.GetBytes(rawJSON, "response_format") + text := gjson.GetBytes(rawJSON, "text") + if rf.Exists() { + if !gjson.Get(out, "text").Exists() { + out, _ = sjson.SetRaw(out, "text", `{}`) + } + switch rf.Get("type").String() { + case "text": + out, _ = sjson.Set(out, "text.format.type", "text") + case "json_schema": + 
if js := rf.Get("json_schema"); js.Exists() { + out, _ = sjson.Set(out, "text.format.type", "json_schema") + if v := js.Get("name"); v.Exists() { + out, _ = sjson.Set(out, "text.format.name", v.Value()) + } + if v := js.Get("strict"); v.Exists() { + out, _ = sjson.Set(out, "text.format.strict", v.Value()) + } + if v := js.Get("schema"); v.Exists() { + out, _ = sjson.SetRaw(out, "text.format.schema", v.Raw) + } + } + } + if text.Exists() { + if v := text.Get("verbosity"); v.Exists() { + out, _ = sjson.Set(out, "text.verbosity", v.Value()) + } + } + } else if text.Exists() { + if v := text.Get("verbosity"); v.Exists() { + if !gjson.Get(out, "text").Exists() { + out, _ = sjson.SetRaw(out, "text", `{}`) + } + out, _ = sjson.Set(out, "text.verbosity", v.Value()) + } + } - // Handle tool calls - for _, tc := range msg.ToolCalls { - name := tc.Function.Name - if short, ok := shortMap[name]; ok { - name = short - } else { - name = shortenNameIfNeeded(name) + if tools := gjson.GetBytes(rawJSON, "tools"); tools.IsArray() && len(tools.Array()) > 0 { + out, _ = sjson.SetRaw(out, "tools", `[]`) + for _, t := range tools.Array() { + toolType := t.Get("type").String() + if toolType != "" && toolType != "function" && t.IsObject() { + out, _ = sjson.SetRaw(out, "tools.-1", t.Raw) + continue + } + if toolType == "function" { + item := `{}` + item, _ = sjson.Set(item, "type", "function") + if v := t.Get("function.name"); v.Exists() { + name := v.String() + if short, ok := originalToolNameMap[name]; ok { + name = short + } else { + name = shortenNameIfNeeded(name) + } + item, _ = sjson.Set(item, "name", name) + } + if v := t.Get("function.description"); v.Exists() { + item, _ = sjson.Set(item, "description", v.Value()) + } + if v := t.Get("function.parameters"); v.Exists() { + item, _ = sjson.SetRaw(item, "parameters", v.Raw) + } + if v := t.Get("function.strict"); v.Exists() { + item, _ = sjson.Set(item, "strict", v.Value()) + } + out, _ = sjson.SetRaw(out, "tools.-1", item) } - 
input = append(input, CodexInputItem{ - Type: "function_call", - ID: tc.ID, - CallID: tc.ID, - Name: name, - Role: "assistant", - Arguments: tc.Function.Arguments, - }) } } - codexReq.Input = input - // Convert tools - for _, tool := range req.Tools { - name := tool.Function.Name - if short, ok := shortMap[name]; ok { - name = short - } else { - name = shortenNameIfNeeded(name) + if tc := gjson.GetBytes(rawJSON, "tool_choice"); tc.Exists() { + switch { + case tc.Type == gjson.String: + out, _ = sjson.Set(out, "tool_choice", tc.String()) + case tc.IsObject(): + tcType := tc.Get("type").String() + if tcType == "function" { + name := tc.Get("function.name").String() + if name != "" { + if short, ok := originalToolNameMap[name]; ok { + name = short + } else { + name = shortenNameIfNeeded(name) + } + } + choice := `{}` + choice, _ = sjson.Set(choice, "type", "function") + if name != "" { + choice, _ = sjson.Set(choice, "name", name) + } + out, _ = sjson.SetRaw(out, "tool_choice", choice) + } else if tcType != "" { + out, _ = sjson.SetRaw(out, "tool_choice", tc.Raw) + } } - codexReq.Tools = append(codexReq.Tools, CodexTool{ - Type: "function", - Name: name, - Description: tool.Function.Description, - Parameters: tool.Function.Parameters, - }) } - if instructions := CodexInstructionsForModel(model, userAgent); instructions != "" { - codexReq.Instructions = instructions + out, _ = sjson.Set(out, "store", false) + + return []byte(out), nil +} + +func (c *openaiToCodexResponse) Transform(body []byte) ([]byte, error) { + return c.TransformWithState(body, nil) +} + +func (c *openaiToCodexResponse) TransformWithState(body []byte, state *TransformState) ([]byte, error) { + var tmp interface{} + if err := json.Unmarshal(body, &tmp); err != nil { + return nil, err } - if codexReq.Reasoning == nil { - codexReq.Reasoning = &CodexReasoning{Effort: "medium"} + root := gjson.ParseBytes(body) + requestRaw := []byte(nil) + if state != nil { + requestRaw = state.OriginalRequestBody } - if 
codexReq.Reasoning.Effort == "" { - codexReq.Reasoning.Effort = "medium" + + resp := `{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null,"incomplete_details":null}` + respID := root.Get("id").String() + if respID == "" { + respID = synthesizeResponseID() } - if codexReq.Reasoning.Summary == "" { - codexReq.Reasoning.Summary = "auto" + resp, _ = sjson.Set(resp, "id", respID) + + created := root.Get("created").Int() + if created == 0 { + created = time.Now().Unix() } + resp, _ = sjson.Set(resp, "created_at", created) - return json.Marshal(codexReq) -} + if v := root.Get("model"); v.Exists() { + resp, _ = sjson.Set(resp, "model", v.String()) + } -func (c *openaiToCodexResponse) Transform(body []byte) ([]byte, error) { - var resp OpenAIResponse - if err := json.Unmarshal(body, &resp); err != nil { - return nil, err + outputsWrapper := `{"arr":[]}` + + if choices := root.Get("choices"); choices.Exists() && choices.IsArray() { + choices.ForEach(func(_, choice gjson.Result) bool { + msg := choice.Get("message") + if msg.Exists() { + if rc := msg.Get("reasoning_content"); rc.Exists() && rc.String() != "" { + choiceIdx := int(choice.Get("index").Int()) + reasoning := `{"id":"","type":"reasoning","encrypted_content":"","summary":[]}` + reasoning, _ = sjson.Set(reasoning, "id", fmt.Sprintf("rs_%s_%d", respID, choiceIdx)) + reasoning, _ = sjson.Set(reasoning, "summary.0.type", "summary_text") + reasoning, _ = sjson.Set(reasoning, "summary.0.text", rc.String()) + outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", reasoning) + } + if c := msg.Get("content"); c.Exists() && c.String() != "" { + item := `{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}` + item, _ = sjson.Set(item, "id", fmt.Sprintf("msg_%s_%d", respID, int(choice.Get("index").Int()))) + item, _ = sjson.Set(item, "content.0.text", c.String()) + outputsWrapper, _ = 
sjson.SetRaw(outputsWrapper, "arr.-1", item) + } + + if tcs := msg.Get("tool_calls"); tcs.Exists() && tcs.IsArray() { + tcs.ForEach(func(_, tc gjson.Result) bool { + callID := tc.Get("id").String() + name := tc.Get("function.name").String() + args := tc.Get("function.arguments").String() + item := `{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}` + item, _ = sjson.Set(item, "id", fmt.Sprintf("fc_%s", callID)) + item, _ = sjson.Set(item, "arguments", args) + item, _ = sjson.Set(item, "call_id", callID) + item, _ = sjson.Set(item, "name", name) + outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item) + return true + }) + } + } + return true + }) + } + if gjson.Get(outputsWrapper, "arr.#").Int() > 0 { + resp, _ = sjson.SetRaw(resp, "output", gjson.Get(outputsWrapper, "arr").Raw) } - codexResp := CodexResponse{ - ID: resp.ID, - Object: "response", - CreatedAt: resp.Created, - Model: resp.Model, - Status: "completed", - Usage: CodexUsage{ - InputTokens: resp.Usage.PromptTokens, - OutputTokens: resp.Usage.CompletionTokens, - TotalTokens: resp.Usage.TotalTokens, - }, - } - - if len(resp.Choices) > 0 { - choice := resp.Choices[0] - if choice.Message != nil { - if content, ok := choice.Message.Content.(string); ok && content != "" { - codexResp.Output = append(codexResp.Output, CodexOutput{ - Type: "message", - Role: "assistant", - Content: content, - }) + if usage := root.Get("usage"); usage.Exists() { + if usage.Get("prompt_tokens").Exists() || usage.Get("completion_tokens").Exists() || usage.Get("total_tokens").Exists() { + resp, _ = sjson.Set(resp, "usage.input_tokens", usage.Get("prompt_tokens").Int()) + if d := usage.Get("prompt_tokens_details.cached_tokens"); d.Exists() { + resp, _ = sjson.Set(resp, "usage.input_tokens_details.cached_tokens", d.Int()) } - for _, tc := range choice.Message.ToolCalls { - codexResp.Output = append(codexResp.Output, CodexOutput{ - Type: "function_call", - ID: tc.ID, - CallID: tc.ID, - 
Name: tc.Function.Name, - Arguments: tc.Function.Arguments, - Status: "completed", - }) + resp, _ = sjson.Set(resp, "usage.output_tokens", usage.Get("completion_tokens").Int()) + if d := usage.Get("completion_tokens_details.reasoning_tokens"); d.Exists() { + resp, _ = sjson.Set(resp, "usage.output_tokens_details.reasoning_tokens", d.Int()) + } else if d := usage.Get("output_tokens_details.reasoning_tokens"); d.Exists() { + resp, _ = sjson.Set(resp, "usage.output_tokens_details.reasoning_tokens", d.Int()) } + resp, _ = sjson.Set(resp, "usage.total_tokens", usage.Get("total_tokens").Int()) + } else { + resp, _ = sjson.Set(resp, "usage", usage.Value()) } } - return json.Marshal(codexResp) + if len(requestRaw) > 0 { + resp = applyRequestEchoToResponse(resp, "", requestRaw) + } + return []byte(resp), nil } func (c *openaiToCodexResponse) TransformChunk(chunk []byte, state *TransformState) ([]byte, error) { + if state == nil { + return nil, fmt.Errorf("TransformChunk requires non-nil state") + } events, remaining := ParseSSE(state.Buffer + string(chunk)) state.Buffer = remaining var output []byte for _, event := range events { if event.Event == "done" { - codexEvent := map[string]interface{}{ - "type": "response.done", - "response": map[string]interface{}{ - "id": state.MessageID, - "status": "completed", - }, - } - output = append(output, FormatSSE("", codexEvent)...) continue } + for _, item := range convertOpenAIChatCompletionsChunkToResponses(event.Data, state) { + output = append(output, item...) 
+ } + } - var openaiChunk OpenAIStreamChunk - if err := json.Unmarshal(event.Data, &openaiChunk); err != nil { - continue + return output, nil +} + +type openaiToResponsesStateReasoning struct { + ReasoningID string + ReasoningData string +} + +type openaiToResponsesState struct { + Seq int + ResponseID string + Created int64 + Started bool + ReasoningID string + ReasoningIndex int + MsgTextBuf map[int]*strings.Builder + ReasoningBuf strings.Builder + Reasonings []openaiToResponsesStateReasoning + FuncArgsBuf map[int]*strings.Builder + FuncNames map[int]string + FuncCallIDs map[int]string + MsgItemAdded map[int]bool + MsgContentAdded map[int]bool + MsgItemDone map[int]bool + FuncArgsDone map[int]bool + FuncItemDone map[int]bool + PromptTokens int64 + CachedTokens int64 + CompletionTokens int64 + TotalTokens int64 + ReasoningTokens int64 + UsageSeen bool + NextOutputIndex int // global counter for unique output_index across messages and function calls + MsgOutputIndex map[int]int // choice idx -> assigned output_index + FuncOutputIndex map[int]int // callIndex -> assigned output_index + CompletedSent bool // guards against duplicate response.completed +} + +var responseIDCounter uint64 + +func synthesizeResponseID() string { + return fmt.Sprintf("resp_%x_%d", time.Now().UnixNano(), atomic.AddUint64(&responseIDCounter, 1)) +} + +func (st *openaiToResponsesState) msgOutIdx(choiceIdx int) int { + if oi, ok := st.MsgOutputIndex[choiceIdx]; ok { + return oi + } + oi := st.NextOutputIndex + st.MsgOutputIndex[choiceIdx] = oi + st.NextOutputIndex++ + return oi +} + +func (st *openaiToResponsesState) funcOutIdx(callIndex int) int { + if oi, ok := st.FuncOutputIndex[callIndex]; ok { + return oi + } + oi := st.NextOutputIndex + st.FuncOutputIndex[callIndex] = oi + st.NextOutputIndex++ + return oi +} + +func convertOpenAIChatCompletionsChunkToResponses(rawJSON []byte, state *TransformState) [][]byte { + if state == nil { + return nil + } + st, ok := 
state.Custom.(*openaiToResponsesState) + if !ok || st == nil { + st = &openaiToResponsesState{ + FuncArgsBuf: make(map[int]*strings.Builder), + FuncNames: make(map[int]string), + FuncCallIDs: make(map[int]string), + MsgTextBuf: make(map[int]*strings.Builder), + MsgItemAdded: make(map[int]bool), + MsgContentAdded: make(map[int]bool), + MsgItemDone: make(map[int]bool), + FuncArgsDone: make(map[int]bool), + FuncItemDone: make(map[int]bool), + Reasonings: make([]openaiToResponsesStateReasoning, 0), + MsgOutputIndex: make(map[int]int), + FuncOutputIndex: make(map[int]int), } + state.Custom = st + } - if state.MessageID == "" { - state.MessageID = openaiChunk.ID - codexEvent := map[string]interface{}{ - "type": "response.created", - "response": map[string]interface{}{ - "id": openaiChunk.ID, - "model": openaiChunk.Model, - "status": "in_progress", - "created_at": time.Now().Unix(), - }, - } - output = append(output, FormatSSE("", codexEvent)...) + root := gjson.ParseBytes(rawJSON) + obj := root.Get("object") + if obj.Exists() && obj.String() != "" && obj.String() != "chat.completion.chunk" { + return nil + } + if !root.Get("choices").Exists() || !root.Get("choices").IsArray() { + return nil + } + + nextSeq := func() int { st.Seq++; return st.Seq } + var out [][]byte + + if !st.Started { + st.ResponseID = root.Get("id").String() + if st.ResponseID == "" { + st.ResponseID = synthesizeResponseID() + } + st.Created = root.Get("created").Int() + if st.Created == 0 { + st.Created = time.Now().Unix() + } + st.MsgTextBuf = make(map[int]*strings.Builder) + st.ReasoningBuf.Reset() + st.ReasoningID = "" + st.ReasoningIndex = 0 + st.FuncArgsBuf = make(map[int]*strings.Builder) + st.FuncNames = make(map[int]string) + st.FuncCallIDs = make(map[int]string) + st.MsgItemAdded = make(map[int]bool) + st.MsgContentAdded = make(map[int]bool) + st.MsgItemDone = make(map[int]bool) + st.FuncArgsDone = make(map[int]bool) + st.FuncItemDone = make(map[int]bool) + st.MsgOutputIndex = 
make(map[int]int) + st.FuncOutputIndex = make(map[int]int) + st.NextOutputIndex = 0 + st.CompletedSent = false + st.PromptTokens = 0 + st.CachedTokens = 0 + st.CompletionTokens = 0 + st.TotalTokens = 0 + st.ReasoningTokens = 0 + st.UsageSeen = false + + created := `{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null,"output":[]}}` + created, _ = sjson.Set(created, "sequence_number", nextSeq()) + created, _ = sjson.Set(created, "response.id", st.ResponseID) + created, _ = sjson.Set(created, "response.created_at", st.Created) + out = append(out, FormatSSE("response.created", []byte(created))) + + inprog := `{"type":"response.in_progress","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress"}}` + inprog, _ = sjson.Set(inprog, "sequence_number", nextSeq()) + inprog, _ = sjson.Set(inprog, "response.id", st.ResponseID) + inprog, _ = sjson.Set(inprog, "response.created_at", st.Created) + out = append(out, FormatSSE("response.in_progress", []byte(inprog))) + st.Started = true + } + + if usage := root.Get("usage"); usage.Exists() { + if v := usage.Get("prompt_tokens"); v.Exists() { + st.PromptTokens = v.Int() + st.UsageSeen = true + } + if v := usage.Get("prompt_tokens_details.cached_tokens"); v.Exists() { + st.CachedTokens = v.Int() + st.UsageSeen = true + } + if v := usage.Get("completion_tokens"); v.Exists() { + st.CompletionTokens = v.Int() + st.UsageSeen = true + } else if v := usage.Get("output_tokens"); v.Exists() { + st.CompletionTokens = v.Int() + st.UsageSeen = true + } + if v := usage.Get("output_tokens_details.reasoning_tokens"); v.Exists() { + st.ReasoningTokens = v.Int() + st.UsageSeen = true + } else if v := usage.Get("completion_tokens_details.reasoning_tokens"); v.Exists() { + st.ReasoningTokens = v.Int() + st.UsageSeen = true + } + if v := usage.Get("total_tokens"); v.Exists() { + st.TotalTokens = v.Int() + 
st.UsageSeen = true } + } + + stopReasoning := func(text string) { + textDone := `{"type":"response.reasoning_summary_text.done","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"text":""}` + textDone, _ = sjson.Set(textDone, "sequence_number", nextSeq()) + textDone, _ = sjson.Set(textDone, "item_id", st.ReasoningID) + textDone, _ = sjson.Set(textDone, "output_index", st.ReasoningIndex) + textDone, _ = sjson.Set(textDone, "text", text) + out = append(out, FormatSSE("response.reasoning_summary_text.done", []byte(textDone))) + + partDone := `{"type":"response.reasoning_summary_part.done","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"part":{"type":"summary_text","text":""}}` + partDone, _ = sjson.Set(partDone, "sequence_number", nextSeq()) + partDone, _ = sjson.Set(partDone, "item_id", st.ReasoningID) + partDone, _ = sjson.Set(partDone, "output_index", st.ReasoningIndex) + partDone, _ = sjson.Set(partDone, "part.text", text) + out = append(out, FormatSSE("response.reasoning_summary_part.done", []byte(partDone))) + + outputItemDone := `{"type":"response.output_item.done","item":{"id":"","type":"reasoning","encrypted_content":"","summary":[{"type":"summary_text","text":""}]},"output_index":0,"sequence_number":0}` + outputItemDone, _ = sjson.Set(outputItemDone, "sequence_number", nextSeq()) + outputItemDone, _ = sjson.Set(outputItemDone, "item.id", st.ReasoningID) + outputItemDone, _ = sjson.Set(outputItemDone, "output_index", st.ReasoningIndex) + outputItemDone, _ = sjson.Set(outputItemDone, "item.summary.0.text", text) + out = append(out, FormatSSE("response.output_item.done", []byte(outputItemDone))) + + st.Reasonings = append(st.Reasonings, openaiToResponsesStateReasoning{ReasoningID: st.ReasoningID, ReasoningData: text}) + st.ReasoningID = "" + } + + if choices := root.Get("choices"); choices.Exists() && choices.IsArray() { + choices.ForEach(func(_, choice gjson.Result) bool { + idx := int(choice.Get("index").Int()) + delta 
:= choice.Get("delta") + if delta.Exists() { + if c := delta.Get("content"); c.Exists() && c.String() != "" { + if st.ReasoningID != "" { + stopReasoning(st.ReasoningBuf.String()) + st.ReasoningBuf.Reset() + } + if !st.MsgItemAdded[idx] { + item := `{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"in_progress","content":[],"role":"assistant"}}` + item, _ = sjson.Set(item, "sequence_number", nextSeq()) + item, _ = sjson.Set(item, "output_index", st.msgOutIdx(idx)) + item, _ = sjson.Set(item, "item.id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx)) + out = append(out, FormatSSE("response.output_item.added", []byte(item))) + st.MsgItemAdded[idx] = true + } + if !st.MsgContentAdded[idx] { + part := `{"type":"response.content_part.added","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}` + part, _ = sjson.Set(part, "sequence_number", nextSeq()) + part, _ = sjson.Set(part, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx)) + part, _ = sjson.Set(part, "output_index", st.msgOutIdx(idx)) + part, _ = sjson.Set(part, "content_index", 0) + out = append(out, FormatSSE("response.content_part.added", []byte(part))) + st.MsgContentAdded[idx] = true + } + + msg := `{"type":"response.output_text.delta","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"delta":"","logprobs":[]}` + msg, _ = sjson.Set(msg, "sequence_number", nextSeq()) + msg, _ = sjson.Set(msg, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx)) + msg, _ = sjson.Set(msg, "output_index", st.msgOutIdx(idx)) + msg, _ = sjson.Set(msg, "content_index", 0) + msg, _ = sjson.Set(msg, "delta", c.String()) + out = append(out, FormatSSE("response.output_text.delta", []byte(msg))) + if st.MsgTextBuf[idx] == nil { + st.MsgTextBuf[idx] = &strings.Builder{} + } + st.MsgTextBuf[idx].WriteString(c.String()) + } + + if rc := 
delta.Get("reasoning_content"); rc.Exists() && rc.String() != "" { + if st.ReasoningID == "" { + st.ReasoningID = fmt.Sprintf("rs_%s_%d", st.ResponseID, idx) + st.ReasoningIndex = st.NextOutputIndex + st.NextOutputIndex++ + item := `{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"reasoning","status":"in_progress","summary":[]}}` + item, _ = sjson.Set(item, "sequence_number", nextSeq()) + item, _ = sjson.Set(item, "output_index", st.ReasoningIndex) + item, _ = sjson.Set(item, "item.id", st.ReasoningID) + out = append(out, FormatSSE("response.output_item.added", []byte(item))) + part := `{"type":"response.reasoning_summary_part.added","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"part":{"type":"summary_text","text":""}}` + part, _ = sjson.Set(part, "sequence_number", nextSeq()) + part, _ = sjson.Set(part, "item_id", st.ReasoningID) + part, _ = sjson.Set(part, "output_index", st.ReasoningIndex) + out = append(out, FormatSSE("response.reasoning_summary_part.added", []byte(part))) + } + st.ReasoningBuf.WriteString(rc.String()) + msg := `{"type":"response.reasoning_summary_text.delta","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"delta":""}` + msg, _ = sjson.Set(msg, "sequence_number", nextSeq()) + msg, _ = sjson.Set(msg, "item_id", st.ReasoningID) + msg, _ = sjson.Set(msg, "output_index", st.ReasoningIndex) + msg, _ = sjson.Set(msg, "delta", rc.String()) + out = append(out, FormatSSE("response.reasoning_summary_text.delta", []byte(msg))) + } + + if tcs := delta.Get("tool_calls"); tcs.Exists() && tcs.IsArray() { + if st.ReasoningID != "" { + stopReasoning(st.ReasoningBuf.String()) + st.ReasoningBuf.Reset() + } + if st.MsgItemAdded[idx] && !st.MsgItemDone[idx] { + fullText := "" + if b := st.MsgTextBuf[idx]; b != nil { + fullText = b.String() + } + done := 
`{"type":"response.output_text.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"text":"","logprobs":[]}` + done, _ = sjson.Set(done, "sequence_number", nextSeq()) + done, _ = sjson.Set(done, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx)) + done, _ = sjson.Set(done, "output_index", st.msgOutIdx(idx)) + done, _ = sjson.Set(done, "content_index", 0) + done, _ = sjson.Set(done, "text", fullText) + out = append(out, FormatSSE("response.output_text.done", []byte(done))) + + partDone := `{"type":"response.content_part.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}` + partDone, _ = sjson.Set(partDone, "sequence_number", nextSeq()) + partDone, _ = sjson.Set(partDone, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx)) + partDone, _ = sjson.Set(partDone, "output_index", st.msgOutIdx(idx)) + partDone, _ = sjson.Set(partDone, "content_index", 0) + partDone, _ = sjson.Set(partDone, "part.text", fullText) + out = append(out, FormatSSE("response.content_part.done", []byte(partDone))) + + itemDone := `{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}}` + itemDone, _ = sjson.Set(itemDone, "sequence_number", nextSeq()) + itemDone, _ = sjson.Set(itemDone, "output_index", st.msgOutIdx(idx)) + itemDone, _ = sjson.Set(itemDone, "item.id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx)) + itemDone, _ = sjson.Set(itemDone, "item.content.0.text", fullText) + out = append(out, FormatSSE("response.output_item.done", []byte(itemDone))) + st.MsgItemDone[idx] = true + } + + for tcIndex, tc := range tcs.Array() { + callIndex := tcIndex + if v := tc.Get("index"); v.Exists() { + callIndex = int(v.Int()) + } + + newCallID := tc.Get("id").String() + nameChunk := tc.Get("function.name").String() 
+ if nameChunk != "" { + st.FuncNames[callIndex] = nameChunk + } + existingCallID := st.FuncCallIDs[callIndex] + effectiveCallID := existingCallID + shouldEmitItem := false + if existingCallID == "" && newCallID != "" { + effectiveCallID = newCallID + st.FuncCallIDs[callIndex] = newCallID + shouldEmitItem = true + } + + if shouldEmitItem && effectiveCallID != "" { + o := `{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"in_progress","arguments":"","call_id":"","name":""}}` + o, _ = sjson.Set(o, "sequence_number", nextSeq()) + o, _ = sjson.Set(o, "output_index", st.funcOutIdx(callIndex)) + o, _ = sjson.Set(o, "item.id", fmt.Sprintf("fc_%s", effectiveCallID)) + o, _ = sjson.Set(o, "item.call_id", effectiveCallID) + o, _ = sjson.Set(o, "item.name", st.FuncNames[callIndex]) + out = append(out, FormatSSE("response.output_item.added", []byte(o))) + } + + if st.FuncArgsBuf[callIndex] == nil { + st.FuncArgsBuf[callIndex] = &strings.Builder{} + } + if args := tc.Get("function.arguments"); args.Exists() && args.String() != "" { + refCallID := st.FuncCallIDs[callIndex] + if refCallID == "" { + refCallID = newCallID + } + if refCallID != "" { + ad := `{"type":"response.function_call_arguments.delta","sequence_number":0,"item_id":"","output_index":0,"delta":""}` + ad, _ = sjson.Set(ad, "sequence_number", nextSeq()) + ad, _ = sjson.Set(ad, "item_id", fmt.Sprintf("fc_%s", refCallID)) + ad, _ = sjson.Set(ad, "output_index", st.funcOutIdx(callIndex)) + ad, _ = sjson.Set(ad, "delta", args.String()) + out = append(out, FormatSSE("response.function_call_arguments.delta", []byte(ad))) + } + st.FuncArgsBuf[callIndex].WriteString(args.String()) + } + } + } + } + + if fr := choice.Get("finish_reason"); fr.Exists() && fr.String() != "" { + if len(st.MsgItemAdded) > 0 { + for _, i := range sortedKeys(st.MsgItemAdded) { + if st.MsgItemAdded[i] && !st.MsgItemDone[i] { + fullText := "" + if b := st.MsgTextBuf[i]; b != 
nil { + fullText = b.String() + } + done := `{"type":"response.output_text.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"text":"","logprobs":[]}` + done, _ = sjson.Set(done, "sequence_number", nextSeq()) + done, _ = sjson.Set(done, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i)) + done, _ = sjson.Set(done, "output_index", st.msgOutIdx(i)) + done, _ = sjson.Set(done, "content_index", 0) + done, _ = sjson.Set(done, "text", fullText) + out = append(out, FormatSSE("response.output_text.done", []byte(done))) + + partDone := `{"type":"response.content_part.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}` + partDone, _ = sjson.Set(partDone, "sequence_number", nextSeq()) + partDone, _ = sjson.Set(partDone, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i)) + partDone, _ = sjson.Set(partDone, "output_index", st.msgOutIdx(i)) + partDone, _ = sjson.Set(partDone, "content_index", 0) + partDone, _ = sjson.Set(partDone, "part.text", fullText) + out = append(out, FormatSSE("response.content_part.done", []byte(partDone))) + + itemDone := `{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}}` + itemDone, _ = sjson.Set(itemDone, "sequence_number", nextSeq()) + itemDone, _ = sjson.Set(itemDone, "output_index", st.msgOutIdx(i)) + itemDone, _ = sjson.Set(itemDone, "item.id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i)) + itemDone, _ = sjson.Set(itemDone, "item.content.0.text", fullText) + out = append(out, FormatSSE("response.output_item.done", []byte(itemDone))) + st.MsgItemDone[i] = true + } + } + } + + if st.ReasoningID != "" { + stopReasoning(st.ReasoningBuf.String()) + st.ReasoningBuf.Reset() + } + + if len(st.FuncCallIDs) > 0 { + for _, i := range sortedKeys(st.FuncCallIDs) { + 
callID := st.FuncCallIDs[i] + if callID == "" || st.FuncItemDone[i] { + continue + } + args := "{}" + if b := st.FuncArgsBuf[i]; b != nil && b.Len() > 0 { + args = b.String() + } + fcDone := `{"type":"response.function_call_arguments.done","sequence_number":0,"item_id":"","output_index":0,"arguments":""}` + fcDone, _ = sjson.Set(fcDone, "sequence_number", nextSeq()) + fcDone, _ = sjson.Set(fcDone, "item_id", fmt.Sprintf("fc_%s", callID)) + fcDone, _ = sjson.Set(fcDone, "output_index", st.funcOutIdx(i)) + fcDone, _ = sjson.Set(fcDone, "arguments", args) + out = append(out, FormatSSE("response.function_call_arguments.done", []byte(fcDone))) - if len(openaiChunk.Choices) > 0 { - choice := openaiChunk.Choices[0] - if choice.Delta != nil { - if content, ok := choice.Delta.Content.(string); ok && content != "" { - codexEvent := map[string]interface{}{ - "type": "response.output_item.delta", - "delta": map[string]interface{}{ - "type": "text", - "text": content, - }, + itemDone := `{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}}` + itemDone, _ = sjson.Set(itemDone, "sequence_number", nextSeq()) + itemDone, _ = sjson.Set(itemDone, "output_index", st.funcOutIdx(i)) + itemDone, _ = sjson.Set(itemDone, "item.id", fmt.Sprintf("fc_%s", callID)) + itemDone, _ = sjson.Set(itemDone, "item.arguments", args) + itemDone, _ = sjson.Set(itemDone, "item.call_id", callID) + itemDone, _ = sjson.Set(itemDone, "item.name", st.FuncNames[i]) + out = append(out, FormatSSE("response.output_item.done", []byte(itemDone))) + st.FuncItemDone[i] = true + st.FuncArgsDone[i] = true } - output = append(output, FormatSSE("", codexEvent)...) 
} } + return true + }) + } - if choice.FinishReason != "" { - codexEvent := map[string]interface{}{ - "type": "response.done", - "response": map[string]interface{}{ - "id": state.MessageID, - "status": "completed", - }, + // Emit response.completed once after all choices have been processed + if !st.CompletedSent { + // Check if any choice had a finish_reason + hasFinish := false + if choices := root.Get("choices"); choices.Exists() && choices.IsArray() { + choices.ForEach(func(_, choice gjson.Result) bool { + if fr := choice.Get("finish_reason"); fr.Exists() && fr.String() != "" { + hasFinish = true + return false + } + return true + }) + } + if hasFinish { + st.CompletedSent = true + completed := `{"type":"response.completed","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null}}` + completed, _ = sjson.Set(completed, "sequence_number", nextSeq()) + completed, _ = sjson.Set(completed, "response.id", st.ResponseID) + completed, _ = sjson.Set(completed, "response.created_at", st.Created) + + outputsWrapper := `{"arr":[]}` + if len(st.Reasonings) > 0 { + for _, r := range st.Reasonings { + item := `{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}` + item, _ = sjson.Set(item, "id", r.ReasoningID) + item, _ = sjson.Set(item, "summary.0.text", r.ReasoningData) + outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item) + } + } + if len(st.MsgItemAdded) > 0 { + for _, i := range sortedKeys(st.MsgItemAdded) { + txt := "" + if b := st.MsgTextBuf[i]; b != nil { + txt = b.String() + } + item := `{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}` + item, _ = sjson.Set(item, "id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i)) + item, _ = sjson.Set(item, "content.0.text", txt) + outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item) } - output = append(output, 
FormatSSE("", codexEvent)...) } + if len(st.FuncCallIDs) > 0 { + for _, i := range sortedKeys(st.FuncCallIDs) { + args := "" + if b := st.FuncArgsBuf[i]; b != nil { + args = b.String() + } + callID := st.FuncCallIDs[i] + name := st.FuncNames[i] + item := `{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}` + item, _ = sjson.Set(item, "id", fmt.Sprintf("fc_%s", callID)) + item, _ = sjson.Set(item, "arguments", args) + item, _ = sjson.Set(item, "call_id", callID) + item, _ = sjson.Set(item, "name", name) + outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item) + } + } + if gjson.Get(outputsWrapper, "arr.#").Int() > 0 { + completed, _ = sjson.SetRaw(completed, "response.output", gjson.Get(outputsWrapper, "arr").Raw) + } + if st.UsageSeen { + completed, _ = sjson.Set(completed, "response.usage.input_tokens", st.PromptTokens) + completed, _ = sjson.Set(completed, "response.usage.input_tokens_details.cached_tokens", st.CachedTokens) + completed, _ = sjson.Set(completed, "response.usage.output_tokens", st.CompletionTokens) + if st.ReasoningTokens > 0 { + completed, _ = sjson.Set(completed, "response.usage.output_tokens_details.reasoning_tokens", st.ReasoningTokens) + } + total := st.TotalTokens + if total == 0 { + total = st.PromptTokens + st.CompletionTokens + } + completed, _ = sjson.Set(completed, "response.usage.total_tokens", total) + } + if len(state.OriginalRequestBody) > 0 { + completed = applyRequestEchoToResponse(completed, "response.", state.OriginalRequestBody) + } + out = append(out, FormatSSE("response.completed", []byte(completed))) } } - return output, nil + return out +} + +func sortedKeys[T any](m map[int]T) []int { + keys := make([]int, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + sort.Ints(keys) + return keys +} + +func applyRequestEchoToResponse(responseJSON string, prefix string, requestRaw []byte) string { + if len(requestRaw) == 0 { + return responseJSON + } + req := 
gjson.ParseBytes(requestRaw) + paths := []string{ + "model", + "instructions", + "input", + "tools", + "tool_choice", + "metadata", + "store", + "max_output_tokens", + "temperature", + "top_p", + "reasoning", + "parallel_tool_calls", + "include", + "previous_response_id", + "text", + "truncation", + } + for _, path := range paths { + val := req.Get(path) + if !val.Exists() { + continue + } + fullPath := prefix + path + if gjson.Get(responseJSON, fullPath).Exists() { + continue + } + switch val.Type { + case gjson.String, gjson.Number, gjson.True, gjson.False: + responseJSON, _ = sjson.Set(responseJSON, fullPath, val.Value()) + default: + responseJSON, _ = sjson.SetRaw(responseJSON, fullPath, val.Raw) + } + } + return responseJSON } diff --git a/internal/converter/registry.go b/internal/converter/registry.go index 77c6081d..a4d68b2b 100644 --- a/internal/converter/registry.go +++ b/internal/converter/registry.go @@ -9,13 +9,15 @@ import ( // TransformState holds state for streaming response conversion type TransformState struct { - MessageID string - CurrentIndex int - CurrentBlockType string // "text", "thinking", "tool_use" - ToolCalls map[int]*ToolCallState - Buffer string // SSE line buffer - Usage *Usage - StopReason string + MessageID string + CurrentIndex int + CurrentBlockType string // "text", "thinking", "tool_use" + ToolCalls map[int]*ToolCallState + Buffer string // SSE line buffer + Usage *Usage + StopReason string + Custom interface{} + OriginalRequestBody []byte } // ToolCallState tracks tool call conversion state @@ -47,6 +49,10 @@ type ResponseTransformer interface { TransformChunk(chunk []byte, state *TransformState) ([]byte, error) } +type ResponseTransformerWithState interface { + TransformWithState(body []byte, state *TransformState) ([]byte, error) +} + // Registry holds all format converters type Registry struct { requests map[domain.ClientType]map[domain.ClientType]RequestTransformer @@ -131,6 +137,26 @@ func (r *Registry) 
TransformResponse(from, to domain.ClientType, body []byte) ([ return transformer.Transform(body) } +// TransformResponseWithState converts a non-streaming response with state +func (r *Registry) TransformResponseWithState(from, to domain.ClientType, body []byte, state *TransformState) ([]byte, error) { + if from == to { + return body, nil + } + + fromMap := r.responses[from] + if fromMap == nil { + return nil, fmt.Errorf("no response transformer from %s", from) + } + transformer := fromMap[to] + if transformer == nil { + return nil, fmt.Errorf("no response transformer from %s to %s", from, to) + } + if withState, ok := transformer.(ResponseTransformerWithState); ok { + return withState.TransformWithState(body, state) + } + return transformer.Transform(body) +} + // TransformStreamChunk converts a streaming chunk func (r *Registry) TransformStreamChunk(from, to domain.ClientType, chunk []byte, state *TransformState) ([]byte, error) { if from == to { diff --git a/internal/converter/test_helpers_test.go b/internal/converter/test_helpers_test.go index d7c44cd7..0b67fd05 100644 --- a/internal/converter/test_helpers_test.go +++ b/internal/converter/test_helpers_test.go @@ -10,8 +10,18 @@ func codexInputHasRoleText(input interface{}, role string, text string) bool { if !ok || m["type"] != "message" || m["role"] != role { continue } - if content, ok := m["content"].(string); ok && content == text { - return true + switch content := m["content"].(type) { + case string: + if content == text { + return true + } + case []interface{}: + for _, part := range content { + pm, ok := part.(map[string]interface{}) + if ok && pm["text"] == text { + return true + } + } } } return false diff --git a/internal/converter/tool_name.go b/internal/converter/tool_name.go index e5c809c6..16810711 100644 --- a/internal/converter/tool_name.go +++ b/internal/converter/tool_name.go @@ -1,8 +1,8 @@ package converter import ( - "fmt" - "hash/crc32" + "strconv" + "strings" ) const maxToolNameLen = 64 
@@ -11,25 +11,54 @@ func shortenNameIfNeeded(name string) string { if len(name) <= maxToolNameLen { return name } - hash := crc32.ChecksumIEEE([]byte(name)) - // Keep a stable prefix to preserve readability, add hash suffix for uniqueness. - prefixLen := maxToolNameLen - 9 // "_" + 8 hex - return fmt.Sprintf("%s_%08x", name[:prefixLen], hash) + if strings.HasPrefix(name, "mcp__") { + idx := strings.LastIndex(name, "__") + if idx > 3 { + candidate := "mcp__" + name[idx+2:] + if len(candidate) > maxToolNameLen { + return candidate[:maxToolNameLen] + } + return candidate + } + } + return name[:maxToolNameLen] } func buildShortNameMap(names []string) map[string]string { + used := map[string]struct{}{} result := make(map[string]string, len(names)) - used := make(map[string]int) - for _, name := range names { - short := shortenNameIfNeeded(name) - if count, ok := used[short]; ok { - count++ - used[short] = count - short = shortenNameIfNeeded(fmt.Sprintf("%s_%d", name, count)) - } else { - used[short] = 0 + + baseCandidate := func(n string) string { + return shortenNameIfNeeded(n) + } + + makeUnique := func(cand string) string { + if _, ok := used[cand]; !ok { + return cand + } + base := cand + for i := 1; ; i++ { + suffix := "_" + strconv.Itoa(i) + allowed := maxToolNameLen - len(suffix) + if allowed < 0 { + allowed = 0 + } + tmp := base + if len(tmp) > allowed { + tmp = tmp[:allowed] + } + tmp = tmp + suffix + if _, ok := used[tmp]; !ok { + return tmp + } } - result[name] = short + } + + for _, n := range names { + cand := baseCandidate(n) + uniq := makeUnique(cand) + used[uniq] = struct{}{} + result[n] = uniq } return result } diff --git a/internal/domain/model.go b/internal/domain/model.go index f8c8c409..5f75e08e 100644 --- a/internal/domain/model.go +++ b/internal/domain/model.go @@ -1,6 +1,9 @@ package domain -import "time" +import ( + "strings" + "time" +) // 各种请求的客户端 type ClientType string @@ -811,52 +814,42 @@ type ResponseModel struct { // MatchWildcard 
检查输入是否匹配通配符模式 func MatchWildcard(pattern, input string) bool { - // 简单情况 + pattern = strings.TrimSpace(pattern) + input = strings.TrimSpace(input) + if pattern == "" { + return false + } if pattern == "*" { return true } - if !containsWildcard(pattern) { - return pattern == input - } - - parts := splitByWildcard(pattern) - - // 处理 prefix* 模式 - if len(parts) == 2 && parts[1] == "" { - return hasPrefix(input, parts[0]) - } - - // 处理 *suffix 模式 - if len(parts) == 2 && parts[0] == "" { - return hasSuffix(input, parts[1]) - } - - // 处理多通配符模式 - pos := 0 - for i, part := range parts { - if part == "" { + // Iterative glob-style matcher supporting only '*' wildcard. + pi, si := 0, 0 + starIdx := -1 + matchIdx := 0 + for si < len(input) { + if pi < len(pattern) && pattern[pi] == input[si] { + pi++ + si++ continue } - - idx := indexOf(input[pos:], part) - if idx < 0 { - return false + if pi < len(pattern) && pattern[pi] == '*' { + starIdx = pi + matchIdx = si + pi++ + continue } - - // 第一部分必须在开头(如果模式不以 * 开头) - if i == 0 && idx != 0 { - return false + if starIdx != -1 { + pi = starIdx + 1 + matchIdx++ + si = matchIdx + continue } - - pos += idx + len(part) - } - - // 最后一部分必须在结尾(如果模式不以 * 结尾) - if parts[len(parts)-1] != "" && !hasSuffix(input, parts[len(parts)-1]) { return false } - - return true + for pi < len(pattern) && pattern[pi] == '*' { + pi++ + } + return pi == len(pattern) } // 辅助函数 diff --git a/internal/executor/converting_writer.go b/internal/executor/converting_writer.go index daeb0a42..b005cc42 100644 --- a/internal/executor/converting_writer.go +++ b/internal/executor/converting_writer.go @@ -152,7 +152,12 @@ func NewConvertingResponseWriter( conv *converter.Registry, originalType, targetType domain.ClientType, isStream bool, + originalRequestBody []byte, ) *ConvertingResponseWriter { + state := converter.NewTransformState() + if len(originalRequestBody) > 0 { + state.OriginalRequestBody = bytes.Clone(originalRequestBody) + } return &ConvertingResponseWriter{ 
underlying: w, converter: conv, @@ -161,7 +166,7 @@ func NewConvertingResponseWriter( isStream: isStream, statusCode: http.StatusOK, headers: make(http.Header), - streamState: converter.NewTransformState(), + streamState: state, } } @@ -226,9 +231,9 @@ func (c *ConvertingResponseWriter) Finalize() error { body := c.buffer.Bytes() // Convert the response - converted, err := c.converter.TransformResponse(c.targetType, c.originalType, body) - if err != nil { - // On conversion error, use original body + converted, err := c.converter.TransformResponseWithState(c.targetType, c.originalType, body, c.streamState) + if err != nil || converted == nil { + // On conversion error or nil result, use original body converted = body } @@ -271,9 +276,9 @@ func NeedsConversion(originalType, targetType domain.ClientType) bool { return originalType != targetType && originalType != "" && targetType != "" } -// GetPreferredTargetType returns the best target type for conversion -// Prefers Claude as it has the richest format support -func GetPreferredTargetType(supportedTypes []domain.ClientType, originalType domain.ClientType) domain.ClientType { +// GetPreferredTargetType returns the best target type for conversion. +// Prefers Codex only for codex providers, otherwise Gemini then Claude. 
+func GetPreferredTargetType(supportedTypes []domain.ClientType, originalType domain.ClientType, providerType string) domain.ClientType { // If original type is supported, no conversion needed for _, t := range supportedTypes { if t == originalType { @@ -281,7 +286,23 @@ func GetPreferredTargetType(supportedTypes []domain.ClientType, originalType dom } } - // Prefer Claude as target (richest format) + if providerType == "codex" { + // Prefer Codex when available (best fit for Codex provider) + for _, t := range supportedTypes { + if t == domain.ClientTypeCodex { + return t + } + } + } + + // Prefer Gemini as target (best fit for Antigravity) + for _, t := range supportedTypes { + if t == domain.ClientTypeGemini { + return t + } + } + + // Prefer Claude as target (fallback) for _, t := range supportedTypes { if t == domain.ClientTypeClaude { return t diff --git a/internal/executor/converting_writer_test.go b/internal/executor/converting_writer_test.go deleted file mode 100644 index e4da5439..00000000 --- a/internal/executor/converting_writer_test.go +++ /dev/null @@ -1,90 +0,0 @@ -package executor - -import ( - "testing" - - "github.com/awsl-project/maxx/internal/domain" -) - -func TestConvertRequestURI(t *testing.T) { - tests := []struct { - name string - original string - from domain.ClientType - to domain.ClientType - mappedModel string - isStream bool - want string - }{ - { - name: "same type passthrough", - original: "/v1/chat/completions", - from: domain.ClientTypeOpenAI, - to: domain.ClientTypeOpenAI, - want: "/v1/chat/completions", - }, - { - name: "openai to claude with query", - original: "/v1/chat/completions?foo=1", - from: domain.ClientTypeOpenAI, - to: domain.ClientTypeClaude, - want: "/v1/messages?foo=1", - }, - { - name: "claude to codex", - original: "/v1/messages", - from: domain.ClientTypeClaude, - to: domain.ClientTypeCodex, - want: "/responses", - }, - { - name: "claude count tokens to openai", - original: "/v1/messages/count_tokens", - from: 
domain.ClientTypeClaude, - to: domain.ClientTypeOpenAI, - want: "/v1/chat/completions", - }, - { - name: "openai to gemini stream", - original: "/v1/chat/completions", - from: domain.ClientTypeOpenAI, - to: domain.ClientTypeGemini, - mappedModel: "gemini-2.5-pro", - isStream: true, - want: "/v1beta/models/gemini-2.5-pro:streamGenerateContent", - }, - { - name: "claude count tokens to gemini", - original: "/v1/messages/count_tokens", - from: domain.ClientTypeClaude, - to: domain.ClientTypeGemini, - mappedModel: "gemini-2.5-pro", - want: "/v1beta/models/gemini-2.5-pro:countTokens", - }, - { - name: "gemini internal preserves version and action", - original: "/v1internal/models/gemini-2.0:generateContent?alt=sse", - from: domain.ClientTypeOpenAI, - to: domain.ClientTypeGemini, - mappedModel: "gemini-2.5-pro", - isStream: true, - want: "/v1internal/models/gemini-2.5-pro:generateContent?alt=sse", - }, - { - name: "gemini target without model keeps original", - original: "/v1/chat/completions", - from: domain.ClientTypeOpenAI, - to: domain.ClientTypeGemini, - want: "/v1/chat/completions", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := ConvertRequestURI(tt.original, tt.from, tt.to, tt.mappedModel, tt.isStream) - if got != tt.want { - t.Fatalf("ConvertRequestURI(%q) = %q, want %q", tt.original, got, tt.want) - } - }) - } -} diff --git a/internal/executor/executor.go b/internal/executor/executor.go index 48dff280..ca230187 100644 --- a/internal/executor/executor.go +++ b/internal/executor/executor.go @@ -2,21 +2,18 @@ package executor import ( "context" - "log" "net/http" "strconv" "time" - ctxutil "github.com/awsl-project/maxx/internal/context" "github.com/awsl-project/maxx/internal/converter" "github.com/awsl-project/maxx/internal/cooldown" "github.com/awsl-project/maxx/internal/domain" "github.com/awsl-project/maxx/internal/event" - "github.com/awsl-project/maxx/internal/pricing" + "github.com/awsl-project/maxx/internal/flow" 
"github.com/awsl-project/maxx/internal/repository" "github.com/awsl-project/maxx/internal/router" "github.com/awsl-project/maxx/internal/stats" - "github.com/awsl-project/maxx/internal/usage" "github.com/awsl-project/maxx/internal/waiter" ) @@ -34,6 +31,8 @@ type Executor struct { instanceID string statsAggregator *stats.StatsAggregator converter *converter.Registry + engine *flow.Engine + middlewares []flow.HandlerFunc } // NewExecutor creates a new executor @@ -63,630 +62,41 @@ func NewExecutor( instanceID: instanceID, statsAggregator: statsAggregator, converter: converter.GetGlobalRegistry(), + engine: flow.NewEngine(), } } -// Execute handles the proxy request with routing and retry logic -func (e *Executor) Execute(ctx context.Context, w http.ResponseWriter, req *http.Request) error { - clientType := ctxutil.GetClientType(ctx) - projectID := ctxutil.GetProjectID(ctx) - sessionID := ctxutil.GetSessionID(ctx) - requestModel := ctxutil.GetRequestModel(ctx) - isStream := ctxutil.GetIsStream(ctx) - - // Get API Token ID from context - apiTokenID := ctxutil.GetAPITokenID(ctx) - - // Create proxy request record immediately (PENDING status) - proxyReq := &domain.ProxyRequest{ - InstanceID: e.instanceID, - RequestID: generateRequestID(), - SessionID: sessionID, - ClientType: clientType, - ProjectID: projectID, - RequestModel: requestModel, - StartTime: time.Now(), - IsStream: isStream, - Status: "PENDING", - APITokenID: apiTokenID, - } - - // Capture client's original request info unless detail retention is disabled. 
- if !e.shouldClearRequestDetail() { - requestURI := ctxutil.GetRequestURI(ctx) - requestHeaders := ctxutil.GetRequestHeaders(ctx) - requestBody := ctxutil.GetRequestBody(ctx) - headers := flattenHeaders(requestHeaders) - // Go stores Host separately from headers, add it explicitly - if req.Host != "" { - if headers == nil { - headers = make(map[string]string) - } - headers["Host"] = req.Host - } - proxyReq.RequestInfo = &domain.RequestInfo{ - Method: req.Method, - URL: requestURI, - Headers: headers, - Body: string(requestBody), - } - } - - if err := e.proxyRequestRepo.Create(proxyReq); err != nil { - log.Printf("[Executor] Failed to create proxy request: %v", err) - } - - // Broadcast the new request immediately - if e.broadcaster != nil { - e.broadcaster.BroadcastProxyRequest(proxyReq) - } - - ctx = ctxutil.WithProxyRequest(ctx, proxyReq) - - // Check for project binding if required - if projectID == 0 && e.projectWaiter != nil { - // Get session for project waiter - session, _ := e.sessionRepo.GetBySessionID(sessionID) - if session == nil { - session = &domain.Session{ - SessionID: sessionID, - ClientType: clientType, - ProjectID: 0, - } - } - - if err := e.projectWaiter.WaitForProject(ctx, session); err != nil { - // Determine status based on error type - status := "REJECTED" - errorMsg := "project binding timeout: " + err.Error() - if err == context.Canceled { - status = "CANCELLED" - errorMsg = "client cancelled: " + err.Error() - // Notify frontend to close the dialog - if e.broadcaster != nil { - e.broadcaster.BroadcastMessage("session_pending_cancelled", map[string]interface{}{ - "sessionID": sessionID, - }) - } - } - - // Update request record with final status - proxyReq.Status = status - proxyReq.Error = errorMsg - proxyReq.EndTime = time.Now() - proxyReq.Duration = proxyReq.EndTime.Sub(proxyReq.StartTime) - _ = e.proxyRequestRepo.Update(proxyReq) - - // Broadcast the updated request - if e.broadcaster != nil { - 
e.broadcaster.BroadcastProxyRequest(proxyReq) - } - - return domain.NewProxyErrorWithMessage(err, false, "project binding required: "+err.Error()) - } - - // Update projectID from the now-bound session - projectID = session.ProjectID - proxyReq.ProjectID = projectID - ctx = ctxutil.WithProjectID(ctx, projectID) - } - - // Match routes - routes, err := e.router.Match(&router.MatchContext{ - ClientType: clientType, - ProjectID: projectID, - RequestModel: requestModel, - APITokenID: apiTokenID, - }) - if err != nil { - proxyReq.Status = "FAILED" - proxyReq.Error = "no routes available" - proxyReq.EndTime = time.Now() - proxyReq.Duration = proxyReq.EndTime.Sub(proxyReq.StartTime) - _ = e.proxyRequestRepo.Update(proxyReq) - if e.broadcaster != nil { - e.broadcaster.BroadcastProxyRequest(proxyReq) - } - return domain.NewProxyErrorWithMessage(domain.ErrNoRoutes, false, "no routes available") - } - - if len(routes) == 0 { - proxyReq.Status = "FAILED" - proxyReq.Error = "no routes configured" - proxyReq.EndTime = time.Now() - proxyReq.Duration = proxyReq.EndTime.Sub(proxyReq.StartTime) - _ = e.proxyRequestRepo.Update(proxyReq) - if e.broadcaster != nil { - e.broadcaster.BroadcastProxyRequest(proxyReq) - } - return domain.NewProxyErrorWithMessage(domain.ErrNoRoutes, false, "no routes configured") - } - - // Update status to IN_PROGRESS - proxyReq.Status = "IN_PROGRESS" - _ = e.proxyRequestRepo.Update(proxyReq) - ctx = ctxutil.WithProxyRequest(ctx, proxyReq) +func (e *Executor) Use(handlers ...flow.HandlerFunc) { + e.middlewares = append(e.middlewares, handlers...) +} - // Add broadcaster to context so adapters can send updates - if e.broadcaster != nil { - ctx = ctxutil.WithBroadcaster(ctx, e.broadcaster) +// Execute runs the executor middleware chain with a new flow context. 
+func (e *Executor) Execute(ctx context.Context, w http.ResponseWriter, req *http.Request) error { + c := flow.NewCtx(w, req) + if ctx != nil { + c.Set(flow.KeyProxyContext, ctx) } + return e.ExecuteWith(c) +} - // Broadcast new request immediately so frontend sees it - if e.broadcaster != nil { - e.broadcaster.BroadcastProxyRequest(proxyReq) +// ExecuteWith runs the executor middleware chain using an existing flow context. +func (e *Executor) ExecuteWith(c *flow.Ctx) error { + if c == nil { + return domain.NewProxyErrorWithMessage(domain.ErrInvalidInput, false, "flow context missing") } - - // Track current attempt for cleanup - var currentAttempt *domain.ProxyUpstreamAttempt - - // Ensure final state is always updated - defer func() { - // If still IN_PROGRESS, mark as cancelled/failed - if proxyReq.Status == "IN_PROGRESS" { - proxyReq.EndTime = time.Now() - proxyReq.Duration = proxyReq.EndTime.Sub(proxyReq.StartTime) - if ctx.Err() != nil { - proxyReq.Status = "CANCELLED" - if ctx.Err() == context.Canceled { - proxyReq.Error = "client disconnected" - } else if ctx.Err() == context.DeadlineExceeded { - proxyReq.Error = "request timeout" - } else { - proxyReq.Error = ctx.Err().Error() - } - } else { - proxyReq.Status = "FAILED" - } - _ = e.proxyRequestRepo.Update(proxyReq) - if e.broadcaster != nil { - e.broadcaster.BroadcastProxyRequest(proxyReq) - } - } - - // If current attempt is still IN_PROGRESS, mark as cancelled/failed - if currentAttempt != nil && currentAttempt.Status == "IN_PROGRESS" { - currentAttempt.EndTime = time.Now() - currentAttempt.Duration = currentAttempt.EndTime.Sub(currentAttempt.StartTime) - if ctx.Err() != nil { - currentAttempt.Status = "CANCELLED" - } else { - currentAttempt.Status = "FAILED" - } - _ = e.attemptRepo.Update(currentAttempt) - if e.broadcaster != nil { - e.broadcaster.BroadcastProxyUpstreamAttempt(currentAttempt) - } - } - }() - - // Try routes in order with retry logic - var lastErr error - for _, matchedRoute := range 
routes { - // Check context before starting new route - if ctx.Err() != nil { - return ctx.Err() - } - - // Update proxyReq with current route/provider for real-time tracking - proxyReq.RouteID = matchedRoute.Route.ID - proxyReq.ProviderID = matchedRoute.Provider.ID - _ = e.proxyRequestRepo.Update(proxyReq) - if e.broadcaster != nil { - e.broadcaster.BroadcastProxyRequest(proxyReq) - } - - // Determine model mapping - // Model mapping is done in Executor after Router has filtered by SupportModels - clientType := ctxutil.GetClientType(ctx) - mappedModel := e.mapModel(requestModel, matchedRoute.Route, matchedRoute.Provider, clientType, projectID, apiTokenID) - ctx = ctxutil.WithMappedModel(ctx, mappedModel) - - // Format conversion: check if client type is supported by provider - // If not, convert request to a supported format - originalClientType := clientType - targetClientType := clientType - needsConversion := false - convertedBody := []byte(nil) - var convErr error - - supportedTypes := matchedRoute.ProviderAdapter.SupportedClientTypes() - if e.converter.NeedConvert(clientType, supportedTypes) { - targetClientType = GetPreferredTargetType(supportedTypes, clientType) - if targetClientType != clientType { - needsConversion = true - log.Printf("[Executor] Format conversion needed: %s -> %s for provider %s", - clientType, targetClientType, matchedRoute.Provider.Name) - - // Convert request body - requestBody := ctxutil.GetRequestBody(ctx) - if targetClientType == domain.ClientTypeCodex { - if headers := ctxutil.GetRequestHeaders(ctx); headers != nil { - requestBody = converter.InjectCodexUserAgent(requestBody, headers.Get("User-Agent")) - } - } - convertedBody, convErr = e.converter.TransformRequest( - clientType, targetClientType, requestBody, mappedModel, isStream) - if convErr != nil { - log.Printf("[Executor] Request conversion failed: %v, proceeding with original format", convErr) - needsConversion = false - } else { - // Update context with converted body and 
new client type - ctx = ctxutil.WithRequestBody(ctx, convertedBody) - ctx = ctxutil.WithClientType(ctx, targetClientType) - ctx = ctxutil.WithOriginalClientType(ctx, originalClientType) - - // Convert request URI to match the target client type - originalURI := ctxutil.GetRequestURI(ctx) - convertedURI := ConvertRequestURI(originalURI, clientType, targetClientType, mappedModel, isStream) - if convertedURI != originalURI { - ctx = ctxutil.WithRequestURI(ctx, convertedURI) - log.Printf("[Executor] URI converted: %s -> %s", originalURI, convertedURI) - } - } - } - } - - // Get retry config - retryConfig := e.getRetryConfig(matchedRoute.RetryConfig) - - // Execute with retries - for attempt := 0; attempt <= retryConfig.MaxRetries; attempt++ { - // Check context before each attempt - if ctx.Err() != nil { - return ctx.Err() - } - - // Create attempt record with start time and request info - attemptStartTime := time.Now() - attemptRecord := &domain.ProxyUpstreamAttempt{ - ProxyRequestID: proxyReq.ID, - RouteID: matchedRoute.Route.ID, - ProviderID: matchedRoute.Provider.ID, - IsStream: isStream, - Status: "IN_PROGRESS", - StartTime: attemptStartTime, - RequestModel: requestModel, - MappedModel: mappedModel, - RequestInfo: proxyReq.RequestInfo, // Use original request info initially - } - if err := e.attemptRepo.Create(attemptRecord); err != nil { - log.Printf("[Executor] Failed to create attempt record: %v", err) - } - currentAttempt = attemptRecord - - // Increment attempt count when creating a new attempt - proxyReq.ProxyUpstreamAttemptCount++ - - // Broadcast updated request with new attempt count - if e.broadcaster != nil { - e.broadcaster.BroadcastProxyRequest(proxyReq) - } - - // Broadcast new attempt immediately - if e.broadcaster != nil { - e.broadcaster.BroadcastProxyUpstreamAttempt(attemptRecord) - } - - // Put attempt into context so adapter can populate request/response info - attemptCtx := ctxutil.WithUpstreamAttempt(ctx, attemptRecord) - - // Create event 
channel for adapter to send events - eventChan := domain.NewAdapterEventChan() - attemptCtx = ctxutil.WithEventChan(attemptCtx, eventChan) - - // Start real-time event processing goroutine - // This ensures RequestInfo is broadcast as soon as adapter sends it - eventDone := make(chan struct{}) - go e.processAdapterEventsRealtime(eventChan, attemptRecord, eventDone) - - // Wrap ResponseWriter to capture actual client response - // If format conversion is needed, use ConvertingResponseWriter - var responseWriter http.ResponseWriter - var convertingWriter *ConvertingResponseWriter - responseCapture := NewResponseCapture(w) - - if needsConversion { - // Use ConvertingResponseWriter to transform response from targetType back to originalType - convertingWriter = NewConvertingResponseWriter( - responseCapture, e.converter, originalClientType, targetClientType, isStream) - responseWriter = convertingWriter - } else { - responseWriter = responseCapture - } - - // Execute request - err := matchedRoute.ProviderAdapter.Execute(attemptCtx, responseWriter, req, matchedRoute.Provider) - - // For non-streaming responses with conversion, finalize the conversion - if needsConversion && convertingWriter != nil && !isStream { - if finalizeErr := convertingWriter.Finalize(); finalizeErr != nil { - log.Printf("[Executor] Response conversion finalize failed: %v", finalizeErr) - } - } - - // Close event channel and wait for processing goroutine to finish - eventChan.Close() - <-eventDone - - if err == nil { - // Success - set end time and duration - attemptRecord.EndTime = time.Now() - attemptRecord.Duration = attemptRecord.EndTime.Sub(attemptRecord.StartTime) - attemptRecord.Status = "COMPLETED" - - // Calculate cost in executor (unified for all adapters) - // Adapter only needs to set token counts, executor handles pricing - if attemptRecord.InputTokenCount > 0 || attemptRecord.OutputTokenCount > 0 { - metrics := &usage.Metrics{ - InputTokens: attemptRecord.InputTokenCount, - 
OutputTokens: attemptRecord.OutputTokenCount, - CacheReadCount: attemptRecord.CacheReadCount, - CacheCreationCount: attemptRecord.CacheWriteCount, - Cache5mCreationCount: attemptRecord.Cache5mWriteCount, - Cache1hCreationCount: attemptRecord.Cache1hWriteCount, - } - // Use ResponseModel for pricing (actual model from API response), fallback to MappedModel - pricingModel := attemptRecord.ResponseModel - if pricingModel == "" { - pricingModel = attemptRecord.MappedModel - } - // Get multiplier from provider config - multiplier := getProviderMultiplier(matchedRoute.Provider, clientType) - result := pricing.GlobalCalculator().CalculateWithResult(pricingModel, metrics, multiplier) - attemptRecord.Cost = result.Cost - attemptRecord.ModelPriceID = result.ModelPriceID - attemptRecord.Multiplier = result.Multiplier - } - - // 检查是否需要立即清理 attempt 详情(设置为 0 时不保存) - if e.shouldClearRequestDetail() { - attemptRecord.RequestInfo = nil - attemptRecord.ResponseInfo = nil - } - - _ = e.attemptRepo.Update(attemptRecord) - if e.broadcaster != nil { - e.broadcaster.BroadcastProxyUpstreamAttempt(attemptRecord) - } - currentAttempt = nil // Clear so defer doesn't update - - // Reset failure counts on success - clientType := string(ctxutil.GetClientType(attemptCtx)) - cooldown.Default().RecordSuccess(matchedRoute.Provider.ID, clientType) - - proxyReq.Status = "COMPLETED" - proxyReq.EndTime = time.Now() - proxyReq.Duration = proxyReq.EndTime.Sub(proxyReq.StartTime) - proxyReq.FinalProxyUpstreamAttemptID = attemptRecord.ID - proxyReq.ModelPriceID = attemptRecord.ModelPriceID - proxyReq.Multiplier = attemptRecord.Multiplier - proxyReq.ResponseModel = mappedModel // Record the actual model used - - // Capture actual client response (what was sent to client, e.g. 
Claude format) - // This is different from attemptRecord.ResponseInfo which is upstream response (Gemini format) - if !e.shouldClearRequestDetail() { - proxyReq.ResponseInfo = &domain.ResponseInfo{ - Status: responseCapture.StatusCode(), - Headers: responseCapture.CapturedHeaders(), - Body: responseCapture.Body(), - } - } - proxyReq.StatusCode = responseCapture.StatusCode() - - // Extract token usage from final client response (not from upstream attempt) - // This ensures we use the correct format (Claude/OpenAI/Gemini) for the client type - if metrics := usage.ExtractFromResponse(responseCapture.Body()); metrics != nil { - proxyReq.InputTokenCount = metrics.InputTokens - proxyReq.OutputTokenCount = metrics.OutputTokens - proxyReq.CacheReadCount = metrics.CacheReadCount - proxyReq.CacheWriteCount = metrics.CacheCreationCount - proxyReq.Cache5mWriteCount = metrics.Cache5mCreationCount - proxyReq.Cache1hWriteCount = metrics.Cache1hCreationCount - } - proxyReq.Cost = attemptRecord.Cost - proxyReq.TTFT = attemptRecord.TTFT - - // 检查是否需要立即清理 proxyReq 详情(设置为 0 时不保存) - if e.shouldClearRequestDetail() { - proxyReq.RequestInfo = nil - proxyReq.ResponseInfo = nil - } - - _ = e.proxyRequestRepo.Update(proxyReq) - - // Broadcast to WebSocket clients - if e.broadcaster != nil { - e.broadcaster.BroadcastProxyRequest(proxyReq) - } - - return nil - } - - // Handle error - set end time and duration - attemptRecord.EndTime = time.Now() - attemptRecord.Duration = attemptRecord.EndTime.Sub(attemptRecord.StartTime) - lastErr = err - - // Update attempt status first (before checking context) - if ctx.Err() != nil { - attemptRecord.Status = "CANCELLED" - } else { - attemptRecord.Status = "FAILED" - } - - // Calculate cost in executor even for failed attempts (may have partial token usage) - if attemptRecord.InputTokenCount > 0 || attemptRecord.OutputTokenCount > 0 { - metrics := &usage.Metrics{ - InputTokens: attemptRecord.InputTokenCount, - OutputTokens: attemptRecord.OutputTokenCount, 
- CacheReadCount: attemptRecord.CacheReadCount, - CacheCreationCount: attemptRecord.CacheWriteCount, - Cache5mCreationCount: attemptRecord.Cache5mWriteCount, - Cache1hCreationCount: attemptRecord.Cache1hWriteCount, - } - // Use ResponseModel for pricing (actual model from API response), fallback to MappedModel - pricingModel := attemptRecord.ResponseModel - if pricingModel == "" { - pricingModel = attemptRecord.MappedModel - } - // Get multiplier from provider config - multiplier := getProviderMultiplier(matchedRoute.Provider, clientType) - result := pricing.GlobalCalculator().CalculateWithResult(pricingModel, metrics, multiplier) - attemptRecord.Cost = result.Cost - attemptRecord.ModelPriceID = result.ModelPriceID - attemptRecord.Multiplier = result.Multiplier - } - - // 检查是否需要立即清理 attempt 详情(设置为 0 时不保存) - if e.shouldClearRequestDetail() { - attemptRecord.RequestInfo = nil - attemptRecord.ResponseInfo = nil - } - - _ = e.attemptRepo.Update(attemptRecord) - if e.broadcaster != nil { - e.broadcaster.BroadcastProxyUpstreamAttempt(attemptRecord) - } - currentAttempt = nil // Clear so defer doesn't double update - - // Update proxyReq with latest attempt info (even on failure) - proxyReq.FinalProxyUpstreamAttemptID = attemptRecord.ID - proxyReq.ModelPriceID = attemptRecord.ModelPriceID - proxyReq.Multiplier = attemptRecord.Multiplier - - // Capture actual client response (even on failure, if any response was sent) - if responseCapture.Body() != "" { - proxyReq.StatusCode = responseCapture.StatusCode() - if !e.shouldClearRequestDetail() { - proxyReq.ResponseInfo = &domain.ResponseInfo{ - Status: responseCapture.StatusCode(), - Headers: responseCapture.CapturedHeaders(), - Body: responseCapture.Body(), - } - } - - // Extract token usage from final client response - if metrics := usage.ExtractFromResponse(responseCapture.Body()); metrics != nil { - proxyReq.InputTokenCount = metrics.InputTokens - proxyReq.OutputTokenCount = metrics.OutputTokens - proxyReq.CacheReadCount = 
metrics.CacheReadCount - proxyReq.CacheWriteCount = metrics.CacheCreationCount - proxyReq.Cache5mWriteCount = metrics.Cache5mCreationCount - proxyReq.Cache1hWriteCount = metrics.Cache1hCreationCount - } - } - proxyReq.Cost = attemptRecord.Cost - proxyReq.TTFT = attemptRecord.TTFT - - _ = e.proxyRequestRepo.Update(proxyReq) - if e.broadcaster != nil { - e.broadcaster.BroadcastProxyRequest(proxyReq) - } - - // Handle cooldown only for real server/network errors, NOT client-side cancellations - proxyErr, ok := err.(*domain.ProxyError) - if ok && ctx.Err() != context.Canceled { - log.Printf("[Executor] ProxyError - IsNetworkError: %v, IsServerError: %v, Retryable: %v, Provider: %d", - proxyErr.IsNetworkError, proxyErr.IsServerError, proxyErr.Retryable, matchedRoute.Provider.ID) - // Handle cooldown (unified cooldown logic for all providers) - e.handleCooldown(attemptCtx, proxyErr, matchedRoute.Provider) - // Broadcast cooldown update event to frontend - if e.broadcaster != nil { - e.broadcaster.BroadcastMessage("cooldown_update", map[string]interface{}{ - "providerID": matchedRoute.Provider.ID, - }) - } - } else if ok && ctx.Err() == context.Canceled { - log.Printf("[Executor] Client disconnected, skipping cooldown for Provider: %d", matchedRoute.Provider.ID) - } else if !ok { - log.Printf("[Executor] Error is not ProxyError, type: %T, error: %v", err, err) - } - - // Check if context was cancelled or timed out - if ctx.Err() != nil { - proxyReq.Status = "CANCELLED" - proxyReq.EndTime = time.Now() - proxyReq.Duration = proxyReq.EndTime.Sub(proxyReq.StartTime) - if ctx.Err() == context.Canceled { - proxyReq.Error = "client disconnected" - } else if ctx.Err() == context.DeadlineExceeded { - proxyReq.Error = "request timeout" - } else { - proxyReq.Error = ctx.Err().Error() - } - _ = e.proxyRequestRepo.Update(proxyReq) - if e.broadcaster != nil { - e.broadcaster.BroadcastProxyRequest(proxyReq) - } - return ctx.Err() - } - - // Check if retryable - if !ok { - break // Move 
to next route - } - - if !proxyErr.Retryable { - break // Move to next route - } - - // Wait before retry (unless last attempt) - if attempt < retryConfig.MaxRetries { - waitTime := e.calculateBackoff(retryConfig, attempt) - if proxyErr.RetryAfter > 0 { - waitTime = proxyErr.RetryAfter - } - select { - case <-ctx.Done(): - // Set final status before returning - proxyReq.Status = "CANCELLED" - proxyReq.EndTime = time.Now() - proxyReq.Duration = proxyReq.EndTime.Sub(proxyReq.StartTime) - if ctx.Err() == context.Canceled { - proxyReq.Error = "client disconnected during retry wait" - } else if ctx.Err() == context.DeadlineExceeded { - proxyReq.Error = "request timeout during retry wait" - } else { - proxyReq.Error = ctx.Err().Error() - } - _ = e.proxyRequestRepo.Update(proxyReq) - if e.broadcaster != nil { - e.broadcaster.BroadcastProxyRequest(proxyReq) - } - return ctx.Err() - case <-time.After(waitTime): - } - } + ctx := context.Background() + if v, ok := c.Get(flow.KeyProxyContext); ok { + if stored, ok := v.(context.Context); ok && stored != nil { + ctx = stored } - // Inner loop ended, will try next route if available - } - - // All routes failed - proxyReq.Status = "FAILED" - proxyReq.EndTime = time.Now() - proxyReq.Duration = proxyReq.EndTime.Sub(proxyReq.StartTime) - if lastErr != nil { - proxyReq.Error = lastErr.Error() - } - - // 检查是否需要立即清理详情(设置为 0 时不保存) - if e.shouldClearRequestDetail() { - proxyReq.RequestInfo = nil - proxyReq.ResponseInfo = nil - } - - _ = e.proxyRequestRepo.Update(proxyReq) - - // Broadcast to WebSocket clients - if e.broadcaster != nil { - e.broadcaster.BroadcastProxyRequest(proxyReq) - } - - if lastErr != nil { - return lastErr } - return domain.NewProxyErrorWithMessage(domain.ErrAllRoutesFailed, false, "all routes exhausted") + state := &execState{ctx: ctx} + c.Set(flow.KeyExecutorState, state) + chain := []flow.HandlerFunc{e.egress, e.ingress} + chain = append(chain, e.middlewares...) 
+ chain = append(chain, e.routeMatch, e.dispatch) + e.engine.HandleWith(c, chain...) + return state.lastErr } func (e *Executor) mapModel(requestModel string, route *domain.Route, provider *domain.Provider, clientType domain.ClientType, projectID uint64, apiTokenID uint64) string { @@ -761,19 +171,16 @@ func flattenHeaders(h http.Header) map[string]string { // handleCooldown processes cooldown information from ProxyError and sets provider cooldown // Priority: 1) Explicit time from API, 2) Policy-based calculation based on failure reason -func (e *Executor) handleCooldown(ctx context.Context, proxyErr *domain.ProxyError, provider *domain.Provider) { - // Determine which client type to apply cooldown to - clientType := proxyErr.CooldownClientType +func (e *Executor) handleCooldown(proxyErr *domain.ProxyError, provider *domain.Provider, clientType domain.ClientType, originalClientType domain.ClientType) { + selectedClientType := proxyErr.CooldownClientType if proxyErr.RateLimitInfo != nil && proxyErr.RateLimitInfo.ClientType != "" { - clientType = proxyErr.RateLimitInfo.ClientType + selectedClientType = proxyErr.RateLimitInfo.ClientType } - // Fallback to original client type (before format conversion) if not specified - if clientType == "" { - // Prefer original client type over converted type - if origCT := ctxutil.GetOriginalClientType(ctx); origCT != "" { - clientType = string(origCT) + if selectedClientType == "" { + if originalClientType != "" { + selectedClientType = string(originalClientType) } else { - clientType = string(ctxutil.GetClientType(ctx)) + selectedClientType = string(clientType) } } @@ -815,11 +222,11 @@ func (e *Executor) handleCooldown(ctx context.Context, proxyErr *domain.ProxyErr // Record failure and apply cooldown // If explicitUntil is not nil, it will be used directly // Otherwise, cooldown duration is calculated based on policy and failure count - cooldown.Default().RecordFailure(provider.ID, clientType, reason, explicitUntil) + 
cooldown.Default().RecordFailure(provider.ID, selectedClientType, reason, explicitUntil) // If there's an async update channel, listen for updates if proxyErr.CooldownUpdateChan != nil { - go e.handleAsyncCooldownUpdate(proxyErr.CooldownUpdateChan, provider, clientType) + go e.handleAsyncCooldownUpdate(proxyErr.CooldownUpdateChan, provider, selectedClientType) } } diff --git a/internal/executor/flow_state.go b/internal/executor/flow_state.go new file mode 100644 index 00000000..0508eb6d --- /dev/null +++ b/internal/executor/flow_state.go @@ -0,0 +1,38 @@ +package executor + +import ( + "context" + "net/http" + + "github.com/awsl-project/maxx/internal/domain" + "github.com/awsl-project/maxx/internal/flow" + "github.com/awsl-project/maxx/internal/router" +) + +type execState struct { + ctx context.Context + proxyReq *domain.ProxyRequest + routes []*router.MatchedRoute + currentAttempt *domain.ProxyUpstreamAttempt + lastErr error + + clientType domain.ClientType + projectID uint64 + sessionID string + requestModel string + isStream bool + apiTokenID uint64 + requestBody []byte + originalRequestBody []byte + requestHeaders http.Header + requestURI string +} + +func getExecState(c *flow.Ctx) (*execState, bool) { + v, ok := c.Get(flow.KeyExecutorState) + if !ok { + return nil, false + } + st, ok := v.(*execState) + return st, ok +} diff --git a/internal/executor/middleware_dispatch.go b/internal/executor/middleware_dispatch.go new file mode 100644 index 00000000..3b49d614 --- /dev/null +++ b/internal/executor/middleware_dispatch.go @@ -0,0 +1,399 @@ +package executor + +import ( + "context" + "log" + "net/http" + "time" + + "github.com/awsl-project/maxx/internal/converter" + "github.com/awsl-project/maxx/internal/cooldown" + "github.com/awsl-project/maxx/internal/domain" + "github.com/awsl-project/maxx/internal/flow" + "github.com/awsl-project/maxx/internal/pricing" + "github.com/awsl-project/maxx/internal/usage" +) + +func (e *Executor) dispatch(c *flow.Ctx) { + state, 
ok := getExecState(c) + if !ok { + err := domain.NewProxyErrorWithMessage(domain.ErrInvalidInput, false, "executor state missing") + c.Err = err + c.Abort() + return + } + + proxyReq := state.proxyReq + ctx := state.ctx + + for _, matchedRoute := range state.routes { + if ctx.Err() != nil { + state.lastErr = ctx.Err() + c.Err = state.lastErr + return + } + + proxyReq.RouteID = matchedRoute.Route.ID + proxyReq.ProviderID = matchedRoute.Provider.ID + _ = e.proxyRequestRepo.Update(proxyReq) + if e.broadcaster != nil { + e.broadcaster.BroadcastProxyRequest(proxyReq) + } + + clientType := state.clientType + mappedModel := e.mapModel(state.requestModel, matchedRoute.Route, matchedRoute.Provider, clientType, state.projectID, state.apiTokenID) + + originalClientType := clientType + currentClientType := clientType + needsConversion := false + convertedBody := []byte(nil) + var convErr error + requestBody := state.requestBody + requestURI := state.requestURI + + supportedTypes := matchedRoute.ProviderAdapter.SupportedClientTypes() + if e.converter.NeedConvert(clientType, supportedTypes) { + currentClientType = GetPreferredTargetType(supportedTypes, clientType, matchedRoute.Provider.Type) + if currentClientType != clientType { + needsConversion = true + log.Printf("[Executor] Format conversion needed: %s -> %s for provider %s", + clientType, currentClientType, matchedRoute.Provider.Name) + + if currentClientType == domain.ClientTypeCodex { + if headers := state.requestHeaders; headers != nil { + requestBody = converter.InjectCodexUserAgent(requestBody, headers.Get("User-Agent")) + } + } + convertedBody, convErr = e.converter.TransformRequest( + clientType, currentClientType, requestBody, mappedModel, state.isStream) + if convErr != nil { + log.Printf("[Executor] Request conversion failed: %v, proceeding with original format", convErr) + needsConversion = false + currentClientType = clientType + } else { + requestBody = convertedBody + + originalURI := requestURI + 
convertedURI := ConvertRequestURI(requestURI, clientType, currentClientType, mappedModel, state.isStream) + if convertedURI != originalURI { + requestURI = convertedURI + log.Printf("[Executor] URI converted: %s -> %s", originalURI, convertedURI) + } + } + } + } + + retryConfig := e.getRetryConfig(matchedRoute.RetryConfig) + + for attempt := 0; attempt <= retryConfig.MaxRetries; attempt++ { + if ctx.Err() != nil { + state.lastErr = ctx.Err() + c.Err = state.lastErr + return + } + + attemptStartTime := time.Now() + attemptRecord := &domain.ProxyUpstreamAttempt{ + ProxyRequestID: proxyReq.ID, + RouteID: matchedRoute.Route.ID, + ProviderID: matchedRoute.Provider.ID, + IsStream: state.isStream, + Status: "IN_PROGRESS", + StartTime: attemptStartTime, + RequestModel: state.requestModel, + MappedModel: mappedModel, + RequestInfo: proxyReq.RequestInfo, + } + if err := e.attemptRepo.Create(attemptRecord); err != nil { + log.Printf("[Executor] Failed to create attempt record: %v", err) + } + state.currentAttempt = attemptRecord + + proxyReq.ProxyUpstreamAttemptCount++ + if e.broadcaster != nil { + e.broadcaster.BroadcastProxyRequest(proxyReq) + e.broadcaster.BroadcastProxyUpstreamAttempt(attemptRecord) + } + + eventChan := domain.NewAdapterEventChan() + c.Set(flow.KeyClientType, currentClientType) + c.Set(flow.KeyOriginalClientType, originalClientType) + c.Set(flow.KeyMappedModel, mappedModel) + c.Set(flow.KeyRequestBody, requestBody) + c.Set(flow.KeyRequestURI, requestURI) + c.Set(flow.KeyRequestHeaders, state.requestHeaders) + c.Set(flow.KeyProxyRequest, state.proxyReq) + c.Set(flow.KeyUpstreamAttempt, attemptRecord) + c.Set(flow.KeyEventChan, eventChan) + c.Set(flow.KeyBroadcaster, e.broadcaster) + eventDone := make(chan struct{}) + go e.processAdapterEventsRealtime(eventChan, attemptRecord, eventDone) + + var responseWriter http.ResponseWriter + var convertingWriter *ConvertingResponseWriter + responseCapture := NewResponseCapture(c.Writer) + if needsConversion { + 
convertingWriter = NewConvertingResponseWriter( + responseCapture, e.converter, originalClientType, currentClientType, state.isStream, state.originalRequestBody) + responseWriter = convertingWriter + } else { + responseWriter = responseCapture + } + + originalWriter := c.Writer + c.Writer = responseWriter + err := matchedRoute.ProviderAdapter.Execute(c, matchedRoute.Provider) + c.Writer = originalWriter + + if needsConversion && convertingWriter != nil && !state.isStream { + if finalizeErr := convertingWriter.Finalize(); finalizeErr != nil { + log.Printf("[Executor] Response conversion finalize failed: %v", finalizeErr) + } + } + + eventChan.Close() + <-eventDone + + if err == nil { + attemptRecord.EndTime = time.Now() + attemptRecord.Duration = attemptRecord.EndTime.Sub(attemptRecord.StartTime) + attemptRecord.Status = "COMPLETED" + + if attemptRecord.InputTokenCount > 0 || attemptRecord.OutputTokenCount > 0 { + metrics := &usage.Metrics{ + InputTokens: attemptRecord.InputTokenCount, + OutputTokens: attemptRecord.OutputTokenCount, + CacheReadCount: attemptRecord.CacheReadCount, + CacheCreationCount: attemptRecord.CacheWriteCount, + Cache5mCreationCount: attemptRecord.Cache5mWriteCount, + Cache1hCreationCount: attemptRecord.Cache1hWriteCount, + } + pricingModel := attemptRecord.ResponseModel + if pricingModel == "" { + pricingModel = attemptRecord.MappedModel + } + multiplier := getProviderMultiplier(matchedRoute.Provider, clientType) + result := pricing.GlobalCalculator().CalculateWithResult(pricingModel, metrics, multiplier) + attemptRecord.Cost = result.Cost + attemptRecord.ModelPriceID = result.ModelPriceID + attemptRecord.Multiplier = result.Multiplier + } + + if e.shouldClearRequestDetail() { + attemptRecord.RequestInfo = nil + attemptRecord.ResponseInfo = nil + } + + _ = e.attemptRepo.Update(attemptRecord) + if e.broadcaster != nil { + e.broadcaster.BroadcastProxyUpstreamAttempt(attemptRecord) + } + state.currentAttempt = nil + + 
cooldown.Default().RecordSuccess(matchedRoute.Provider.ID, string(currentClientType)) + + proxyReq.Status = "COMPLETED" + proxyReq.EndTime = time.Now() + proxyReq.Duration = proxyReq.EndTime.Sub(proxyReq.StartTime) + proxyReq.FinalProxyUpstreamAttemptID = attemptRecord.ID + proxyReq.ModelPriceID = attemptRecord.ModelPriceID + proxyReq.Multiplier = attemptRecord.Multiplier + proxyReq.ResponseModel = mappedModel + + if !e.shouldClearRequestDetail() { + proxyReq.ResponseInfo = &domain.ResponseInfo{ + Status: responseCapture.StatusCode(), + Headers: responseCapture.CapturedHeaders(), + Body: responseCapture.Body(), + } + } + proxyReq.StatusCode = responseCapture.StatusCode() + + if metrics := usage.ExtractFromResponse(responseCapture.Body()); metrics != nil { + proxyReq.InputTokenCount = metrics.InputTokens + proxyReq.OutputTokenCount = metrics.OutputTokens + proxyReq.CacheReadCount = metrics.CacheReadCount + proxyReq.CacheWriteCount = metrics.CacheCreationCount + proxyReq.Cache5mWriteCount = metrics.Cache5mCreationCount + proxyReq.Cache1hWriteCount = metrics.Cache1hCreationCount + } + proxyReq.Cost = attemptRecord.Cost + proxyReq.TTFT = attemptRecord.TTFT + + if e.shouldClearRequestDetail() { + proxyReq.RequestInfo = nil + proxyReq.ResponseInfo = nil + } + + _ = e.proxyRequestRepo.Update(proxyReq) + if e.broadcaster != nil { + e.broadcaster.BroadcastProxyRequest(proxyReq) + } + + state.lastErr = nil + state.ctx = ctx + return + } + + attemptRecord.EndTime = time.Now() + attemptRecord.Duration = attemptRecord.EndTime.Sub(attemptRecord.StartTime) + state.lastErr = err + + if ctx.Err() != nil { + attemptRecord.Status = "CANCELLED" + } else { + attemptRecord.Status = "FAILED" + } + + if attemptRecord.InputTokenCount > 0 || attemptRecord.OutputTokenCount > 0 { + metrics := &usage.Metrics{ + InputTokens: attemptRecord.InputTokenCount, + OutputTokens: attemptRecord.OutputTokenCount, + CacheReadCount: attemptRecord.CacheReadCount, + CacheCreationCount: 
attemptRecord.CacheWriteCount, + Cache5mCreationCount: attemptRecord.Cache5mWriteCount, + Cache1hCreationCount: attemptRecord.Cache1hWriteCount, + } + pricingModel := attemptRecord.ResponseModel + if pricingModel == "" { + pricingModel = attemptRecord.MappedModel + } + multiplier := getProviderMultiplier(matchedRoute.Provider, clientType) + result := pricing.GlobalCalculator().CalculateWithResult(pricingModel, metrics, multiplier) + attemptRecord.Cost = result.Cost + attemptRecord.ModelPriceID = result.ModelPriceID + attemptRecord.Multiplier = result.Multiplier + } + + if e.shouldClearRequestDetail() { + attemptRecord.RequestInfo = nil + attemptRecord.ResponseInfo = nil + } + + _ = e.attemptRepo.Update(attemptRecord) + if e.broadcaster != nil { + e.broadcaster.BroadcastProxyUpstreamAttempt(attemptRecord) + } + state.currentAttempt = nil + + proxyReq.FinalProxyUpstreamAttemptID = attemptRecord.ID + proxyReq.ModelPriceID = attemptRecord.ModelPriceID + proxyReq.Multiplier = attemptRecord.Multiplier + + if responseCapture.Body() != "" { + proxyReq.StatusCode = responseCapture.StatusCode() + if !e.shouldClearRequestDetail() { + proxyReq.ResponseInfo = &domain.ResponseInfo{ + Status: responseCapture.StatusCode(), + Headers: responseCapture.CapturedHeaders(), + Body: responseCapture.Body(), + } + } + if metrics := usage.ExtractFromResponse(responseCapture.Body()); metrics != nil { + proxyReq.InputTokenCount = metrics.InputTokens + proxyReq.OutputTokenCount = metrics.OutputTokens + proxyReq.CacheReadCount = metrics.CacheReadCount + proxyReq.CacheWriteCount = metrics.CacheCreationCount + proxyReq.Cache5mWriteCount = metrics.Cache5mCreationCount + proxyReq.Cache1hWriteCount = metrics.Cache1hCreationCount + } + } + proxyReq.Cost = attemptRecord.Cost + proxyReq.TTFT = attemptRecord.TTFT + + _ = e.proxyRequestRepo.Update(proxyReq) + if e.broadcaster != nil { + e.broadcaster.BroadcastProxyRequest(proxyReq) + } + + proxyErr, ok := err.(*domain.ProxyError)
+ if ok && ctx.Err() != context.Canceled { + log.Printf("[Executor] ProxyError - IsNetworkError: %v, IsServerError: %v, Retryable: %v, Provider: %d", + proxyErr.IsNetworkError, proxyErr.IsServerError, proxyErr.Retryable, matchedRoute.Provider.ID) + e.handleCooldown(proxyErr, matchedRoute.Provider, currentClientType, originalClientType) + if e.broadcaster != nil { + e.broadcaster.BroadcastMessage("cooldown_update", map[string]interface{}{ + "providerID": matchedRoute.Provider.ID, + }) + } + } else if ok && ctx.Err() == context.Canceled { + log.Printf("[Executor] Client disconnected, skipping cooldown for Provider: %d", matchedRoute.Provider.ID) + } else if !ok { + log.Printf("[Executor] Error is not ProxyError, type: %T, error: %v", err, err) + } + + if ctx.Err() != nil { + proxyReq.Status = "CANCELLED" + proxyReq.EndTime = time.Now() + proxyReq.Duration = proxyReq.EndTime.Sub(proxyReq.StartTime) + if ctx.Err() == context.Canceled { + proxyReq.Error = "client disconnected" + } else if ctx.Err() == context.DeadlineExceeded { + proxyReq.Error = "request timeout" + } else { + proxyReq.Error = ctx.Err().Error() + } + _ = e.proxyRequestRepo.Update(proxyReq) + if e.broadcaster != nil { + e.broadcaster.BroadcastProxyRequest(proxyReq) + } + state.lastErr = ctx.Err() + c.Err = state.lastErr + return + } + + if !ok || !proxyErr.Retryable { + break + } + + if attempt < retryConfig.MaxRetries { + waitTime := e.calculateBackoff(retryConfig, attempt) + if proxyErr.RetryAfter > 0 { + waitTime = proxyErr.RetryAfter + } + select { + case <-ctx.Done(): + proxyReq.Status = "CANCELLED" + proxyReq.EndTime = time.Now() + proxyReq.Duration = proxyReq.EndTime.Sub(proxyReq.StartTime) + if ctx.Err() == context.Canceled { + proxyReq.Error = "client disconnected during retry wait" + } else if ctx.Err() == context.DeadlineExceeded { + proxyReq.Error = "request timeout during retry wait" + } else { + proxyReq.Error = ctx.Err().Error() + } + _ = e.proxyRequestRepo.Update(proxyReq) + if e.broadcaster != nil { + 
e.broadcaster.BroadcastProxyRequest(proxyReq) + } + state.lastErr = ctx.Err() + c.Err = state.lastErr + return + case <-time.After(waitTime): + } + } + } + } + + proxyReq.Status = "FAILED" + proxyReq.EndTime = time.Now() + proxyReq.Duration = proxyReq.EndTime.Sub(proxyReq.StartTime) + if state.lastErr != nil { + proxyReq.Error = state.lastErr.Error() + } + if e.shouldClearRequestDetail() { + proxyReq.RequestInfo = nil + proxyReq.ResponseInfo = nil + } + _ = e.proxyRequestRepo.Update(proxyReq) + if e.broadcaster != nil { + e.broadcaster.BroadcastProxyRequest(proxyReq) + } + + if state.lastErr == nil { + state.lastErr = domain.NewProxyErrorWithMessage(domain.ErrAllRoutesFailed, false, "all routes exhausted") + } + state.ctx = ctx + c.Err = state.lastErr +} diff --git a/internal/executor/middleware_egress.go b/internal/executor/middleware_egress.go new file mode 100644 index 00000000..9b41cd1b --- /dev/null +++ b/internal/executor/middleware_egress.go @@ -0,0 +1,56 @@ +package executor + +import ( + "context" + "time" + + "github.com/awsl-project/maxx/internal/flow" +) + +func (e *Executor) egress(c *flow.Ctx) { + state, ok := getExecState(c) + if !ok { + c.Next() + return + } + + c.Next() + + proxyReq := state.proxyReq + if proxyReq != nil && proxyReq.Status == "IN_PROGRESS" { + proxyReq.EndTime = time.Now() + proxyReq.Duration = proxyReq.EndTime.Sub(proxyReq.StartTime) + if state.ctx != nil && state.ctx.Err() != nil { + proxyReq.Status = "CANCELLED" + if state.ctx.Err() == context.Canceled { + proxyReq.Error = "client disconnected" + } else if state.ctx.Err() == context.DeadlineExceeded { + proxyReq.Error = "request timeout" + } else { + proxyReq.Error = state.ctx.Err().Error() + } + } else { + proxyReq.Status = "FAILED" + } + _ = e.proxyRequestRepo.Update(proxyReq) + if e.broadcaster != nil { + e.broadcaster.BroadcastProxyRequest(proxyReq) + } + } + + if state.currentAttempt != nil && state.currentAttempt.Status == "IN_PROGRESS" { + state.currentAttempt.EndTime = 
time.Now() + state.currentAttempt.Duration = state.currentAttempt.EndTime.Sub(state.currentAttempt.StartTime) + if state.ctx != nil && state.ctx.Err() != nil { + state.currentAttempt.Status = "CANCELLED" + } else { + state.currentAttempt.Status = "FAILED" + } + _ = e.attemptRepo.Update(state.currentAttempt) + if e.broadcaster != nil { + e.broadcaster.BroadcastProxyUpstreamAttempt(state.currentAttempt) + } + } + + _ = state.lastErr +} diff --git a/internal/executor/middleware_ingress.go b/internal/executor/middleware_ingress.go new file mode 100644 index 00000000..00a959d7 --- /dev/null +++ b/internal/executor/middleware_ingress.go @@ -0,0 +1,169 @@ +package executor + +import ( + "context" + "errors" + "log" + "net/http" + "time" + + "github.com/awsl-project/maxx/internal/domain" + "github.com/awsl-project/maxx/internal/flow" +) + +func (e *Executor) ingress(c *flow.Ctx) { + state, ok := getExecState(c) + if !ok { + err := domain.NewProxyErrorWithMessage(domain.ErrInvalidInput, false, "executor state missing") + c.Err = err + c.Abort() + return + } + + ctx := state.ctx + if v, ok := c.Get(flow.KeyClientType); ok { + if ct, ok := v.(domain.ClientType); ok { + state.clientType = ct + } + } + if v, ok := c.Get(flow.KeyProjectID); ok { + if pid, ok := v.(uint64); ok { + state.projectID = pid + } + } + if v, ok := c.Get(flow.KeySessionID); ok { + if sid, ok := v.(string); ok { + state.sessionID = sid + } + } + if v, ok := c.Get(flow.KeyRequestModel); ok { + if model, ok := v.(string); ok { + state.requestModel = model + } + } + if v, ok := c.Get(flow.KeyIsStream); ok { + if s, ok := v.(bool); ok { + state.isStream = s + } + } + if v, ok := c.Get(flow.KeyAPITokenID); ok { + if id, ok := v.(uint64); ok { + state.apiTokenID = id + } + } + if v, ok := c.Get(flow.KeyRequestBody); ok { + if body, ok := v.([]byte); ok { + state.requestBody = body + } + } + if v, ok := c.Get(flow.KeyOriginalRequestBody); ok { + if body, ok := v.([]byte); ok { + state.originalRequestBody = body 
+ } + } + if v, ok := c.Get(flow.KeyRequestHeaders); ok { + if headers, ok := v.(map[string][]string); ok { + state.requestHeaders = headers + } + if headers, ok := v.(http.Header); ok { + state.requestHeaders = headers + } + } + if v, ok := c.Get(flow.KeyRequestURI); ok { + if uri, ok := v.(string); ok { + state.requestURI = uri + } + } + + proxyReq := &domain.ProxyRequest{ + InstanceID: e.instanceID, + RequestID: generateRequestID(), + SessionID: state.sessionID, + ClientType: state.clientType, + ProjectID: state.projectID, + RequestModel: state.requestModel, + StartTime: time.Now(), + IsStream: state.isStream, + Status: "PENDING", + APITokenID: state.apiTokenID, + } + + if !e.shouldClearRequestDetail() { + requestURI := state.requestURI + requestHeaders := state.requestHeaders + requestBody := state.requestBody + headers := flattenHeaders(requestHeaders) + if c.Request != nil { + if c.Request.Host != "" { + if headers == nil { + headers = make(map[string]string) + } + headers["Host"] = c.Request.Host + } + proxyReq.RequestInfo = &domain.RequestInfo{ + Method: c.Request.Method, + URL: requestURI, + Headers: headers, + Body: string(requestBody), + } + } + } + + if err := e.proxyRequestRepo.Create(proxyReq); err != nil { + log.Printf("[Executor] Failed to create proxy request: %v", err) + } + + if e.broadcaster != nil { + e.broadcaster.BroadcastProxyRequest(proxyReq) + } + + state.proxyReq = proxyReq + state.ctx = ctx + + if state.projectID == 0 && e.projectWaiter != nil { + session, _ := e.sessionRepo.GetBySessionID(state.sessionID) + if session == nil { + session = &domain.Session{ + SessionID: state.sessionID, + ClientType: state.clientType, + ProjectID: 0, + } + } + + if err := e.projectWaiter.WaitForProject(ctx, session); err != nil { + status := "REJECTED" + errorMsg := "project binding timeout: " + err.Error() + if errors.Is(err, context.Canceled) { + status = "CANCELLED" + errorMsg = "client cancelled: " + err.Error() + if e.broadcaster != nil { + 
e.broadcaster.BroadcastMessage("session_pending_cancelled", map[string]interface{}{ + "sessionID": state.sessionID, + }) + } + } + + proxyReq.Status = status + proxyReq.Error = errorMsg + proxyReq.EndTime = time.Now() + proxyReq.Duration = proxyReq.EndTime.Sub(proxyReq.StartTime) + _ = e.proxyRequestRepo.Update(proxyReq) + + if e.broadcaster != nil { + e.broadcaster.BroadcastProxyRequest(proxyReq) + } + + err := domain.NewProxyErrorWithMessage(err, false, "project binding required: "+err.Error()) + state.lastErr = err + c.Err = err + c.Abort() + return + } + + state.projectID = session.ProjectID + proxyReq.ProjectID = state.projectID + state.ctx = ctx + } + + c.Next() +} diff --git a/internal/executor/middleware_route_match.go b/internal/executor/middleware_route_match.go new file mode 100644 index 00000000..fb11ad71 --- /dev/null +++ b/internal/executor/middleware_route_match.go @@ -0,0 +1,75 @@ +package executor + +import ( + "fmt" + "log" + "time" + + "github.com/awsl-project/maxx/internal/domain" + "github.com/awsl-project/maxx/internal/flow" + "github.com/awsl-project/maxx/internal/router" +) + +func (e *Executor) routeMatch(c *flow.Ctx) { + state, ok := getExecState(c) + if !ok { + err := domain.NewProxyErrorWithMessage(domain.ErrInvalidInput, false, "executor state missing") + c.Err = err + c.Abort() + return + } + + proxyReq := state.proxyReq + routes, err := e.router.Match(&router.MatchContext{ + ClientType: state.clientType, + ProjectID: state.projectID, + RequestModel: state.requestModel, + APITokenID: state.apiTokenID, + }) + if err != nil { + proxyReq.Status = "FAILED" + proxyReq.Error = "no routes available" + proxyReq.EndTime = time.Now() + proxyReq.Duration = proxyReq.EndTime.Sub(proxyReq.StartTime) + if err := e.proxyRequestRepo.Update(proxyReq); err != nil { + log.Printf("[Executor] failed to update proxy request: %v", err) + } + if e.broadcaster != nil { + e.broadcaster.BroadcastProxyRequest(proxyReq) + } + err = 
domain.NewProxyErrorWithMessage(domain.ErrNoRoutes, false, fmt.Sprintf("route match failed: %v", err)) + state.lastErr = err + c.Err = err + c.Abort() + return + } + + if len(routes) == 0 { + proxyReq.Status = "FAILED" + proxyReq.Error = "no routes configured" + proxyReq.EndTime = time.Now() + proxyReq.Duration = proxyReq.EndTime.Sub(proxyReq.StartTime) + if err := e.proxyRequestRepo.Update(proxyReq); err != nil { + log.Printf("[Executor] failed to update proxy request: %v", err) + } + if e.broadcaster != nil { + e.broadcaster.BroadcastProxyRequest(proxyReq) + } + err = domain.NewProxyErrorWithMessage(domain.ErrNoRoutes, false, "no routes configured") + state.lastErr = err + c.Err = err + c.Abort() + return + } + + proxyReq.Status = "IN_PROGRESS" + if err := e.proxyRequestRepo.Update(proxyReq); err != nil { + log.Printf("[Executor] failed to update proxy request: %v", err) + } + if e.broadcaster != nil { + e.broadcaster.BroadcastProxyRequest(proxyReq) + } + state.routes = routes + + c.Next() +} diff --git a/internal/flow/engine.go b/internal/flow/engine.go new file mode 100644 index 00000000..34c81191 --- /dev/null +++ b/internal/flow/engine.go @@ -0,0 +1,88 @@ +package flow + +import ( + "io" + "net/http" +) + +type HandlerFunc func(*Ctx) + +type Engine struct { + handlers []HandlerFunc +} + +func NewEngine() *Engine { + return &Engine{} +} + +func (e *Engine) Use(handlers ...HandlerFunc) { + e.handlers = append(e.handlers, handlers...) +} + +func (e *Engine) Handle(c *Ctx) { + c.handlers = e.handlers + c.index = -1 + c.Next() +} + +func (e *Engine) HandleWith(c *Ctx, handlers ...HandlerFunc) { + c.handlers = append(append([]HandlerFunc{}, e.handlers...), handlers...) 
+ c.index = -1 + c.Next() +} + +type Ctx struct { + Writer http.ResponseWriter + Request *http.Request + InboundBody []byte + OutboundBody []byte + StreamBody io.ReadCloser + IsStream bool + Keys map[string]interface{} + Err error + + handlers []HandlerFunc + index int + aborted bool +} + +func NewCtx(w http.ResponseWriter, r *http.Request) *Ctx { + return &Ctx{ + Writer: w, + Request: r, + Keys: make(map[string]interface{}), + } +} + +func (c *Ctx) Next() { + if c.aborted { + return + } + c.index++ + for c.index < len(c.handlers) { + c.handlers[c.index](c) + if c.aborted { + return + } + c.index++ + } +} + +func (c *Ctx) Abort() { + c.aborted = true +} + +func (c *Ctx) Set(key string, value interface{}) { + if c.Keys == nil { + c.Keys = make(map[string]interface{}) + } + c.Keys[key] = value +} + +func (c *Ctx) Get(key string) (interface{}, bool) { + if c.Keys == nil { + return nil, false + } + v, ok := c.Keys[key] + return v, ok +} diff --git a/internal/flow/helpers.go b/internal/flow/helpers.go new file mode 100644 index 00000000..37cb14eb --- /dev/null +++ b/internal/flow/helpers.go @@ -0,0 +1,152 @@ +package flow + +import ( + "net/http" + + "github.com/awsl-project/maxx/internal/domain" + "github.com/awsl-project/maxx/internal/event" +) + +func GetClientType(c *Ctx) domain.ClientType { + if v, ok := c.Get(KeyClientType); ok { + if ct, ok := v.(domain.ClientType); ok { + return ct + } + } + return "" +} + +func GetOriginalClientType(c *Ctx) domain.ClientType { + if v, ok := c.Get(KeyOriginalClientType); ok { + if ct, ok := v.(domain.ClientType); ok { + return ct + } + } + return "" +} + +func GetSessionID(c *Ctx) string { + if v, ok := c.Get(KeySessionID); ok { + if s, ok := v.(string); ok { + return s + } + } + return "" +} + +func GetProjectID(c *Ctx) uint64 { + if v, ok := c.Get(KeyProjectID); ok { + if id, ok := v.(uint64); ok { + return id + } + } + return 0 +} + +func GetRequestModel(c *Ctx) string { + if v, ok := c.Get(KeyRequestModel); ok { + if s, ok 
:= v.(string); ok { + return s + } + } + return "" +} + +func GetMappedModel(c *Ctx) string { + if v, ok := c.Get(KeyMappedModel); ok { + if s, ok := v.(string); ok { + return s + } + } + return "" +} + +func GetRequestBody(c *Ctx) []byte { + if v, ok := c.Get(KeyRequestBody); ok { + if b, ok := v.([]byte); ok { + return b + } + } + return nil +} + +func GetOriginalRequestBody(c *Ctx) []byte { + if v, ok := c.Get(KeyOriginalRequestBody); ok { + if b, ok := v.([]byte); ok { + return b + } + } + return nil +} + +func GetRequestHeaders(c *Ctx) http.Header { + if v, ok := c.Get(KeyRequestHeaders); ok { + if h, ok := v.(http.Header); ok { + return h + } + } + return nil +} + +func GetRequestURI(c *Ctx) string { + if v, ok := c.Get(KeyRequestURI); ok { + if s, ok := v.(string); ok { + return s + } + } + return "" +} + +func GetIsStream(c *Ctx) bool { + if v, ok := c.Get(KeyIsStream); ok { + if s, ok := v.(bool); ok { + return s + } + } + return false +} + +func GetAPITokenID(c *Ctx) uint64 { + if v, ok := c.Get(KeyAPITokenID); ok { + if id, ok := v.(uint64); ok { + return id + } + } + return 0 +} + +func GetProxyRequest(c *Ctx) *domain.ProxyRequest { + if v, ok := c.Get(KeyProxyRequest); ok { + if pr, ok := v.(*domain.ProxyRequest); ok { + return pr + } + } + return nil +} + +func GetUpstreamAttempt(c *Ctx) *domain.ProxyUpstreamAttempt { + if v, ok := c.Get(KeyUpstreamAttempt); ok { + if at, ok := v.(*domain.ProxyUpstreamAttempt); ok { + return at + } + } + return nil +} + +func GetEventChan(c *Ctx) domain.AdapterEventChan { + if v, ok := c.Get(KeyEventChan); ok { + if ch, ok := v.(domain.AdapterEventChan); ok { + return ch + } + } + return nil +} + +func GetBroadcaster(c *Ctx) event.Broadcaster { + if v, ok := c.Get(KeyBroadcaster); ok { + if b, ok := v.(event.Broadcaster); ok { + return b + } + } + return nil +} diff --git a/internal/flow/keys.go b/internal/flow/keys.go new file mode 100644 index 00000000..caaec1cf --- /dev/null +++ b/internal/flow/keys.go @@ -0,0 
+1,25 @@ +package flow + +const ( + KeyProxyContext = "proxy_context" + KeyProxyStream = "proxy_stream" + KeyProxyRequestModel = "proxy_request_model" + KeyExecutorState = "executor_state" + + KeyClientType = "client_type" + KeyOriginalClientType = "original_client_type" + KeySessionID = "session_id" + KeyProjectID = "project_id" + KeyRequestModel = "request_model" + KeyMappedModel = "mapped_model" + KeyRequestBody = "request_body" + KeyOriginalRequestBody = "original_request_body" + KeyRequestHeaders = "request_headers" + KeyRequestURI = "request_uri" + KeyIsStream = "is_stream" + KeyAPITokenID = "api_token_id" + KeyProxyRequest = "proxy_request" + KeyUpstreamAttempt = "upstream_attempt" + KeyEventChan = "event_chan" + KeyBroadcaster = "broadcaster" +) diff --git a/internal/handler/proxy.go b/internal/handler/proxy.go index 5934e27e..f2e63dfc 100644 --- a/internal/handler/proxy.go +++ b/internal/handler/proxy.go @@ -1,6 +1,7 @@ package handler import ( + "bytes" "encoding/json" "io" "log" @@ -10,9 +11,10 @@ import ( "sync" "github.com/awsl-project/maxx/internal/adapter/client" - ctxutil "github.com/awsl-project/maxx/internal/context" + "github.com/awsl-project/maxx/internal/converter" "github.com/awsl-project/maxx/internal/domain" "github.com/awsl-project/maxx/internal/executor" + "github.com/awsl-project/maxx/internal/flow" "github.com/awsl-project/maxx/internal/repository/cached" ) @@ -31,6 +33,8 @@ type ProxyHandler struct { tokenAuth *TokenAuthMiddleware tracker RequestTracker trackerMu sync.RWMutex + engine *flow.Engine + extra []flow.HandlerFunc } // NewProxyHandler creates a new proxy handler @@ -40,12 +44,19 @@ func NewProxyHandler( sessionRepo *cached.SessionRepository, tokenAuth *TokenAuthMiddleware, ) *ProxyHandler { - return &ProxyHandler{ + h := &ProxyHandler{ clientAdapter: clientAdapter, executor: exec, sessionRepo: sessionRepo, tokenAuth: tokenAuth, + engine: flow.NewEngine(), } + h.engine.Use(h.ingress) + return h +} + +func (h *ProxyHandler) 
Use(handlers ...flow.HandlerFunc) { + h.extra = append(h.extra, handlers...) } // SetRequestTracker sets the request tracker for graceful shutdown @@ -57,6 +68,16 @@ func (h *ProxyHandler) SetRequestTracker(tracker RequestTracker) { // ServeHTTP handles proxy requests func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := flow.NewCtx(w, r) + handlers := make([]flow.HandlerFunc, len(h.extra)+1) + copy(handlers, h.extra) + handlers[len(h.extra)] = h.dispatch + h.engine.HandleWith(ctx, handlers...) +} + +func (h *ProxyHandler) ingress(c *flow.Ctx) { + r := c.Request + w := c.Writer log.Printf("[Proxy] Received request: %s %s", r.Method, r.URL.Path) // Track request for graceful shutdown @@ -66,9 +87,9 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if tracker != nil { if !tracker.Add() { - // Server is shutting down, reject new requests log.Printf("[Proxy] Rejecting request during shutdown: %s %s", r.Method, r.URL.Path) writeError(w, http.StatusServiceUnavailable, "server is shutting down") + c.Abort() return } defer tracker.Done() @@ -76,6 +97,7 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { writeError(w, http.StatusMethodNotAllowed, "method not allowed") + c.Abort() return } @@ -88,26 +110,36 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { _, _ = io.Copy(io.Discard, r.Body) _ = r.Body.Close() writeCountTokensResponse(w) + c.Abort() return } - // Read body body, err := io.ReadAll(r.Body) if err != nil { writeError(w, http.StatusBadRequest, "failed to read request body") + c.Abort() return } - defer r.Body.Close() + _ = r.Body.Close() + + // Normalize OpenAI Responses payloads sent to chat/completions + if strings.HasPrefix(r.URL.Path, "/v1/chat/completions") { + if normalized, ok := normalizeOpenAIChatCompletionsPayload(body); ok { + body = normalized + } + } + + r.Body = io.NopCloser(bytes.NewReader(body)) + ctx := 
r.Context() - // Detect client type and extract info clientType := h.clientAdapter.DetectClientType(r, body) log.Printf("[Proxy] Detected client type: %s", clientType) if clientType == "" { writeError(w, http.StatusBadRequest, "unable to detect client type") + c.Abort() return } - // Token authentication (uses clientType for primary header, with fallback) var apiToken *domain.APIToken var apiTokenID uint64 if h.tokenAuth != nil { @@ -115,6 +147,7 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if err != nil { log.Printf("[Proxy] Token auth failed: %v", err) writeError(w, http.StatusUnauthorized, err.Error()) + c.Abort() return } if apiToken != nil { @@ -128,18 +161,17 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { sessionID := h.clientAdapter.ExtractSessionID(r, body, clientType) stream := h.clientAdapter.IsStreamRequest(r, body) - // Build context - ctx := r.Context() - ctx = ctxutil.WithClientType(ctx, clientType) - ctx = ctxutil.WithSessionID(ctx, sessionID) - ctx = ctxutil.WithRequestModel(ctx, requestModel) - ctx = ctxutil.WithRequestBody(ctx, body) - ctx = ctxutil.WithRequestHeaders(ctx, r.Header) - ctx = ctxutil.WithRequestURI(ctx, r.URL.RequestURI()) - ctx = ctxutil.WithIsStream(ctx, stream) - ctx = ctxutil.WithAPITokenID(ctx, apiTokenID) - - // Check for project ID from header (set by ProjectProxyHandler) + c.Set(flow.KeyClientType, clientType) + c.Set(flow.KeySessionID, sessionID) + c.Set(flow.KeyRequestModel, requestModel) + originalBody := bytes.Clone(body) + c.Set(flow.KeyRequestBody, body) + c.Set(flow.KeyOriginalRequestBody, originalBody) + c.Set(flow.KeyRequestHeaders, r.Header) + c.Set(flow.KeyRequestURI, r.URL.RequestURI()) + c.Set(flow.KeyIsStream, stream) + c.Set(flow.KeyAPITokenID, apiTokenID) + var projectID uint64 if pidStr := r.Header.Get("X-Maxx-Project-ID"); pidStr != "" { if pid, err := strconv.ParseUint(pidStr, 10, 64); err == nil { @@ -148,10 +180,8 @@ func (h *ProxyHandler) 
ServeHTTP(w http.ResponseWriter, r *http.Request) { } } - // Get or create session to get project ID session, _ := h.sessionRepo.GetBySessionID(sessionID) if session != nil { - // Priority: Session binding (Admin configured) > Token association > Header > 0 if session.ProjectID > 0 { projectID = session.ProjectID log.Printf("[Proxy] Using project ID from session binding: %d", projectID) @@ -160,8 +190,6 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { log.Printf("[Proxy] Using project ID from token: %d", projectID) } } else { - // Create new session - // If no project from header, use token's project if projectID == 0 && apiToken != nil && apiToken.ProjectID > 0 { projectID = apiToken.ProjectID log.Printf("[Proxy] Using project ID from token for new session: %d", projectID) @@ -174,22 +202,74 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { _ = h.sessionRepo.Create(session) } - ctx = ctxutil.WithProjectID(ctx, projectID) + c.Set(flow.KeyProjectID, projectID) - // Execute request (executor handles request recording, project binding, routing, etc.) 
- err = h.executor.Execute(ctx, w, r) - if err != nil { - proxyErr, ok := err.(*domain.ProxyError) - if ok { - if stream { - writeStreamError(w, proxyErr) - } else { - writeProxyError(w, proxyErr) - } + r = r.WithContext(ctx) + c.Request = r + c.InboundBody = body + c.IsStream = stream + c.Set(flow.KeyProxyContext, ctx) + c.Set(flow.KeyProxyStream, stream) + c.Set(flow.KeyProxyRequestModel, requestModel) + + c.Next() +} + +func (h *ProxyHandler) dispatch(c *flow.Ctx) { + stream := c.IsStream + if v, ok := c.Get(flow.KeyProxyStream); ok { + if s, ok := v.(bool); ok { + stream = s + } + } + + err := h.executor.ExecuteWith(c) + if err == nil { + return + } + proxyErr, ok := err.(*domain.ProxyError) + if ok { + if stream { + writeStreamError(c.Writer, proxyErr) } else { - writeError(w, http.StatusInternalServerError, err.Error()) + writeProxyError(c.Writer, proxyErr) } + c.Err = err + c.Abort() + return + } + writeError(c.Writer, http.StatusInternalServerError, err.Error()) + c.Err = err + c.Abort() +} + +func normalizeOpenAIChatCompletionsPayload(body []byte) ([]byte, bool) { + var data map[string]interface{} + if err := json.Unmarshal(body, &data); err != nil { + return nil, false + } + if _, hasMessages := data["messages"]; hasMessages { + return nil, false + } + if _, hasInput := data["input"]; !hasInput { + if _, hasInstructions := data["instructions"]; !hasInstructions { + return nil, false + } + } + + model, _ := data["model"].(string) + stream, _ := data["stream"].(bool) + converted, err := converter.GetGlobalRegistry().TransformRequest( + domain.ClientTypeCodex, + domain.ClientTypeOpenAI, + body, + model, + stream, + ) + if err != nil { + return nil, false } + return converted, true } // Helper functions diff --git a/internal/service/admin.go b/internal/service/admin.go index bcd56147..12a7c32a 100644 --- a/internal/service/admin.go +++ b/internal/service/admin.go @@ -553,11 +553,11 @@ func (s *AdminService) GetLogs(limit int) (*LogsResult, error) { func (s 
*AdminService) autoSetSupportedClientTypes(provider *domain.Provider) { switch provider.Type { case "antigravity": - // Antigravity natively supports Claude and Gemini - // OpenAI requests will be converted to Claude format by Executor + // Antigravity natively supports Claude and Gemini. + // Conversion preference is Gemini-first. provider.SupportedClientTypes = []domain.ClientType{ - domain.ClientTypeClaude, domain.ClientTypeGemini, + domain.ClientTypeClaude, } case "kiro": // Kiro natively supports Claude protocol only diff --git a/web/package.json b/web/package.json index 817611a9..373b61b1 100644 --- a/web/package.json +++ b/web/package.json @@ -52,6 +52,11 @@ "tw-animate-css": "^1.4.0", "zustand": "^5.0.9" }, + "pnpm": { + "onlyBuiltDependencies": [ + "esbuild" + ] + }, "devDependencies": { "@eslint/js": "^9.39.1", "@tailwindcss/postcss": "^4.1.18", diff --git a/web/src/hooks/queries/use-providers.ts b/web/src/hooks/queries/use-providers.ts index 009b8731..d5bcfa16 100644 --- a/web/src/hooks/queries/use-providers.ts +++ b/web/src/hooks/queries/use-providers.ts @@ -51,8 +51,16 @@ export function useUpdateProvider() { const queryClient = useQueryClient(); return useMutation({ - mutationFn: ({ id, data }: { id: number; data: Partial }) => - getTransport().updateProvider(id, data), + mutationFn: async ({ id, data }: { id: number; data: Partial }) => { + const existing = + queryClient.getQueryData(providerKeys.detail(id)) || + queryClient + .getQueryData(providerKeys.list()) + ?.find((provider) => provider.id === id) || + (await getTransport().getProvider(id)); + const payload = existing ? 
{ ...existing, ...data } : (data as Provider); + return getTransport().updateProvider(id, payload); + }, onSuccess: (_, { id }) => { queryClient.invalidateQueries({ queryKey: providerKeys.detail(id) }); queryClient.invalidateQueries({ queryKey: providerKeys.lists() }); diff --git a/web/src/pages/providers/components/clients-config-section.tsx b/web/src/pages/providers/components/clients-config-section.tsx index ac9c7496..0094bc5f 100644 --- a/web/src/pages/providers/components/clients-config-section.tsx +++ b/web/src/pages/providers/components/clients-config-section.tsx @@ -170,12 +170,12 @@ export function ClientsConfigSection({

-
-
-

+

+
+
@@ -193,7 +193,7 @@ export function ClientsConfigSection({ value={cloak.sensitiveWords} onChange={(e) => onUpdateCloak({ sensitiveWords: e.target.value })} placeholder={t('provider.cloakSensitiveWordsPlaceholder')} - className="min-h-[88px]" + className="min-h-[100px] bg-card" />

{t('provider.cloakSensitiveWordsDesc')}