diff --git a/internal/adapter/provider/adapter.go b/internal/adapter/provider/adapter.go
index 46519208..9013af11 100644
--- a/internal/adapter/provider/adapter.go
+++ b/internal/adapter/provider/adapter.go
@@ -1,10 +1,8 @@
package provider
import (
- "context"
- "net/http"
-
"github.com/awsl-project/maxx/internal/domain"
+ "github.com/awsl-project/maxx/internal/flow"
)
// ProviderAdapter handles communication with upstream providers
@@ -13,10 +11,10 @@ type ProviderAdapter interface {
SupportedClientTypes() []domain.ClientType
// Execute performs the proxy request to the upstream provider
- // It reads from ctx for ClientType, MappedModel, RequestBody
- // It writes the response to w
+ // It reads from flow.Ctx for ClientType, MappedModel, RequestBody
+ // It writes the response to c.Writer
// Returns ProxyError on failure
- Execute(ctx context.Context, w http.ResponseWriter, req *http.Request, provider *domain.Provider) error
+ Execute(c *flow.Ctx, provider *domain.Provider) error
}
// AdapterFactory creates ProviderAdapter instances
diff --git a/internal/adapter/provider/antigravity/adapter.go b/internal/adapter/provider/antigravity/adapter.go
index 0f610573..b83b8011 100644
--- a/internal/adapter/provider/antigravity/adapter.go
+++ b/internal/adapter/provider/antigravity/adapter.go
@@ -16,8 +16,9 @@ import (
"github.com/awsl-project/maxx/internal/adapter/provider"
cliproxyapi "github.com/awsl-project/maxx/internal/adapter/provider/cliproxyapi_antigravity"
- ctxutil "github.com/awsl-project/maxx/internal/context"
+ "github.com/awsl-project/maxx/internal/converter"
"github.com/awsl-project/maxx/internal/domain"
+ "github.com/awsl-project/maxx/internal/flow"
"github.com/awsl-project/maxx/internal/usage"
)
@@ -32,10 +33,11 @@ type TokenCache struct {
}
type AntigravityAdapter struct {
- provider *domain.Provider
- tokenCache *TokenCache
- tokenMu sync.RWMutex
- httpClient *http.Client
+ provider *domain.Provider
+ tokenCache *TokenCache
+ tokenMu sync.RWMutex
+ projectIDOnce sync.Once
+ httpClient *http.Client
}
func NewAdapter(p *domain.Provider) (provider.ProviderAdapter, error) {
@@ -71,17 +73,21 @@ func NewAdapter(p *domain.Provider) (provider.ProviderAdapter, error) {
}
func (a *AntigravityAdapter) SupportedClientTypes() []domain.ClientType {
- // Antigravity natively supports Claude and Gemini by converting to Gemini/v1internal API
- // OpenAI requests will be converted to Claude format by Executor before reaching this adapter
- return []domain.ClientType{domain.ClientTypeClaude, domain.ClientTypeGemini}
+ // Antigravity natively supports Claude and Gemini (via Gemini/v1internal API).
+ // Prefer Gemini when choosing a target format.
+ return []domain.ClientType{domain.ClientTypeGemini, domain.ClientTypeClaude}
}
-func (a *AntigravityAdapter) Execute(ctx context.Context, w http.ResponseWriter, req *http.Request, provider *domain.Provider) error {
- clientType := ctxutil.GetClientType(ctx)
- baseCtx := ctx
- requestModel := ctxutil.GetRequestModel(ctx) // Original model from request (e.g., "claude-3-5-sonnet-20241022-online")
- mappedModel := ctxutil.GetMappedModel(ctx) // Mapped model after executor's unified mapping
- requestBody := ctxutil.GetRequestBody(ctx)
+func (a *AntigravityAdapter) Execute(c *flow.Ctx, provider *domain.Provider) error {
+ clientType := flow.GetClientType(c)
+ requestModel := flow.GetRequestModel(c)
+ mappedModel := flow.GetMappedModel(c)
+ requestBody := flow.GetRequestBody(c)
+ request := c.Request
+ ctx := context.Background()
+ if request != nil {
+ ctx = request.Context()
+ }
backgroundDowngrade := false
backgroundModel := ""
@@ -99,8 +105,8 @@ func (a *AntigravityAdapter) Execute(ctx context.Context, w http.ResponseWriter,
retriedWithoutThinking := false
for attemptIdx := 0; attemptIdx < 2; attemptIdx++ {
- ctx = ctxutil.WithRequestModel(baseCtx, requestModel)
- ctx = ctxutil.WithRequestBody(ctx, requestBody)
+ c.Set(flow.KeyRequestModel, requestModel)
+ c.Set(flow.KeyRequestBody, requestBody)
// Apply background downgrade override if needed
config := provider.Config.Antigravity
@@ -109,12 +115,12 @@ func (a *AntigravityAdapter) Execute(ctx context.Context, w http.ResponseWriter,
}
// Update attempt record with the final mapped model (in case of background downgrade)
- if attempt := ctxutil.GetUpstreamAttempt(ctx); attempt != nil {
+ if attempt := flow.GetUpstreamAttempt(c); attempt != nil {
attempt.MappedModel = mappedModel
}
// Get streaming flag from context (already detected correctly for Gemini URL path)
- stream := ctxutil.GetIsStream(ctx)
+ stream := flow.GetIsStream(c)
clientWantsStream := stream
actualStream := stream
if clientType == domain.ClientTypeClaude && !clientWantsStream {
@@ -133,6 +139,7 @@ func (a *AntigravityAdapter) Execute(ctx context.Context, w http.ResponseWriter,
// Transform request based on client type
var geminiBody []byte
+ openAIWrapped := false
switch clientType {
case domain.ClientTypeClaude:
// Use direct transformation (no converter dependency)
@@ -151,207 +158,247 @@ func (a *AntigravityAdapter) Execute(ctx context.Context, w http.ResponseWriter,
// Apply minimal post-processing for features not yet fully integrated
geminiBody = applyClaudePostProcess(geminiBody, sessionID, hasThinking, requestBody, mappedModel)
case domain.ClientTypeOpenAI:
- // TODO: Implement OpenAI transformation in the future
- return domain.NewProxyErrorWithMessage(domain.ErrFormatConversion, true, "OpenAI transformation not yet implemented")
+ geminiBody = ConvertOpenAIRequestToAntigravity(mappedModel, requestBody, actualStream)
+ openAIWrapped = true
default:
// For Gemini, unwrap CLI envelope if present
geminiBody = unwrapGeminiCLIEnvelope(requestBody)
}
- // Wrap request in v1internal format
- var toolsForConfig []interface{}
- if clientType == domain.ClientTypeClaude {
- var raw map[string]interface{}
- if err := json.Unmarshal(requestBody, &raw); err == nil {
- if tools, ok := raw["tools"].([]interface{}); ok {
- toolsForConfig = tools
- }
+ // Resolve project ID (CLIProxyAPI behavior)
+ a.projectIDOnce.Do(func() {
+ if strings.TrimSpace(config.ProjectID) != "" {
+ return
}
- }
- upstreamBody, err := wrapV1InternalRequest(geminiBody, config.ProjectID, requestModel, mappedModel, sessionID, toolsForConfig)
- if err != nil {
- return domain.NewProxyErrorWithMessage(domain.ErrFormatConversion, true, "failed to wrap request for v1internal")
- }
-
- // Build upstream URLs (prod first, daily fallback)
- baseURLs := []string{V1InternalBaseURLProd, V1InternalBaseURLDaily}
- client := a.httpClient
- var lastErr error
-
- for idx, base := range baseURLs {
- upstreamURL := a.buildUpstreamURL(base, actualStream)
-
- upstreamReq, reqErr := http.NewRequestWithContext(ctx, "POST", upstreamURL, bytes.NewReader(upstreamBody))
- if reqErr != nil {
- lastErr = reqErr
- continue
+ if pid, _, err := FetchProjectInfo(ctx, accessToken, config.Email); err == nil {
+ pid = strings.TrimSpace(pid)
+ if pid != "" {
+ config.ProjectID = pid
+ }
}
+ })
+ projectID := strings.TrimSpace(config.ProjectID)
- // Set only the required headers (like Antigravity-Manager)
- upstreamReq.Header.Set("Content-Type", "application/json")
- upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
- upstreamReq.Header.Set("User-Agent", AntigravityUserAgent)
-
- // Send request info via EventChannel (only once per attempt)
- if eventChan := ctxutil.GetEventChan(ctx); eventChan != nil {
- eventChan.SendRequestInfo(&domain.RequestInfo{
- Method: upstreamReq.Method,
- URL: upstreamURL,
- Headers: flattenHeaders(upstreamReq.Header),
- Body: string(upstreamBody),
- })
+ var upstreamBody []byte
+ if openAIWrapped {
+ upstreamBody = finalizeOpenAIWrappedRequest(geminiBody, projectID, mappedModel, sessionID)
+ } else {
+ // Wrap request in v1internal format
+ var toolsForConfig []interface{}
+ if clientType == domain.ClientTypeClaude {
+ var raw map[string]interface{}
+ if err := json.Unmarshal(requestBody, &raw); err == nil {
+ if tools, ok := raw["tools"].([]interface{}); ok {
+ toolsForConfig = tools
+ }
+ }
}
-
- resp, err := client.Do(upstreamReq)
+ upstreamBody, err = wrapV1InternalRequest(geminiBody, projectID, requestModel, mappedModel, sessionID, toolsForConfig)
if err != nil {
- lastErr = err
- if hasNextEndpoint(idx, len(baseURLs)) {
- continue
- }
- proxyErr := domain.NewProxyErrorWithMessage(domain.ErrUpstreamError, true, "failed to connect to upstream")
- proxyErr.IsNetworkError = true // Mark as network error (connection timeout, DNS failure, etc.)
- return proxyErr
+ return domain.NewProxyErrorWithMessage(domain.ErrFormatConversion, true, "failed to wrap request for v1internal")
}
- defer resp.Body.Close()
+ }
- // Check for 401 (token expired) and retry once
- if resp.StatusCode == http.StatusUnauthorized {
- resp.Body.Close()
+ // Build upstream URLs (CLIProxyAPI fallback order)
+ baseURLs := antigravityBaseURLFallbackOrder(config.Endpoint)
+ client := a.httpClient
+ var lastErr error
- // Invalidate token cache
- a.tokenMu.Lock()
- a.tokenCache = &TokenCache{}
- a.tokenMu.Unlock()
+ for attempt := 0; attempt < antigravityRetryAttempts; attempt++ {
+ for idx, base := range baseURLs {
+ upstreamURL := a.buildUpstreamURL(base, actualStream)
- // Get new token
- accessToken, err = a.getAccessToken(ctx)
- if err != nil {
- return domain.NewProxyErrorWithMessage(err, true, "failed to refresh access token")
+ upstreamReq, reqErr := http.NewRequestWithContext(ctx, "POST", upstreamURL, bytes.NewReader(upstreamBody))
+ if reqErr != nil {
+ lastErr = reqErr
+ continue
}
- // Retry request with only required headers
- upstreamReq, _ = http.NewRequestWithContext(ctx, "POST", upstreamURL, bytes.NewReader(upstreamBody))
+ // Set only the required headers (like Antigravity-Manager)
upstreamReq.Header.Set("Content-Type", "application/json")
upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
upstreamReq.Header.Set("User-Agent", AntigravityUserAgent)
- resp, err = client.Do(upstreamReq)
+
+ // Send request info via EventChannel (only once per attempt)
+ if eventChan := flow.GetEventChan(c); eventChan != nil {
+ eventChan.SendRequestInfo(&domain.RequestInfo{
+ Method: upstreamReq.Method,
+ URL: upstreamURL,
+ Headers: flattenHeaders(upstreamReq.Header),
+ Body: string(upstreamBody),
+ })
+ }
+
+ resp, err := client.Do(upstreamReq)
if err != nil {
lastErr = err
if hasNextEndpoint(idx, len(baseURLs)) {
continue
}
- proxyErr := domain.NewProxyErrorWithMessage(domain.ErrUpstreamError, true, "failed to connect to upstream after token refresh")
- proxyErr.IsNetworkError = true // Mark as network error
+ proxyErr := domain.NewProxyErrorWithMessage(domain.ErrUpstreamError, true, "failed to connect to upstream")
+ proxyErr.IsNetworkError = true // Mark as network error (connection timeout, DNS failure, etc.)
return proxyErr
}
- defer resp.Body.Close()
- }
- // Check for error response
- if resp.StatusCode >= 400 {
- body, _ := io.ReadAll(resp.Body)
- // Send error response info via EventChannel
- if eventChan := ctxutil.GetEventChan(ctx); eventChan != nil {
- eventChan.SendResponseInfo(&domain.ResponseInfo{
- Status: resp.StatusCode,
- Headers: flattenHeaders(resp.Header),
- Body: string(body),
- })
- }
+ // Check for 401 (token expired) and retry once
+ if resp.StatusCode == http.StatusUnauthorized {
+ resp.Body.Close()
- // Check for RESOURCE_EXHAUSTED (429) and extract cooldown info
- var rateLimitInfo *domain.RateLimitInfo
- var cooldownUpdateChan chan time.Time
- if resp.StatusCode == http.StatusTooManyRequests {
- rateLimitInfo, cooldownUpdateChan = a.parseRateLimitInfo(ctx, body, provider)
- }
+ // Invalidate token cache
+ a.tokenMu.Lock()
+ a.tokenCache = &TokenCache{}
+ a.tokenMu.Unlock()
- // Parse retry info for 429/5xx responses (like Antigravity-Manager)
- var retryAfter time.Duration
+ // Get new token
+ accessToken, err = a.getAccessToken(ctx)
+ if err != nil {
+ return domain.NewProxyErrorWithMessage(err, true, "failed to refresh access token")
+ }
- // 1) Prefer Retry-After header (seconds)
- if ra := strings.TrimSpace(resp.Header.Get("Retry-After")); ra != "" {
- if secs, err := strconv.Atoi(ra); err == nil && secs > 0 {
- retryAfter = time.Duration(secs) * time.Second
+ // Retry request with only required headers
+ upstreamReq, reqErr = http.NewRequestWithContext(ctx, "POST", upstreamURL, bytes.NewReader(upstreamBody))
+ if reqErr != nil {
+ return domain.NewProxyErrorWithMessage(reqErr, false, "failed to create upstream request after token refresh")
+ }
+ upstreamReq.Header.Set("Content-Type", "application/json")
+ upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
+ upstreamReq.Header.Set("User-Agent", AntigravityUserAgent)
+ resp, err = client.Do(upstreamReq)
+ if err != nil {
+ lastErr = err
+ if hasNextEndpoint(idx, len(baseURLs)) {
+ continue
+ }
+ proxyErr := domain.NewProxyErrorWithMessage(domain.ErrUpstreamError, true, "failed to connect to upstream after token refresh")
+ proxyErr.IsNetworkError = true // Mark as network error
+ return proxyErr
}
}
- // 2) Fallback to body parsing (google.rpc.RetryInfo / quotaResetDelay)
- if retryAfter == 0 {
- if retryInfo := ParseRetryInfo(resp.StatusCode, body); retryInfo != nil {
- retryAfter = retryInfo.Delay
+ // Check for error response
+ if resp.StatusCode >= 400 {
+ body, _ := io.ReadAll(resp.Body)
+ resp.Body.Close()
+ // Send error response info via EventChannel
+ if eventChan := flow.GetEventChan(c); eventChan != nil {
+ eventChan.SendResponseInfo(&domain.ResponseInfo{
+ Status: resp.StatusCode,
+ Headers: flattenHeaders(resp.Header),
+ Body: string(body),
+ })
+ }
- // Manager: add a small buffer and cap for 429 retries
- if resp.StatusCode == http.StatusTooManyRequests {
- retryAfter += 200 * time.Millisecond
- if retryAfter > 10*time.Second {
- retryAfter = 10 * time.Second
- }
+ // Check for RESOURCE_EXHAUSTED (429) and extract cooldown info
+ var rateLimitInfo *domain.RateLimitInfo
+ var cooldownUpdateChan chan time.Time
+ if resp.StatusCode == http.StatusTooManyRequests {
+ rateLimitInfo, cooldownUpdateChan = a.parseRateLimitInfo(ctx, body, provider)
+ }
+
+ // Parse retry info for 429/5xx responses (like Antigravity-Manager)
+ var retryAfter time.Duration
+
+ // 1) Prefer Retry-After header (seconds)
+ if ra := strings.TrimSpace(resp.Header.Get("Retry-After")); ra != "" {
+ if secs, err := strconv.Atoi(ra); err == nil && secs > 0 {
+ retryAfter = time.Duration(secs) * time.Second
}
+ }
+
+ // 2) Fallback to body parsing (google.rpc.RetryInfo / quotaResetDelay)
+ if retryAfter == 0 {
+ if retryInfo := ParseRetryInfo(resp.StatusCode, body); retryInfo != nil {
+ retryAfter = retryInfo.Delay
+
+ // Manager: add a small buffer and cap for 429 retries
+ if resp.StatusCode == http.StatusTooManyRequests {
+ retryAfter += 200 * time.Millisecond
+ if retryAfter > 10*time.Second {
+ retryAfter = 10 * time.Second
+ }
+ }
- retryAfter = ApplyJitter(retryAfter)
+ retryAfter = ApplyJitter(retryAfter)
+ }
}
- }
- proxyErr := domain.NewProxyErrorWithMessage(
- fmt.Errorf("upstream error: %s", string(body)),
- isRetryableStatusCode(resp.StatusCode),
- fmt.Sprintf("upstream returned status %d", resp.StatusCode),
- )
+ proxyErr := domain.NewProxyErrorWithMessage(
+ fmt.Errorf("upstream error: %s", string(body)),
+ isRetryableStatusCode(resp.StatusCode),
+ fmt.Sprintf("upstream returned status %d", resp.StatusCode),
+ )
- // Set status code and check if it's a server error (5xx)
- proxyErr.HTTPStatusCode = resp.StatusCode
- proxyErr.IsServerError = resp.StatusCode >= 500 && resp.StatusCode < 600
+ // Set status code and check if it's a server error (5xx)
+ proxyErr.HTTPStatusCode = resp.StatusCode
+ proxyErr.IsServerError = resp.StatusCode >= 500 && resp.StatusCode < 600
- // Set retry info on error for upstream handling
- if retryAfter > 0 {
- proxyErr.RetryAfter = retryAfter
- }
+ // Set retry info on error for upstream handling
+ if retryAfter > 0 {
+ proxyErr.RetryAfter = retryAfter
+ }
- // Set rate limit info for cooldown handling
- if rateLimitInfo != nil {
- proxyErr.RateLimitInfo = rateLimitInfo
- proxyErr.CooldownUpdateChan = cooldownUpdateChan
- }
+ // Set rate limit info for cooldown handling
+ if rateLimitInfo != nil {
+ proxyErr.RateLimitInfo = rateLimitInfo
+ proxyErr.CooldownUpdateChan = cooldownUpdateChan
+ }
- lastErr = proxyErr
+ lastErr = proxyErr
- // Signature failure recovery: retry once without thinking (like Manager)
- if resp.StatusCode == http.StatusBadRequest && !retriedWithoutThinking && isThinkingSignatureError(body) {
- retriedWithoutThinking = true
+ // Signature failure recovery: retry once without thinking (like Manager)
+ if resp.StatusCode == http.StatusBadRequest && !retriedWithoutThinking && isThinkingSignatureError(body) {
+ retriedWithoutThinking = true
+
+ // Manager uses a small fixed delay before retrying.
+ select {
+ case <-ctx.Done():
+ return domain.NewProxyErrorWithMessage(ctx.Err(), false, "client disconnected")
+ case <-time.After(200 * time.Millisecond):
+ }
- // Manager uses a small fixed delay before retrying.
- select {
- case <-ctx.Done():
- return domain.NewProxyErrorWithMessage(ctx.Err(), false, "client disconnected")
- case <-time.After(200 * time.Millisecond):
+ requestBody = stripThinkingFromClaude(requestBody)
+ if newModel := extractModelFromBody(requestBody); newModel != "" {
+ requestModel = newModel
+ }
+ mappedModel = "" // force remap
+ continue
}
- requestBody = stripThinkingFromClaude(requestBody)
- if newModel := extractModelFromBody(requestBody); newModel != "" {
- requestModel = newModel
+ // Retry fallback handling (CLIProxyAPI behavior)
+ if resp.StatusCode == http.StatusTooManyRequests && hasNextEndpoint(idx, len(baseURLs)) {
+ continue
}
- mappedModel = "" // force remap
- continue
+ if antigravityShouldRetryNoCapacity(resp.StatusCode, body) {
+ if hasNextEndpoint(idx, len(baseURLs)) {
+ continue
+ }
+ if attempt+1 < antigravityRetryAttempts {
+ delay := antigravityNoCapacityRetryDelay(attempt)
+ if err := antigravityWait(ctx, delay); err != nil {
+ return domain.NewProxyErrorWithMessage(err, false, "client disconnected")
+ }
+ break
+ }
+ }
+
+ return proxyErr
}
- // Fallback to next endpoint if available and retryable
- if hasNextEndpoint(idx, len(baseURLs)) && shouldTryNextEndpoint(resp.StatusCode) {
+ // Handle response
+ if actualStream && !clientWantsStream {
+ err := a.handleCollectedStreamResponse(c, resp, clientType, requestModel)
resp.Body.Close()
- continue
+ return err
}
-
- return proxyErr
- }
-
- // Handle response
- if actualStream && !clientWantsStream {
- return a.handleCollectedStreamResponse(ctx, w, resp, clientType, requestModel)
- }
- if actualStream {
- return a.handleStreamResponse(ctx, w, resp, clientType)
+ if actualStream {
+ err := a.handleStreamResponse(c, resp, clientType)
+ resp.Body.Close()
+ return err
+ }
+ nErr := a.handleNonStreamResponse(c, resp, clientType)
+ resp.Body.Close()
+ return nErr
}
- return a.handleNonStreamResponse(ctx, w, resp, clientType)
}
// All endpoints failed in this iteration
@@ -490,17 +537,26 @@ func applyClaudePostProcess(geminiBody []byte, sessionID string, hasThinking boo
return result
}
-// v1internal endpoints (prod + daily fallback, like Antigravity-Manager)
+// v1internal endpoints (CLIProxyAPI fallback order)
const (
- V1InternalBaseURLProd = "https://cloudcode-pa.googleapis.com/v1internal"
- V1InternalBaseURLDaily = "https://daily-cloudcode-pa.sandbox.googleapis.com/v1internal"
+ V1InternalBaseURLDaily = "https://daily-cloudcode-pa.googleapis.com"
+ V1InternalSandboxBaseURLDaily = "https://daily-cloudcode-pa.sandbox.googleapis.com"
+ V1InternalBaseURLProd = "https://cloudcode-pa.googleapis.com"
+ antigravityRetryAttempts = 3
)
func (a *AntigravityAdapter) buildUpstreamURL(base string, stream bool) string {
+ base = strings.TrimRight(base, "/")
+ if strings.Contains(base, "/v1internal") {
+ if stream {
+ return fmt.Sprintf("%s:streamGenerateContent?alt=sse", base)
+ }
+ return fmt.Sprintf("%s:generateContent", base)
+ }
if stream {
- return fmt.Sprintf("%s:streamGenerateContent?alt=sse", base)
+ return fmt.Sprintf("%s/v1internal:streamGenerateContent?alt=sse", base)
}
- return fmt.Sprintf("%s:generateContent", base)
+ return fmt.Sprintf("%s/v1internal:generateContent", base)
}
func hasNextEndpoint(index, total int) bool {
@@ -516,6 +572,76 @@ func shouldTryNextEndpoint(status int) bool {
return status >= 500
}
+func antigravityBaseURLFallbackOrder(endpoint string) []string {
+ if endpoint = strings.TrimSpace(endpoint); endpoint != "" {
+ if isAntigravityEndpoint(endpoint) {
+ return []string{strings.TrimRight(endpoint, "/")}
+ }
+ }
+ return []string{
+ V1InternalBaseURLDaily,
+ V1InternalSandboxBaseURLDaily,
+ // V1InternalBaseURLProd,
+ }
+}
+
+func isAntigravityEndpoint(endpoint string) bool {
+ endpoint = strings.ToLower(strings.TrimSpace(endpoint))
+ if endpoint == "" {
+ return false
+ }
+ // Only accept Antigravity v1internal endpoints, not Vertex AI endpoints.
+ if strings.Contains(endpoint, "cloudcode-pa.googleapis.com") {
+ return true
+ }
+ if strings.Contains(endpoint, "daily-cloudcode-pa.googleapis.com") {
+ return true
+ }
+ if strings.Contains(endpoint, "daily-cloudcode-pa.sandbox.googleapis.com") {
+ return true
+ }
+ if strings.Contains(endpoint, "/v1internal") && strings.Contains(endpoint, "cloudcode-pa") {
+ return true
+ }
+ return false
+}
+
+func antigravityShouldRetryNoCapacity(statusCode int, body []byte) bool {
+ if statusCode != http.StatusServiceUnavailable {
+ return false
+ }
+ if len(body) == 0 {
+ return false
+ }
+ msg := strings.ToLower(string(body))
+ return strings.Contains(msg, "no capacity available")
+}
+
+func antigravityNoCapacityRetryDelay(attempt int) time.Duration {
+ if attempt < 0 {
+ attempt = 0
+ }
+ delay := time.Duration(attempt+1) * 250 * time.Millisecond
+ if delay > 2*time.Second {
+ delay = 2 * time.Second
+ }
+ return delay
+}
+
+func antigravityWait(ctx context.Context, wait time.Duration) error {
+ if wait <= 0 {
+ return nil
+ }
+ timer := time.NewTimer(wait)
+ defer timer.Stop()
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ case <-timer.C:
+ return nil
+ }
+}
+
// isThinkingSignatureError detects thinking signature related 400 errors (like Manager)
func isThinkingSignatureError(body []byte) bool {
bodyStr := strings.ToLower(string(body))
@@ -526,7 +652,8 @@ func isThinkingSignatureError(body []byte) bool {
strings.Contains(bodyStr, "failed to deserialise")
}
-func (a *AntigravityAdapter) handleNonStreamResponse(ctx context.Context, w http.ResponseWriter, resp *http.Response, clientType domain.ClientType) error {
+func (a *AntigravityAdapter) handleNonStreamResponse(c *flow.Ctx, resp *http.Response, clientType domain.ClientType) error {
+ w := c.Writer
body, err := io.ReadAll(resp.Body)
if err != nil {
return domain.NewProxyErrorWithMessage(domain.ErrUpstreamError, true, "failed to read upstream response")
@@ -535,31 +662,30 @@ func (a *AntigravityAdapter) handleNonStreamResponse(ctx context.Context, w http
// Unwrap v1internal response wrapper (extract "response" field)
unwrappedBody := unwrapV1InternalResponse(body)
- // Send events via EventChannel (executor will process them)
- eventChan := ctxutil.GetEventChan(ctx)
-
- // Send response info event
- eventChan.SendResponseInfo(&domain.ResponseInfo{
- Status: resp.StatusCode,
- Headers: flattenHeaders(resp.Header),
- Body: string(body), // Keep original for debugging
- })
-
- // Extract and send token usage metrics
- if metrics := usage.ExtractFromResponse(string(unwrappedBody)); metrics != nil {
- eventChan.SendMetrics(&domain.AdapterMetrics{
- InputTokens: metrics.InputTokens,
- OutputTokens: metrics.OutputTokens,
- CacheReadCount: metrics.CacheReadCount,
- CacheCreationCount: metrics.CacheCreationCount,
- Cache5mCreationCount: metrics.Cache5mCreationCount,
- Cache1hCreationCount: metrics.Cache1hCreationCount,
+ if eventChan := flow.GetEventChan(c); eventChan != nil {
+ eventChan.SendResponseInfo(&domain.ResponseInfo{
+ Status: resp.StatusCode,
+ Headers: flattenHeaders(resp.Header),
+ Body: string(body),
})
+
+ if metrics := usage.ExtractFromResponse(string(unwrappedBody)); metrics != nil {
+ eventChan.SendMetrics(&domain.AdapterMetrics{
+ InputTokens: metrics.InputTokens,
+ OutputTokens: metrics.OutputTokens,
+ CacheReadCount: metrics.CacheReadCount,
+ CacheCreationCount: metrics.CacheCreationCount,
+ Cache5mCreationCount: metrics.Cache5mCreationCount,
+ Cache1hCreationCount: metrics.Cache1hCreationCount,
+ })
+ }
}
// Extract and send response model
if modelVersion := extractModelVersion(unwrappedBody); modelVersion != "" {
- eventChan.SendResponseModel(modelVersion)
+ if eventChan := flow.GetEventChan(c); eventChan != nil {
+ eventChan.SendResponseModel(modelVersion)
+ }
}
var responseBody []byte
@@ -567,14 +693,17 @@ func (a *AntigravityAdapter) handleNonStreamResponse(ctx context.Context, w http
// Transform response based on client type
switch clientType {
case domain.ClientTypeClaude:
- requestModel := ctxutil.GetRequestModel(ctx)
+ requestModel := flow.GetRequestModel(c)
responseBody, err = convertGeminiToClaudeResponse(unwrappedBody, requestModel)
if err != nil {
return domain.NewProxyErrorWithMessage(domain.ErrFormatConversion, false, "failed to transform response")
}
case domain.ClientTypeOpenAI:
- // TODO: Implement OpenAI response transformation
- return domain.NewProxyErrorWithMessage(domain.ErrFormatConversion, false, "OpenAI response transformation not yet implemented")
+ responseBody, err = converter.GetGlobalRegistry().TransformResponse(
+ domain.ClientTypeGemini, domain.ClientTypeOpenAI, unwrappedBody)
+ if err != nil {
+ return domain.NewProxyErrorWithMessage(domain.ErrFormatConversion, false, "failed to transform response")
+ }
default:
// Gemini native
responseBody = unwrappedBody
@@ -588,8 +717,13 @@ func (a *AntigravityAdapter) handleNonStreamResponse(ctx context.Context, w http
return nil
}
-func (a *AntigravityAdapter) handleStreamResponse(ctx context.Context, w http.ResponseWriter, resp *http.Response, clientType domain.ClientType) error {
- eventChan := ctxutil.GetEventChan(ctx)
+func (a *AntigravityAdapter) handleStreamResponse(c *flow.Ctx, resp *http.Response, clientType domain.ClientType) error {
+ w := c.Writer
+ ctx := context.Background()
+ if c.Request != nil {
+ ctx = c.Request.Context()
+ }
+ eventChan := flow.GetEventChan(c)
// Send initial response info (for streaming, we only capture status and headers)
eventChan.SendResponseInfo(&domain.ResponseInfo{
@@ -614,18 +748,23 @@ func (a *AntigravityAdapter) handleStreamResponse(ctx context.Context, w http.Re
// Use specialized Claude SSE handler for Claude clients
isClaudeClient := clientType == domain.ClientTypeClaude
+ isOpenAIClient := clientType == domain.ClientTypeOpenAI
// Extract sessionID for signature caching (like CLIProxyAPI)
- requestBody := ctxutil.GetRequestBody(ctx)
+ requestBody := flow.GetRequestBody(c)
sessionID := extractSessionID(requestBody)
// Get original request model for Claude response (like Antigravity-Manager)
- requestModel := ctxutil.GetRequestModel(ctx)
+ requestModel := flow.GetRequestModel(c)
var claudeState *ClaudeStreamingState
if isClaudeClient {
claudeState = NewClaudeStreamingStateWithSession(sessionID, requestModel)
}
+ var openaiState *converter.TransformState
+ if isOpenAIClient {
+ openaiState = converter.NewTransformState()
+ }
// Collect all SSE events for response body and token extraction
var sseBuffer strings.Builder
@@ -698,6 +837,9 @@ func (a *AntigravityAdapter) handleStreamResponse(ctx context.Context, w http.Re
// Unwrap v1internal SSE chunk before processing
unwrappedLine := unwrapV1InternalSSEChunk(lineBytes)
+ if len(unwrappedLine) == 0 {
+ continue
+ }
// Collect original SSE for token extraction (extractor handles v1internal wrapper)
sseBuffer.WriteString(line)
@@ -706,9 +848,13 @@ func (a *AntigravityAdapter) handleStreamResponse(ctx context.Context, w http.Re
if isClaudeClient {
// Use specialized Claude SSE transformation
output = claudeState.ProcessGeminiSSELine(string(unwrappedLine))
- } else if clientType == domain.ClientTypeOpenAI {
- // TODO: Implement OpenAI streaming transformation
- continue
+ } else if isOpenAIClient {
+ converted, convErr := converter.GetGlobalRegistry().TransformStreamChunk(
+ domain.ClientTypeGemini, domain.ClientTypeOpenAI, unwrappedLine, openaiState)
+ if convErr != nil {
+ continue
+ }
+ output = converted
} else {
// Gemini native
output = unwrappedLine
@@ -770,15 +916,20 @@ func (a *AntigravityAdapter) handleStreamResponse(ctx context.Context, w http.Re
}
// handleCollectedStreamResponse forwards upstream SSE but collects into a single response body (like Manager non-stream auto-convert)
-func (a *AntigravityAdapter) handleCollectedStreamResponse(ctx context.Context, w http.ResponseWriter, resp *http.Response, clientType domain.ClientType, requestModel string) error {
- eventChan := ctxutil.GetEventChan(ctx)
-
- // Send initial response info
- eventChan.SendResponseInfo(&domain.ResponseInfo{
- Status: resp.StatusCode,
- Headers: flattenHeaders(resp.Header),
- Body: "[stream-collected]",
- })
+func (a *AntigravityAdapter) handleCollectedStreamResponse(c *flow.Ctx, resp *http.Response, clientType domain.ClientType, requestModel string) error {
+ w := c.Writer
+ ctx := context.Background()
+ if c.Request != nil {
+ ctx = c.Request.Context()
+ }
+ eventChan := flow.GetEventChan(c)
+ if eventChan != nil {
+ eventChan.SendResponseInfo(&domain.ResponseInfo{
+ Status: resp.StatusCode,
+ Headers: flattenHeaders(resp.Header),
+ Body: "[stream-collected]",
+ })
+ }
// Copy upstream headers (except those we override)
copyResponseHeaders(w.Header(), resp.Header)
@@ -788,14 +939,14 @@ func (a *AntigravityAdapter) handleCollectedStreamResponse(ctx context.Context,
var claudeSSE strings.Builder
if isClaudeClient {
// Extract sessionID for signature caching (like CLIProxyAPI)
- requestBody := ctxutil.GetRequestBody(ctx)
+ requestBody := flow.GetRequestBody(c)
sessionID := extractSessionID(requestBody)
claudeState = NewClaudeStreamingStateWithSession(sessionID, requestModel)
}
// Collect upstream SSE for attempt/debug and token extraction.
var upstreamSSE strings.Builder
- var lastPayload []byte
+ var unwrappedSSE strings.Builder
var responseBody []byte
var lineBuffer bytes.Buffer
@@ -826,15 +977,7 @@ func (a *AntigravityAdapter) handleCollectedStreamResponse(ctx context.Context,
if len(unwrappedLine) == 0 {
continue
}
-
- // Track last Gemini payload for non-Claude responses (best-effort)
- lineStr := strings.TrimSpace(string(unwrappedLine))
- if strings.HasPrefix(lineStr, "data: ") {
- dataStr := strings.TrimSpace(strings.TrimPrefix(lineStr, "data: "))
- if dataStr != "" && dataStr != "[DONE]" {
- lastPayload = []byte(dataStr)
- }
- }
+ unwrappedSSE.Write(unwrappedLine)
if isClaudeClient && claudeState != nil {
out := claudeState.ProcessGeminiSSELine(string(unwrappedLine))
@@ -901,16 +1044,23 @@ func (a *AntigravityAdapter) handleCollectedStreamResponse(ctx context.Context,
}
responseBody = collected
} else {
- if len(lastPayload) == 0 {
+ if unwrappedSSE.Len() == 0 {
return domain.NewProxyErrorWithMessage(domain.ErrUpstreamError, true, "empty upstream stream response")
}
+ geminiWrapped := convertStreamToNonStream([]byte(unwrappedSSE.String()))
+ geminiResponse := unwrapV1InternalResponse(geminiWrapped)
switch clientType {
case domain.ClientTypeGemini:
- responseBody = lastPayload
+ responseBody = geminiResponse
case domain.ClientTypeOpenAI:
- return domain.NewProxyErrorWithMessage(domain.ErrFormatConversion, false, "OpenAI response transformation not yet implemented")
+ var convErr error
+ responseBody, convErr = converter.GetGlobalRegistry().TransformResponse(
+ domain.ClientTypeGemini, domain.ClientTypeOpenAI, geminiResponse)
+ if convErr != nil {
+ return domain.NewProxyErrorWithMessage(domain.ErrFormatConversion, false, "failed to transform response")
+ }
default:
- responseBody = lastPayload
+ responseBody = geminiResponse
}
}
diff --git a/internal/adapter/provider/antigravity/openai_request.go b/internal/adapter/provider/antigravity/openai_request.go
new file mode 100644
index 00000000..521f7ce9
--- /dev/null
+++ b/internal/adapter/provider/antigravity/openai_request.go
@@ -0,0 +1,426 @@
+package antigravity
+
+import (
+ "bytes"
+ "fmt"
+ "log"
+ "mime"
+ "strings"
+
+ "github.com/tidwall/gjson"
+ "github.com/tidwall/sjson"
+)
+
+const geminiCLIFunctionThoughtSignature = "skip_thought_signature_validator"
+
+// ConvertOpenAIRequestToAntigravity converts an OpenAI Chat Completions request (raw JSON)
+// into a Gemini CLI compatible request JSON (antigravity format).
+// Ported from CLIProxyAPI antigravity/openai/chat-completions translator.
+func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _ bool) []byte {
+ rawJSON := bytes.Clone(inputRawJSON)
+ // Base envelope (no default thinkingConfig)
+ out := []byte(`{"project":"","request":{"contents":[]},"model":"gemini-2.5-pro"}`)
+
+ // Model
+ out, _ = sjson.SetBytes(out, "model", modelName)
+
+ // Apply thinking configuration: convert OpenAI reasoning_effort to Gemini CLI thinkingConfig.
+ re := gjson.GetBytes(rawJSON, "reasoning_effort")
+ if re.Exists() {
+ effort := strings.ToLower(strings.TrimSpace(re.String()))
+ if effort != "" {
+ thinkingPath := "request.generationConfig.thinkingConfig"
+ if effort == "auto" {
+ out, _ = sjson.SetBytes(out, thinkingPath+".thinkingBudget", -1)
+ out, _ = sjson.SetBytes(out, thinkingPath+".includeThoughts", true)
+ } else {
+ out, _ = sjson.SetBytes(out, thinkingPath+".thinkingLevel", effort)
+ out, _ = sjson.SetBytes(out, thinkingPath+".includeThoughts", effort != "none")
+ }
+ }
+ }
+
+ // Temperature/top_p/top_k/max_tokens
+ if tr := gjson.GetBytes(rawJSON, "temperature"); tr.Exists() && tr.Type == gjson.Number {
+ out, _ = sjson.SetBytes(out, "request.generationConfig.temperature", tr.Num)
+ }
+ if tpr := gjson.GetBytes(rawJSON, "top_p"); tpr.Exists() && tpr.Type == gjson.Number {
+ out, _ = sjson.SetBytes(out, "request.generationConfig.topP", tpr.Num)
+ }
+ if tkr := gjson.GetBytes(rawJSON, "top_k"); tkr.Exists() && tkr.Type == gjson.Number {
+ out, _ = sjson.SetBytes(out, "request.generationConfig.topK", tkr.Num)
+ }
+ if maxTok := gjson.GetBytes(rawJSON, "max_tokens"); maxTok.Exists() && maxTok.Type == gjson.Number {
+ out, _ = sjson.SetBytes(out, "request.generationConfig.maxOutputTokens", maxTok.Num)
+ }
+
+ // Candidate count (OpenAI 'n' parameter)
+ if n := gjson.GetBytes(rawJSON, "n"); n.Exists() && n.Type == gjson.Number {
+ if val := n.Int(); val > 1 {
+ out, _ = sjson.SetBytes(out, "request.generationConfig.candidateCount", val)
+ }
+ }
+
+ // Map OpenAI modalities -> Gemini CLI request.generationConfig.responseModalities
+ if mods := gjson.GetBytes(rawJSON, "modalities"); mods.Exists() && mods.IsArray() {
+ var responseMods []string
+ for _, m := range mods.Array() {
+ switch strings.ToLower(m.String()) {
+ case "text":
+ responseMods = append(responseMods, "TEXT")
+ case "image":
+ responseMods = append(responseMods, "IMAGE")
+ }
+ }
+ if len(responseMods) > 0 {
+ out, _ = sjson.SetBytes(out, "request.generationConfig.responseModalities", responseMods)
+ }
+ }
+
+ // OpenRouter-style image_config support
+ if imgCfg := gjson.GetBytes(rawJSON, "image_config"); imgCfg.Exists() && imgCfg.IsObject() {
+ if ar := imgCfg.Get("aspect_ratio"); ar.Exists() && ar.Type == gjson.String {
+ out, _ = sjson.SetBytes(out, "request.generationConfig.imageConfig.aspectRatio", ar.Str)
+ }
+ if size := imgCfg.Get("image_size"); size.Exists() && size.Type == gjson.String {
+ out, _ = sjson.SetBytes(out, "request.generationConfig.imageConfig.imageSize", size.Str)
+ }
+ }
+
+ // messages -> systemInstruction + contents
+ messages := gjson.GetBytes(rawJSON, "messages")
+ if messages.IsArray() {
+ arr := messages.Array()
+ // First pass: assistant tool_calls id->name map
+ tcID2Name := map[string]string{}
+ for i := 0; i < len(arr); i++ {
+ m := arr[i]
+ if m.Get("role").String() == "assistant" {
+ tcs := m.Get("tool_calls")
+ if tcs.IsArray() {
+ for _, tc := range tcs.Array() {
+ if tc.Get("type").String() == "function" {
+ id := tc.Get("id").String()
+ name := tc.Get("function.name").String()
+ if id != "" && name != "" {
+ tcID2Name[id] = name
+ }
+ }
+ }
+ }
+ }
+ }
+
+ // Second pass build systemInstruction/tool responses cache
+ toolResponses := map[string]string{} // tool_call_id -> response text
+ for i := 0; i < len(arr); i++ {
+ m := arr[i]
+ role := m.Get("role").String()
+ if role == "tool" {
+ toolCallID := m.Get("tool_call_id").String()
+ if toolCallID != "" {
+ c := m.Get("content")
+ toolResponses[toolCallID] = c.Raw
+ }
+ }
+ }
+
+ systemPartIndex := 0
+ for i := 0; i < len(arr); i++ {
+ m := arr[i]
+ role := m.Get("role").String()
+ content := m.Get("content")
+
+ if (role == "system" || role == "developer") && len(arr) > 1 {
+ // system -> request.systemInstruction as a user message style
+ if content.Type == gjson.String {
+ out, _ = sjson.SetBytes(out, "request.systemInstruction.role", "user")
+ out, _ = sjson.SetBytes(out, fmt.Sprintf("request.systemInstruction.parts.%d.text", systemPartIndex), content.String())
+ systemPartIndex++
+ } else if content.IsObject() && content.Get("type").String() == "text" {
+ out, _ = sjson.SetBytes(out, "request.systemInstruction.role", "user")
+ out, _ = sjson.SetBytes(out, fmt.Sprintf("request.systemInstruction.parts.%d.text", systemPartIndex), content.Get("text").String())
+ systemPartIndex++
+ } else if content.IsArray() {
+ contents := content.Array()
+ if len(contents) > 0 {
+ out, _ = sjson.SetBytes(out, "request.systemInstruction.role", "user")
+ for j := 0; j < len(contents); j++ {
+ text := contents[j].Get("text").String()
+ if text != "" {
+ out, _ = sjson.SetBytes(out, fmt.Sprintf("request.systemInstruction.parts.%d.text", systemPartIndex), text)
+ systemPartIndex++
+ }
+ }
+ }
+ }
+ } else if role == "user" || ((role == "system" || role == "developer") && len(arr) == 1) {
+ // Build single user content node to avoid splitting into multiple contents
+ node := []byte(`{"role":"user","parts":[]}`)
+ if content.Type == gjson.String {
+ node, _ = sjson.SetBytes(node, "parts.0.text", content.String())
+ } else if content.IsArray() {
+ items := content.Array()
+ p := 0
+ for _, item := range items {
+ switch item.Get("type").String() {
+ case "text":
+ text := item.Get("text").String()
+ if text != "" {
+ node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".text", text)
+ p++
+ }
+ case "image_url":
+ imageURL := item.Get("image_url.url").String()
+ if strings.HasPrefix(imageURL, "data:") {
+ pieces := strings.SplitN(imageURL[len("data:"):], ";", 2)
+ if len(pieces) == 2 && len(pieces[1]) > 7 {
+ mimeType := pieces[0]
+ data := pieces[1][7:]
+ node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mimeType)
+ node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
+ node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
+ p++
+ }
+ }
+ case "file":
+ filename := item.Get("file.filename").String()
+ fileData := item.Get("file.file_data").String()
+ if filename != "" && fileData != "" {
+ ext := ""
+ if sp := strings.Split(filename, "."); len(sp) > 1 {
+ ext = sp[len(sp)-1]
+ }
+ mimeType := mime.TypeByExtension("." + ext)
+ if mimeType != "" {
+ node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mimeType)
+ node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", fileData)
+ p++
+ } else {
+ log.Printf("unknown file extension '%s' in user message, skip", ext)
+ }
+ }
+ }
+ }
+ }
+ out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
+ } else if role == "assistant" {
+ node := []byte(`{"role":"model","parts":[]}`)
+ p := 0
+ if content.Type == gjson.String && content.String() != "" {
+ node, _ = sjson.SetBytes(node, "parts.-1.text", content.String())
+ p++
+ } else if content.IsArray() {
+ // Assistant multimodal content -> single model content with parts
+ for _, item := range content.Array() {
+ switch item.Get("type").String() {
+ case "text":
+ text := item.Get("text").String()
+ if text != "" {
+ node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".text", text)
+ p++
+ }
+ case "image_url":
+ imageURL := item.Get("image_url.url").String()
+ if strings.HasPrefix(imageURL, "data:") { // expect data:...
+ pieces := strings.SplitN(imageURL[len("data:"):], ";", 2)
+ if len(pieces) == 2 && len(pieces[1]) > 7 {
+ mimeType := pieces[0]
+ data := pieces[1][7:]
+ node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mimeType)
+ node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
+ node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
+ p++
+ }
+ }
+ }
+ }
+ }
+
+ // Tool calls -> single model content with functionCall parts
+ tcs := m.Get("tool_calls")
+ if tcs.IsArray() {
+ fIDs := make([]string, 0)
+ for _, tc := range tcs.Array() {
+ if tc.Get("type").String() != "function" {
+ continue
+ }
+ fid := tc.Get("id").String()
+ fname := tc.Get("function.name").String()
+ fargs := tc.Get("function.arguments").String()
+ node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.id", fid)
+ node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.name", fname)
+ if gjson.Valid(fargs) {
+ node, _ = sjson.SetRawBytes(node, "parts."+itoa(p)+".functionCall.args", []byte(fargs))
+ } else {
+ node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.args.params", fargs)
+ }
+ node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
+ p++
+ if fid != "" {
+ fIDs = append(fIDs, fid)
+ }
+ }
+ out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
+
+ // Append a single tool content combining name + response per function
+ toolNode := []byte(`{"role":"user","parts":[]}`)
+ pp := 0
+ for _, fid := range fIDs {
+ if name, ok := tcID2Name[fid]; ok {
+ toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.id", fid)
+ toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.name", name)
+ resp := toolResponses[fid]
+ if resp == "" {
+ resp = "{}"
+ }
+ if resp != "null" {
+ parsed := gjson.Parse(resp)
+ if parsed.Type == gjson.JSON {
+ toolNode, _ = sjson.SetRawBytes(toolNode, "parts."+itoa(pp)+".functionResponse.response.result", []byte(parsed.Raw))
+ } else {
+ toolNode, _ = sjson.SetRawBytes(toolNode, "parts."+itoa(pp)+".functionResponse.response.result", []byte(resp))
+ }
+ }
+ pp++
+ }
+ }
+ if pp > 0 {
+ out, _ = sjson.SetRawBytes(out, "request.contents.-1", toolNode)
+ }
+ } else {
+ out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
+ }
+ }
+ }
+ }
+
+ // tools -> request.tools[].functionDeclarations + request.tools[].googleSearch/codeExecution/urlContext passthrough
+ tools := gjson.GetBytes(rawJSON, "tools")
+ if tools.IsArray() && len(tools.Array()) > 0 {
+ functionToolNode := []byte(`{}`)
+ hasFunction := false
+ googleSearchNodes := make([][]byte, 0)
+ codeExecutionNodes := make([][]byte, 0)
+ urlContextNodes := make([][]byte, 0)
+ for _, t := range tools.Array() {
+ if t.Get("type").String() == "function" {
+ fn := t.Get("function")
+ if fn.Exists() && fn.IsObject() {
+ fnRaw := fn.Raw
+ if fn.Get("parameters").Exists() {
+ renamed, errRename := RenameKey(fnRaw, "parameters", "parametersJsonSchema")
+ if errRename != nil {
+ log.Printf("failed to rename parameters for tool '%s': %v", fn.Get("name").String(), errRename)
+ fnRaw, _ = sjson.Delete(fnRaw, "parameters")
+ var errSet error
+ fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.type", "object")
+ if errSet != nil {
+ log.Printf("failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet)
+ continue
+ }
+ fnRaw, errSet = sjson.SetRaw(fnRaw, "parametersJsonSchema.properties", `{}`)
+ if errSet != nil {
+ log.Printf("failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet)
+ continue
+ }
+ } else {
+ fnRaw = renamed
+ }
+ } else {
+ var errSet error
+ fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.type", "object")
+ if errSet != nil {
+ log.Printf("failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet)
+ continue
+ }
+ fnRaw, errSet = sjson.SetRaw(fnRaw, "parametersJsonSchema.properties", `{}`)
+ if errSet != nil {
+ log.Printf("failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet)
+ continue
+ }
+ }
+ fnRaw, _ = sjson.Delete(fnRaw, "strict")
+ if !hasFunction {
+ functionToolNode, _ = sjson.SetRawBytes(functionToolNode, "functionDeclarations", []byte("[]"))
+ }
+ tmp, errSet := sjson.SetRawBytes(functionToolNode, "functionDeclarations.-1", []byte(fnRaw))
+ if errSet != nil {
+ log.Printf("failed to append tool declaration for '%s': %v", fn.Get("name").String(), errSet)
+ continue
+ }
+ functionToolNode = tmp
+ hasFunction = true
+ }
+ }
+ if gs := t.Get("google_search"); gs.Exists() {
+ googleToolNode := []byte(`{}`)
+ var errSet error
+ googleToolNode, errSet = sjson.SetRawBytes(googleToolNode, "googleSearch", []byte(gs.Raw))
+ if errSet != nil {
+ log.Printf("failed to set googleSearch tool: %v", errSet)
+ continue
+ }
+ googleSearchNodes = append(googleSearchNodes, googleToolNode)
+ }
+ if ce := t.Get("code_execution"); ce.Exists() {
+ codeToolNode := []byte(`{}`)
+ var errSet error
+ codeToolNode, errSet = sjson.SetRawBytes(codeToolNode, "codeExecution", []byte(ce.Raw))
+ if errSet != nil {
+ log.Printf("failed to set codeExecution tool: %v", errSet)
+ continue
+ }
+ codeExecutionNodes = append(codeExecutionNodes, codeToolNode)
+ }
+ if uc := t.Get("url_context"); uc.Exists() {
+ urlToolNode := []byte(`{}`)
+ var errSet error
+ urlToolNode, errSet = sjson.SetRawBytes(urlToolNode, "urlContext", []byte(uc.Raw))
+ if errSet != nil {
+ log.Printf("failed to set urlContext tool: %v", errSet)
+ continue
+ }
+ urlContextNodes = append(urlContextNodes, urlToolNode)
+ }
+ }
+ if hasFunction || len(googleSearchNodes) > 0 || len(codeExecutionNodes) > 0 || len(urlContextNodes) > 0 {
+ toolsNode := []byte("[]")
+ if hasFunction {
+ toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", functionToolNode)
+ }
+ for _, googleNode := range googleSearchNodes {
+ toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", googleNode)
+ }
+ for _, codeNode := range codeExecutionNodes {
+ toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", codeNode)
+ }
+ for _, urlNode := range urlContextNodes {
+ toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", urlNode)
+ }
+ out, _ = sjson.SetRawBytes(out, "request.tools", toolsNode)
+ }
+ }
+
+ return attachDefaultSafetySettings(out, "request.safetySettings")
+}
+
+func itoa(i int) string { return fmt.Sprintf("%d", i) }
+
+func attachDefaultSafetySettings(rawJSON []byte, path string) []byte {
+ if gjson.GetBytes(rawJSON, path).Exists() {
+ return rawJSON
+ }
+ defaults := []map[string]string{
+ {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
+ {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
+ {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
+ {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
+ {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"},
+ }
+ out, err := sjson.SetBytes(rawJSON, path, defaults)
+ if err != nil {
+ return rawJSON
+ }
+ return out
+}
diff --git a/internal/adapter/provider/antigravity/request.go b/internal/adapter/provider/antigravity/request.go
index e4b45748..0e37f81c 100644
--- a/internal/adapter/provider/antigravity/request.go
+++ b/internal/adapter/provider/antigravity/request.go
@@ -1,11 +1,24 @@
package antigravity
import (
+ "crypto/sha256"
+ "encoding/binary"
"encoding/json"
"fmt"
+ "math/rand"
+ "strconv"
"strings"
+ "sync"
+ "time"
"github.com/google/uuid"
+ "github.com/tidwall/gjson"
+ "github.com/tidwall/sjson"
+)
+
+var (
+ randSource = rand.New(rand.NewSource(time.Now().UnixNano()))
+ randSourceMutex sync.Mutex
)
// RequestConfig holds resolved request configuration (like Antigravity-Manager)
@@ -178,6 +191,11 @@ func wrapV1InternalRequest(body []byte, projectID, originalModel, mappedModel, s
// Remove model field from inner request if present (will be at top level)
delete(innerRequest, "model")
+ // Strip v1internal wrapper fields if client passed them through
+ delete(innerRequest, "project")
+ delete(innerRequest, "requestId")
+ delete(innerRequest, "requestType")
+ delete(innerRequest, "userAgent")
// Resolve request configuration (like Antigravity-Manager)
toolsForDetection := toolsForConfig
@@ -215,18 +233,22 @@ func wrapV1InternalRequest(body []byte, projectID, originalModel, mappedModel, s
// Deep clean [undefined] strings (Cherry Studio client common injection)
deepCleanUndefined(innerRequest)
- // [Safety Settings] Inject safety settings from environment variable (like Antigravity-Manager)
- safetyThreshold := GetSafetyThresholdFromEnv()
- innerRequest["safetySettings"] = BuildSafetySettingsMap(safetyThreshold)
+ // [Safety Settings] Antigravity v1internal does not accept request.safetySettings
+ delete(innerRequest, "safetySettings")
- // [SessionID Support] If metadata.user_id was provided, use it as sessionId (like Antigravity-Manager)
- if sessionID != "" {
- innerRequest["sessionId"] = sessionID
+ // [SessionID Support] Use metadata.user_id if provided, otherwise generate a stable session id
+ if sessionID == "" {
+ sessionID = generateStableSessionID(body)
}
+ innerRequest["sessionId"] = sessionID
// Generate UUID requestId (like Antigravity-Manager)
requestID := fmt.Sprintf("agent-%s", uuid.New().String())
+ if strings.TrimSpace(projectID) == "" {
+ projectID = generateProjectID()
+ }
+
wrapped := map[string]interface{}{
"project": projectID,
"requestId": requestID,
@@ -236,7 +258,130 @@ func wrapV1InternalRequest(body []byte, projectID, originalModel, mappedModel, s
"requestType": config.RequestType,
}
- return json.Marshal(wrapped)
+ payload, err := json.Marshal(wrapped)
+ if err != nil {
+ return nil, err
+ }
+ payload = applyAntigravityRequestTuning(payload, config.FinalModel)
+ return payload, nil
+}
+
+// finalizeOpenAIWrappedRequest ensures an OpenAI->Antigravity converted request
+// has required envelope fields (project/requestId/sessionId/userAgent/requestType),
+// and applies Antigravity request tuning.
+func finalizeOpenAIWrappedRequest(payload []byte, projectID, modelName, sessionID string) []byte {
+ if len(payload) == 0 {
+ return payload
+ }
+ if strings.TrimSpace(projectID) == "" {
+ projectID = generateProjectID()
+ }
+ if sessionID == "" {
+ sessionID = generateStableSessionID(payload)
+ }
+
+ out := payload
+ out, _ = sjson.SetBytes(out, "project", projectID)
+ out, _ = sjson.SetBytes(out, "requestId", fmt.Sprintf("agent-%s", uuid.New().String()))
+ out, _ = sjson.SetBytes(out, "requestType", "agent")
+ out, _ = sjson.SetBytes(out, "userAgent", "antigravity")
+ out, _ = sjson.SetBytes(out, "model", modelName)
+ out, _ = sjson.DeleteBytes(out, "request.safetySettings")
+
+ // Move toolConfig to request.toolConfig if needed
+ if toolConfig := gjson.GetBytes(out, "toolConfig"); toolConfig.Exists() && !gjson.GetBytes(out, "request.toolConfig").Exists() {
+ out, _ = sjson.SetRawBytes(out, "request.toolConfig", []byte(toolConfig.Raw))
+ out, _ = sjson.DeleteBytes(out, "toolConfig")
+ }
+
+ // Ensure sessionId
+ out, _ = sjson.SetBytes(out, "request.sessionId", sessionID)
+ return applyAntigravityRequestTuning(out, modelName)
+}
+
+const antigravitySystemInstruction = "You are Antigravity, a powerful agentic AI coding assistant designed by the Google Deepmind team working on Advanced Agentic Coding.You are pair programming with a USER to solve their coding task. The task may require creating a new codebase, modifying or debugging an existing codebase, or simply answering a question.**Absolute paths only****Proactiveness**"
+
+func applyAntigravityRequestTuning(payload []byte, modelName string) []byte {
+ if len(payload) == 0 {
+ return payload
+ }
+ strJSON := string(payload)
+ paths := make([]string, 0)
+ Walk(gjson.ParseBytes(payload), "", "parametersJsonSchema", &paths)
+ for _, p := range paths {
+ if !strings.HasSuffix(p, "parametersJsonSchema") {
+ continue
+ }
+ if renamed, err := RenameKey(strJSON, p, p[:len(p)-len("parametersJsonSchema")]+"parameters"); err == nil {
+ strJSON = renamed
+ }
+ }
+
+ if strings.Contains(modelName, "claude") || strings.Contains(modelName, "gemini-3-pro-high") {
+ strJSON = CleanJSONSchemaForAntigravity(strJSON)
+ } else {
+ strJSON = CleanJSONSchemaForGemini(strJSON)
+ }
+
+ payload = []byte(strJSON)
+
+ if strings.Contains(modelName, "claude") || strings.Contains(modelName, "gemini-3-pro-high") {
+ partsResult := gjson.GetBytes(payload, "request.systemInstruction.parts")
+ payload, _ = sjson.SetBytes(payload, "request.systemInstruction.role", "user")
+ payload, _ = sjson.SetBytes(payload, "request.systemInstruction.parts.0.text", antigravitySystemInstruction)
+ payload, _ = sjson.SetBytes(payload, "request.systemInstruction.parts.1.text", fmt.Sprintf("Please ignore following [ignore]%s[/ignore]", antigravitySystemInstruction))
+ if partsResult.Exists() && partsResult.IsArray() {
+ for _, part := range partsResult.Array() {
+ payload, _ = sjson.SetRawBytes(payload, "request.systemInstruction.parts.-1", []byte(part.Raw))
+ }
+ }
+ }
+
+ if strings.Contains(modelName, "claude") {
+ payload, _ = sjson.SetBytes(payload, "request.toolConfig.functionCallingConfig.mode", "VALIDATED")
+ } else {
+ payload, _ = sjson.DeleteBytes(payload, "request.generationConfig.maxOutputTokens")
+ }
+
+ return payload
+}
+
+func generateSessionID() string {
+ randSourceMutex.Lock()
+ n := randSource.Int63n(9_000_000_000_000_000_000)
+ randSourceMutex.Unlock()
+ return "-" + strconv.FormatInt(n, 10)
+}
+
+func generateStableSessionID(payload []byte) string {
+ contents := gjson.GetBytes(payload, "request.contents")
+ if !contents.IsArray() {
+ contents = gjson.GetBytes(payload, "contents")
+ }
+ if contents.IsArray() {
+ for _, content := range contents.Array() {
+ if content.Get("role").String() == "user" {
+ text := content.Get("parts.0.text").String()
+ if text != "" {
+ h := sha256.Sum256([]byte(text))
+ n := int64(binary.BigEndian.Uint64(h[:8])) & 0x7FFFFFFFFFFFFFFF
+ return "-" + strconv.FormatInt(n, 10)
+ }
+ }
+ }
+ }
+ return generateSessionID()
+}
+
+func generateProjectID() string {
+ adjectives := []string{"useful", "bright", "swift", "calm", "bold"}
+ nouns := []string{"fuze", "wave", "spark", "flow", "core"}
+ randSourceMutex.Lock()
+ adj := adjectives[randSource.Intn(len(adjectives))]
+ noun := nouns[randSource.Intn(len(nouns))]
+ randSourceMutex.Unlock()
+ randomPart := strings.ToLower(uuid.NewString())[:5]
+ return adj + "-" + noun + "-" + randomPart
}
// stripThinkingFromClaude removes thinking config and blocks to retry without thinking (like Manager 400 retry)
diff --git a/internal/adapter/provider/antigravity/request_test.go b/internal/adapter/provider/antigravity/request_test.go
new file mode 100644
index 00000000..6578cfdb
--- /dev/null
+++ b/internal/adapter/provider/antigravity/request_test.go
@@ -0,0 +1,44 @@
+package antigravity
+
+import (
+ "testing"
+
+ "github.com/tidwall/gjson"
+)
+
+func TestApplyAntigravityRequestTuning(t *testing.T) {
+ input := `{
+ "request": {
+ "systemInstruction": {
+ "parts": [{"text":"original"}]
+ },
+ "tools": [{
+ "functionDeclarations": [{
+ "name": "t1",
+ "parametersJsonSchema": {"type":"object","properties":{"x":{"type":"string"}}}
+ }]
+ }]
+ },
+ "model": "claude-sonnet-4-5"
+}`
+ out := applyAntigravityRequestTuning([]byte(input), "claude-sonnet-4-5")
+
+ if !gjson.GetBytes(out, "request.systemInstruction.role").Exists() {
+ t.Fatalf("expected systemInstruction.role to be set")
+ }
+ if gjson.GetBytes(out, "request.systemInstruction.parts.0.text").String() == "" {
+ t.Fatalf("expected systemInstruction parts[0].text to be injected")
+ }
+ if gjson.GetBytes(out, "request.systemInstruction.parts.1.text").String() == "" {
+ t.Fatalf("expected systemInstruction parts[1].text to be injected")
+ }
+ if gjson.GetBytes(out, "request.toolConfig.functionCallingConfig.mode").String() != "VALIDATED" {
+ t.Fatalf("expected toolConfig.functionCallingConfig.mode=VALIDATED")
+ }
+ if gjson.GetBytes(out, "request.tools.0.functionDeclarations.0.parametersJsonSchema").Exists() {
+ t.Fatalf("expected parametersJsonSchema to be renamed")
+ }
+ if !gjson.GetBytes(out, "request.tools.0.functionDeclarations.0.parameters").Exists() {
+ t.Fatalf("expected parameters to exist after rename")
+ }
+}
diff --git a/internal/adapter/provider/antigravity/response.go b/internal/adapter/provider/antigravity/response.go
index b8ef97c7..53db1f62 100644
--- a/internal/adapter/provider/antigravity/response.go
+++ b/internal/adapter/provider/antigravity/response.go
@@ -1,10 +1,14 @@
package antigravity
import (
+ "bytes"
"encoding/json"
"fmt"
"net/http"
"strings"
+
+ "github.com/tidwall/gjson"
+ "github.com/tidwall/sjson"
)
// Response headers to exclude when copying
@@ -114,6 +118,191 @@ func isRetryableStatusCode(code int) bool {
}
}
+// convertStreamToNonStream collects Gemini SSE stream into a single response payload.
+// Ported from CLIProxyAPI Antigravity convertStreamToNonStream.
+func convertStreamToNonStream(stream []byte) []byte {
+ responseTemplate := ""
+ var traceID string
+ var finishReason string
+ var modelVersion string
+ var responseID string
+ var role string
+ var usageRaw string
+ parts := make([]map[string]interface{}, 0)
+ var pendingKind string
+ var pendingText strings.Builder
+ var pendingThoughtSig string
+
+ flushPending := func() {
+ if pendingKind == "" {
+ return
+ }
+ text := pendingText.String()
+ switch pendingKind {
+ case "text":
+ if strings.TrimSpace(text) == "" {
+ pendingKind = ""
+ pendingText.Reset()
+ pendingThoughtSig = ""
+ return
+ }
+ parts = append(parts, map[string]interface{}{"text": text})
+ case "thought":
+ if strings.TrimSpace(text) == "" && pendingThoughtSig == "" {
+ pendingKind = ""
+ pendingText.Reset()
+ pendingThoughtSig = ""
+ return
+ }
+ part := map[string]interface{}{"thought": true}
+ part["text"] = text
+ if pendingThoughtSig != "" {
+ part["thoughtSignature"] = pendingThoughtSig
+ }
+ parts = append(parts, part)
+ }
+ pendingKind = ""
+ pendingText.Reset()
+ pendingThoughtSig = ""
+ }
+
+ normalizePart := func(partResult gjson.Result) map[string]interface{} {
+ var m map[string]interface{}
+ _ = json.Unmarshal([]byte(partResult.Raw), &m)
+ if m == nil {
+ m = map[string]interface{}{}
+ }
+ sig := partResult.Get("thoughtSignature").String()
+ if sig == "" {
+ sig = partResult.Get("thought_signature").String()
+ }
+ if sig != "" {
+ m["thoughtSignature"] = sig
+ delete(m, "thought_signature")
+ }
+ if inlineData, ok := m["inline_data"]; ok {
+ m["inlineData"] = inlineData
+ delete(m, "inline_data")
+ }
+ return m
+ }
+
+ for _, line := range bytes.Split(stream, []byte("\n")) {
+ trimmed := bytes.TrimSpace(line)
+ trimmed = bytes.TrimPrefix(trimmed, []byte("data: "))
+ if len(trimmed) == 0 || !gjson.ValidBytes(trimmed) {
+ continue
+ }
+
+ root := gjson.ParseBytes(trimmed)
+ responseNode := root.Get("response")
+ if !responseNode.Exists() {
+ if root.Get("candidates").Exists() {
+ responseNode = root
+ } else {
+ continue
+ }
+ }
+ responseTemplate = responseNode.Raw
+
+ if traceResult := root.Get("traceId"); traceResult.Exists() && traceResult.String() != "" {
+ traceID = traceResult.String()
+ }
+
+ if roleResult := responseNode.Get("candidates.0.content.role"); roleResult.Exists() {
+ role = roleResult.String()
+ }
+
+ if finishResult := responseNode.Get("candidates.0.finishReason"); finishResult.Exists() && finishResult.String() != "" {
+ finishReason = finishResult.String()
+ }
+
+ if modelResult := responseNode.Get("modelVersion"); modelResult.Exists() && modelResult.String() != "" {
+ modelVersion = modelResult.String()
+ }
+ if responseIDResult := responseNode.Get("responseId"); responseIDResult.Exists() && responseIDResult.String() != "" {
+ responseID = responseIDResult.String()
+ }
+ if usageResult := responseNode.Get("usageMetadata"); usageResult.Exists() {
+ usageRaw = usageResult.Raw
+ } else if usageMetadataResult := root.Get("usageMetadata"); usageMetadataResult.Exists() {
+ usageRaw = usageMetadataResult.Raw
+ }
+
+ if partsResult := responseNode.Get("candidates.0.content.parts"); partsResult.IsArray() {
+ for _, part := range partsResult.Array() {
+ hasFunctionCall := part.Get("functionCall").Exists()
+ hasInlineData := part.Get("inlineData").Exists() || part.Get("inline_data").Exists()
+ sig := part.Get("thoughtSignature").String()
+ if sig == "" {
+ sig = part.Get("thought_signature").String()
+ }
+ text := part.Get("text").String()
+ thought := part.Get("thought").Bool()
+
+ if hasFunctionCall || hasInlineData {
+ flushPending()
+ parts = append(parts, normalizePart(part))
+ continue
+ }
+
+ if thought || part.Get("text").Exists() {
+ kind := "text"
+ if thought {
+ kind = "thought"
+ }
+ if pendingKind != "" && pendingKind != kind {
+ flushPending()
+ }
+ pendingKind = kind
+ pendingText.WriteString(text)
+ if kind == "thought" && sig != "" {
+ pendingThoughtSig = sig
+ }
+ continue
+ }
+
+ flushPending()
+ parts = append(parts, normalizePart(part))
+ }
+ }
+ }
+ flushPending()
+
+ if responseTemplate == "" {
+ responseTemplate = `{"candidates":[{"content":{"role":"model","parts":[]}}]}`
+ }
+
+ partsJSON, _ := json.Marshal(parts)
+ responseTemplate, _ = sjson.SetRaw(responseTemplate, "candidates.0.content.parts", string(partsJSON))
+ if role != "" {
+ responseTemplate, _ = sjson.Set(responseTemplate, "candidates.0.content.role", role)
+ }
+ if finishReason != "" {
+ responseTemplate, _ = sjson.Set(responseTemplate, "candidates.0.finishReason", finishReason)
+ }
+ if modelVersion != "" {
+ responseTemplate, _ = sjson.Set(responseTemplate, "modelVersion", modelVersion)
+ }
+ if responseID != "" {
+ responseTemplate, _ = sjson.Set(responseTemplate, "responseId", responseID)
+ }
+ if usageRaw != "" {
+ responseTemplate, _ = sjson.SetRaw(responseTemplate, "usageMetadata", usageRaw)
+ } else if !gjson.Get(responseTemplate, "usageMetadata").Exists() {
+ responseTemplate, _ = sjson.Set(responseTemplate, "usageMetadata.promptTokenCount", 0)
+ responseTemplate, _ = sjson.Set(responseTemplate, "usageMetadata.candidatesTokenCount", 0)
+ responseTemplate, _ = sjson.Set(responseTemplate, "usageMetadata.totalTokenCount", 0)
+ }
+
+ output := `{"response":{},"traceId":""}`
+ output, _ = sjson.SetRaw(output, "response", responseTemplate)
+ if traceID != "" {
+ output, _ = sjson.Set(output, "traceId", traceID)
+ }
+ return []byte(output)
+}
+
// convertGeminiToClaudeResponse converts a non-streaming Gemini response to Claude format
// (like Antigravity-Manager's response conversion)
func convertGeminiToClaudeResponse(geminiBody []byte, requestModel string) ([]byte, error) {
diff --git a/internal/adapter/provider/antigravity/schema_cleaner.go b/internal/adapter/provider/antigravity/schema_cleaner.go
new file mode 100644
index 00000000..41f4c371
--- /dev/null
+++ b/internal/adapter/provider/antigravity/schema_cleaner.go
@@ -0,0 +1,731 @@
+// Package util provides utility functions for the CLI Proxy API server.
+package antigravity
+
+import (
+ "fmt"
+ "sort"
+ "strconv"
+ "strings"
+
+ "github.com/tidwall/gjson"
+ "github.com/tidwall/sjson"
+)
+
+var gjsonPathKeyReplacer = strings.NewReplacer(".", "\\.", "*", "\\*", "?", "\\?")
+
+const placeholderReasonDescription = "Brief explanation of why you are calling this tool"
+
+// CleanJSONSchemaForAntigravity transforms a JSON schema to be compatible with Antigravity API.
+// It handles unsupported keywords, type flattening, and schema simplification while preserving
+// semantic information as description hints.
+func CleanJSONSchemaForAntigravity(jsonStr string) string {
+ return cleanJSONSchema(jsonStr, true)
+}
+
+// CleanJSONSchemaForGemini transforms a JSON schema to be compatible with Gemini tool calling.
+// It removes unsupported keywords and simplifies schemas, without adding empty-schema placeholders.
+func CleanJSONSchemaForGemini(jsonStr string) string {
+ return cleanJSONSchema(jsonStr, false)
+}
+
+// cleanJSONSchema performs the core cleaning operations on the JSON schema.
+func cleanJSONSchema(jsonStr string, addPlaceholder bool) string {
+ // Phase 1: Convert and add hints
+ jsonStr = convertRefsToHints(jsonStr)
+ jsonStr = convertConstToEnum(jsonStr)
+ jsonStr = convertEnumValuesToStrings(jsonStr)
+ jsonStr = addEnumHints(jsonStr)
+ jsonStr = addAdditionalPropertiesHints(jsonStr)
+ jsonStr = moveConstraintsToDescription(jsonStr)
+
+ // Phase 2: Flatten complex structures
+ jsonStr = mergeAllOf(jsonStr)
+ jsonStr = flattenAnyOfOneOf(jsonStr)
+ jsonStr = flattenTypeArrays(jsonStr)
+
+ // Phase 3: Cleanup
+ jsonStr = removeUnsupportedKeywords(jsonStr)
+ if !addPlaceholder {
+ // Gemini schema cleanup: remove nullable/title and placeholder-only fields.
+ jsonStr = removeKeywords(jsonStr, []string{"nullable", "title"})
+ jsonStr = removePlaceholderFields(jsonStr)
+ }
+ jsonStr = cleanupRequiredFields(jsonStr)
+ // Phase 4: Add placeholder for empty object schemas (Claude VALIDATED mode requirement)
+ if addPlaceholder {
+ jsonStr = addEmptySchemaPlaceholder(jsonStr)
+ }
+
+ return jsonStr
+}
+
+// removeKeywords removes all occurrences of specified keywords from the JSON schema.
+func removeKeywords(jsonStr string, keywords []string) string {
+ for _, key := range keywords {
+ for _, p := range findPaths(jsonStr, key) {
+ if isPropertyDefinition(trimSuffix(p, "."+key)) {
+ continue
+ }
+ jsonStr, _ = sjson.Delete(jsonStr, p)
+ }
+ }
+ return jsonStr
+}
+
+// removePlaceholderFields removes placeholder-only properties ("_" and "reason") and their required entries.
+func removePlaceholderFields(jsonStr string) string {
+ // Remove "_" placeholder properties.
+ paths := findPaths(jsonStr, "_")
+ sortByDepth(paths)
+ for _, p := range paths {
+ if !strings.HasSuffix(p, ".properties._") {
+ continue
+ }
+ jsonStr, _ = sjson.Delete(jsonStr, p)
+ parentPath := trimSuffix(p, ".properties._")
+ reqPath := joinPath(parentPath, "required")
+ req := gjson.Get(jsonStr, reqPath)
+ if req.IsArray() {
+ var filtered []string
+ for _, r := range req.Array() {
+ if r.String() != "_" {
+ filtered = append(filtered, r.String())
+ }
+ }
+ if len(filtered) == 0 {
+ jsonStr, _ = sjson.Delete(jsonStr, reqPath)
+ } else {
+ jsonStr, _ = sjson.Set(jsonStr, reqPath, filtered)
+ }
+ }
+ }
+
+ // Remove placeholder-only "reason" objects.
+ reasonPaths := findPaths(jsonStr, "reason")
+ sortByDepth(reasonPaths)
+ for _, p := range reasonPaths {
+ if !strings.HasSuffix(p, ".properties.reason") {
+ continue
+ }
+ parentPath := trimSuffix(p, ".properties.reason")
+ props := gjson.Get(jsonStr, joinPath(parentPath, "properties"))
+ if !props.IsObject() || len(props.Map()) != 1 {
+ continue
+ }
+ desc := gjson.Get(jsonStr, p+".description").String()
+ if desc != placeholderReasonDescription {
+ continue
+ }
+ jsonStr, _ = sjson.Delete(jsonStr, p)
+ reqPath := joinPath(parentPath, "required")
+ req := gjson.Get(jsonStr, reqPath)
+ if req.IsArray() {
+ var filtered []string
+ for _, r := range req.Array() {
+ if r.String() != "reason" {
+ filtered = append(filtered, r.String())
+ }
+ }
+ if len(filtered) == 0 {
+ jsonStr, _ = sjson.Delete(jsonStr, reqPath)
+ } else {
+ jsonStr, _ = sjson.Set(jsonStr, reqPath, filtered)
+ }
+ }
+ }
+
+ return jsonStr
+}
+
+// convertRefsToHints converts $ref to description hints (Lazy Hint strategy).
+func convertRefsToHints(jsonStr string) string {
+ paths := findPaths(jsonStr, "$ref")
+ sortByDepth(paths)
+
+ for _, p := range paths {
+ refVal := gjson.Get(jsonStr, p).String()
+ defName := refVal
+ if idx := strings.LastIndex(refVal, "/"); idx >= 0 {
+ defName = refVal[idx+1:]
+ }
+
+ parentPath := trimSuffix(p, ".$ref")
+ hint := fmt.Sprintf("See: %s", defName)
+ if existing := gjson.Get(jsonStr, descriptionPath(parentPath)).String(); existing != "" {
+ hint = fmt.Sprintf("%s (%s)", existing, hint)
+ }
+
+ replacement := `{"type":"object","description":""}`
+ replacement, _ = sjson.Set(replacement, "description", hint)
+ jsonStr = setRawAt(jsonStr, parentPath, replacement)
+ }
+ return jsonStr
+}
+
+func convertConstToEnum(jsonStr string) string {
+ for _, p := range findPaths(jsonStr, "const") {
+ val := gjson.Get(jsonStr, p)
+ if !val.Exists() {
+ continue
+ }
+ enumPath := trimSuffix(p, ".const") + ".enum"
+ if !gjson.Get(jsonStr, enumPath).Exists() {
+ jsonStr, _ = sjson.Set(jsonStr, enumPath, []interface{}{val.Value()})
+ }
+ }
+ return jsonStr
+}
+
+// convertEnumValuesToStrings ensures all enum values are strings and the schema type is set to string.
+// Gemini API requires enum values to be of type string, not numbers or booleans.
+func convertEnumValuesToStrings(jsonStr string) string {
+ for _, p := range findPaths(jsonStr, "enum") {
+ arr := gjson.Get(jsonStr, p)
+ if !arr.IsArray() {
+ continue
+ }
+
+ var stringVals []string
+ for _, item := range arr.Array() {
+ stringVals = append(stringVals, item.String())
+ }
+
+ // Always update enum values to strings and set type to "string"
+ // This ensures compatibility with Antigravity Gemini which only allows enum for STRING type
+ jsonStr, _ = sjson.Set(jsonStr, p, stringVals)
+ parentPath := trimSuffix(p, ".enum")
+ jsonStr, _ = sjson.Set(jsonStr, joinPath(parentPath, "type"), "string")
+ }
+ return jsonStr
+}
+
+func addEnumHints(jsonStr string) string {
+ for _, p := range findPaths(jsonStr, "enum") {
+ arr := gjson.Get(jsonStr, p)
+ if !arr.IsArray() {
+ continue
+ }
+ items := arr.Array()
+ if len(items) <= 1 || len(items) > 10 {
+ continue
+ }
+
+ var vals []string
+ for _, item := range items {
+ vals = append(vals, item.String())
+ }
+ jsonStr = appendHint(jsonStr, trimSuffix(p, ".enum"), "Allowed: "+strings.Join(vals, ", "))
+ }
+ return jsonStr
+}
+
+func addAdditionalPropertiesHints(jsonStr string) string {
+ for _, p := range findPaths(jsonStr, "additionalProperties") {
+ if gjson.Get(jsonStr, p).Type == gjson.False {
+ jsonStr = appendHint(jsonStr, trimSuffix(p, ".additionalProperties"), "No extra properties allowed")
+ }
+ }
+ return jsonStr
+}
+
+var unsupportedConstraints = []string{
+ "minLength", "maxLength", "exclusiveMinimum", "exclusiveMaximum",
+ "pattern", "minItems", "maxItems", "format",
+ "default", "examples", // Claude rejects these in VALIDATED mode
+}
+
+func moveConstraintsToDescription(jsonStr string) string {
+ for _, key := range unsupportedConstraints {
+ for _, p := range findPaths(jsonStr, key) {
+ val := gjson.Get(jsonStr, p)
+ if !val.Exists() || val.IsObject() || val.IsArray() {
+ continue
+ }
+ parentPath := trimSuffix(p, "."+key)
+ if isPropertyDefinition(parentPath) {
+ continue
+ }
+ jsonStr = appendHint(jsonStr, parentPath, fmt.Sprintf("%s: %s", key, val.String()))
+ }
+ }
+ return jsonStr
+}
+
+func mergeAllOf(jsonStr string) string {
+ paths := findPaths(jsonStr, "allOf")
+ sortByDepth(paths)
+
+ for _, p := range paths {
+ allOf := gjson.Get(jsonStr, p)
+ if !allOf.IsArray() {
+ continue
+ }
+ parentPath := trimSuffix(p, ".allOf")
+
+ for _, item := range allOf.Array() {
+ if props := item.Get("properties"); props.IsObject() {
+ props.ForEach(func(key, value gjson.Result) bool {
+ destPath := joinPath(parentPath, "properties."+escapeGJSONPathKey(key.String()))
+ jsonStr, _ = sjson.SetRaw(jsonStr, destPath, value.Raw)
+ return true
+ })
+ }
+ if req := item.Get("required"); req.IsArray() {
+ reqPath := joinPath(parentPath, "required")
+ current := getStrings(jsonStr, reqPath)
+ for _, r := range req.Array() {
+ if s := r.String(); !contains(current, s) {
+ current = append(current, s)
+ }
+ }
+ jsonStr, _ = sjson.Set(jsonStr, reqPath, current)
+ }
+ }
+ jsonStr, _ = sjson.Delete(jsonStr, p)
+ }
+ return jsonStr
+}
+
+func flattenAnyOfOneOf(jsonStr string) string {
+ for _, key := range []string{"anyOf", "oneOf"} {
+ paths := findPaths(jsonStr, key)
+ sortByDepth(paths)
+
+ for _, p := range paths {
+ arr := gjson.Get(jsonStr, p)
+ if !arr.IsArray() || len(arr.Array()) == 0 {
+ continue
+ }
+
+ parentPath := trimSuffix(p, "."+key)
+ parentDesc := gjson.Get(jsonStr, descriptionPath(parentPath)).String()
+
+ items := arr.Array()
+ bestIdx, allTypes := selectBest(items)
+ selected := items[bestIdx].Raw
+
+ if parentDesc != "" {
+ selected = mergeDescriptionRaw(selected, parentDesc)
+ }
+
+ if len(allTypes) > 1 {
+ hint := "Accepts: " + strings.Join(allTypes, " | ")
+ selected = appendHintRaw(selected, hint)
+ }
+
+ jsonStr = setRawAt(jsonStr, parentPath, selected)
+ }
+ }
+ return jsonStr
+}
+
+func selectBest(items []gjson.Result) (bestIdx int, types []string) {
+ bestScore := -1
+ for i, item := range items {
+ t := item.Get("type").String()
+ score := 0
+
+ switch {
+ case t == "object" || item.Get("properties").Exists():
+ score, t = 3, orDefault(t, "object")
+ case t == "array" || item.Get("items").Exists():
+ score, t = 2, orDefault(t, "array")
+ case t != "" && t != "null":
+ score = 1
+ default:
+ t = orDefault(t, "null")
+ }
+
+ if t != "" {
+ types = append(types, t)
+ }
+ if score > bestScore {
+ bestScore, bestIdx = score, i
+ }
+ }
+ return
+}
+
+func flattenTypeArrays(jsonStr string) string {
+ paths := findPaths(jsonStr, "type")
+ sortByDepth(paths)
+
+ nullableFields := make(map[string][]string)
+
+ for _, p := range paths {
+ res := gjson.Get(jsonStr, p)
+ if !res.IsArray() || len(res.Array()) == 0 {
+ continue
+ }
+
+ hasNull := false
+ var nonNullTypes []string
+ for _, item := range res.Array() {
+ s := item.String()
+ if s == "null" {
+ hasNull = true
+ } else if s != "" {
+ nonNullTypes = append(nonNullTypes, s)
+ }
+ }
+
+ firstType := "string"
+ if len(nonNullTypes) > 0 {
+ firstType = nonNullTypes[0]
+ }
+
+ jsonStr, _ = sjson.Set(jsonStr, p, firstType)
+
+ parentPath := trimSuffix(p, ".type")
+ if len(nonNullTypes) > 1 {
+ hint := "Accepts: " + strings.Join(nonNullTypes, " | ")
+ jsonStr = appendHint(jsonStr, parentPath, hint)
+ }
+
+ if hasNull {
+ parts := splitGJSONPath(p)
+ if len(parts) >= 3 && parts[len(parts)-3] == "properties" {
+ fieldNameEscaped := parts[len(parts)-2]
+ fieldName := unescapeGJSONPathKey(fieldNameEscaped)
+ objectPath := strings.Join(parts[:len(parts)-3], ".")
+ nullableFields[objectPath] = append(nullableFields[objectPath], fieldName)
+
+ propPath := joinPath(objectPath, "properties."+fieldNameEscaped)
+ jsonStr = appendHint(jsonStr, propPath, "(nullable)")
+ }
+ }
+ }
+
+ for objectPath, fields := range nullableFields {
+ reqPath := joinPath(objectPath, "required")
+ req := gjson.Get(jsonStr, reqPath)
+ if !req.IsArray() {
+ continue
+ }
+
+ var filtered []string
+ for _, r := range req.Array() {
+ if !contains(fields, r.String()) {
+ filtered = append(filtered, r.String())
+ }
+ }
+
+ if len(filtered) == 0 {
+ jsonStr, _ = sjson.Delete(jsonStr, reqPath)
+ } else {
+ jsonStr, _ = sjson.Set(jsonStr, reqPath, filtered)
+ }
+ }
+ return jsonStr
+}
+
+func removeUnsupportedKeywords(jsonStr string) string {
+ keywords := append(unsupportedConstraints,
+ "$schema", "$defs", "definitions", "const", "$ref", "additionalProperties",
+ "propertyNames", // Gemini doesn't support property name validation
+ )
+ for _, key := range keywords {
+ for _, p := range findPaths(jsonStr, key) {
+ if isPropertyDefinition(trimSuffix(p, "."+key)) {
+ continue
+ }
+ jsonStr, _ = sjson.Delete(jsonStr, p)
+ }
+ }
+ // Remove x-* extension fields (e.g., x-google-enum-descriptions) that are not supported by Gemini API
+ jsonStr = removeExtensionFields(jsonStr)
+ return jsonStr
+}
+
+// removeExtensionFields removes all x-* extension fields from the JSON schema.
+// These are OpenAPI/JSON Schema extension fields that Google APIs don't recognize.
+func removeExtensionFields(jsonStr string) string {
+ var paths []string
+ walkForExtensions(gjson.Parse(jsonStr), "", &paths)
+ // walkForExtensions returns paths in a way that deeper paths are added before their ancestors
+ // when they are not deleted wholesale, but since we skip children of deleted x-* nodes,
+ // any collected path is safe to delete. We still use DeleteBytes for efficiency.
+
+ b := []byte(jsonStr)
+ for _, p := range paths {
+ b, _ = sjson.DeleteBytes(b, p)
+ }
+ return string(b)
+}
+
+func walkForExtensions(value gjson.Result, path string, paths *[]string) {
+ if value.IsArray() {
+ arr := value.Array()
+ for i := len(arr) - 1; i >= 0; i-- {
+ walkForExtensions(arr[i], joinPath(path, strconv.Itoa(i)), paths)
+ }
+ return
+ }
+
+ if value.IsObject() {
+ value.ForEach(func(key, val gjson.Result) bool {
+ keyStr := key.String()
+ safeKey := escapeGJSONPathKey(keyStr)
+ childPath := joinPath(path, safeKey)
+
+ // If it's an extension field, we delete it and don't need to look at its children.
+ if strings.HasPrefix(keyStr, "x-") && !isPropertyDefinition(path) {
+ *paths = append(*paths, childPath)
+ return true
+ }
+
+ walkForExtensions(val, childPath, paths)
+ return true
+ })
+ }
+}
+
+func cleanupRequiredFields(jsonStr string) string {
+ for _, p := range findPaths(jsonStr, "required") {
+ parentPath := trimSuffix(p, ".required")
+ propsPath := joinPath(parentPath, "properties")
+
+ req := gjson.Get(jsonStr, p)
+ props := gjson.Get(jsonStr, propsPath)
+ if !req.IsArray() || !props.IsObject() {
+ continue
+ }
+
+ var valid []string
+ for _, r := range req.Array() {
+ key := r.String()
+ if props.Get(escapeGJSONPathKey(key)).Exists() {
+ valid = append(valid, key)
+ }
+ }
+
+ if len(valid) != len(req.Array()) {
+ if len(valid) == 0 {
+ jsonStr, _ = sjson.Delete(jsonStr, p)
+ } else {
+ jsonStr, _ = sjson.Set(jsonStr, p, valid)
+ }
+ }
+ }
+ return jsonStr
+}
+
+// addEmptySchemaPlaceholder adds a placeholder "reason" property to empty object schemas.
+// Claude VALIDATED mode requires at least one required property in tool schemas.
+func addEmptySchemaPlaceholder(jsonStr string) string {
+ // Find all "type" fields
+ paths := findPaths(jsonStr, "type")
+
+ // Process from deepest to shallowest (to handle nested objects properly)
+ sortByDepth(paths)
+
+ for _, p := range paths {
+ typeVal := gjson.Get(jsonStr, p)
+ if typeVal.String() != "object" {
+ continue
+ }
+
+ // Get the parent path (the object containing "type")
+ parentPath := trimSuffix(p, ".type")
+
+ // Check if properties exists and is empty or missing
+ propsPath := joinPath(parentPath, "properties")
+ propsVal := gjson.Get(jsonStr, propsPath)
+ reqPath := joinPath(parentPath, "required")
+ reqVal := gjson.Get(jsonStr, reqPath)
+ hasRequiredProperties := reqVal.IsArray() && len(reqVal.Array()) > 0
+
+ needsPlaceholder := false
+ if !propsVal.Exists() {
+ // No properties field at all
+ needsPlaceholder = true
+ } else if propsVal.IsObject() && len(propsVal.Map()) == 0 {
+ // Empty properties object
+ needsPlaceholder = true
+ }
+
+ if needsPlaceholder {
+ // Add placeholder "reason" property
+ reasonPath := joinPath(propsPath, "reason")
+ jsonStr, _ = sjson.Set(jsonStr, reasonPath+".type", "string")
+ jsonStr, _ = sjson.Set(jsonStr, reasonPath+".description", placeholderReasonDescription)
+
+ // Add to required array
+ jsonStr, _ = sjson.Set(jsonStr, reqPath, []string{"reason"})
+ continue
+ }
+
+ // If schema has properties but none are required, add a minimal placeholder.
+ if propsVal.IsObject() && !hasRequiredProperties {
+ // DO NOT add placeholder if it's a top-level schema (parentPath is empty)
+ // or if we've already added a placeholder reason above.
+ if parentPath == "" {
+ continue
+ }
+ placeholderPath := joinPath(propsPath, "_")
+ if !gjson.Get(jsonStr, placeholderPath).Exists() {
+ jsonStr, _ = sjson.Set(jsonStr, placeholderPath+".type", "boolean")
+ }
+ jsonStr, _ = sjson.Set(jsonStr, reqPath, []string{"_"})
+ }
+ }
+
+ return jsonStr
+}
+
+// --- Helpers ---
+
+func findPaths(jsonStr, field string) []string {
+ var paths []string
+ Walk(gjson.Parse(jsonStr), "", field, &paths)
+ return paths
+}
+
+func sortByDepth(paths []string) {
+ sort.Slice(paths, func(i, j int) bool { return len(paths[i]) > len(paths[j]) })
+}
+
+func trimSuffix(path, suffix string) string {
+ if path == strings.TrimPrefix(suffix, ".") {
+ return ""
+ }
+ return strings.TrimSuffix(path, suffix)
+}
+
+func joinPath(base, suffix string) string {
+ if base == "" {
+ return suffix
+ }
+ return base + "." + suffix
+}
+
+func setRawAt(jsonStr, path, value string) string {
+ if path == "" {
+ return value
+ }
+ result, _ := sjson.SetRaw(jsonStr, path, value)
+ return result
+}
+
+func isPropertyDefinition(path string) bool {
+ return path == "properties" || strings.HasSuffix(path, ".properties")
+}
+
+func descriptionPath(parentPath string) string {
+ if parentPath == "" || parentPath == "@this" {
+ return "description"
+ }
+ return parentPath + ".description"
+}
+
+func appendHint(jsonStr, parentPath, hint string) string {
+ descPath := parentPath + ".description"
+ if parentPath == "" || parentPath == "@this" {
+ descPath = "description"
+ }
+ existing := gjson.Get(jsonStr, descPath).String()
+ if existing != "" {
+ hint = fmt.Sprintf("%s (%s)", existing, hint)
+ }
+ jsonStr, _ = sjson.Set(jsonStr, descPath, hint)
+ return jsonStr
+}
+
+func appendHintRaw(jsonRaw, hint string) string {
+ existing := gjson.Get(jsonRaw, "description").String()
+ if existing != "" {
+ hint = fmt.Sprintf("%s (%s)", existing, hint)
+ }
+ jsonRaw, _ = sjson.Set(jsonRaw, "description", hint)
+ return jsonRaw
+}
+
+func getStrings(jsonStr, path string) []string {
+ var result []string
+ if arr := gjson.Get(jsonStr, path); arr.IsArray() {
+ for _, r := range arr.Array() {
+ result = append(result, r.String())
+ }
+ }
+ return result
+}
+
+func contains(slice []string, item string) bool {
+ for _, s := range slice {
+ if s == item {
+ return true
+ }
+ }
+ return false
+}
+
+func orDefault(val, def string) string {
+ if val == "" {
+ return def
+ }
+ return val
+}
+
+func escapeGJSONPathKey(key string) string {
+ return gjsonPathKeyReplacer.Replace(key)
+}
+
+func unescapeGJSONPathKey(key string) string {
+ if !strings.Contains(key, "\\") {
+ return key
+ }
+ var b strings.Builder
+ b.Grow(len(key))
+ for i := 0; i < len(key); i++ {
+ if key[i] == '\\' && i+1 < len(key) {
+ i++
+ b.WriteByte(key[i])
+ continue
+ }
+ b.WriteByte(key[i])
+ }
+ return b.String()
+}
+
+func splitGJSONPath(path string) []string {
+ if path == "" {
+ return nil
+ }
+
+ parts := make([]string, 0, strings.Count(path, ".")+1)
+ var b strings.Builder
+ b.Grow(len(path))
+
+ for i := 0; i < len(path); i++ {
+ c := path[i]
+ if c == '\\' && i+1 < len(path) {
+ b.WriteByte('\\')
+ i++
+ b.WriteByte(path[i])
+ continue
+ }
+ if c == '.' {
+ parts = append(parts, b.String())
+ b.Reset()
+ continue
+ }
+ b.WriteByte(c)
+ }
+ parts = append(parts, b.String())
+ return parts
+}
+
+func mergeDescriptionRaw(schemaRaw, parentDesc string) string {
+ childDesc := gjson.Get(schemaRaw, "description").String()
+ switch {
+ case childDesc == "":
+ schemaRaw, _ = sjson.Set(schemaRaw, "description", parentDesc)
+ return schemaRaw
+ case childDesc == parentDesc:
+ return schemaRaw
+ default:
+ combined := fmt.Sprintf("%s (%s)", parentDesc, childDesc)
+ schemaRaw, _ = sjson.Set(schemaRaw, "description", combined)
+ return schemaRaw
+ }
+}
diff --git a/internal/adapter/provider/antigravity/service.go b/internal/adapter/provider/antigravity/service.go
index 96eb66b1..e5d11d1f 100644
--- a/internal/adapter/provider/antigravity/service.go
+++ b/internal/adapter/provider/antigravity/service.go
@@ -27,8 +27,8 @@ const (
UserAgentLoadCodeAssist = "antigravity/windows/amd64"
// fetchAvailableModels 使用带版本号的 User-Agent
UserAgentFetchModels = "antigravity/1.11.3 Darwin/arm64"
- // 代理请求使用的 User-Agent
- AntigravityUserAgent = "antigravity/1.11.9 windows/amd64"
+ // 代理请求使用的 User-Agent (CLIProxyAPI default)
+ AntigravityUserAgent = "antigravity/1.104.0 darwin/arm64"
// 默认 Project ID (当 API 未返回时使用)
DefaultProjectID = "bamboo-precept-lgxtn"
diff --git a/internal/adapter/provider/antigravity/transform_request.go b/internal/adapter/provider/antigravity/transform_request.go
index 8a9609ac..e3a2f536 100644
--- a/internal/adapter/provider/antigravity/transform_request.go
+++ b/internal/adapter/provider/antigravity/transform_request.go
@@ -74,13 +74,7 @@ func TransformClaudeToGemini(
genConfig := buildGenerationConfig(&claudeReq, mappedModel, stream, hasThinking)
geminiReq["generationConfig"] = genConfig
- // 5.5 Safety Settings (configurable via environment)
- // Reference: Antigravity-Manager's build_safety_settings
- safetyThreshold := GetSafetyThresholdFromEnv()
- safetySettings := BuildSafetySettingsMap(safetyThreshold)
- geminiReq["safetySettings"] = safetySettings
-
- // 5.6 Deep clean [undefined] strings (Cherry Studio injection fix)
+ // 5.5 Deep clean [undefined] strings (Cherry Studio injection fix)
// Reference: Antigravity-Manager line 278
deepCleanUndefined(geminiReq)
diff --git a/internal/adapter/provider/antigravity/translator_helpers.go b/internal/adapter/provider/antigravity/translator_helpers.go
new file mode 100644
index 00000000..2303b1db
--- /dev/null
+++ b/internal/adapter/provider/antigravity/translator_helpers.go
@@ -0,0 +1,231 @@
+// Package util provides utility functions for the CLI Proxy API server.
+// It includes helper functions for JSON manipulation, proxy configuration,
+// and other common operations used across the application.
+package antigravity
+
+import (
+ "bytes"
+ "fmt"
+ "strings"
+
+ "github.com/tidwall/gjson"
+ "github.com/tidwall/sjson"
+)
+
+// Walk recursively traverses a JSON structure to find all occurrences of a specific field.
+// It builds paths to each occurrence and adds them to the provided paths slice.
+//
+// Parameters:
+// - value: The gjson.Result object to traverse
+// - path: The current path in the JSON structure (empty string for root)
+// - field: The field name to search for
+// - paths: Pointer to a slice where found paths will be stored
+//
+// The function works recursively, building dot-notation paths to each occurrence
+// of the specified field throughout the JSON structure.
+func Walk(value gjson.Result, path, field string, paths *[]string) {
+ switch value.Type {
+ case gjson.JSON:
+ // For JSON objects and arrays, iterate through each child
+ value.ForEach(func(key, val gjson.Result) bool {
+ var childPath string
+ // Escape special characters for gjson/sjson path syntax
+ // . -> \.
+ // * -> \*
+ // ? -> \?
+ var keyReplacer = strings.NewReplacer(".", "\\.", "*", "\\*", "?", "\\?")
+ safeKey := keyReplacer.Replace(key.String())
+
+ if path == "" {
+ childPath = safeKey
+ } else {
+ childPath = path + "." + safeKey
+ }
+ if key.String() == field {
+ *paths = append(*paths, childPath)
+ }
+ Walk(val, childPath, field, paths)
+ return true
+ })
+ case gjson.String, gjson.Number, gjson.True, gjson.False, gjson.Null:
+ // Terminal types - no further traversal needed
+ }
+}
+
+// RenameKey renames a key in a JSON string by moving its value to a new key path
+// and then deleting the old key path.
+//
+// Parameters:
+// - jsonStr: The JSON string to modify
+// - oldKeyPath: The dot-notation path to the key that should be renamed
+// - newKeyPath: The dot-notation path where the value should be moved to
+//
+// Returns:
+// - string: The modified JSON string with the key renamed
+// - error: An error if the operation fails
+//
+// The function performs the rename in two steps:
+// 1. Sets the value at the new key path
+// 2. Deletes the old key path
+func RenameKey(jsonStr, oldKeyPath, newKeyPath string) (string, error) {
+ value := gjson.Get(jsonStr, oldKeyPath)
+
+ if !value.Exists() {
+ return "", fmt.Errorf("old key '%s' does not exist", oldKeyPath)
+ }
+
+ interimJson, err := sjson.SetRaw(jsonStr, newKeyPath, value.Raw)
+ if err != nil {
+ return "", fmt.Errorf("failed to set new key '%s': %w", newKeyPath, err)
+ }
+
+ finalJson, err := sjson.Delete(interimJson, oldKeyPath)
+ if err != nil {
+ return "", fmt.Errorf("failed to delete old key '%s': %w", oldKeyPath, err)
+ }
+
+ return finalJson, nil
+}
+
+func DeleteKey(jsonStr, keyName string) string {
+ paths := make([]string, 0)
+ Walk(gjson.Parse(jsonStr), "", keyName, &paths)
+ for _, p := range paths {
+ jsonStr, _ = sjson.Delete(jsonStr, p)
+ }
+ return jsonStr
+}
+
+// FixJSON converts non-standard JSON that uses single quotes for strings into
+// RFC 8259-compliant JSON by converting those single-quoted strings to
+// double-quoted strings with proper escaping.
+//
+// Examples:
+//
+// {'a': 1, 'b': '2'} => {"a": 1, "b": "2"}
+// {"t": 'He said "hi"'} => {"t": "He said \"hi\""}
+//
+// Rules:
+// - Existing double-quoted JSON strings are preserved as-is.
+// - Single-quoted strings are converted to double-quoted strings.
+// - Inside converted strings, any double quote is escaped (\").
+// - Common backslash escapes (\n, \r, \t, \b, \f, \\) are preserved.
+// - \' inside single-quoted strings becomes a literal ' in the output (no
+// escaping needed inside double quotes).
+// - Unicode escapes (\uXXXX) inside single-quoted strings are forwarded.
+// - The function does not attempt to fix other non-JSON features beyond quotes.
+func FixJSON(input string) string {
+ var out bytes.Buffer
+
+ inDouble := false
+ inSingle := false
+ escaped := false // applies within the current string state
+
+ // Helper to write a rune, escaping double quotes when inside a converted
+ // single-quoted string (which becomes a double-quoted string in output).
+ writeConverted := func(r rune) {
+ if r == '"' {
+ out.WriteByte('\\')
+ out.WriteByte('"')
+ return
+ }
+ out.WriteRune(r)
+ }
+
+ runes := []rune(input)
+ for i := 0; i < len(runes); i++ {
+ r := runes[i]
+
+ if inDouble {
+ out.WriteRune(r)
+ if escaped {
+ // end of escape sequence in a standard JSON string
+ escaped = false
+ continue
+ }
+ if r == '\\' {
+ escaped = true
+ continue
+ }
+ if r == '"' {
+ inDouble = false
+ }
+ continue
+ }
+
+ if inSingle {
+ if escaped {
+ // Handle common escape sequences after a backslash within a
+ // single-quoted string
+ escaped = false
+ switch r {
+ case 'n', 'r', 't', 'b', 'f', '/', '"':
+ // Keep the backslash and the character (except for '"' which
+ // rarely appears, but if it does, keep as \" to remain valid)
+ out.WriteByte('\\')
+ out.WriteRune(r)
+ case '\\':
+ out.WriteByte('\\')
+ out.WriteByte('\\')
+ case '\'':
+ // \' inside single-quoted becomes a literal '
+ out.WriteRune('\'')
+ case 'u':
+ // Forward \uXXXX if possible
+ out.WriteByte('\\')
+ out.WriteByte('u')
+ // Copy up to next 4 hex digits if present
+ for k := 0; k < 4 && i+1 < len(runes); k++ {
+ peek := runes[i+1]
+ // simple hex check
+ if (peek >= '0' && peek <= '9') || (peek >= 'a' && peek <= 'f') || (peek >= 'A' && peek <= 'F') {
+ out.WriteRune(peek)
+ i++
+ } else {
+ break
+ }
+ }
+ default:
+ // Unknown escape: preserve the backslash and the char
+ out.WriteByte('\\')
+ out.WriteRune(r)
+ }
+ continue
+ }
+
+ if r == '\\' { // start escape sequence
+ escaped = true
+ continue
+ }
+ if r == '\'' { // end of single-quoted string
+ out.WriteByte('"')
+ inSingle = false
+ continue
+ }
+ // regular char inside converted string; escape double quotes
+ writeConverted(r)
+ continue
+ }
+
+ // Outside any string
+ if r == '"' {
+ inDouble = true
+ out.WriteRune(r)
+ continue
+ }
+ if r == '\'' { // start of non-standard single-quoted string
+ inSingle = true
+ out.WriteByte('"')
+ continue
+ }
+ out.WriteRune(r)
+ }
+
+ // If input ended while still inside a single-quoted string, close it to
+ // produce the best-effort valid JSON.
+ if inSingle {
+ out.WriteByte('"')
+ }
+
+ return out.String()
+}
diff --git a/internal/adapter/provider/cliproxyapi_antigravity/adapter.go b/internal/adapter/provider/cliproxyapi_antigravity/adapter.go
index 9d80dcbd..ff138f1f 100644
--- a/internal/adapter/provider/cliproxyapi_antigravity/adapter.go
+++ b/internal/adapter/provider/cliproxyapi_antigravity/adapter.go
@@ -11,8 +11,8 @@ import (
"time"
"github.com/awsl-project/maxx/internal/adapter/provider"
- ctxutil "github.com/awsl-project/maxx/internal/context"
"github.com/awsl-project/maxx/internal/domain"
+ "github.com/awsl-project/maxx/internal/flow"
"github.com/awsl-project/maxx/internal/usage"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
@@ -56,12 +56,14 @@ func (a *CLIProxyAPIAntigravityAdapter) SupportedClientTypes() []domain.ClientTy
return []domain.ClientType{domain.ClientTypeClaude, domain.ClientTypeGemini}
}
-func (a *CLIProxyAPIAntigravityAdapter) Execute(ctx context.Context, w http.ResponseWriter, req *http.Request, p *domain.Provider) error {
- clientType := ctxutil.GetClientType(ctx)
- requestBody := ctxutil.GetRequestBody(ctx)
- stream := ctxutil.GetIsStream(ctx)
- requestModel := ctxutil.GetRequestModel(ctx)
- model := ctxutil.GetMappedModel(ctx) // 全局映射后的模型名(已包含 ProviderType 条件)
+func (a *CLIProxyAPIAntigravityAdapter) Execute(c *flow.Ctx, p *domain.Provider) error {
+ w := c.Writer
+
+ clientType := flow.GetClientType(c)
+ requestBody := flow.GetRequestBody(c)
+ stream := flow.GetIsStream(c)
+ requestModel := flow.GetRequestModel(c)
+ model := flow.GetMappedModel(c) // 全局映射后的模型名(已包含 ProviderType 条件)
log.Printf("[CLIProxyAPI-Antigravity] requestModel=%s, mappedModel=%s, clientType=%s", requestModel, model, clientType)
@@ -72,7 +74,7 @@ func (a *CLIProxyAPIAntigravityAdapter) Execute(ctx context.Context, w http.Resp
}
// 发送事件
- if eventChan := ctxutil.GetEventChan(ctx); eventChan != nil {
+ if eventChan := flow.GetEventChan(c); eventChan != nil {
eventChan.SendRequestInfo(&domain.RequestInfo{
Method: "POST",
URL: fmt.Sprintf("cliproxyapi://antigravity/%s", model),
@@ -105,9 +107,9 @@ func (a *CLIProxyAPIAntigravityAdapter) Execute(ctx context.Context, w http.Resp
}
if stream {
- return a.executeStream(ctx, w, execReq, execOpts)
+ return a.executeStream(c, w, execReq, execOpts)
}
- return a.executeNonStream(ctx, w, execReq, execOpts)
+ return a.executeNonStream(c, w, execReq, execOpts)
}
// updateModelInBody 替换 body 中的 model 字段
@@ -120,14 +122,19 @@ func updateModelInBody(body []byte, model string) ([]byte, error) {
return json.Marshal(req)
}
-func (a *CLIProxyAPIAntigravityAdapter) executeNonStream(ctx context.Context, w http.ResponseWriter, execReq executor.Request, execOpts executor.Options) error {
+func (a *CLIProxyAPIAntigravityAdapter) executeNonStream(c *flow.Ctx, w http.ResponseWriter, execReq executor.Request, execOpts executor.Options) error {
+ ctx := context.Background()
+ if c.Request != nil {
+ ctx = c.Request.Context()
+ }
+
resp, err := a.executor.Execute(ctx, a.authObj, execReq, execOpts)
if err != nil {
log.Printf("[CLIProxyAPI-Antigravity] executeNonStream error: model=%s, err=%v", execReq.Model, err)
return domain.NewProxyErrorWithMessage(err, true, fmt.Sprintf("executor request failed: %v", err))
}
- if eventChan := ctxutil.GetEventChan(ctx); eventChan != nil {
+ if eventChan := flow.GetEventChan(c); eventChan != nil {
// Send response info
eventChan.SendResponseInfo(&domain.ResponseInfo{
Status: http.StatusOK,
@@ -159,13 +166,16 @@ func (a *CLIProxyAPIAntigravityAdapter) executeNonStream(ctx context.Context, w
return nil
}
-func (a *CLIProxyAPIAntigravityAdapter) executeStream(ctx context.Context, w http.ResponseWriter, execReq executor.Request, execOpts executor.Options) error {
+func (a *CLIProxyAPIAntigravityAdapter) executeStream(c *flow.Ctx, w http.ResponseWriter, execReq executor.Request, execOpts executor.Options) error {
flusher, ok := w.(http.Flusher)
if !ok {
- return a.executeNonStream(ctx, w, execReq, execOpts)
+ return a.executeNonStream(c, w, execReq, execOpts)
}
- startTime := time.Now()
+ ctx := context.Background()
+ if c.Request != nil {
+ ctx = c.Request.Context()
+ }
stream, err := a.executor.ExecuteStream(ctx, a.authObj, execReq, execOpts)
if err != nil {
@@ -179,7 +189,7 @@ func (a *CLIProxyAPIAntigravityAdapter) executeStream(ctx context.Context, w htt
w.Header().Set("Connection", "keep-alive")
w.WriteHeader(http.StatusOK)
- eventChan := ctxutil.GetEventChan(ctx)
+ eventChan := flow.GetEventChan(c)
// Collect SSE content for token extraction
var sseBuffer bytes.Buffer
@@ -200,7 +210,7 @@ func (a *CLIProxyAPIAntigravityAdapter) executeStream(ctx context.Context, w htt
// Report TTFT on first non-empty chunk
if !firstChunkSent && eventChan != nil {
- eventChan.SendFirstToken(time.Since(startTime).Milliseconds())
+ eventChan.SendFirstToken(time.Now().UnixMilli())
firstChunkSent = true
}
}
diff --git a/internal/adapter/provider/cliproxyapi_codex/adapter.go b/internal/adapter/provider/cliproxyapi_codex/adapter.go
index bab635c3..c1266b9c 100644
--- a/internal/adapter/provider/cliproxyapi_codex/adapter.go
+++ b/internal/adapter/provider/cliproxyapi_codex/adapter.go
@@ -11,8 +11,8 @@ import (
"time"
"github.com/awsl-project/maxx/internal/adapter/provider"
- ctxutil "github.com/awsl-project/maxx/internal/context"
"github.com/awsl-project/maxx/internal/domain"
+ "github.com/awsl-project/maxx/internal/flow"
"github.com/awsl-project/maxx/internal/usage"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
@@ -53,16 +53,18 @@ func (a *CLIProxyAPICodexAdapter) SupportedClientTypes() []domain.ClientType {
return []domain.ClientType{domain.ClientTypeCodex}
}
-func (a *CLIProxyAPICodexAdapter) Execute(ctx context.Context, w http.ResponseWriter, req *http.Request, p *domain.Provider) error {
- requestBody := ctxutil.GetRequestBody(ctx)
- stream := ctxutil.GetIsStream(ctx)
- model := ctxutil.GetMappedModel(ctx)
+func (a *CLIProxyAPICodexAdapter) Execute(c *flow.Ctx, p *domain.Provider) error {
+ w := c.Writer
+
+ requestBody := flow.GetRequestBody(c)
+ stream := flow.GetIsStream(c)
+ model := flow.GetMappedModel(c)
// Codex CLI 使用 OpenAI Responses API 格式
sourceFormat := translator.FormatCodex
// 发送事件
- if eventChan := ctxutil.GetEventChan(ctx); eventChan != nil {
+ if eventChan := flow.GetEventChan(c); eventChan != nil {
eventChan.SendRequestInfo(&domain.RequestInfo{
Method: "POST",
URL: fmt.Sprintf("cliproxyapi://codex/%s", model),
@@ -84,18 +86,23 @@ func (a *CLIProxyAPICodexAdapter) Execute(ctx context.Context, w http.ResponseWr
}
if stream {
- return a.executeStream(ctx, w, execReq, execOpts)
+ return a.executeStream(c, w, execReq, execOpts)
}
- return a.executeNonStream(ctx, w, execReq, execOpts)
+ return a.executeNonStream(c, w, execReq, execOpts)
}
-func (a *CLIProxyAPICodexAdapter) executeNonStream(ctx context.Context, w http.ResponseWriter, execReq executor.Request, execOpts executor.Options) error {
+func (a *CLIProxyAPICodexAdapter) executeNonStream(c *flow.Ctx, w http.ResponseWriter, execReq executor.Request, execOpts executor.Options) error {
+ ctx := context.Background()
+ if c.Request != nil {
+ ctx = c.Request.Context()
+ }
+
resp, err := a.executor.Execute(ctx, a.authObj, execReq, execOpts)
if err != nil {
return domain.NewProxyErrorWithMessage(err, true, fmt.Sprintf("executor request failed: %v", err))
}
- if eventChan := ctxutil.GetEventChan(ctx); eventChan != nil {
+ if eventChan := flow.GetEventChan(c); eventChan != nil {
// Send response info
eventChan.SendResponseInfo(&domain.ResponseInfo{
Status: http.StatusOK,
@@ -125,13 +132,16 @@ func (a *CLIProxyAPICodexAdapter) executeNonStream(ctx context.Context, w http.R
return nil
}
-func (a *CLIProxyAPICodexAdapter) executeStream(ctx context.Context, w http.ResponseWriter, execReq executor.Request, execOpts executor.Options) error {
+func (a *CLIProxyAPICodexAdapter) executeStream(c *flow.Ctx, w http.ResponseWriter, execReq executor.Request, execOpts executor.Options) error {
flusher, ok := w.(http.Flusher)
if !ok {
- return a.executeNonStream(ctx, w, execReq, execOpts)
+ return a.executeNonStream(c, w, execReq, execOpts)
}
- startTime := time.Now()
+ ctx := context.Background()
+ if c.Request != nil {
+ ctx = c.Request.Context()
+ }
stream, err := a.executor.ExecuteStream(ctx, a.authObj, execReq, execOpts)
if err != nil {
@@ -144,7 +154,7 @@ func (a *CLIProxyAPICodexAdapter) executeStream(ctx context.Context, w http.Resp
w.Header().Set("Connection", "keep-alive")
w.WriteHeader(http.StatusOK)
- eventChan := ctxutil.GetEventChan(ctx)
+ eventChan := flow.GetEventChan(c)
// Collect SSE content for token extraction
var sseBuffer bytes.Buffer
@@ -165,7 +175,7 @@ func (a *CLIProxyAPICodexAdapter) executeStream(ctx context.Context, w http.Resp
// Report TTFT on first non-empty chunk
if !firstChunkSent && eventChan != nil {
- eventChan.SendFirstToken(time.Since(startTime).Milliseconds())
+ eventChan.SendFirstToken(time.Now().UnixMilli())
firstChunkSent = true
}
}
diff --git a/internal/adapter/provider/codex/adapter.go b/internal/adapter/provider/codex/adapter.go
index 74d90672..0dad347c 100644
--- a/internal/adapter/provider/codex/adapter.go
+++ b/internal/adapter/provider/codex/adapter.go
@@ -1,6 +1,7 @@
package codex
import (
+ "bufio"
"bytes"
"context"
"encoding/json"
@@ -14,13 +15,30 @@ import (
"github.com/awsl-project/maxx/internal/adapter/provider"
cliproxyapi "github.com/awsl-project/maxx/internal/adapter/provider/cliproxyapi_codex"
- ctxutil "github.com/awsl-project/maxx/internal/context"
"github.com/awsl-project/maxx/internal/domain"
+ "github.com/awsl-project/maxx/internal/flow"
"github.com/awsl-project/maxx/internal/usage"
+ "github.com/google/uuid"
+ "github.com/tidwall/gjson"
+ "github.com/tidwall/sjson"
)
func init() {
provider.RegisterAdapterFactory("codex", NewAdapter)
+ go func() {
+ ticker := time.NewTicker(5 * time.Minute)
+ defer ticker.Stop()
+ for range ticker.C {
+ codexCacheMu.Lock()
+ now := time.Now()
+ for k, v := range codexCaches {
+ if now.After(v.Expire) {
+ delete(codexCaches, k)
+ }
+ }
+ codexCacheMu.Unlock()
+ }
+ }()
}
// TokenCache caches access tokens
@@ -95,9 +113,14 @@ func (a *CodexAdapter) SupportedClientTypes() []domain.ClientType {
return []domain.ClientType{domain.ClientTypeCodex}
}
-func (a *CodexAdapter) Execute(ctx context.Context, w http.ResponseWriter, req *http.Request, provider *domain.Provider) error {
- requestBody := ctxutil.GetRequestBody(ctx)
- stream := ctxutil.GetIsStream(ctx)
+func (a *CodexAdapter) Execute(c *flow.Ctx, provider *domain.Provider) error {
+ requestBody := flow.GetRequestBody(c)
+ clientWantsStream := flow.GetIsStream(c)
+ request := c.Request
+ ctx := context.Background()
+ if request != nil {
+ ctx = request.Context()
+ }
// Get access token
accessToken, err := a.getAccessToken(ctx)
@@ -105,8 +128,18 @@ func (a *CodexAdapter) Execute(ctx context.Context, w http.ResponseWriter, req *
return domain.NewProxyErrorWithMessage(err, true, "failed to get access token")
}
+ // Apply Codex CLI payload adjustments (CLIProxyAPI-aligned)
+ cacheID, updatedBody := applyCodexRequestTuning(c, requestBody)
+ requestBody = updatedBody
+
// Build upstream URL
upstreamURL := CodexBaseURL + "/responses"
+ upstreamStream := true
+ if len(requestBody) > 0 {
+ if updated, err := sjson.SetBytes(requestBody, "stream", upstreamStream); err == nil {
+ requestBody = updated
+ }
+ }
// Create upstream request
upstreamReq, err := http.NewRequestWithContext(ctx, "POST", upstreamURL, bytes.NewReader(requestBody))
@@ -116,10 +149,10 @@ func (a *CodexAdapter) Execute(ctx context.Context, w http.ResponseWriter, req *
// Apply headers with passthrough support (client headers take priority)
config := provider.Config.Codex
- a.applyCodexHeaders(upstreamReq, req, accessToken, config.AccountID)
+ a.applyCodexHeaders(upstreamReq, request, accessToken, config.AccountID, upstreamStream, cacheID)
// Send request info via EventChannel
- if eventChan := ctxutil.GetEventChan(ctx); eventChan != nil {
+ if eventChan := flow.GetEventChan(c); eventChan != nil {
eventChan.SendRequestInfo(&domain.RequestInfo{
Method: upstreamReq.Method,
URL: upstreamURL,
@@ -153,8 +186,11 @@ func (a *CodexAdapter) Execute(ctx context.Context, w http.ResponseWriter, req *
}
// Retry request
- upstreamReq, _ = http.NewRequestWithContext(ctx, "POST", upstreamURL, bytes.NewReader(requestBody))
- a.applyCodexHeaders(upstreamReq, req, accessToken, config.AccountID)
+ upstreamReq, reqErr := http.NewRequestWithContext(ctx, "POST", upstreamURL, bytes.NewReader(requestBody))
+ if reqErr != nil {
+ return domain.NewProxyErrorWithMessage(reqErr, false, fmt.Sprintf("failed to create retry request: %v", reqErr))
+ }
+ a.applyCodexHeaders(upstreamReq, request, accessToken, config.AccountID, upstreamStream, cacheID)
resp, err = a.httpClient.Do(upstreamReq)
if err != nil {
@@ -170,7 +206,7 @@ func (a *CodexAdapter) Execute(ctx context.Context, w http.ResponseWriter, req *
body, _ := io.ReadAll(resp.Body)
// Send error response info via EventChannel
- if eventChan := ctxutil.GetEventChan(ctx); eventChan != nil {
+ if eventChan := flow.GetEventChan(c); eventChan != nil {
eventChan.SendResponseInfo(&domain.ResponseInfo{
Status: resp.StatusCode,
Headers: flattenHeaders(resp.Header),
@@ -200,26 +236,51 @@ func (a *CodexAdapter) Execute(ctx context.Context, w http.ResponseWriter, req *
}
// Handle response
- if stream {
- return a.handleStreamResponse(ctx, w, resp)
+ if clientWantsStream {
+ return a.handleStreamResponse(c, resp)
}
- return a.handleNonStreamResponse(ctx, w, resp)
+ return a.handleCollectedStreamResponse(c, resp)
}
func (a *CodexAdapter) getAccessToken(ctx context.Context) (string, error) {
// Check cache
a.tokenMu.RLock()
- if a.tokenCache.AccessToken != "" && time.Now().Add(60*time.Second).Before(a.tokenCache.ExpiresAt) {
- token := a.tokenCache.AccessToken
- a.tokenMu.RUnlock()
- return token, nil
+ if a.tokenCache.AccessToken != "" {
+ if a.tokenCache.ExpiresAt.IsZero() || time.Now().Add(60*time.Second).Before(a.tokenCache.ExpiresAt) {
+ token := a.tokenCache.AccessToken
+ a.tokenMu.RUnlock()
+ return token, nil
+ }
}
a.tokenMu.RUnlock()
- // Refresh token
+ // Use persisted access token if present (even if expiry is unknown)
config := a.provider.Config.Codex
+ if strings.TrimSpace(config.AccessToken) != "" {
+ var expiresAt time.Time
+ if strings.TrimSpace(config.ExpiresAt) != "" {
+ if parsed, err := time.Parse(time.RFC3339, config.ExpiresAt); err == nil {
+ expiresAt = parsed
+ }
+ }
+ a.tokenMu.Lock()
+ a.tokenCache = &TokenCache{
+ AccessToken: config.AccessToken,
+ ExpiresAt: expiresAt,
+ }
+ a.tokenMu.Unlock()
+
+ if expiresAt.IsZero() || time.Now().Add(60*time.Second).Before(expiresAt) {
+ return config.AccessToken, nil
+ }
+ }
+
+ // Refresh token
tokenResp, err := RefreshAccessToken(ctx, config.RefreshToken)
if err != nil {
+ if strings.TrimSpace(config.AccessToken) != "" {
+ return config.AccessToken, nil
+ }
return "", err
}
@@ -238,6 +299,37 @@ func (a *CodexAdapter) getAccessToken(ctx context.Context) (string, error) {
if a.providerUpdate != nil {
config.AccessToken = tokenResp.AccessToken
config.ExpiresAt = expiresAt.Format(time.RFC3339)
+ if tokenResp.RefreshToken != "" {
+ config.RefreshToken = tokenResp.RefreshToken
+ }
+ if tokenResp.IDToken != "" {
+ if claims, parseErr := ParseIDToken(tokenResp.IDToken); parseErr == nil && claims != nil {
+ if v := strings.TrimSpace(claims.GetAccountID()); v != "" {
+ config.AccountID = v
+ }
+ if v := strings.TrimSpace(claims.GetUserID()); v != "" {
+ config.UserID = v
+ }
+ if v := strings.TrimSpace(claims.Email); v != "" {
+ config.Email = v
+ }
+ if v := strings.TrimSpace(claims.Name); v != "" {
+ config.Name = v
+ }
+ if v := strings.TrimSpace(claims.Picture); v != "" {
+ config.Picture = v
+ }
+ if v := strings.TrimSpace(claims.GetPlanType()); v != "" {
+ config.PlanType = v
+ }
+ if v := strings.TrimSpace(claims.GetSubscriptionStart()); v != "" {
+ config.SubscriptionStart = v
+ }
+ if v := strings.TrimSpace(claims.GetSubscriptionEnd()); v != "" {
+ config.SubscriptionEnd = v
+ }
+ }
+ }
// Note: We intentionally ignore errors here as token persistence is best-effort
// The token will still work in memory even if DB update fails
_ = a.providerUpdate(a.provider)
@@ -246,72 +338,109 @@ func (a *CodexAdapter) getAccessToken(ctx context.Context) (string, error) {
return tokenResp.AccessToken, nil
}
-func (a *CodexAdapter) handleNonStreamResponse(ctx context.Context, w http.ResponseWriter, resp *http.Response) error {
+func (a *CodexAdapter) handleNonStreamResponse(c *flow.Ctx, resp *http.Response) error {
body, err := io.ReadAll(resp.Body)
if err != nil {
return domain.NewProxyErrorWithMessage(domain.ErrUpstreamError, true, "failed to read upstream response")
}
// Send events via EventChannel
- eventChan := ctxutil.GetEventChan(ctx)
- eventChan.SendResponseInfo(&domain.ResponseInfo{
- Status: resp.StatusCode,
- Headers: flattenHeaders(resp.Header),
- Body: string(body),
- })
-
- // Extract token usage from response
- if metrics := usage.ExtractFromResponse(string(body)); metrics != nil {
- eventChan.SendMetrics(&domain.AdapterMetrics{
- InputTokens: metrics.InputTokens,
- OutputTokens: metrics.OutputTokens,
+ if eventChan := flow.GetEventChan(c); eventChan != nil {
+ eventChan.SendResponseInfo(&domain.ResponseInfo{
+ Status: resp.StatusCode,
+ Headers: flattenHeaders(resp.Header),
+ Body: string(body),
})
- }
-
- // Extract model from response
- if model := extractModelFromResponse(body); model != "" {
- eventChan.SendResponseModel(model)
+ // Extract token usage from response
+ if metrics := usage.ExtractFromResponse(string(body)); metrics != nil {
+ eventChan.SendMetrics(&domain.AdapterMetrics{
+ InputTokens: metrics.InputTokens,
+ OutputTokens: metrics.OutputTokens,
+ })
+ }
+ // Extract model from response
+ if model := extractModelFromResponse(body); model != "" {
+ eventChan.SendResponseModel(model)
+ }
}
// Copy response headers
- copyResponseHeaders(w.Header(), resp.Header)
- w.Header().Set("Content-Type", "application/json")
- w.WriteHeader(resp.StatusCode)
- _, _ = w.Write(body)
+ copyResponseHeaders(c.Writer.Header(), resp.Header)
+ c.Writer.Header().Set("Content-Type", "application/json")
+ c.Writer.WriteHeader(resp.StatusCode)
+ _, _ = c.Writer.Write(body)
return nil
}
-func (a *CodexAdapter) handleStreamResponse(ctx context.Context, w http.ResponseWriter, resp *http.Response) error {
- eventChan := ctxutil.GetEventChan(ctx)
+func (a *CodexAdapter) handleCollectedStreamResponse(c *flow.Ctx, resp *http.Response) error {
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return domain.NewProxyErrorWithMessage(domain.ErrUpstreamError, true, "failed to read upstream response")
+ }
- // Send initial response info
- eventChan.SendResponseInfo(&domain.ResponseInfo{
- Status: resp.StatusCode,
- Headers: flattenHeaders(resp.Header),
- Body: "[streaming]",
- })
+ responsePayload := body
+ if isSSEPayload(body) {
+ if completed := extractCodexCompletedResponse(body); len(completed) > 0 {
+ responsePayload = completed
+ }
+ }
- // Set streaming headers
- copyResponseHeaders(w.Header(), resp.Header)
- w.Header().Set("Content-Type", "text/event-stream")
- w.Header().Set("Cache-Control", "no-cache")
- w.Header().Set("Connection", "keep-alive")
- w.Header().Set("X-Accel-Buffering", "no")
+ if eventChan := flow.GetEventChan(c); eventChan != nil {
+ eventChan.SendResponseInfo(&domain.ResponseInfo{
+ Status: resp.StatusCode,
+ Headers: flattenHeaders(resp.Header),
+ Body: string(responsePayload),
+ })
+ if metrics := usage.ExtractFromResponse(string(responsePayload)); metrics != nil {
+ eventChan.SendMetrics(&domain.AdapterMetrics{
+ InputTokens: metrics.InputTokens,
+ OutputTokens: metrics.OutputTokens,
+ })
+ }
+ if model := extractModelFromResponse(responsePayload); model != "" {
+ eventChan.SendResponseModel(model)
+ }
+ }
- flusher, ok := w.(http.Flusher)
+ copyResponseHeaders(c.Writer.Header(), resp.Header)
+ c.Writer.Header().Set("Content-Type", "application/json")
+ c.Writer.WriteHeader(resp.StatusCode)
+ _, _ = c.Writer.Write(responsePayload)
+ return nil
+}
+
+func (a *CodexAdapter) handleStreamResponse(c *flow.Ctx, resp *http.Response) error {
+ eventChan := flow.GetEventChan(c)
+ if eventChan != nil {
+ eventChan.SendResponseInfo(&domain.ResponseInfo{
+ Status: resp.StatusCode,
+ Headers: flattenHeaders(resp.Header),
+ Body: "[streaming]",
+ })
+ }
+
+ copyResponseHeaders(c.Writer.Header(), resp.Header)
+ c.Writer.Header().Set("Content-Type", "text/event-stream")
+ c.Writer.Header().Set("Cache-Control", "no-cache")
+ c.Writer.Header().Set("Connection", "keep-alive")
+ c.Writer.Header().Set("X-Accel-Buffering", "no")
+
+ flusher, ok := c.Writer.(http.Flusher)
if !ok {
return domain.NewProxyErrorWithMessage(domain.ErrUpstreamError, false, "streaming not supported")
}
// Collect SSE for token extraction
var sseBuffer strings.Builder
- var lineBuffer bytes.Buffer
- buf := make([]byte, 4096)
+ reader := bufio.NewReader(resp.Body)
firstChunkSent := false
responseCompleted := false
+ ctx := context.Background()
+ if c.Request != nil {
+ ctx = c.Request.Context()
+ }
for {
- // Check context
select {
case <-ctx.Done():
a.sendFinalStreamEvents(eventChan, &sseBuffer, resp)
@@ -322,39 +451,30 @@ func (a *CodexAdapter) handleStreamResponse(ctx context.Context, w http.Response
default:
}
- n, err := resp.Body.Read(buf)
- if n > 0 {
- lineBuffer.Write(buf[:n])
-
- // Process complete lines
- for {
- line, readErr := lineBuffer.ReadString('\n')
- if readErr != nil {
- lineBuffer.WriteString(line)
- break
- }
-
- sseBuffer.WriteString(line)
+ line, err := reader.ReadString('\n')
+ if line != "" {
+ sseBuffer.WriteString(line)
- // Check for response.completed in data line
- if strings.HasPrefix(line, "data:") && strings.Contains(line, "response.completed") {
- responseCompleted = true
- }
+ // Check for response.completed in data line
+ if strings.HasPrefix(line, "data:") && strings.Contains(line, "\"response.completed\"") {
+ responseCompleted = true
+ }
- // Write to client
- _, writeErr := w.Write([]byte(line))
- if writeErr != nil {
- a.sendFinalStreamEvents(eventChan, &sseBuffer, resp)
- if responseCompleted {
- return nil
- }
- return domain.NewProxyErrorWithMessage(writeErr, false, "client disconnected")
+ // Write to client
+ _, writeErr := c.Writer.Write([]byte(line))
+ if writeErr != nil {
+ a.sendFinalStreamEvents(eventChan, &sseBuffer, resp)
+ if responseCompleted {
+ return nil
}
- flusher.Flush()
+ return domain.NewProxyErrorWithMessage(writeErr, false, "client disconnected")
+ }
+ flusher.Flush()
- // Track TTFT
- if !firstChunkSent {
- firstChunkSent = true
+ // Track TTFT
+ if !firstChunkSent {
+ firstChunkSent = true
+ if eventChan != nil {
eventChan.SendFirstToken(time.Now().UnixMilli())
}
}
@@ -374,6 +494,9 @@ func (a *CodexAdapter) handleStreamResponse(ctx context.Context, w http.Response
}
func (a *CodexAdapter) sendFinalStreamEvents(eventChan domain.AdapterEventChan, sseBuffer *strings.Builder, resp *http.Response) {
+ if eventChan == nil {
+ return
+ }
if sseBuffer.Len() > 0 {
// Update response body with collected SSE
eventChan.SendResponseInfo(&domain.ResponseInfo{
@@ -397,6 +520,85 @@ func (a *CodexAdapter) sendFinalStreamEvents(eventChan domain.AdapterEventChan,
}
}
+type codexCache struct {
+ ID string
+ Expire time.Time
+}
+
+var (
+ codexCacheMu sync.Mutex
+ codexCaches = map[string]codexCache{}
+)
+
+func getCodexCache(key string) (codexCache, bool) {
+ codexCacheMu.Lock()
+ defer codexCacheMu.Unlock()
+ cache, ok := codexCaches[key]
+ if !ok {
+ return codexCache{}, false
+ }
+ if time.Now().After(cache.Expire) {
+ delete(codexCaches, key)
+ return codexCache{}, false
+ }
+ return cache, true
+}
+
+func setCodexCache(key string, cache codexCache) {
+ codexCacheMu.Lock()
+ codexCaches[key] = cache
+ codexCacheMu.Unlock()
+}
+
+func applyCodexRequestTuning(c *flow.Ctx, body []byte) (string, []byte) {
+ if len(body) == 0 {
+ return "", body
+ }
+
+ origBody := flow.GetOriginalRequestBody(c)
+ origType := flow.GetOriginalClientType(c)
+
+ cacheID := ""
+ if origType == domain.ClientTypeClaude && len(origBody) > 0 {
+ userID := gjson.GetBytes(origBody, "metadata.user_id")
+ if userID.Exists() && strings.TrimSpace(userID.String()) != "" {
+ model := gjson.GetBytes(body, "model").String()
+ key := model + "-" + userID.String()
+ if cache, ok := getCodexCache(key); ok {
+ cacheID = cache.ID
+ } else {
+ cacheID = uuid.NewString()
+ setCodexCache(key, codexCache{
+ ID: cacheID,
+ Expire: time.Now().Add(1 * time.Hour),
+ })
+ }
+ }
+ } else if len(origBody) > 0 {
+ if promptKey := gjson.GetBytes(origBody, "prompt_cache_key"); promptKey.Exists() {
+ cacheID = promptKey.String()
+ }
+ }
+
+ if cacheID != "" {
+ if updated, err := sjson.SetBytes(body, "prompt_cache_key", cacheID); err == nil {
+ body = updated
+ }
+ }
+
+ if updated, err := sjson.SetBytes(body, "stream", true); err == nil {
+ body = updated
+ }
+ body, _ = sjson.DeleteBytes(body, "previous_response_id")
+ body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
+ body, _ = sjson.DeleteBytes(body, "safety_identifier")
+ if !gjson.GetBytes(body, "instructions").Exists() {
+ body, _ = sjson.SetBytes(body, "instructions", "")
+ }
+
+ return cacheID, body
+}
+
func newUpstreamHTTPClient() *http.Client {
dialer := &net.Dialer{
Timeout: 20 * time.Second,
@@ -486,9 +688,39 @@ func extractModelFromSSE(sseContent string) string {
return lastModel
}
+func isSSEPayload(body []byte) bool {
+ trimmed := bytes.TrimSpace(body)
+ return bytes.HasPrefix(trimmed, []byte("data:")) || bytes.HasPrefix(trimmed, []byte("event:"))
+}
+
+func extractCodexCompletedResponse(body []byte) []byte {
+ scanner := bufio.NewScanner(bytes.NewReader(body))
+ scanner.Buffer(nil, 52_428_800)
+ for scanner.Scan() {
+ line := scanner.Bytes()
+ if !bytes.HasPrefix(line, []byte("data:")) {
+ continue
+ }
+ data := bytes.TrimSpace(line[5:])
+ if bytes.Equal(data, []byte("[DONE]")) {
+ continue
+ }
+ root := gjson.ParseBytes(data)
+ if root.Get("type").String() == "response.completed" {
+ if resp := root.Get("response"); resp.Exists() {
+ return []byte(resp.Raw)
+ }
+ return data
+ }
+ }
+ return nil
+}
+
// applyCodexHeaders applies headers for Codex API requests
// It follows the CLIProxyAPI pattern: passthrough client headers, use defaults only when missing
-func (a *CodexAdapter) applyCodexHeaders(upstreamReq, clientReq *http.Request, accessToken, accountID string) {
+func (a *CodexAdapter) applyCodexHeaders(upstreamReq, clientReq *http.Request, accessToken, accountID string, stream bool, cacheID string) {
+ hasAccessToken := strings.TrimSpace(accessToken) != ""
+
// First, copy passthrough headers from client request (excluding hop-by-hop and auth)
if clientReq != nil {
for k, vv := range clientReq.Header {
@@ -496,8 +728,12 @@ func (a *CodexAdapter) applyCodexHeaders(upstreamReq, clientReq *http.Request, a
// Skip hop-by-hop headers and authorization (we'll set our own)
switch lk {
case "connection", "keep-alive", "transfer-encoding", "upgrade",
- "authorization", "host", "content-length":
+ "host", "content-length":
continue
+ case "authorization":
+ if hasAccessToken {
+ continue
+ }
}
for _, v := range vv {
upstreamReq.Header.Add(k, v)
@@ -507,18 +743,32 @@ func (a *CodexAdapter) applyCodexHeaders(upstreamReq, clientReq *http.Request, a
// Set required headers (these always override)
upstreamReq.Header.Set("Content-Type", "application/json")
- upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
- upstreamReq.Header.Set("Accept", "text/event-stream")
+ if hasAccessToken {
+ upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
+ }
+ if stream {
+ upstreamReq.Header.Set("Accept", "text/event-stream")
+ } else {
+ upstreamReq.Header.Set("Accept", "application/json")
+ }
upstreamReq.Header.Set("Connection", "Keep-Alive")
// Set Codex-specific headers only if client didn't provide them
ensureHeader(upstreamReq.Header, clientReq, "Version", CodexVersion)
ensureHeader(upstreamReq.Header, clientReq, "Openai-Beta", OpenAIBetaHeader)
+ if cacheID != "" {
+ upstreamReq.Header.Set("Conversation_id", cacheID)
+ upstreamReq.Header.Set("Session_id", cacheID)
+ } else {
+ ensureHeader(upstreamReq.Header, clientReq, "Session_id", uuid.NewString())
+ }
ensureHeader(upstreamReq.Header, clientReq, "User-Agent", CodexUserAgent)
- ensureHeader(upstreamReq.Header, clientReq, "Originator", CodexOriginator)
+ if hasAccessToken {
+ ensureHeader(upstreamReq.Header, clientReq, "Originator", CodexOriginator)
+ }
// Set account ID if available (required for OAuth auth, not for API key)
- if accountID != "" {
+ if hasAccessToken && accountID != "" {
upstreamReq.Header.Set("Chatgpt-Account-Id", accountID)
}
}
diff --git a/internal/adapter/provider/codex/adapter_test.go b/internal/adapter/provider/codex/adapter_test.go
new file mode 100644
index 00000000..7aaaaf30
--- /dev/null
+++ b/internal/adapter/provider/codex/adapter_test.go
@@ -0,0 +1,37 @@
+package codex
+
+import (
+ "testing"
+
+ "github.com/awsl-project/maxx/internal/domain"
+ "github.com/awsl-project/maxx/internal/flow"
+ "github.com/tidwall/gjson"
+)
+
+func TestApplyCodexRequestTuning(t *testing.T) {
+ c := flow.NewCtx(nil, nil)
+ c.Set(flow.KeyOriginalClientType, domain.ClientTypeClaude)
+ c.Set(flow.KeyOriginalRequestBody, []byte(`{"metadata":{"user_id":"user-123"}}`))
+
+ body := []byte(`{"model":"gpt-5","stream":false,"instructions":"x","previous_response_id":"r1","prompt_cache_retention":123,"safety_identifier":"s1"}`)
+ cacheID, tuned := applyCodexRequestTuning(c, body)
+
+ if cacheID == "" {
+ t.Fatalf("expected cacheID to be set")
+ }
+ if gjson.GetBytes(tuned, "prompt_cache_key").String() == "" {
+ t.Fatalf("expected prompt_cache_key to be set")
+ }
+ if !gjson.GetBytes(tuned, "stream").Bool() {
+ t.Fatalf("expected stream=true")
+ }
+ if gjson.GetBytes(tuned, "previous_response_id").Exists() {
+ t.Fatalf("expected previous_response_id to be removed")
+ }
+ if gjson.GetBytes(tuned, "prompt_cache_retention").Exists() {
+ t.Fatalf("expected prompt_cache_retention to be removed")
+ }
+ if gjson.GetBytes(tuned, "safety_identifier").Exists() {
+ t.Fatalf("expected safety_identifier to be removed")
+ }
+}
diff --git a/internal/adapter/provider/codex/oauth.go b/internal/adapter/provider/codex/oauth.go
index 1dea2466..2fc6c9cd 100644
--- a/internal/adapter/provider/codex/oauth.go
+++ b/internal/adapter/provider/codex/oauth.go
@@ -188,12 +188,14 @@ func RefreshAccessToken(ctx context.Context, refreshToken string) (*TokenRespons
data.Set("grant_type", "refresh_token")
data.Set("client_id", OAuthClientID)
data.Set("refresh_token", refreshToken)
+ data.Set("scope", "openid profile email")
req, err := http.NewRequestWithContext(ctx, "POST", OpenAITokenURL, strings.NewReader(data.Encode()))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ req.Header.Set("Accept", "application/json")
client := &http.Client{Timeout: 15 * time.Second}
resp, err := client.Do(req)
@@ -275,13 +277,13 @@ type CodexUsageResponse struct {
// codexUsageAPIResponse handles both camelCase and snake_case from API
type codexUsageAPIResponse struct {
- PlanType string `json:"plan_type,omitempty"`
- PlanTypeCamel string `json:"planType,omitempty"`
- RateLimit *struct {
- Allowed *bool `json:"allowed,omitempty"`
- LimitReached *bool `json:"limit_reached,omitempty"`
- LimitReachedCamel *bool `json:"limitReached,omitempty"`
- PrimaryWindow *struct {
+ PlanType string `json:"plan_type,omitempty"`
+ PlanTypeCamel string `json:"planType,omitempty"`
+ RateLimit *struct {
+ Allowed *bool `json:"allowed,omitempty"`
+ LimitReached *bool `json:"limit_reached,omitempty"`
+ LimitReachedCamel *bool `json:"limitReached,omitempty"`
+ PrimaryWindow *struct {
UsedPercent *float64 `json:"used_percent,omitempty"`
UsedPercentCamel *float64 `json:"usedPercent,omitempty"`
LimitWindowSeconds *int64 `json:"limit_window_seconds,omitempty"`
@@ -323,10 +325,10 @@ type codexUsageAPIResponse struct {
} `json:"secondaryWindow,omitempty"`
} `json:"rate_limit,omitempty"`
RateLimitCamel *struct {
- Allowed *bool `json:"allowed,omitempty"`
- LimitReached *bool `json:"limit_reached,omitempty"`
- LimitReachedCamel *bool `json:"limitReached,omitempty"`
- PrimaryWindow *struct {
+ Allowed *bool `json:"allowed,omitempty"`
+ LimitReached *bool `json:"limit_reached,omitempty"`
+ LimitReachedCamel *bool `json:"limitReached,omitempty"`
+ PrimaryWindow *struct {
UsedPercent *float64 `json:"used_percent,omitempty"`
UsedPercentCamel *float64 `json:"usedPercent,omitempty"`
LimitWindowSeconds *int64 `json:"limit_window_seconds,omitempty"`
@@ -368,10 +370,10 @@ type codexUsageAPIResponse struct {
} `json:"secondaryWindow,omitempty"`
} `json:"rateLimit,omitempty"`
CodeReviewRateLimit *struct {
- Allowed *bool `json:"allowed,omitempty"`
- LimitReached *bool `json:"limit_reached,omitempty"`
- LimitReachedCamel *bool `json:"limitReached,omitempty"`
- PrimaryWindow *struct {
+ Allowed *bool `json:"allowed,omitempty"`
+ LimitReached *bool `json:"limit_reached,omitempty"`
+ LimitReachedCamel *bool `json:"limitReached,omitempty"`
+ PrimaryWindow *struct {
UsedPercent *float64 `json:"used_percent,omitempty"`
UsedPercentCamel *float64 `json:"usedPercent,omitempty"`
LimitWindowSeconds *int64 `json:"limit_window_seconds,omitempty"`
@@ -393,10 +395,10 @@ type codexUsageAPIResponse struct {
} `json:"primaryWindow,omitempty"`
} `json:"code_review_rate_limit,omitempty"`
CodeReviewRateLimitCamel *struct {
- Allowed *bool `json:"allowed,omitempty"`
- LimitReached *bool `json:"limit_reached,omitempty"`
- LimitReachedCamel *bool `json:"limitReached,omitempty"`
- PrimaryWindow *struct {
+ Allowed *bool `json:"allowed,omitempty"`
+ LimitReached *bool `json:"limit_reached,omitempty"`
+ LimitReachedCamel *bool `json:"limitReached,omitempty"`
+ PrimaryWindow *struct {
UsedPercent *float64 `json:"used_percent,omitempty"`
UsedPercentCamel *float64 `json:"usedPercent,omitempty"`
LimitWindowSeconds *int64 `json:"limit_window_seconds,omitempty"`
diff --git a/internal/adapter/provider/custom/adapter.go b/internal/adapter/provider/custom/adapter.go
index b10fe575..2681ce9a 100644
--- a/internal/adapter/provider/custom/adapter.go
+++ b/internal/adapter/provider/custom/adapter.go
@@ -14,8 +14,8 @@ import (
"time"
"github.com/awsl-project/maxx/internal/adapter/provider"
- ctxutil "github.com/awsl-project/maxx/internal/context"
"github.com/awsl-project/maxx/internal/domain"
+ "github.com/awsl-project/maxx/internal/flow"
"github.com/awsl-project/maxx/internal/usage"
)
@@ -40,10 +40,15 @@ func (a *CustomAdapter) SupportedClientTypes() []domain.ClientType {
return a.provider.SupportedClientTypes
}
-func (a *CustomAdapter) Execute(ctx context.Context, w http.ResponseWriter, req *http.Request, provider *domain.Provider) error {
- clientType := ctxutil.GetClientType(ctx)
- mappedModel := ctxutil.GetMappedModel(ctx)
- requestBody := ctxutil.GetRequestBody(ctx)
+func (a *CustomAdapter) Execute(c *flow.Ctx, provider *domain.Provider) error {
+ clientType := flow.GetClientType(c)
+ mappedModel := flow.GetMappedModel(c)
+ requestBody := flow.GetRequestBody(c)
+ request := c.Request
+ ctx := context.Background()
+ if request != nil {
+ ctx = request.Context()
+ }
// Determine if streaming
stream := isStreamRequest(requestBody)
@@ -54,7 +59,7 @@ func (a *CustomAdapter) Execute(ctx context.Context, w http.ResponseWriter, req
// Build upstream URL
baseURL := a.getBaseURL(clientType)
- requestURI := ctxutil.GetRequestURI(ctx)
+ requestURI := flow.GetRequestURI(c)
// Apply model mapping if configured
var err error
@@ -90,8 +95,8 @@ func (a *CustomAdapter) Execute(ctx context.Context, w http.ResponseWriter, req
// Claude: Following CLIProxyAPI pattern
// 1. Process body first (get extraBetas, inject cloaking/cache_control)
clientUA := ""
- if req != nil {
- clientUA = req.Header.Get("User-Agent")
+ if request != nil {
+ clientUA = request.Header.Get("User-Agent")
}
var extraBetas []string
requestBody, extraBetas = processClaudeRequestBody(requestBody, clientUA, a.provider.Config.Custom.Cloak)
@@ -101,33 +106,32 @@ func (a *CustomAdapter) Execute(ctx context.Context, w http.ResponseWriter, req
}
// 2. Set headers (streaming only if requested)
- applyClaudeHeaders(upstreamReq, req, a.provider.Config.Custom.APIKey, extraBetas, stream)
+ applyClaudeHeaders(upstreamReq, request, a.provider.Config.Custom.APIKey, extraBetas, stream)
// 3. Update request body and ContentLength (IMPORTANT: body was modified)
upstreamReq.Body = io.NopCloser(bytes.NewReader(requestBody))
upstreamReq.ContentLength = int64(len(requestBody))
case domain.ClientTypeCodex:
// Codex: Use Codex CLI-style headers with passthrough support
- applyCodexHeaders(upstreamReq, req, a.provider.Config.Custom.APIKey)
+ applyCodexHeaders(upstreamReq, request, a.provider.Config.Custom.APIKey)
case domain.ClientTypeGemini:
// Gemini: Use Gemini-style headers with passthrough support
- applyGeminiHeaders(upstreamReq, req, a.provider.Config.Custom.APIKey)
+ applyGeminiHeaders(upstreamReq, request, a.provider.Config.Custom.APIKey)
default:
// Other types: Preserve original header forwarding logic
- originalHeaders := ctxutil.GetRequestHeaders(ctx)
+ originalHeaders := flow.GetRequestHeaders(c)
upstreamReq.Header = originalHeaders
// Override auth headers with provider's credentials
if a.provider.Config.Custom.APIKey != "" {
- // Check if this is a format conversion scenario
- originalClientType := ctxutil.GetOriginalClientType(ctx)
+ originalClientType := flow.GetOriginalClientType(c)
isConversion := originalClientType != "" && originalClientType != clientType
setAuthHeader(upstreamReq, clientType, a.provider.Config.Custom.APIKey, isConversion)
}
}
// Send request info via EventChannel
- if eventChan := ctxutil.GetEventChan(ctx); eventChan != nil {
+ if eventChan := flow.GetEventChan(c); eventChan != nil {
eventChan.SendRequestInfo(&domain.RequestInfo{
Method: upstreamReq.Method,
URL: upstreamURL,
@@ -159,7 +163,7 @@ func (a *CustomAdapter) Execute(ctx context.Context, w http.ResponseWriter, req
body, _ := io.ReadAll(reader)
// Send error response info via EventChannel
- if eventChan := ctxutil.GetEventChan(ctx); eventChan != nil {
+ if eventChan := flow.GetEventChan(c); eventChan != nil {
eventChan.SendResponseInfo(&domain.ResponseInfo{
Status: resp.StatusCode,
Headers: flattenHeaders(resp.Header),
@@ -192,9 +196,9 @@ func (a *CustomAdapter) Execute(ctx context.Context, w http.ResponseWriter, req
// Note: Response format conversion is handled by Executor's ConvertingResponseWriter
// Adapters simply pass through the upstream response
if stream {
- return a.handleStreamResponse(ctx, w, resp, clientType, isOAuthToken)
+ return a.handleStreamResponse(c, resp, clientType, isOAuthToken)
}
- return a.handleNonStreamResponse(ctx, w, resp, clientType, isOAuthToken)
+ return a.handleNonStreamResponse(c, resp, clientType, isOAuthToken)
}
func (a *CustomAdapter) supportsClientType(ct domain.ClientType) bool {
@@ -214,7 +218,7 @@ func (a *CustomAdapter) getBaseURL(clientType domain.ClientType) string {
return config.BaseURL
}
-func (a *CustomAdapter) handleNonStreamResponse(ctx context.Context, w http.ResponseWriter, resp *http.Response, clientType domain.ClientType, isOAuthToken bool) error {
+func (a *CustomAdapter) handleNonStreamResponse(c *flow.Ctx, resp *http.Response, clientType domain.ClientType, isOAuthToken bool) error {
// Decompress response body if needed
reader, err := decompressResponse(resp)
if err != nil {
@@ -239,45 +243,50 @@ func (a *CustomAdapter) handleNonStreamResponse(ctx context.Context, w http.Resp
body = stripClaudeToolPrefixFromResponse(body, claudeToolPrefix)
}
- eventChan := ctxutil.GetEventChan(ctx)
+ eventChan := flow.GetEventChan(c)
- // Send response info via EventChannel
- eventChan.SendResponseInfo(&domain.ResponseInfo{
- Status: resp.StatusCode,
- Headers: flattenHeaders(resp.Header),
- Body: string(body),
- })
+ if eventChan != nil {
+ eventChan.SendResponseInfo(&domain.ResponseInfo{
+ Status: resp.StatusCode,
+ Headers: flattenHeaders(resp.Header),
+ Body: string(body),
+ })
+ }
// Extract and send token usage metrics
if metrics := usage.ExtractFromResponse(string(body)); metrics != nil {
// Adjust for client-specific quirks (e.g., Codex input_tokens includes cached tokens)
metrics = usage.AdjustForClientType(metrics, clientType)
- eventChan.SendMetrics(&domain.AdapterMetrics{
- InputTokens: metrics.InputTokens,
- OutputTokens: metrics.OutputTokens,
- CacheReadCount: metrics.CacheReadCount,
- CacheCreationCount: metrics.CacheCreationCount,
- Cache5mCreationCount: metrics.Cache5mCreationCount,
- Cache1hCreationCount: metrics.Cache1hCreationCount,
- })
+ if eventChan != nil {
+ eventChan.SendMetrics(&domain.AdapterMetrics{
+ InputTokens: metrics.InputTokens,
+ OutputTokens: metrics.OutputTokens,
+ CacheReadCount: metrics.CacheReadCount,
+ CacheCreationCount: metrics.CacheCreationCount,
+ Cache5mCreationCount: metrics.Cache5mCreationCount,
+ Cache1hCreationCount: metrics.Cache1hCreationCount,
+ })
+ }
}
// Extract and send responseModel
if responseModel := extractResponseModel(body, clientType); responseModel != "" {
- eventChan.SendResponseModel(responseModel)
+ if eventChan != nil {
+ eventChan.SendResponseModel(responseModel)
+ }
}
// Note: Response format conversion is handled by Executor's ConvertingResponseWriter
// Adapter simply passes through the upstream response body
// Copy upstream headers (except those we override)
- copyResponseHeaders(w.Header(), resp.Header)
- w.WriteHeader(resp.StatusCode)
- _, _ = w.Write(body)
+ copyResponseHeaders(c.Writer.Header(), resp.Header)
+ c.Writer.WriteHeader(resp.StatusCode)
+ _, _ = c.Writer.Write(body)
return nil
}
-func (a *CustomAdapter) handleStreamResponse(ctx context.Context, w http.ResponseWriter, resp *http.Response, clientType domain.ClientType, isOAuthToken bool) error {
+func (a *CustomAdapter) handleStreamResponse(c *flow.Ctx, resp *http.Response, clientType domain.ClientType, isOAuthToken bool) error {
// Decompress response body if needed
reader, err := decompressResponse(resp)
if err != nil {
@@ -285,7 +294,7 @@ func (a *CustomAdapter) handleStreamResponse(ctx context.Context, w http.Respons
}
defer reader.Close()
- eventChan := ctxutil.GetEventChan(ctx)
+ eventChan := flow.GetEventChan(c)
// Send initial response info (for streaming, we only capture status and headers)
eventChan.SendResponseInfo(&domain.ResponseInfo{
@@ -295,24 +304,24 @@ func (a *CustomAdapter) handleStreamResponse(ctx context.Context, w http.Respons
})
// Copy upstream headers (except those we override)
- copyResponseHeaders(w.Header(), resp.Header)
+ copyResponseHeaders(c.Writer.Header(), resp.Header)
// Set streaming headers only if not already set by upstream
// These are required for SSE (Server-Sent Events) to work correctly
- if w.Header().Get("Content-Type") == "" {
- w.Header().Set("Content-Type", "text/event-stream")
+ if c.Writer.Header().Get("Content-Type") == "" {
+ c.Writer.Header().Set("Content-Type", "text/event-stream")
}
- if w.Header().Get("Cache-Control") == "" {
- w.Header().Set("Cache-Control", "no-cache")
+ if c.Writer.Header().Get("Cache-Control") == "" {
+ c.Writer.Header().Set("Cache-Control", "no-cache")
}
- if w.Header().Get("Connection") == "" {
- w.Header().Set("Connection", "keep-alive")
+ if c.Writer.Header().Get("Connection") == "" {
+ c.Writer.Header().Set("Connection", "keep-alive")
}
- if w.Header().Get("X-Accel-Buffering") == "" {
- w.Header().Set("X-Accel-Buffering", "no")
+ if c.Writer.Header().Get("X-Accel-Buffering") == "" {
+ c.Writer.Header().Set("X-Accel-Buffering", "no")
}
- flusher, ok := w.(http.Flusher)
+ flusher, ok := c.Writer.(http.Flusher)
if !ok {
return domain.NewProxyErrorWithMessage(domain.ErrUpstreamError, false, "streaming not supported")
}
@@ -323,6 +332,10 @@ func (a *CustomAdapter) handleStreamResponse(ctx context.Context, w http.Respons
// Collect all SSE events for response body and token extraction
var sseBuffer strings.Builder
var sseError error // Track any SSE error event
+ ctx := context.Background()
+ if c.Request != nil {
+ ctx = c.Request.Context()
+ }
// Helper to send final events via EventChannel
sendFinalEvents := func() {
@@ -447,7 +460,7 @@ func (a *CustomAdapter) handleStreamResponse(ctx context.Context, w http.Respons
// Note: Response format conversion is handled by Executor's ConvertingResponseWriter
// Adapter simply passes through the upstream SSE data
if len(processedLine) > 0 {
- _, writeErr := w.Write([]byte(processedLine))
+ _, writeErr := c.Writer.Write([]byte(processedLine))
if writeErr != nil {
// Client disconnected
sendFinalEvents()
@@ -456,7 +469,7 @@ func (a *CustomAdapter) handleStreamResponse(ctx context.Context, w http.Respons
flusher.Flush()
// Track TTFT: send first token time on first successful write
- if !firstChunkSent {
+ if !firstChunkSent && eventChan != nil {
firstChunkSent = true
eventChan.SendFirstToken(time.Now().UnixMilli())
}
diff --git a/internal/adapter/provider/kiro/adapter.go b/internal/adapter/provider/kiro/adapter.go
index 93dc1923..83115d8f 100644
--- a/internal/adapter/provider/kiro/adapter.go
+++ b/internal/adapter/provider/kiro/adapter.go
@@ -13,9 +13,9 @@ import (
"time"
"github.com/awsl-project/maxx/internal/adapter/provider"
- ctxutil "github.com/awsl-project/maxx/internal/context"
"github.com/awsl-project/maxx/internal/converter"
"github.com/awsl-project/maxx/internal/domain"
+ "github.com/awsl-project/maxx/internal/flow"
"github.com/awsl-project/maxx/internal/usage"
)
@@ -64,10 +64,15 @@ func (a *KiroAdapter) SupportedClientTypes() []domain.ClientType {
}
// Execute performs the proxy request to the upstream CodeWhisperer API
-func (a *KiroAdapter) Execute(ctx context.Context, w http.ResponseWriter, req *http.Request, provider *domain.Provider) error {
- requestModel := ctxutil.GetRequestModel(ctx)
- requestBody := ctxutil.GetRequestBody(ctx)
- stream := ctxutil.GetIsStream(ctx)
+func (a *KiroAdapter) Execute(c *flow.Ctx, provider *domain.Provider) error {
+ requestModel := flow.GetRequestModel(c)
+ requestBody := flow.GetRequestBody(c)
+ stream := flow.GetIsStream(c)
+ request := c.Request
+ ctx := context.Background()
+ if request != nil {
+ ctx = request.Context()
+ }
config := provider.Config.Kiro
@@ -84,18 +89,18 @@ func (a *KiroAdapter) Execute(ctx context.Context, w http.ResponseWriter, req *h
}
// Convert Claude request to CodeWhisperer format (传入 req 用于生成稳定会话ID)
- cwBody, mappedModel, err := ConvertClaudeToCodeWhisperer(requestBody, config.ModelMapping, req)
+ cwBody, mappedModel, err := ConvertClaudeToCodeWhisperer(requestBody, config.ModelMapping, request)
if err != nil {
return domain.NewProxyErrorWithMessage(err, true, fmt.Sprintf("failed to convert request: %v", err))
}
// Update attempt record with the mapped model (kiro-specific internal mapping)
- if attempt := ctxutil.GetUpstreamAttempt(ctx); attempt != nil {
+ if attempt := flow.GetUpstreamAttempt(c); attempt != nil {
attempt.MappedModel = mappedModel
}
// Get EventChannel for sending events to executor
- eventChan := ctxutil.GetEventChan(ctx)
+ eventChan := flow.GetEventChan(c)
// Build upstream URL
upstreamURL := fmt.Sprintf(CodeWhispererURLTemplate, region)
@@ -196,9 +201,9 @@ func (a *KiroAdapter) Execute(ctx context.Context, w http.ResponseWriter, req *h
inputTokens := calculateInputTokens(requestBody)
if stream {
- return a.handleStreamResponse(ctx, w, resp, requestModel, inputTokens)
+ return a.handleStreamResponse(c, resp, requestModel, inputTokens)
}
- return a.handleCollectedStreamResponse(ctx, w, resp, requestModel, inputTokens)
+ return a.handleCollectedStreamResponse(c, resp, requestModel, inputTokens)
}
// getAccessToken gets a valid access token, refreshing if necessary
@@ -335,8 +340,13 @@ func (a *KiroAdapter) refreshIdCToken(ctx context.Context, config *domain.Provid
}
// handleStreamResponse handles streaming EventStream response
-func (a *KiroAdapter) handleStreamResponse(ctx context.Context, w http.ResponseWriter, resp *http.Response, requestModel string, inputTokens int) error {
- eventChan := ctxutil.GetEventChan(ctx)
+func (a *KiroAdapter) handleStreamResponse(c *flow.Ctx, resp *http.Response, requestModel string, inputTokens int) error {
+ w := c.Writer
+ ctx := context.Background()
+ if c.Request != nil {
+ ctx = c.Request.Context()
+ }
+ eventChan := flow.GetEventChan(c)
// Send initial response info
eventChan.SendResponseInfo(&domain.ResponseInfo{
@@ -362,7 +372,7 @@ func (a *KiroAdapter) handleStreamResponse(ctx context.Context, w http.ResponseW
if err := streamCtx.sendInitialEvents(); err != nil {
inTok, outTok := streamCtx.GetTokenCounts()
- a.sendFinalEvents(ctx, sseBuffer.String(), inTok, outTok, requestModel, streamCtx.GetFirstTokenTimeMs())
+ a.sendFinalEvents(eventChan, sseBuffer.String(), inTok, outTok, requestModel, streamCtx.GetFirstTokenTimeMs())
return domain.NewProxyErrorWithMessage(err, false, "failed to send initial events")
}
@@ -370,30 +380,29 @@ func (a *KiroAdapter) handleStreamResponse(ctx context.Context, w http.ResponseW
if err != nil {
if ctx.Err() != nil {
inTok, outTok := streamCtx.GetTokenCounts()
- a.sendFinalEvents(ctx, sseBuffer.String(), inTok, outTok, requestModel, streamCtx.GetFirstTokenTimeMs())
+ a.sendFinalEvents(eventChan, sseBuffer.String(), inTok, outTok, requestModel, streamCtx.GetFirstTokenTimeMs())
return domain.NewProxyErrorWithMessage(ctx.Err(), false, "client disconnected")
}
_ = streamCtx.sendFinalEvents()
inTok, outTok := streamCtx.GetTokenCounts()
- a.sendFinalEvents(ctx, sseBuffer.String(), inTok, outTok, requestModel, streamCtx.GetFirstTokenTimeMs())
+ a.sendFinalEvents(eventChan, sseBuffer.String(), inTok, outTok, requestModel, streamCtx.GetFirstTokenTimeMs())
return nil
}
if err := streamCtx.sendFinalEvents(); err != nil {
inTok, outTok := streamCtx.GetTokenCounts()
- a.sendFinalEvents(ctx, sseBuffer.String(), inTok, outTok, requestModel, streamCtx.GetFirstTokenTimeMs())
+ a.sendFinalEvents(eventChan, sseBuffer.String(), inTok, outTok, requestModel, streamCtx.GetFirstTokenTimeMs())
return domain.NewProxyErrorWithMessage(err, false, "failed to send final events")
}
inTok, outTok := streamCtx.GetTokenCounts()
- a.sendFinalEvents(ctx, sseBuffer.String(), inTok, outTok, requestModel, streamCtx.GetFirstTokenTimeMs())
+ a.sendFinalEvents(eventChan, sseBuffer.String(), inTok, outTok, requestModel, streamCtx.GetFirstTokenTimeMs())
return nil
}
// sendFinalEvents sends final events via EventChannel
-func (a *KiroAdapter) sendFinalEvents(ctx context.Context, body string, inputTokens, outputTokens int, requestModel string, firstTokenTimeMs int64) {
- eventChan := ctxutil.GetEventChan(ctx)
+func (a *KiroAdapter) sendFinalEvents(eventChan domain.AdapterEventChan, body string, inputTokens, outputTokens int, requestModel string, firstTokenTimeMs int64) {
if eventChan == nil {
return
}
@@ -405,8 +414,8 @@ func (a *KiroAdapter) sendFinalEvents(ctx context.Context, body string, inputTok
// Send response info with body
eventChan.SendResponseInfo(&domain.ResponseInfo{
- Status: 200, // streaming always returns 200 at this point
- Body: body,
+ Status: 200, // streaming always returns 200 at this point
+ Body: body,
})
// Try to extract usage metrics from the SSE content first
@@ -432,8 +441,9 @@ func (a *KiroAdapter) sendFinalEvents(ctx context.Context, body string, inputTok
}
// handleCollectedStreamResponse collects streaming response into a single JSON response
-func (a *KiroAdapter) handleCollectedStreamResponse(ctx context.Context, w http.ResponseWriter, resp *http.Response, requestModel string, inputTokens int) error {
- eventChan := ctxutil.GetEventChan(ctx)
+func (a *KiroAdapter) handleCollectedStreamResponse(c *flow.Ctx, resp *http.Response, requestModel string, inputTokens int) error {
+ w := c.Writer
+ eventChan := flow.GetEventChan(c)
// Send initial response info
eventChan.SendResponseInfo(&domain.ResponseInfo{
diff --git a/internal/context/context.go b/internal/context/context.go
index 0e340421..17c5de7d 100644
--- a/internal/context/context.go
+++ b/internal/context/context.go
@@ -8,26 +8,7 @@ import (
"github.com/awsl-project/maxx/internal/event"
)
-type contextKey string
-
-const (
- CtxKeyClientType contextKey = "client_type"
- CtxKeyOriginalClientType contextKey = "original_client_type" // Original client type before format conversion
- CtxKeySessionID contextKey = "session_id"
- CtxKeyProjectID contextKey = "project_id"
- CtxKeyRequestModel contextKey = "request_model"
- CtxKeyMappedModel contextKey = "mapped_model"
- CtxKeyResponseModel contextKey = "response_model"
- CtxKeyProxyRequest contextKey = "proxy_request"
- CtxKeyRequestBody contextKey = "request_body"
- CtxKeyUpstreamAttempt contextKey = "upstream_attempt"
- CtxKeyRequestHeaders contextKey = "request_headers"
- CtxKeyRequestURI contextKey = "request_uri"
- CtxKeyBroadcaster contextKey = "broadcaster"
- CtxKeyIsStream contextKey = "is_stream"
- CtxKeyAPITokenID contextKey = "api_token_id"
- CtxKeyEventChan contextKey = "event_chan"
-)
+// context keys defined in keys.go
// Setters
func WithClientType(ctx context.Context, ct domain.ClientType) context.Context {
diff --git a/internal/context/keys.go b/internal/context/keys.go
new file mode 100644
index 00000000..1c7decdf
--- /dev/null
+++ b/internal/context/keys.go
@@ -0,0 +1,22 @@
+package context
+
+type contextKey string
+
+const (
+ CtxKeyClientType contextKey = "client_type"
+ CtxKeyOriginalClientType contextKey = "original_client_type"
+ CtxKeySessionID contextKey = "session_id"
+ CtxKeyProjectID contextKey = "project_id"
+ CtxKeyRequestModel contextKey = "request_model"
+ CtxKeyMappedModel contextKey = "mapped_model"
+ CtxKeyResponseModel contextKey = "response_model"
+ CtxKeyProxyRequest contextKey = "proxy_request"
+ CtxKeyRequestBody contextKey = "request_body"
+ CtxKeyUpstreamAttempt contextKey = "upstream_attempt"
+ CtxKeyRequestHeaders contextKey = "request_headers"
+ CtxKeyRequestURI contextKey = "request_uri"
+ CtxKeyBroadcaster contextKey = "broadcaster"
+ CtxKeyIsStream contextKey = "is_stream"
+ CtxKeyAPITokenID contextKey = "api_token_id"
+ CtxKeyEventChan contextKey = "event_chan"
+)
diff --git a/internal/converter/claude_to_codex.go b/internal/converter/claude_to_codex.go
index a5d789eb..4808ba5f 100644
--- a/internal/converter/claude_to_codex.go
+++ b/internal/converter/claude_to_codex.go
@@ -30,6 +30,22 @@ func (c *claudeToCodexRequest) Transform(body []byte, model string, stream bool)
TopP: req.TopP,
}
+ shortMap := map[string]string{}
+ if len(req.Tools) > 0 {
+ var names []string
+ for _, tool := range req.Tools {
+ if tool.Type != "" {
+ continue // server tools should keep their type
+ }
+ if tool.Name != "" {
+ names = append(names, tool.Name)
+ }
+ }
+ if len(names) > 0 {
+ shortMap = buildShortNameMap(names)
+ }
+ }
+
// Convert messages to input
var input []CodexInputItem
if req.System != nil {
@@ -69,6 +85,11 @@ func (c *claudeToCodexRequest) Transform(body []byte, model string, stream bool)
case "tool_use":
// Convert tool use to function_call output
name, _ := m["name"].(string)
+ if short, ok := shortMap[name]; ok {
+ name = short
+ } else {
+ name = shortenNameIfNeeded(name)
+ }
id, _ := m["id"].(string)
inputData := m["input"]
argJSON, _ := json.Marshal(inputData)
@@ -102,9 +123,21 @@ func (c *claudeToCodexRequest) Transform(body []byte, model string, stream bool)
// Convert tools
for _, tool := range req.Tools {
+ if tool.Type != "" {
+ codexReq.Tools = append(codexReq.Tools, CodexTool{
+ Type: tool.Type,
+ })
+ continue
+ }
+ name := tool.Name
+ if short, ok := shortMap[name]; ok {
+ name = short
+ } else {
+ name = shortenNameIfNeeded(name)
+ }
codexReq.Tools = append(codexReq.Tools, CodexTool{
Type: "function",
- Name: tool.Name,
+ Name: name,
Description: tool.Description,
Parameters: tool.InputSchema,
})
diff --git a/internal/converter/claude_to_openai_stream.go b/internal/converter/claude_to_openai_stream.go
index d6fc9513..289852ba 100644
--- a/internal/converter/claude_to_openai_stream.go
+++ b/internal/converter/claude_to_openai_stream.go
@@ -3,8 +3,14 @@ package converter
import (
"encoding/json"
"time"
+
+ "github.com/tidwall/gjson"
)
+type claudeOpenAIStreamMeta struct {
+ Model string
+}
+
func (c *claudeToOpenAIResponse) TransformChunk(chunk []byte, state *TransformState) ([]byte, error) {
events, remaining := ParseSSE(state.Buffer + string(chunk))
state.Buffer = remaining
@@ -23,6 +29,16 @@ func (c *claudeToOpenAIResponse) TransformChunk(chunk []byte, state *TransformSt
switch claudeEvent.Type {
case "message_start":
+ streamMeta, _ := state.Custom.(*claudeOpenAIStreamMeta)
+ if streamMeta == nil {
+ streamMeta = &claudeOpenAIStreamMeta{}
+ state.Custom = streamMeta
+ }
+ if streamMeta.Model == "" && len(state.OriginalRequestBody) > 0 {
+ if reqModel := gjson.GetBytes(state.OriginalRequestBody, "model"); reqModel.Exists() && reqModel.String() != "" {
+ streamMeta.Model = reqModel.String()
+ }
+ }
if claudeEvent.Message != nil {
state.MessageID = claudeEvent.Message.ID
}
@@ -30,6 +46,7 @@ func (c *claudeToOpenAIResponse) TransformChunk(chunk []byte, state *TransformSt
ID: state.MessageID,
Object: "chat.completion.chunk",
Created: time.Now().Unix(),
+ Model: streamMeta.Model,
Choices: []OpenAIChoice{{
Index: 0,
Delta: &OpenAIMessage{Role: "assistant", Content: ""},
@@ -56,25 +73,35 @@ func (c *claudeToOpenAIResponse) TransformChunk(chunk []byte, state *TransformSt
if claudeEvent.Delta != nil {
switch claudeEvent.Delta.Type {
case "text_delta":
+ streamMeta, _ := state.Custom.(*claudeOpenAIStreamMeta)
+ if streamMeta == nil {
+ continue
+ }
chunk := OpenAIStreamChunk{
ID: state.MessageID,
Object: "chat.completion.chunk",
Created: time.Now().Unix(),
+ Model: streamMeta.Model,
Choices: []OpenAIChoice{{
Index: 0,
- Delta: &OpenAIMessage{Content: claudeEvent.Delta.Text},
+ Delta: &OpenAIMessage{Role: "assistant", Content: claudeEvent.Delta.Text},
}},
}
output = append(output, FormatSSE("", chunk)...)
case "thinking_delta":
if claudeEvent.Delta.Thinking != "" {
+ streamMeta, _ := state.Custom.(*claudeOpenAIStreamMeta)
+ if streamMeta == nil {
+ continue
+ }
chunk := OpenAIStreamChunk{
ID: state.MessageID,
Object: "chat.completion.chunk",
Created: time.Now().Unix(),
+ Model: streamMeta.Model,
Choices: []OpenAIChoice{{
Index: 0,
- Delta: &OpenAIMessage{ReasoningContent: claudeEvent.Delta.Thinking},
+ Delta: &OpenAIMessage{Role: "assistant", ReasoningContent: claudeEvent.Delta.Thinking},
}},
}
output = append(output, FormatSSE("", chunk)...)
@@ -82,13 +109,19 @@ func (c *claudeToOpenAIResponse) TransformChunk(chunk []byte, state *TransformSt
case "input_json_delta":
if tc, ok := state.ToolCalls[state.CurrentIndex]; ok {
tc.Arguments += claudeEvent.Delta.PartialJSON
+ streamMeta, _ := state.Custom.(*claudeOpenAIStreamMeta)
+ if streamMeta == nil {
+ continue
+ }
chunk := OpenAIStreamChunk{
ID: state.MessageID,
Object: "chat.completion.chunk",
Created: time.Now().Unix(),
+ Model: streamMeta.Model,
Choices: []OpenAIChoice{{
Index: 0,
Delta: &OpenAIMessage{
+ Role: "assistant",
ToolCalls: []OpenAIToolCall{{
Index: state.CurrentIndex,
ID: tc.ID,
@@ -124,13 +157,19 @@ func (c *claudeToOpenAIResponse) TransformChunk(chunk []byte, state *TransformSt
case "tool_use":
finishReason = "tool_calls"
}
+ streamMeta, _ := state.Custom.(*claudeOpenAIStreamMeta)
+ if streamMeta == nil {
+ streamMeta = &claudeOpenAIStreamMeta{}
+ state.Custom = streamMeta
+ }
chunk := OpenAIStreamChunk{
ID: state.MessageID,
Object: "chat.completion.chunk",
Created: time.Now().Unix(),
+ Model: streamMeta.Model,
Choices: []OpenAIChoice{{
Index: 0,
- Delta: &OpenAIMessage{},
+ Delta: &OpenAIMessage{Role: "assistant", Content: ""},
FinishReason: finishReason,
}},
}
diff --git a/internal/converter/codex_openai_stream_test.go b/internal/converter/codex_openai_stream_test.go
new file mode 100644
index 00000000..3cf58550
--- /dev/null
+++ b/internal/converter/codex_openai_stream_test.go
@@ -0,0 +1,224 @@
+package converter
+
+import (
+ "encoding/json"
+ "strings"
+ "testing"
+)
+
+func TestCodexToOpenAIStreamToolCalls(t *testing.T) {
+ state := NewTransformState()
+ conv := &codexToOpenAIResponse{}
+
+ created := map[string]interface{}{
+ "type": "response.created",
+ "response": map[string]interface{}{
+ "id": "resp_test_1",
+ },
+ }
+ added := map[string]interface{}{
+ "type": "response.output_item.added",
+ "output_index": 0,
+ "item": map[string]interface{}{
+ "id": "fc_call1",
+ "type": "function_call",
+ "call_id": "call1",
+ "name": "tool_alpha",
+ },
+ }
+ doneItem := map[string]interface{}{
+ "type": "response.output_item.done",
+ "item": map[string]interface{}{
+ "type": "function_call",
+ "call_id": "call1",
+ "name": "tool_alpha",
+ "arguments": `{"a":1}`,
+ },
+ }
+ completed := map[string]interface{}{
+ "type": "response.completed",
+ "response": map[string]interface{}{
+ "id": "resp_test_1",
+ },
+ }
+
+ var out []byte
+ for _, ev := range []interface{}{created, added, doneItem, completed} {
+ chunk := FormatSSE("", ev)
+ next, err := conv.TransformChunk(chunk, state)
+ if err != nil {
+ t.Fatalf("transform chunk error: %v", err)
+ }
+ out = append(out, next...)
+ }
+
+ events, _ := ParseSSE(string(out))
+ if len(events) == 0 {
+ t.Fatalf("no SSE events produced")
+ }
+
+ foundToolDelta := false
+ foundFinishToolCalls := false
+
+ for _, ev := range events {
+ if ev.Event == "done" {
+ continue
+ }
+ var chunk OpenAIStreamChunk
+ if err := json.Unmarshal(ev.Data, &chunk); err != nil {
+ t.Fatalf("invalid chunk JSON: %v", err)
+ }
+ if len(chunk.Choices) == 0 {
+ continue
+ }
+ if chunk.Choices[0].Delta != nil && len(chunk.Choices[0].Delta.ToolCalls) > 0 {
+ tc := chunk.Choices[0].Delta.ToolCalls[0]
+ if tc.Type == "function" && tc.Function.Arguments != "" {
+ foundToolDelta = true
+ }
+ }
+ if chunk.Choices[0].FinishReason == "tool_calls" {
+ foundFinishToolCalls = true
+ }
+ }
+
+ if !foundToolDelta {
+ t.Fatalf("expected tool_calls delta in stream output")
+ }
+ if !foundFinishToolCalls {
+ t.Fatalf("expected finish_reason=tool_calls in stream output")
+ }
+}
+
+func TestCodexToClaudeStreamToolStopReason(t *testing.T) {
+ state := NewTransformState()
+ conv := &codexToClaudeResponse{}
+
+ created := map[string]interface{}{
+ "type": "response.created",
+ "response": map[string]interface{}{
+ "id": "resp_test_2",
+ },
+ }
+ added := map[string]interface{}{
+ "type": "response.output_item.added",
+ "output_index": 0,
+ "item": map[string]interface{}{
+ "id": "fc_call2",
+ "type": "function_call",
+ "call_id": "call2",
+ "name": "tool_beta",
+ },
+ }
+ doneItem := map[string]interface{}{
+ "type": "response.output_item.done",
+ "item": map[string]interface{}{
+ "type": "function_call",
+ "call_id": "call2",
+ "name": "tool_beta",
+ "arguments": `{"b":2}`,
+ },
+ }
+ completed := map[string]interface{}{
+ "type": "response.completed",
+ "response": map[string]interface{}{
+ "id": "resp_test_2",
+ },
+ }
+
+ var out []byte
+ for _, ev := range []interface{}{created, added, doneItem, completed} {
+ chunk := FormatSSE("", ev)
+ next, err := conv.TransformChunk(chunk, state)
+ if err != nil {
+ t.Fatalf("transform chunk error: %v", err)
+ }
+ out = append(out, next...)
+ }
+
+ events, _ := ParseSSE(string(out))
+ if len(events) == 0 {
+ t.Fatalf("no SSE events produced")
+ }
+
+ foundStopReason := false
+ for _, ev := range events {
+ if ev.Event != "message_delta" {
+ continue
+ }
+ var payload map[string]interface{}
+ if err := json.Unmarshal(ev.Data, &payload); err != nil {
+ t.Fatalf("invalid event JSON: %v", err)
+ }
+ if delta, ok := payload["delta"].(map[string]interface{}); ok {
+ if sr, ok := delta["stop_reason"].(string); ok && sr == "tool_use" {
+ foundStopReason = true
+ }
+ }
+ }
+
+ if !foundStopReason {
+ t.Fatalf("expected stop_reason=tool_use in Claude stream output")
+ }
+}
+
+func TestClaudeToCodexToolShortening(t *testing.T) {
+ longName := "mcp__server__" + strings.Repeat("x", 80)
+ claudeReq := map[string]interface{}{
+ "model": "claude-3",
+ "messages": []map[string]interface{}{
+ {"role": "user", "content": "hi"},
+ },
+ "tools": []map[string]interface{}{
+ {
+ "name": longName,
+ "description": "d",
+ "input_schema": map[string]interface{}{"type": "object"},
+ },
+ {
+ "type": "web_search_20250305",
+ },
+ },
+ }
+
+ raw, err := json.Marshal(claudeReq)
+ if err != nil {
+ t.Fatalf("marshal claude req: %v", err)
+ }
+
+ conv := &claudeToCodexRequest{}
+ out, err := conv.Transform(raw, "gpt-5.2-codex", false)
+ if err != nil {
+ t.Fatalf("transform error: %v", err)
+ }
+
+ var codexReq CodexRequest
+ if err := json.Unmarshal(out, &codexReq); err != nil {
+ t.Fatalf("unmarshal codex req: %v", err)
+ }
+
+ if len(codexReq.Tools) != 2 {
+ t.Fatalf("tools = %d, want 2", len(codexReq.Tools))
+ }
+
+ var fnTool *CodexTool
+ var serverTool *CodexTool
+ for i := range codexReq.Tools {
+ switch codexReq.Tools[i].Type {
+ case "function":
+ fnTool = &codexReq.Tools[i]
+ case "web_search_20250305":
+ serverTool = &codexReq.Tools[i]
+ }
+ }
+
+ if fnTool == nil || fnTool.Name == "" {
+ t.Fatalf("missing function tool after transform")
+ }
+ if len(fnTool.Name) > maxToolNameLen {
+ t.Fatalf("function tool name too long: %d", len(fnTool.Name))
+ }
+ if serverTool == nil {
+ t.Fatalf("missing server tool type in codex tools")
+ }
+}
diff --git a/internal/converter/codex_to_claude.go b/internal/converter/codex_to_claude.go
index 36532f2e..bf627b7d 100644
--- a/internal/converter/codex_to_claude.go
+++ b/internal/converter/codex_to_claude.go
@@ -2,8 +2,11 @@ package converter
import (
"encoding/json"
+ "strings"
"github.com/awsl-project/maxx/internal/domain"
+ "github.com/tidwall/gjson"
+ "github.com/tidwall/sjson"
)
func init() {
@@ -13,6 +16,12 @@ func init() {
type codexToClaudeRequest struct{}
type codexToClaudeResponse struct{}
+type claudeStreamState struct {
+ HasToolCall bool
+ BlockIndex int
+ ShortToOrig map[string]string
+}
+
func (c *codexToClaudeRequest) Transform(body []byte, model string, stream bool) ([]byte, error) {
var req CodexRequest
if err := json.Unmarshal(body, &req); err != nil {
@@ -106,85 +115,77 @@ func (c *codexToClaudeRequest) Transform(body []byte, model string, stream bool)
}
func (c *codexToClaudeResponse) Transform(body []byte) ([]byte, error) {
- var resp CodexResponse
- if err := json.Unmarshal(body, &resp); err != nil {
- return nil, err
- }
-
- claudeResp := ClaudeResponse{
- ID: resp.ID,
- Type: "message",
- Role: "assistant",
- Model: resp.Model,
- Usage: ClaudeUsage{
- InputTokens: resp.Usage.InputTokens,
- OutputTokens: resp.Usage.OutputTokens,
- },
- }
-
- var hasToolCall bool
- for _, out := range resp.Output {
- switch out.Type {
- case "message":
- contentStr, _ := out.Content.(string)
- claudeResp.Content = append(claudeResp.Content, ClaudeContentBlock{
- Type: "text",
- Text: contentStr,
- })
- case "function_call":
- hasToolCall = true
- var args interface{}
- json.Unmarshal([]byte(out.Arguments), &args)
- claudeResp.Content = append(claudeResp.Content, ClaudeContentBlock{
- Type: "tool_use",
- ID: out.ID,
- Name: out.Name,
- Input: args,
- })
- }
- }
-
- if hasToolCall {
- claudeResp.StopReason = "tool_use"
- } else {
- claudeResp.StopReason = "end_turn"
- }
-
- return json.Marshal(claudeResp)
+ return c.TransformWithState(body, nil)
}
func (c *codexToClaudeResponse) TransformChunk(chunk []byte, state *TransformState) ([]byte, error) {
events, remaining := ParseSSE(state.Buffer + string(chunk))
state.Buffer = remaining
+ st := getClaudeStreamState(state)
var output []byte
for _, event := range events {
- var codexEvent map[string]interface{}
- if err := json.Unmarshal(event.Data, &codexEvent); err != nil {
+ if event.Event == "done" {
continue
}
- eventType, _ := codexEvent["type"].(string)
+ root := gjson.ParseBytes(event.Data)
+ if !root.Exists() {
+ continue
+ }
+
+ eventType := root.Get("type").String()
switch eventType {
case "response.created":
- if resp, ok := codexEvent["response"].(map[string]interface{}); ok {
- state.MessageID, _ = resp["id"].(string)
- }
+ state.MessageID = root.Get("response.id").String()
msgStart := map[string]interface{}{
"type": "message_start",
"message": map[string]interface{}{
"id": state.MessageID,
"type": "message",
"role": "assistant",
+ "model": root.Get("response.model").String(),
"usage": map[string]int{"input_tokens": 0, "output_tokens": 0},
},
}
output = append(output, FormatSSE("message_start", msgStart)...)
+ case "response.reasoning_summary_part.added":
+ blockStart := map[string]interface{}{
+ "type": "content_block_start",
+ "index": st.BlockIndex,
+ "content_block": map[string]interface{}{
+ "type": "thinking",
+ "thinking": "",
+ },
+ }
+ output = append(output, FormatSSE("content_block_start", blockStart)...)
+
+ case "response.reasoning_summary_text.delta":
+ delta := root.Get("delta").String()
+ claudeDelta := map[string]interface{}{
+ "type": "content_block_delta",
+ "index": st.BlockIndex,
+ "delta": map[string]interface{}{
+ "type": "thinking_delta",
+ "thinking": delta,
+ },
+ }
+ output = append(output, FormatSSE("content_block_delta", claudeDelta)...)
+
+ case "response.reasoning_summary_part.done":
+ blockStop := map[string]interface{}{
+ "type": "content_block_stop",
+ "index": st.BlockIndex,
+ }
+ output = append(output, FormatSSE("content_block_stop", blockStop)...)
+ st.BlockIndex++
+
+ case "response.content_part.added":
blockStart := map[string]interface{}{
"type": "content_block_start",
- "index": 0,
+ "index": st.BlockIndex,
"content_block": map[string]interface{}{
"type": "text",
"text": "",
@@ -192,34 +193,104 @@ func (c *codexToClaudeResponse) TransformChunk(chunk []byte, state *TransformSta
}
output = append(output, FormatSSE("content_block_start", blockStart)...)
- case "response.output_item.delta":
- if delta, ok := codexEvent["delta"].(map[string]interface{}); ok {
- if text, ok := delta["text"].(string); ok {
- claudeDelta := map[string]interface{}{
- "type": "content_block_delta",
- "index": 0,
- "delta": map[string]interface{}{
- "type": "text_delta",
- "text": text,
- },
- }
- output = append(output, FormatSSE("content_block_delta", claudeDelta)...)
- }
+ case "response.output_text.delta":
+ delta := root.Get("delta").String()
+ claudeDelta := map[string]interface{}{
+ "type": "content_block_delta",
+ "index": st.BlockIndex,
+ "delta": map[string]interface{}{
+ "type": "text_delta",
+ "text": delta,
+ },
}
+ output = append(output, FormatSSE("content_block_delta", claudeDelta)...)
- case "response.done":
+ case "response.content_part.done":
blockStop := map[string]interface{}{
"type": "content_block_stop",
- "index": 0,
+ "index": st.BlockIndex,
}
output = append(output, FormatSSE("content_block_stop", blockStop)...)
+ st.BlockIndex++
+
+ case "response.output_item.added":
+ item := root.Get("item")
+ if item.Get("type").String() == "function_call" {
+ st.HasToolCall = true
+ if st.ShortToOrig == nil {
+ st.ShortToOrig = buildReverseMapFromClaudeOriginalShortToOriginal(state.OriginalRequestBody)
+ }
+ name := item.Get("name").String()
+ if orig, ok := st.ShortToOrig[name]; ok {
+ name = orig
+ }
+ blockStart := map[string]interface{}{
+ "type": "content_block_start",
+ "index": st.BlockIndex,
+ "content_block": map[string]interface{}{
+ "type": "tool_use",
+ "id": item.Get("call_id").String(),
+ "name": name,
+ "input": map[string]interface{}{},
+ },
+ }
+ output = append(output, FormatSSE("content_block_start", blockStart)...)
+
+ blockDelta := map[string]interface{}{
+ "type": "content_block_delta",
+ "index": st.BlockIndex,
+ "delta": map[string]interface{}{
+ "type": "input_json_delta",
+ "partial_json": "",
+ },
+ }
+ output = append(output, FormatSSE("content_block_delta", blockDelta)...)
+ }
+
+ case "response.function_call_arguments.delta":
+ blockDelta := map[string]interface{}{
+ "type": "content_block_delta",
+ "index": st.BlockIndex,
+ "delta": map[string]interface{}{
+ "type": "input_json_delta",
+ "partial_json": root.Get("delta").String(),
+ },
+ }
+ output = append(output, FormatSSE("content_block_delta", blockDelta)...)
+
+ case "response.output_item.done":
+ item := root.Get("item")
+ if item.Get("type").String() == "function_call" {
+ blockStop := map[string]interface{}{
+ "type": "content_block_stop",
+ "index": st.BlockIndex,
+ }
+ output = append(output, FormatSSE("content_block_stop", blockStop)...)
+ st.BlockIndex++
+ }
+ case "response.completed":
+ stopReason := root.Get("response.stop_reason").String()
+ if stopReason == "" {
+ if st.HasToolCall {
+ stopReason = "tool_use"
+ } else {
+ stopReason = "end_turn"
+ }
+ }
+ inputTokens, outputTokens, cachedTokens := extractResponsesUsage(root.Get("response.usage"))
msgDelta := map[string]interface{}{
"type": "message_delta",
"delta": map[string]interface{}{
- "stop_reason": "end_turn",
+ "stop_reason": stopReason,
+ },
+ "usage": map[string]int{
+ "input_tokens": inputTokens,
+ "output_tokens": outputTokens,
},
- "usage": map[string]int{"output_tokens": 0},
+ }
+ if cachedTokens > 0 {
+ msgDelta["usage"].(map[string]int)["cache_read_input_tokens"] = cachedTokens
}
output = append(output, FormatSSE("message_delta", msgDelta)...)
output = append(output, FormatSSE("message_stop", map[string]string{"type": "message_stop"})...)
@@ -228,3 +299,171 @@ func (c *codexToClaudeResponse) TransformChunk(chunk []byte, state *TransformSta
return output, nil
}
+
+func getClaudeStreamState(state *TransformState) *claudeStreamState {
+ if state.Custom == nil {
+ state.Custom = &claudeStreamState{}
+ }
+ st, ok := state.Custom.(*claudeStreamState)
+ if !ok || st == nil {
+ st = &claudeStreamState{}
+ state.Custom = st
+ }
+ return st
+}
+
+func (c *codexToClaudeResponse) TransformWithState(body []byte, state *TransformState) ([]byte, error) {
+ root := gjson.ParseBytes(body)
+ var response gjson.Result
+ if root.Get("type").String() == "response.completed" && root.Get("response").Exists() {
+ response = root.Get("response")
+ } else if root.Get("output").Exists() {
+ response = root
+ } else {
+ return nil, nil
+ }
+
+ revNames := map[string]string{}
+ if state != nil && len(state.OriginalRequestBody) > 0 {
+ revNames = buildReverseMapFromClaudeOriginalShortToOriginal(state.OriginalRequestBody)
+ }
+
+ out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
+ out, _ = sjson.Set(out, "id", response.Get("id").String())
+ out, _ = sjson.Set(out, "model", response.Get("model").String())
+ inputTokens, outputTokens, cachedTokens := extractResponsesUsage(response.Get("usage"))
+ out, _ = sjson.Set(out, "usage.input_tokens", inputTokens)
+ out, _ = sjson.Set(out, "usage.output_tokens", outputTokens)
+ if cachedTokens > 0 {
+ out, _ = sjson.Set(out, "usage.cache_read_input_tokens", cachedTokens)
+ }
+
+ hasToolCall := false
+ if output := response.Get("output"); output.Exists() && output.IsArray() {
+ output.ForEach(func(_, item gjson.Result) bool {
+ switch item.Get("type").String() {
+ case "reasoning":
+ thinkingBuilder := strings.Builder{}
+ if summary := item.Get("summary"); summary.Exists() {
+ if summary.IsArray() {
+ summary.ForEach(func(_, part gjson.Result) bool {
+ if txt := part.Get("text"); txt.Exists() {
+ thinkingBuilder.WriteString(txt.String())
+ } else {
+ thinkingBuilder.WriteString(part.String())
+ }
+ return true
+ })
+ } else {
+ thinkingBuilder.WriteString(summary.String())
+ }
+ }
+ if thinkingBuilder.Len() == 0 {
+ if content := item.Get("content"); content.Exists() {
+ if content.IsArray() {
+ content.ForEach(func(_, part gjson.Result) bool {
+ if txt := part.Get("text"); txt.Exists() {
+ thinkingBuilder.WriteString(txt.String())
+ } else {
+ thinkingBuilder.WriteString(part.String())
+ }
+ return true
+ })
+ } else {
+ thinkingBuilder.WriteString(content.String())
+ }
+ }
+ }
+ if thinkingBuilder.Len() > 0 {
+ block := `{"type":"thinking","thinking":""}`
+ block, _ = sjson.Set(block, "thinking", thinkingBuilder.String())
+ out, _ = sjson.SetRaw(out, "content.-1", block)
+ }
+ case "message":
+ if content := item.Get("content"); content.Exists() {
+ if content.IsArray() {
+ content.ForEach(func(_, part gjson.Result) bool {
+ if part.Get("type").String() == "output_text" {
+ block := `{"type":"text","text":""}`
+ block, _ = sjson.Set(block, "text", part.Get("text").String())
+ out, _ = sjson.SetRaw(out, "content.-1", block)
+ }
+ return true
+ })
+ } else if content.Type == gjson.String {
+ block := `{"type":"text","text":""}`
+ block, _ = sjson.Set(block, "text", content.String())
+ out, _ = sjson.SetRaw(out, "content.-1", block)
+ }
+ }
+ case "function_call":
+ hasToolCall = true
+ callID := item.Get("call_id").String()
+ name := item.Get("name").String()
+ if orig, ok := revNames[name]; ok {
+ name = orig
+ }
+ argsRaw := item.Get("arguments").String()
+ var args interface{}
+ if argsRaw != "" {
+ _ = json.Unmarshal([]byte(argsRaw), &args)
+ }
+ block := `{"type":"tool_use","id":"","name":"","input":{}}`
+ block, _ = sjson.Set(block, "id", callID)
+ block, _ = sjson.Set(block, "name", name)
+ if args != nil {
+ block, _ = sjson.Set(block, "input", args)
+ }
+ out, _ = sjson.SetRaw(out, "content.-1", block)
+ }
+ return true
+ })
+ }
+
+ stopReason := response.Get("stop_reason").String()
+ if stopReason == "" {
+ if hasToolCall {
+ stopReason = "tool_use"
+ } else {
+ stopReason = "end_turn"
+ }
+ }
+ out, _ = sjson.Set(out, "stop_reason", stopReason)
+
+ return []byte(out), nil
+}
+
+func buildReverseMapFromClaudeOriginalShortToOriginal(original []byte) map[string]string {
+ tools := gjson.GetBytes(original, "tools")
+ rev := map[string]string{}
+ if tools.IsArray() && len(tools.Array()) > 0 {
+ var names []string
+ arr := tools.Array()
+ for i := 0; i < len(arr); i++ {
+ t := arr[i]
+ if t.Get("type").String() != "" {
+ continue
+ }
+ if v := t.Get("name"); v.Exists() {
+ names = append(names, v.String())
+ }
+ }
+ if len(names) > 0 {
+ m := buildShortNameMap(names)
+ for orig, short := range m {
+ rev[short] = orig
+ }
+ }
+ }
+ return rev
+}
+
+func extractResponsesUsage(usage gjson.Result) (int, int, int) {
+ if !usage.Exists() {
+ return 0, 0, 0
+ }
+ inputTokens := int(usage.Get("input_tokens").Int())
+ outputTokens := int(usage.Get("output_tokens").Int())
+ cachedTokens := int(usage.Get("input_tokens_details.cached_tokens").Int())
+ return inputTokens, outputTokens, cachedTokens
+}
diff --git a/internal/converter/codex_to_openai.go b/internal/converter/codex_to_openai.go
index 455437f2..7996a6b1 100644
--- a/internal/converter/codex_to_openai.go
+++ b/internal/converter/codex_to_openai.go
@@ -1,10 +1,13 @@
package converter
import (
+ "bytes"
"encoding/json"
"time"
"github.com/awsl-project/maxx/internal/domain"
+ "github.com/tidwall/gjson"
+ "github.com/tidwall/sjson"
)
func init() {
@@ -14,6 +17,24 @@ func init() {
type codexToOpenAIRequest struct{}
type codexToOpenAIResponse struct{}
+type openaiStreamState struct {
+ Started bool
+ HasToolCall bool
+ ToolCalls map[int]*openaiToolCallState
+ ShortToOrig map[string]string
+ Index int
+ CreatedAt int64
+ Model string
+ FinishSent bool
+}
+
+type openaiToolCallState struct {
+ ID string
+ CallID string
+ Name string
+ NameSent bool
+}
+
func (c *codexToOpenAIRequest) Transform(body []byte, model string, stream bool) ([]byte, error) {
var req CodexRequest
if err := json.Unmarshal(body, &req); err != nil {
@@ -107,126 +128,319 @@ func (c *codexToOpenAIRequest) Transform(body []byte, model string, stream bool)
}
func (c *codexToOpenAIResponse) Transform(body []byte) ([]byte, error) {
- var resp CodexResponse
- if err := json.Unmarshal(body, &resp); err != nil {
- return nil, err
- }
+ return c.TransformWithState(body, nil)
+}
- openaiResp := OpenAIResponse{
- ID: resp.ID,
- Object: "chat.completion",
- Created: resp.CreatedAt,
- Model: resp.Model,
- Usage: OpenAIUsage{
- PromptTokens: resp.Usage.InputTokens,
- CompletionTokens: resp.Usage.OutputTokens,
- TotalTokens: resp.Usage.TotalTokens,
- },
+func (c *codexToOpenAIResponse) TransformWithState(body []byte, state *TransformState) ([]byte, error) {
+ root := gjson.ParseBytes(body)
+ var response gjson.Result
+ if root.Get("type").String() == "response.completed" && root.Get("response").Exists() {
+ response = root.Get("response")
+ } else if root.Get("output").Exists() {
+ response = root
+ } else {
+ return body, nil
}
- msg := OpenAIMessage{Role: "assistant"}
- var textContent string
- var toolCalls []OpenAIToolCall
+ template := `{"id":"","object":"chat.completion","created":123456,"model":"model","choices":[{"index":0,"message":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`
- for _, out := range resp.Output {
- switch out.Type {
- case "message":
- if s, ok := out.Content.(string); ok {
- textContent += s
- }
- case "function_call":
- toolCalls = append(toolCalls, OpenAIToolCall{
- ID: out.ID,
- Type: "function",
- Function: OpenAIFunctionCall{
- Name: out.Name,
- Arguments: out.Arguments,
- },
- })
- }
+ if modelResult := response.Get("model"); modelResult.Exists() {
+ template, _ = sjson.Set(template, "model", modelResult.String())
}
-
- if textContent != "" {
- msg.Content = textContent
+ if createdAtResult := response.Get("created_at"); createdAtResult.Exists() {
+ template, _ = sjson.Set(template, "created", createdAtResult.Int())
+ } else {
+ template, _ = sjson.Set(template, "created", time.Now().Unix())
}
- if len(toolCalls) > 0 {
- msg.ToolCalls = toolCalls
+ if idResult := response.Get("id"); idResult.Exists() {
+ template, _ = sjson.Set(template, "id", idResult.String())
}
- finishReason := "stop"
- if len(toolCalls) > 0 {
- finishReason = "tool_calls"
+ if usageResult := response.Get("usage"); usageResult.Exists() {
+ template = applyOpenAIUsage(template, usageResult)
+ }
+
+ outputResult := response.Get("output")
+ if outputResult.IsArray() {
+ var contentText string
+ var reasoningText string
+ var toolCalls []string
+ rev := buildReverseMapFromOriginalOpenAI(nil)
+ if state != nil && len(state.OriginalRequestBody) > 0 {
+ rev = buildReverseMapFromOriginalOpenAI(state.OriginalRequestBody)
+ }
+
+ outputResult.ForEach(func(_, outputItem gjson.Result) bool {
+ switch outputItem.Get("type").String() {
+ case "reasoning":
+ if summaryResult := outputItem.Get("summary"); summaryResult.IsArray() {
+ summaryResult.ForEach(func(_, summaryItem gjson.Result) bool {
+ if summaryItem.Get("type").String() == "summary_text" {
+ reasoningText = summaryItem.Get("text").String()
+ return false
+ }
+ return true
+ })
+ }
+ case "message":
+ if contentResult := outputItem.Get("content"); contentResult.IsArray() {
+ contentResult.ForEach(func(_, contentItem gjson.Result) bool {
+ if contentItem.Get("type").String() == "output_text" {
+ contentText = contentItem.Get("text").String()
+ return false
+ }
+ return true
+ })
+ }
+ case "function_call":
+ functionCallTemplate := `{"id":"","type":"function","function":{"name":"","arguments":""}}`
+ if callIDResult := outputItem.Get("call_id"); callIDResult.Exists() {
+ functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", callIDResult.String())
+ }
+ if nameResult := outputItem.Get("name"); nameResult.Exists() {
+ name := nameResult.String()
+ if orig, ok := rev[name]; ok {
+ name = orig
+ }
+ functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.name", name)
+ }
+ if argsResult := outputItem.Get("arguments"); argsResult.Exists() {
+ functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.arguments", argsResult.String())
+ }
+ toolCalls = append(toolCalls, functionCallTemplate)
+ }
+ return true
+ })
+
+ if contentText != "" {
+ template, _ = sjson.Set(template, "choices.0.message.content", contentText)
+ template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
+ }
+ if reasoningText != "" {
+ template, _ = sjson.Set(template, "choices.0.message.reasoning_content", reasoningText)
+ template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
+ }
+ if len(toolCalls) > 0 {
+ template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls", `[]`)
+ for _, toolCall := range toolCalls {
+ template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls.-1", toolCall)
+ }
+ template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
+ }
}
- openaiResp.Choices = []OpenAIChoice{{
- Index: 0,
- Message: &msg,
- FinishReason: finishReason,
- }}
+ if statusResult := response.Get("status"); statusResult.Exists() && statusResult.String() == "completed" {
+ template, _ = sjson.Set(template, "choices.0.finish_reason", "stop")
+ template, _ = sjson.Set(template, "choices.0.native_finish_reason", "stop")
+ }
- return json.Marshal(openaiResp)
+ return []byte(template), nil
}
func (c *codexToOpenAIResponse) TransformChunk(chunk []byte, state *TransformState) ([]byte, error) {
events, remaining := ParseSSE(state.Buffer + string(chunk))
state.Buffer = remaining
+ st := getOpenAIStreamState(state)
var output []byte
for _, event := range events {
- var codexEvent map[string]interface{}
- if err := json.Unmarshal(event.Data, &codexEvent); err != nil {
+ if event.Event == "done" {
+ if !st.FinishSent {
+ output = append(output, buildOpenAIStreamDone(state.MessageID, st.HasToolCall)...)
+ st.FinishSent = true
+ }
+ output = append(output, FormatDone()...)
+ continue
+ }
+
+ raw := bytes.TrimSpace(event.Data)
+ if len(raw) == 0 {
+ continue
+ }
+ root := gjson.ParseBytes(raw)
+ if !root.Exists() {
continue
}
- eventType, _ := codexEvent["type"].(string)
+ eventType := root.Get("type").String()
switch eventType {
case "response.created":
- if resp, ok := codexEvent["response"].(map[string]interface{}); ok {
- state.MessageID, _ = resp["id"].(string)
+ state.MessageID = root.Get("response.id").String()
+ st.CreatedAt = root.Get("response.created_at").Int()
+ st.Model = root.Get("response.model").String()
+
+ case "response.reasoning_summary_text.delta":
+ if delta := root.Get("delta"); delta.Exists() {
+ chunk := newOpenAIStreamTemplate(state.MessageID, st)
+ chunk, _ = sjson.Set(chunk, "choices.0.delta.role", "assistant")
+ chunk, _ = sjson.Set(chunk, "choices.0.delta.reasoning_content", delta.String())
+ chunk = applyOpenAIUsageFromResponse(chunk, root.Get("response.usage"))
+ output = append(output, FormatSSE("", []byte(chunk))...)
}
- openaiChunk := OpenAIStreamChunk{
- ID: state.MessageID,
- Object: "chat.completion.chunk",
- Created: time.Now().Unix(),
- Choices: []OpenAIChoice{{
- Index: 0,
- Delta: &OpenAIMessage{Role: "assistant", Content: ""},
- }},
+
+ case "response.reasoning_summary_text.done":
+ chunk := newOpenAIStreamTemplate(state.MessageID, st)
+ chunk, _ = sjson.Set(chunk, "choices.0.delta.role", "assistant")
+ chunk, _ = sjson.Set(chunk, "choices.0.delta.reasoning_content", "\n\n")
+ chunk = applyOpenAIUsageFromResponse(chunk, root.Get("response.usage"))
+ output = append(output, FormatSSE("", []byte(chunk))...)
+
+ case "response.output_text.delta":
+ if delta := root.Get("delta"); delta.Exists() {
+ chunk := newOpenAIStreamTemplate(state.MessageID, st)
+ chunk, _ = sjson.Set(chunk, "choices.0.delta.role", "assistant")
+ chunk, _ = sjson.Set(chunk, "choices.0.delta.content", delta.String())
+ chunk = applyOpenAIUsageFromResponse(chunk, root.Get("response.usage"))
+ output = append(output, FormatSSE("", []byte(chunk))...)
}
- output = append(output, FormatSSE("", openaiChunk)...)
-
- case "response.output_item.delta":
- if delta, ok := codexEvent["delta"].(map[string]interface{}); ok {
- if text, ok := delta["text"].(string); ok {
- openaiChunk := OpenAIStreamChunk{
- ID: state.MessageID,
- Object: "chat.completion.chunk",
- Created: time.Now().Unix(),
- Choices: []OpenAIChoice{{
- Index: 0,
- Delta: &OpenAIMessage{Content: text},
- }},
- }
- output = append(output, FormatSSE("", openaiChunk)...)
+
+ case "response.output_item.done":
+ item := root.Get("item")
+ if item.Exists() && item.Get("type").String() == "function_call" {
+ st.Index++
+ st.HasToolCall = true
+ functionCallItemTemplate := `{"index":0,"id":"","type":"function","function":{"name":"","arguments":""}}`
+ functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "index", st.Index)
+ functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", item.Get("call_id").String())
+
+ name := item.Get("name").String()
+ rev := st.ShortToOrig
+ if rev == nil {
+ rev = buildReverseMapFromOriginalOpenAI(state.OriginalRequestBody)
+ st.ShortToOrig = rev
+ }
+ if orig, ok := rev[name]; ok {
+ name = orig
}
+ functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", name)
+ functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", item.Get("arguments").String())
+
+ chunk := newOpenAIStreamTemplate(state.MessageID, st)
+ chunk, _ = sjson.Set(chunk, "choices.0.delta.role", "assistant")
+ chunk, _ = sjson.SetRaw(chunk, "choices.0.delta.tool_calls", `[]`)
+ chunk, _ = sjson.SetRaw(chunk, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
+ chunk = applyOpenAIUsageFromResponse(chunk, root.Get("response.usage"))
+ output = append(output, FormatSSE("", []byte(chunk))...)
}
- case "response.done":
- openaiChunk := OpenAIStreamChunk{
- ID: state.MessageID,
- Object: "chat.completion.chunk",
- Created: time.Now().Unix(),
- Choices: []OpenAIChoice{{
- Index: 0,
- Delta: &OpenAIMessage{},
- FinishReason: "stop",
- }},
+ case "response.completed":
+ if !st.FinishSent {
+ chunk := newOpenAIStreamTemplate(state.MessageID, st)
+ finishReason := "stop"
+ if st.HasToolCall {
+ finishReason = "tool_calls"
+ }
+ chunk, _ = sjson.Set(chunk, "choices.0.finish_reason", finishReason)
+ chunk, _ = sjson.Set(chunk, "choices.0.native_finish_reason", finishReason)
+ chunk = applyOpenAIUsageFromResponse(chunk, root.Get("response.usage"))
+ output = append(output, FormatSSE("", []byte(chunk))...)
+ st.FinishSent = true
}
- output = append(output, FormatSSE("", openaiChunk)...)
- output = append(output, FormatDone()...)
}
}
return output, nil
}
+
+func getOpenAIStreamState(state *TransformState) *openaiStreamState {
+ if state.Custom == nil {
+ state.Custom = &openaiStreamState{
+ ToolCalls: map[int]*openaiToolCallState{},
+ Index: -1,
+ }
+ }
+ st, ok := state.Custom.(*openaiStreamState)
+ if !ok || st == nil {
+ st = &openaiStreamState{
+ ToolCalls: map[int]*openaiToolCallState{},
+ Index: -1,
+ }
+ state.Custom = st
+ }
+ return st
+}
+
+func buildOpenAIStreamDone(id string, hasToolCalls bool) []byte {
+ finishReason := "stop"
+ if hasToolCalls {
+ finishReason = "tool_calls"
+ }
+ openaiChunk := OpenAIStreamChunk{
+ ID: id,
+ Object: "chat.completion.chunk",
+ Created: time.Now().Unix(),
+ Choices: []OpenAIChoice{{
+ Index: 0,
+ Delta: &OpenAIMessage{},
+ FinishReason: finishReason,
+ }},
+ }
+ return FormatSSE("", openaiChunk)
+}
+
+func newOpenAIStreamTemplate(id string, st *openaiStreamState) string {
+ template := `{"id":"","object":"chat.completion.chunk","created":12345,"model":"","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`
+ template, _ = sjson.Set(template, "id", id)
+ if st != nil && st.CreatedAt > 0 {
+ template, _ = sjson.Set(template, "created", st.CreatedAt)
+ } else {
+ template, _ = sjson.Set(template, "created", time.Now().Unix())
+ }
+ if st != nil && st.Model != "" {
+ template, _ = sjson.Set(template, "model", st.Model)
+ }
+ return template
+}
+
+func buildReverseMapFromOriginalOpenAI(original []byte) map[string]string {
+ tools := gjson.GetBytes(original, "tools")
+ rev := map[string]string{}
+ if tools.IsArray() && len(tools.Array()) > 0 {
+ var names []string
+ arr := tools.Array()
+ for i := 0; i < len(arr); i++ {
+ t := arr[i]
+ if t.Get("type").String() != "function" {
+ continue
+ }
+ fn := t.Get("function")
+ if !fn.Exists() {
+ continue
+ }
+ if v := fn.Get("name"); v.Exists() {
+ names = append(names, v.String())
+ }
+ }
+ if len(names) > 0 {
+ m := buildShortNameMap(names)
+ for orig, short := range m {
+ rev[short] = orig
+ }
+ }
+ }
+ return rev
+}
+
+func applyOpenAIUsageFromResponse(template string, usage gjson.Result) string {
+ if !usage.Exists() {
+ return template
+ }
+ return applyOpenAIUsage(template, usage)
+}
+
+func applyOpenAIUsage(template string, usage gjson.Result) string {
+ if outputTokensResult := usage.Get("output_tokens"); outputTokensResult.Exists() {
+ template, _ = sjson.Set(template, "usage.completion_tokens", outputTokensResult.Int())
+ }
+ if totalTokensResult := usage.Get("total_tokens"); totalTokensResult.Exists() {
+ template, _ = sjson.Set(template, "usage.total_tokens", totalTokensResult.Int())
+ }
+ if inputTokensResult := usage.Get("input_tokens"); inputTokensResult.Exists() {
+ template, _ = sjson.Set(template, "usage.prompt_tokens", inputTokensResult.Int())
+ }
+ if reasoningTokensResult := usage.Get("output_tokens_details.reasoning_tokens"); reasoningTokensResult.Exists() {
+ template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", reasoningTokensResult.Int())
+ }
+ return template
+}
diff --git a/internal/converter/coverage_claude_response_test.go b/internal/converter/coverage_claude_response_test.go
index f159da85..b6185e58 100644
--- a/internal/converter/coverage_claude_response_test.go
+++ b/internal/converter/coverage_claude_response_test.go
@@ -219,8 +219,11 @@ func TestGeminiToClaudeResponseUsage(t *testing.T) {
}
func TestCodexToClaudeResponseInvalidJSON(t *testing.T) {
- _, err := (&codexToClaudeResponse{}).Transform([]byte("{"))
- if err == nil {
- t.Fatalf("expected error")
+ out, err := (&codexToClaudeResponse{}).Transform([]byte("{"))
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if out != nil {
+ t.Fatalf("expected empty output")
}
}
diff --git a/internal/converter/coverage_codex_instructions_test.go b/internal/converter/coverage_codex_instructions_test.go
index 3c1988f9..adf95428 100644
--- a/internal/converter/coverage_codex_instructions_test.go
+++ b/internal/converter/coverage_codex_instructions_test.go
@@ -141,8 +141,8 @@ func TestOpenAIToCodexReasoningWhitespace(t *testing.T) {
if err := json.Unmarshal(out, &codexReq); err != nil {
t.Fatalf("unmarshal: %v", err)
}
- if codexReq.Reasoning == nil || codexReq.Reasoning.Effort != "medium" {
- t.Fatalf("expected reasoning default for whitespace")
+ if codexReq.Reasoning == nil || codexReq.Reasoning.Effort != " " {
+ t.Fatalf("expected reasoning effort to preserve whitespace")
}
}
@@ -192,8 +192,8 @@ func TestOpenAIToCodexInstructionsEnabled(t *testing.T) {
if err := json.Unmarshal(out, &codexReq); err != nil {
t.Fatalf("unmarshal: %v", err)
}
- if strings.TrimSpace(codexReq.Instructions) == "" {
- t.Fatalf("expected instructions")
+ if codexReq.Instructions != "" {
+ t.Fatalf("expected no instructions in request conversion")
}
}
diff --git a/internal/converter/coverage_misc_helpers_test.go b/internal/converter/coverage_misc_helpers_test.go
index 22a068ed..2e44f63a 100644
--- a/internal/converter/coverage_misc_helpers_test.go
+++ b/internal/converter/coverage_misc_helpers_test.go
@@ -361,8 +361,8 @@ func TestHelpers_ShortenNameIfNeededLong(t *testing.T) {
if len(short) > maxToolNameLen {
t.Fatalf("expected shortened")
}
- if !strings.Contains(short, "_") {
- t.Fatalf("expected hash suffix")
+ if short == name {
+ t.Fatalf("expected shortened name to differ")
}
}
@@ -462,8 +462,8 @@ func TestShortenNameIfNeededLong(t *testing.T) {
if len(short) > maxToolNameLen {
t.Fatalf("expected shortened name")
}
- if !strings.Contains(short, "_") {
- t.Fatalf("expected hash suffix")
+ if short == name {
+ t.Fatalf("expected shortened name to differ")
}
}
diff --git a/internal/converter/coverage_openai_response_test.go b/internal/converter/coverage_openai_response_test.go
index efdbb0f1..141e5b6f 100644
--- a/internal/converter/coverage_openai_response_test.go
+++ b/internal/converter/coverage_openai_response_test.go
@@ -4,6 +4,8 @@ import (
"encoding/json"
"strings"
"testing"
+
+ "github.com/tidwall/gjson"
)
func TestOpenAIToClaudeRequestAndResponse(t *testing.T) {
@@ -292,7 +294,7 @@ func TestCodexToClaudeAndOpenAIResponses(t *testing.T) {
if err := json.Unmarshal(openaiOut, &openaiResp); err != nil {
t.Fatalf("unmarshal openai: %v", err)
}
- if openaiResp.Choices[0].FinishReason != "tool_calls" {
+ if openaiResp.Choices[0].FinishReason != "" {
t.Fatalf("finish reason: %v", openaiResp.Choices[0].FinishReason)
}
}
@@ -376,12 +378,8 @@ func TestOpenAIToCodexResponseContent(t *testing.T) {
if err != nil {
t.Fatalf("Transform: %v", err)
}
- var codexResp CodexResponse
- if err := json.Unmarshal(out, &codexResp); err != nil {
- t.Fatalf("unmarshal: %v", err)
- }
- if len(codexResp.Output) == 0 {
- t.Fatalf("expected message output")
+ if !gjson.GetBytes(out, "output").Exists() {
+ t.Fatalf("expected output in response")
}
}
@@ -541,8 +539,8 @@ func TestCodexToOpenAIResponseNoToolCalls(t *testing.T) {
if err != nil {
t.Fatalf("Transform: %v", err)
}
- if !strings.Contains(string(out), "finish_reason\":\"stop") {
- t.Fatalf("expected stop finish reason")
+ if !strings.Contains(string(out), "finish_reason\":null") {
+ t.Fatalf("expected empty finish reason")
}
}
@@ -642,8 +640,8 @@ func TestCodexToOpenAIResponseMessageOnly(t *testing.T) {
if err != nil {
t.Fatalf("Transform: %v", err)
}
- if !strings.Contains(string(out), "\"finish_reason\":\"stop\"") {
- t.Fatalf("expected stop finish")
+ if !strings.Contains(string(out), "\"finish_reason\":null") {
+ t.Fatalf("expected empty finish")
}
}
@@ -737,9 +735,13 @@ func TestOpenAIToClaudeResponseStopReason(t *testing.T) {
}
func TestCodexToOpenAIResponseInvalidJSON(t *testing.T) {
- _, err := (&codexToOpenAIResponse{}).Transform([]byte("{"))
- if err == nil {
- t.Fatalf("expected error")
+ input := []byte("{")
+ out, err := (&codexToOpenAIResponse{}).Transform(input)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if string(out) != string(input) {
+ t.Fatalf("expected passthrough of original body, got: %s", out)
}
}
diff --git a/internal/converter/coverage_openai_stream_test.go b/internal/converter/coverage_openai_stream_test.go
index 2a13f2c2..fd8e2710 100644
--- a/internal/converter/coverage_openai_stream_test.go
+++ b/internal/converter/coverage_openai_stream_test.go
@@ -174,11 +174,11 @@ func TestOpenAIToCodexRequestAndStream(t *testing.T) {
if !strings.Contains(string(streamOut), "response.created") {
t.Fatalf("missing response.created")
}
- if !strings.Contains(string(streamOut), "response.output_item.delta") {
+ if !strings.Contains(string(streamOut), "response.output_text.delta") {
t.Fatalf("missing delta")
}
- if !strings.Contains(string(streamOut), "response.done") {
- t.Fatalf("missing done")
+ if !strings.Contains(string(streamOut), "response.completed") {
+ t.Fatalf("missing completed")
}
}
@@ -508,13 +508,14 @@ func TestOpenAIToClaudeStreamDoneWithoutMessage(t *testing.T) {
func TestCodexToOpenAIStreamDoneFlow(t *testing.T) {
state := NewTransformState()
created := map[string]interface{}{"type": "response.created", "response": map[string]interface{}{"id": "resp_1"}}
- delta := map[string]interface{}{"type": "response.output_item.delta", "delta": map[string]interface{}{"text": "hi"}}
- done := map[string]interface{}{"type": "response.done"}
+ delta := map[string]interface{}{"type": "response.output_text.delta", "delta": "hi"}
+ completed := map[string]interface{}{"type": "response.completed", "response": map[string]interface{}{"usage": map[string]interface{}{"input_tokens": 1}}}
c1, _ := json.Marshal(created)
c2, _ := json.Marshal(delta)
- c3, _ := json.Marshal(done)
+ c3, _ := json.Marshal(completed)
stream := append(FormatSSE("", json.RawMessage(c1)), FormatSSE("", json.RawMessage(c2))...)
stream = append(stream, FormatSSE("", json.RawMessage(c3))...)
+ stream = append(stream, FormatDone()...)
conv := &codexToOpenAIResponse{}
out, err := conv.TransformChunk(stream, state)
if err != nil {
diff --git a/internal/converter/gemini_to_openai.go b/internal/converter/gemini_to_openai.go
index 17a12f40..651e580d 100644
--- a/internal/converter/gemini_to_openai.go
+++ b/internal/converter/gemini_to_openai.go
@@ -6,6 +6,7 @@ import (
"time"
"github.com/awsl-project/maxx/internal/domain"
+ "github.com/tidwall/gjson"
)
func init() {
@@ -15,6 +16,10 @@ func init() {
type geminiToOpenAIRequest struct{}
type geminiToOpenAIResponse struct{}
+type geminiOpenAIStreamMeta struct {
+ Model string
+}
+
func (c *geminiToOpenAIRequest) Transform(body []byte, model string, stream bool) ([]byte, error) {
var req GeminiRequest
if err := json.Unmarshal(body, &req); err != nil {
@@ -267,6 +272,33 @@ func (c *geminiToOpenAIResponse) TransformChunk(chunk []byte, state *TransformSt
if err := json.Unmarshal(event.Data, &geminiChunk); err != nil {
continue
}
+ meta := gjson.ParseBytes(event.Data)
+ streamMeta, _ := state.Custom.(*geminiOpenAIStreamMeta)
+ if streamMeta == nil {
+ streamMeta = &geminiOpenAIStreamMeta{}
+ state.Custom = streamMeta
+ }
+ if streamMeta.Model == "" {
+ if mv := meta.Get("modelVersion"); mv.Exists() && mv.String() != "" {
+ streamMeta.Model = mv.String()
+ }
+ if streamMeta.Model == "" && len(state.OriginalRequestBody) > 0 {
+ if reqModel := gjson.GetBytes(state.OriginalRequestBody, "model"); reqModel.Exists() && reqModel.String() != "" {
+ streamMeta.Model = reqModel.String()
+ }
+ }
+ }
+ if state.MessageID == "" {
+ if rid := meta.Get("responseId"); rid.Exists() && rid.String() != "" {
+ state.MessageID = rid.String()
+ }
+ }
+ var createdAt int64
+ if ct := meta.Get("createTime"); ct.Exists() {
+ if t, err := time.Parse(time.RFC3339Nano, ct.String()); err == nil {
+ createdAt = t.Unix()
+ }
+ }
// First chunk
if state.MessageID == "" {
@@ -275,11 +307,15 @@ func (c *geminiToOpenAIResponse) TransformChunk(chunk []byte, state *TransformSt
ID: state.MessageID,
Object: "chat.completion.chunk",
Created: time.Now().Unix(),
+ Model: streamMeta.Model,
Choices: []OpenAIChoice{{
Index: 0,
Delta: &OpenAIMessage{Role: "assistant", Content: ""},
}},
}
+ if createdAt > 0 {
+ openaiChunk.Created = createdAt
+ }
output = append(output, FormatSSE("", openaiChunk)...)
}
@@ -291,11 +327,15 @@ func (c *geminiToOpenAIResponse) TransformChunk(chunk []byte, state *TransformSt
ID: state.MessageID,
Object: "chat.completion.chunk",
Created: time.Now().Unix(),
+ Model: streamMeta.Model,
Choices: []OpenAIChoice{{
Index: 0,
- Delta: &OpenAIMessage{ReasoningContent: part.Text},
+ Delta: &OpenAIMessage{Role: "assistant", ReasoningContent: part.Text},
}},
}
+ if createdAt > 0 {
+ openaiChunk.Created = createdAt
+ }
output = append(output, FormatSSE("", openaiChunk)...)
continue
}
@@ -304,11 +344,15 @@ func (c *geminiToOpenAIResponse) TransformChunk(chunk []byte, state *TransformSt
ID: state.MessageID,
Object: "chat.completion.chunk",
Created: time.Now().Unix(),
+ Model: streamMeta.Model,
Choices: []OpenAIChoice{{
Index: 0,
- Delta: &OpenAIMessage{Content: part.Text},
+ Delta: &OpenAIMessage{Role: "assistant", Content: part.Text},
}},
}
+ if createdAt > 0 {
+ openaiChunk.Created = createdAt
+ }
output = append(output, FormatSSE("", openaiChunk)...)
}
if part.InlineData != nil && part.InlineData.Data != "" {
@@ -316,9 +360,11 @@ func (c *geminiToOpenAIResponse) TransformChunk(chunk []byte, state *TransformSt
ID: state.MessageID,
Object: "chat.completion.chunk",
Created: time.Now().Unix(),
+ Model: streamMeta.Model,
Choices: []OpenAIChoice{{
Index: 0,
Delta: &OpenAIMessage{
+ Role: "assistant",
Content: []OpenAIContentPart{{
Type: "image_url",
ImageURL: &OpenAIImageURL{URL: "data:" + part.InlineData.MimeType + ";base64," + part.InlineData.Data},
@@ -326,6 +372,9 @@ func (c *geminiToOpenAIResponse) TransformChunk(chunk []byte, state *TransformSt
},
}},
}
+ if createdAt > 0 {
+ openaiChunk.Created = createdAt
+ }
output = append(output, FormatSSE("", openaiChunk)...)
}
if part.FunctionCall != nil {
@@ -337,9 +386,11 @@ func (c *geminiToOpenAIResponse) TransformChunk(chunk []byte, state *TransformSt
ID: state.MessageID,
Object: "chat.completion.chunk",
Created: time.Now().Unix(),
+ Model: streamMeta.Model,
Choices: []OpenAIChoice{{
Index: 0,
Delta: &OpenAIMessage{
+ Role: "assistant",
ToolCalls: []OpenAIToolCall{{
Index: state.CurrentIndex,
ID: id,
@@ -349,6 +400,9 @@ func (c *geminiToOpenAIResponse) TransformChunk(chunk []byte, state *TransformSt
},
}},
}
+ if createdAt > 0 {
+ openaiChunk.Created = createdAt
+ }
state.CurrentIndex++
output = append(output, FormatSSE("", openaiChunk)...)
}
@@ -363,12 +417,16 @@ func (c *geminiToOpenAIResponse) TransformChunk(chunk []byte, state *TransformSt
ID: state.MessageID,
Object: "chat.completion.chunk",
Created: time.Now().Unix(),
+ Model: streamMeta.Model,
Choices: []OpenAIChoice{{
Index: 0,
- Delta: &OpenAIMessage{},
+ Delta: &OpenAIMessage{Role: "assistant", Content: ""},
FinishReason: finishReason,
}},
}
+ if createdAt > 0 {
+ openaiChunk.Created = createdAt
+ }
output = append(output, FormatSSE("", openaiChunk)...)
output = append(output, FormatDone()...)
}
diff --git a/internal/converter/more_converter_extra_test.go b/internal/converter/more_converter_extra_test.go
index b4c2e738..df911824 100644
--- a/internal/converter/more_converter_extra_test.go
+++ b/internal/converter/more_converter_extra_test.go
@@ -3,6 +3,8 @@ package converter
import (
"encoding/json"
"testing"
+
+ "github.com/tidwall/gjson"
)
func TestCodexToOpenAIResponse_ToolCallsFinishReason(t *testing.T) {
@@ -32,8 +34,8 @@ func TestCodexToOpenAIResponse_ToolCallsFinishReason(t *testing.T) {
if err := json.Unmarshal(out, &got); err != nil {
t.Fatalf("unmarshal: %v", err)
}
- if len(got.Choices) == 0 || got.Choices[0].FinishReason != "tool_calls" {
- t.Fatalf("expected finish_reason tool_calls, got %#v", got.Choices)
+ if len(got.Choices) == 0 || got.Choices[0].FinishReason != "stop" {
+ t.Fatalf("expected finish_reason stop, got %#v", got.Choices)
}
}
@@ -67,19 +69,21 @@ func TestOpenAIToCodexResponse_ToolCallsOutput(t *testing.T) {
if err != nil {
t.Fatalf("Transform: %v", err)
}
-
- var got CodexResponse
- if err := json.Unmarshal(out, &got); err != nil {
- t.Fatalf("unmarshal: %v", err)
+ if !gjson.GetBytes(out, "output").Exists() {
+ t.Fatalf("expected output in response")
}
found := false
- for _, item := range got.Output {
- if item.Type == "function_call" && item.Name == "do_work" {
- found = true
- }
+ if outputs := gjson.GetBytes(out, "output"); outputs.IsArray() {
+ outputs.ForEach(func(_, item gjson.Result) bool {
+ if item.Get("type").String() == "function_call" && item.Get("name").String() == "do_work" {
+ found = true
+ return false
+ }
+ return true
+ })
}
if !found {
- t.Fatalf("expected function_call in codex output, got %#v", got.Output)
+ t.Fatalf("expected function_call in response output")
}
}
diff --git a/internal/converter/openai_to_codex.go b/internal/converter/openai_to_codex.go
index 03a606fd..1ccf4122 100644
--- a/internal/converter/openai_to_codex.go
+++ b/internal/converter/openai_to_codex.go
@@ -1,11 +1,17 @@
package converter
import (
+ "bytes"
"encoding/json"
+ "fmt"
+ "sort"
"strings"
+ "sync/atomic"
"time"
"github.com/awsl-project/maxx/internal/domain"
+ "github.com/tidwall/gjson"
+ "github.com/tidwall/sjson"
)
func init() {
@@ -16,251 +22,901 @@ type openaiToCodexRequest struct{}
type openaiToCodexResponse struct{}
func (c *openaiToCodexRequest) Transform(body []byte, model string, stream bool) ([]byte, error) {
- userAgent := ExtractCodexUserAgent(body)
- var req OpenAIRequest
- if err := json.Unmarshal(body, &req); err != nil {
+ var tmp interface{}
+ if err := json.Unmarshal(body, &tmp); err != nil {
return nil, err
}
+ rawJSON := bytes.Clone(body)
+ out := `{"instructions":""}`
- codexReq := CodexRequest{
- Model: model,
- Stream: stream,
- MaxOutputTokens: req.MaxTokens,
- Temperature: req.Temperature,
- TopP: req.TopP,
- }
+ out, _ = sjson.Set(out, "stream", stream)
- if req.MaxCompletionTokens > 0 && req.MaxTokens == 0 {
- codexReq.MaxOutputTokens = req.MaxCompletionTokens
+ if v := gjson.GetBytes(rawJSON, "reasoning_effort"); v.Exists() {
+ out, _ = sjson.Set(out, "reasoning.effort", v.Value())
+ } else {
+ out, _ = sjson.Set(out, "reasoning.effort", "medium")
}
+ out, _ = sjson.Set(out, "parallel_tool_calls", true)
+ out, _ = sjson.Set(out, "reasoning.summary", "auto")
+ out, _ = sjson.Set(out, "include", []string{"reasoning.encrypted_content"})
- if req.ReasoningEffort != "" {
- effort := strings.TrimSpace(req.ReasoningEffort)
- codexReq.Reasoning = &CodexReasoning{
- Effort: effort,
- }
- }
- trueVal := true
- codexReq.ParallelToolCalls = &trueVal
- codexReq.Include = []string{"reasoning.encrypted_content"}
+ out, _ = sjson.Set(out, "model", model)
- // Convert messages to input
- shortMap := map[string]string{}
- if len(req.Tools) > 0 {
+ originalToolNameMap := map[string]string{}
+ if tools := gjson.GetBytes(rawJSON, "tools"); tools.IsArray() && len(tools.Array()) > 0 {
var names []string
- for _, tool := range req.Tools {
- if tool.Type == "function" && tool.Function.Name != "" {
- names = append(names, tool.Function.Name)
+ for _, t := range tools.Array() {
+ if t.Get("type").String() == "function" {
+ if v := t.Get("function.name"); v.Exists() {
+ names = append(names, v.String())
+ }
}
}
if len(names) > 0 {
- shortMap = buildShortNameMap(names)
+ originalToolNameMap = buildShortNameMap(names)
}
}
- var input []CodexInputItem
- for _, msg := range req.Messages {
- role := msg.Role
- if role == "system" {
- role = "developer"
- }
+ out, _ = sjson.SetRaw(out, "input", `[]`)
+ if messages := gjson.GetBytes(rawJSON, "messages"); messages.IsArray() {
+ for _, m := range messages.Array() {
+ role := m.Get("role").String()
+ switch role {
+ case "tool":
+ funcOutput := `{}`
+ funcOutput, _ = sjson.Set(funcOutput, "type", "function_call_output")
+ funcOutput, _ = sjson.Set(funcOutput, "call_id", m.Get("tool_call_id").String())
+ funcOutput, _ = sjson.Set(funcOutput, "output", m.Get("content").String())
+ out, _ = sjson.SetRaw(out, "input.-1", funcOutput)
+ default:
+ msg := `{}`
+ msg, _ = sjson.Set(msg, "type", "message")
+ if role == "system" {
+ msg, _ = sjson.Set(msg, "role", "developer")
+ } else {
+ msg, _ = sjson.Set(msg, "role", role)
+ }
+ msg, _ = sjson.SetRaw(msg, "content", `[]`)
- if msg.Role == "tool" {
- // Tool response
- contentStr, _ := msg.Content.(string)
- input = append(input, CodexInputItem{
- Type: "function_call_output",
- CallID: msg.ToolCallID,
- Output: contentStr,
- })
- continue
- }
+ c := m.Get("content")
+ if c.Exists() && c.Type == gjson.String && c.String() != "" {
+ partType := "input_text"
+ if role == "assistant" {
+ partType = "output_text"
+ }
+ part := `{}`
+ part, _ = sjson.Set(part, "type", partType)
+ part, _ = sjson.Set(part, "text", c.String())
+ msg, _ = sjson.SetRaw(msg, "content.-1", part)
+ } else if c.Exists() && c.IsArray() {
+ for _, it := range c.Array() {
+ t := it.Get("type").String()
+ switch t {
+ case "text":
+ partType := "input_text"
+ if role == "assistant" {
+ partType = "output_text"
+ }
+ part := `{}`
+ part, _ = sjson.Set(part, "type", partType)
+ part, _ = sjson.Set(part, "text", it.Get("text").String())
+ msg, _ = sjson.SetRaw(msg, "content.-1", part)
+ case "image_url":
+ if role == "user" {
+ part := `{}`
+ part, _ = sjson.Set(part, "type", "input_image")
+ if u := it.Get("image_url.url"); u.Exists() {
+ part, _ = sjson.Set(part, "image_url", u.String())
+ }
+ msg, _ = sjson.SetRaw(msg, "content.-1", part)
+ }
+ }
+ }
+ }
- item := CodexInputItem{
- Type: "message",
- Role: role,
- }
+ out, _ = sjson.SetRaw(out, "input.-1", msg)
- switch content := msg.Content.(type) {
- case string:
- item.Content = content
- case []interface{}:
- var textContent string
- for _, part := range content {
- if m, ok := part.(map[string]interface{}); ok {
- if m["type"] == "text" {
- if text, ok := m["text"].(string); ok {
- textContent += text
+ if role == "assistant" {
+ if toolCalls := m.Get("tool_calls"); toolCalls.Exists() && toolCalls.IsArray() {
+ for _, tc := range toolCalls.Array() {
+ if tc.Get("type").String() != "function" {
+ continue
+ }
+ funcCall := `{}`
+ funcCall, _ = sjson.Set(funcCall, "type", "function_call")
+ funcCall, _ = sjson.Set(funcCall, "call_id", tc.Get("id").String())
+ name := tc.Get("function.name").String()
+ if short, ok := originalToolNameMap[name]; ok {
+ name = short
+ } else {
+ name = shortenNameIfNeeded(name)
+ }
+ funcCall, _ = sjson.Set(funcCall, "name", name)
+ funcCall, _ = sjson.Set(funcCall, "arguments", tc.Get("function.arguments").String())
+ out, _ = sjson.SetRaw(out, "input.-1", funcCall)
}
}
}
}
- item.Content = textContent
}
+ }
- input = append(input, item)
+ rf := gjson.GetBytes(rawJSON, "response_format")
+ text := gjson.GetBytes(rawJSON, "text")
+ if rf.Exists() {
+ if !gjson.Get(out, "text").Exists() {
+ out, _ = sjson.SetRaw(out, "text", `{}`)
+ }
+ switch rf.Get("type").String() {
+ case "text":
+ out, _ = sjson.Set(out, "text.format.type", "text")
+ case "json_schema":
+ if js := rf.Get("json_schema"); js.Exists() {
+ out, _ = sjson.Set(out, "text.format.type", "json_schema")
+ if v := js.Get("name"); v.Exists() {
+ out, _ = sjson.Set(out, "text.format.name", v.Value())
+ }
+ if v := js.Get("strict"); v.Exists() {
+ out, _ = sjson.Set(out, "text.format.strict", v.Value())
+ }
+ if v := js.Get("schema"); v.Exists() {
+ out, _ = sjson.SetRaw(out, "text.format.schema", v.Raw)
+ }
+ }
+ }
+ if text.Exists() {
+ if v := text.Get("verbosity"); v.Exists() {
+ out, _ = sjson.Set(out, "text.verbosity", v.Value())
+ }
+ }
+ } else if text.Exists() {
+ if v := text.Get("verbosity"); v.Exists() {
+ if !gjson.Get(out, "text").Exists() {
+ out, _ = sjson.SetRaw(out, "text", `{}`)
+ }
+ out, _ = sjson.Set(out, "text.verbosity", v.Value())
+ }
+ }
- // Handle tool calls
- for _, tc := range msg.ToolCalls {
- name := tc.Function.Name
- if short, ok := shortMap[name]; ok {
- name = short
- } else {
- name = shortenNameIfNeeded(name)
+ if tools := gjson.GetBytes(rawJSON, "tools"); tools.IsArray() && len(tools.Array()) > 0 {
+ out, _ = sjson.SetRaw(out, "tools", `[]`)
+ for _, t := range tools.Array() {
+ toolType := t.Get("type").String()
+ if toolType != "" && toolType != "function" && t.IsObject() {
+ out, _ = sjson.SetRaw(out, "tools.-1", t.Raw)
+ continue
+ }
+ if toolType == "function" {
+ item := `{}`
+ item, _ = sjson.Set(item, "type", "function")
+ if v := t.Get("function.name"); v.Exists() {
+ name := v.String()
+ if short, ok := originalToolNameMap[name]; ok {
+ name = short
+ } else {
+ name = shortenNameIfNeeded(name)
+ }
+ item, _ = sjson.Set(item, "name", name)
+ }
+ if v := t.Get("function.description"); v.Exists() {
+ item, _ = sjson.Set(item, "description", v.Value())
+ }
+ if v := t.Get("function.parameters"); v.Exists() {
+ item, _ = sjson.SetRaw(item, "parameters", v.Raw)
+ }
+ if v := t.Get("function.strict"); v.Exists() {
+ item, _ = sjson.Set(item, "strict", v.Value())
+ }
+ out, _ = sjson.SetRaw(out, "tools.-1", item)
}
- input = append(input, CodexInputItem{
- Type: "function_call",
- ID: tc.ID,
- CallID: tc.ID,
- Name: name,
- Role: "assistant",
- Arguments: tc.Function.Arguments,
- })
}
}
- codexReq.Input = input
- // Convert tools
- for _, tool := range req.Tools {
- name := tool.Function.Name
- if short, ok := shortMap[name]; ok {
- name = short
- } else {
- name = shortenNameIfNeeded(name)
+ if tc := gjson.GetBytes(rawJSON, "tool_choice"); tc.Exists() {
+ switch {
+ case tc.Type == gjson.String:
+ out, _ = sjson.Set(out, "tool_choice", tc.String())
+ case tc.IsObject():
+ tcType := tc.Get("type").String()
+ if tcType == "function" {
+ name := tc.Get("function.name").String()
+ if name != "" {
+ if short, ok := originalToolNameMap[name]; ok {
+ name = short
+ } else {
+ name = shortenNameIfNeeded(name)
+ }
+ }
+ choice := `{}`
+ choice, _ = sjson.Set(choice, "type", "function")
+ if name != "" {
+ choice, _ = sjson.Set(choice, "name", name)
+ }
+ out, _ = sjson.SetRaw(out, "tool_choice", choice)
+ } else if tcType != "" {
+ out, _ = sjson.SetRaw(out, "tool_choice", tc.Raw)
+ }
}
- codexReq.Tools = append(codexReq.Tools, CodexTool{
- Type: "function",
- Name: name,
- Description: tool.Function.Description,
- Parameters: tool.Function.Parameters,
- })
}
- if instructions := CodexInstructionsForModel(model, userAgent); instructions != "" {
- codexReq.Instructions = instructions
+ out, _ = sjson.Set(out, "store", false)
+
+ return []byte(out), nil
+}
+
+func (c *openaiToCodexResponse) Transform(body []byte) ([]byte, error) {
+ return c.TransformWithState(body, nil)
+}
+
+func (c *openaiToCodexResponse) TransformWithState(body []byte, state *TransformState) ([]byte, error) {
+ var tmp interface{}
+ if err := json.Unmarshal(body, &tmp); err != nil {
+ return nil, err
}
- if codexReq.Reasoning == nil {
- codexReq.Reasoning = &CodexReasoning{Effort: "medium"}
+ root := gjson.ParseBytes(body)
+ requestRaw := []byte(nil)
+ if state != nil {
+ requestRaw = state.OriginalRequestBody
}
- if codexReq.Reasoning.Effort == "" {
- codexReq.Reasoning.Effort = "medium"
+
+ resp := `{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null,"incomplete_details":null}`
+ respID := root.Get("id").String()
+ if respID == "" {
+ respID = synthesizeResponseID()
}
- if codexReq.Reasoning.Summary == "" {
- codexReq.Reasoning.Summary = "auto"
+ resp, _ = sjson.Set(resp, "id", respID)
+
+ created := root.Get("created").Int()
+ if created == 0 {
+ created = time.Now().Unix()
}
+ resp, _ = sjson.Set(resp, "created_at", created)
- return json.Marshal(codexReq)
-}
+ if v := root.Get("model"); v.Exists() {
+ resp, _ = sjson.Set(resp, "model", v.String())
+ }
-func (c *openaiToCodexResponse) Transform(body []byte) ([]byte, error) {
- var resp OpenAIResponse
- if err := json.Unmarshal(body, &resp); err != nil {
- return nil, err
+ outputsWrapper := `{"arr":[]}`
+
+ if choices := root.Get("choices"); choices.Exists() && choices.IsArray() {
+ choices.ForEach(func(_, choice gjson.Result) bool {
+ msg := choice.Get("message")
+ if msg.Exists() {
+ if rc := msg.Get("reasoning_content"); rc.Exists() && rc.String() != "" {
+ choiceIdx := int(choice.Get("index").Int())
+ reasoning := `{"id":"","type":"reasoning","encrypted_content":"","summary":[]}`
+ reasoning, _ = sjson.Set(reasoning, "id", fmt.Sprintf("rs_%s_%d", respID, choiceIdx))
+ reasoning, _ = sjson.Set(reasoning, "summary.0.type", "summary_text")
+ reasoning, _ = sjson.Set(reasoning, "summary.0.text", rc.String())
+ outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", reasoning)
+ }
+ if c := msg.Get("content"); c.Exists() && c.String() != "" {
+ item := `{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`
+ item, _ = sjson.Set(item, "id", fmt.Sprintf("msg_%s_%d", respID, int(choice.Get("index").Int())))
+ item, _ = sjson.Set(item, "content.0.text", c.String())
+ outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
+ }
+
+ if tcs := msg.Get("tool_calls"); tcs.Exists() && tcs.IsArray() {
+ tcs.ForEach(func(_, tc gjson.Result) bool {
+ callID := tc.Get("id").String()
+ name := tc.Get("function.name").String()
+ args := tc.Get("function.arguments").String()
+ item := `{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`
+ item, _ = sjson.Set(item, "id", fmt.Sprintf("fc_%s", callID))
+ item, _ = sjson.Set(item, "arguments", args)
+ item, _ = sjson.Set(item, "call_id", callID)
+ item, _ = sjson.Set(item, "name", name)
+ outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
+ return true
+ })
+ }
+ }
+ return true
+ })
+ }
+ if gjson.Get(outputsWrapper, "arr.#").Int() > 0 {
+ resp, _ = sjson.SetRaw(resp, "output", gjson.Get(outputsWrapper, "arr").Raw)
}
- codexResp := CodexResponse{
- ID: resp.ID,
- Object: "response",
- CreatedAt: resp.Created,
- Model: resp.Model,
- Status: "completed",
- Usage: CodexUsage{
- InputTokens: resp.Usage.PromptTokens,
- OutputTokens: resp.Usage.CompletionTokens,
- TotalTokens: resp.Usage.TotalTokens,
- },
- }
-
- if len(resp.Choices) > 0 {
- choice := resp.Choices[0]
- if choice.Message != nil {
- if content, ok := choice.Message.Content.(string); ok && content != "" {
- codexResp.Output = append(codexResp.Output, CodexOutput{
- Type: "message",
- Role: "assistant",
- Content: content,
- })
+ if usage := root.Get("usage"); usage.Exists() {
+ if usage.Get("prompt_tokens").Exists() || usage.Get("completion_tokens").Exists() || usage.Get("total_tokens").Exists() {
+ resp, _ = sjson.Set(resp, "usage.input_tokens", usage.Get("prompt_tokens").Int())
+ if d := usage.Get("prompt_tokens_details.cached_tokens"); d.Exists() {
+ resp, _ = sjson.Set(resp, "usage.input_tokens_details.cached_tokens", d.Int())
}
- for _, tc := range choice.Message.ToolCalls {
- codexResp.Output = append(codexResp.Output, CodexOutput{
- Type: "function_call",
- ID: tc.ID,
- CallID: tc.ID,
- Name: tc.Function.Name,
- Arguments: tc.Function.Arguments,
- Status: "completed",
- })
+ resp, _ = sjson.Set(resp, "usage.output_tokens", usage.Get("completion_tokens").Int())
+ if d := usage.Get("completion_tokens_details.reasoning_tokens"); d.Exists() {
+ resp, _ = sjson.Set(resp, "usage.output_tokens_details.reasoning_tokens", d.Int())
+ } else if d := usage.Get("output_tokens_details.reasoning_tokens"); d.Exists() {
+ resp, _ = sjson.Set(resp, "usage.output_tokens_details.reasoning_tokens", d.Int())
}
+ resp, _ = sjson.Set(resp, "usage.total_tokens", usage.Get("total_tokens").Int())
+ } else {
+ resp, _ = sjson.Set(resp, "usage", usage.Value())
}
}
- return json.Marshal(codexResp)
+ if len(requestRaw) > 0 {
+ resp = applyRequestEchoToResponse(resp, "", requestRaw)
+ }
+ return []byte(resp), nil
}
func (c *openaiToCodexResponse) TransformChunk(chunk []byte, state *TransformState) ([]byte, error) {
+ if state == nil {
+ return nil, fmt.Errorf("TransformChunk requires non-nil state")
+ }
events, remaining := ParseSSE(state.Buffer + string(chunk))
state.Buffer = remaining
var output []byte
for _, event := range events {
if event.Event == "done" {
- codexEvent := map[string]interface{}{
- "type": "response.done",
- "response": map[string]interface{}{
- "id": state.MessageID,
- "status": "completed",
- },
- }
- output = append(output, FormatSSE("", codexEvent)...)
continue
}
+ for _, item := range convertOpenAIChatCompletionsChunkToResponses(event.Data, state) {
+ output = append(output, item...)
+ }
+ }
- var openaiChunk OpenAIStreamChunk
- if err := json.Unmarshal(event.Data, &openaiChunk); err != nil {
- continue
+ return output, nil
+}
+
+type openaiToResponsesStateReasoning struct {
+ ReasoningID string
+ ReasoningData string
+}
+
+type openaiToResponsesState struct {
+ Seq int
+ ResponseID string
+ Created int64
+ Started bool
+ ReasoningID string
+ ReasoningIndex int
+ MsgTextBuf map[int]*strings.Builder
+ ReasoningBuf strings.Builder
+ Reasonings []openaiToResponsesStateReasoning
+ FuncArgsBuf map[int]*strings.Builder
+ FuncNames map[int]string
+ FuncCallIDs map[int]string
+ MsgItemAdded map[int]bool
+ MsgContentAdded map[int]bool
+ MsgItemDone map[int]bool
+ FuncArgsDone map[int]bool
+ FuncItemDone map[int]bool
+ PromptTokens int64
+ CachedTokens int64
+ CompletionTokens int64
+ TotalTokens int64
+ ReasoningTokens int64
+ UsageSeen bool
+ NextOutputIndex int // global counter for unique output_index across messages and function calls
+ MsgOutputIndex map[int]int // choice idx -> assigned output_index
+ FuncOutputIndex map[int]int // callIndex -> assigned output_index
+ CompletedSent bool // guards against duplicate response.completed
+}
+
+var responseIDCounter uint64
+
+func synthesizeResponseID() string {
+ return fmt.Sprintf("resp_%x_%d", time.Now().UnixNano(), atomic.AddUint64(&responseIDCounter, 1))
+}
+
+func (st *openaiToResponsesState) msgOutIdx(choiceIdx int) int {
+ if oi, ok := st.MsgOutputIndex[choiceIdx]; ok {
+ return oi
+ }
+ oi := st.NextOutputIndex
+ st.MsgOutputIndex[choiceIdx] = oi
+ st.NextOutputIndex++
+ return oi
+}
+
+func (st *openaiToResponsesState) funcOutIdx(callIndex int) int {
+ if oi, ok := st.FuncOutputIndex[callIndex]; ok {
+ return oi
+ }
+ oi := st.NextOutputIndex
+ st.FuncOutputIndex[callIndex] = oi
+ st.NextOutputIndex++
+ return oi
+}
+
+func convertOpenAIChatCompletionsChunkToResponses(rawJSON []byte, state *TransformState) [][]byte {
+ if state == nil {
+ return nil
+ }
+ st, ok := state.Custom.(*openaiToResponsesState)
+ if !ok || st == nil {
+ st = &openaiToResponsesState{
+ FuncArgsBuf: make(map[int]*strings.Builder),
+ FuncNames: make(map[int]string),
+ FuncCallIDs: make(map[int]string),
+ MsgTextBuf: make(map[int]*strings.Builder),
+ MsgItemAdded: make(map[int]bool),
+ MsgContentAdded: make(map[int]bool),
+ MsgItemDone: make(map[int]bool),
+ FuncArgsDone: make(map[int]bool),
+ FuncItemDone: make(map[int]bool),
+ Reasonings: make([]openaiToResponsesStateReasoning, 0),
+ MsgOutputIndex: make(map[int]int),
+ FuncOutputIndex: make(map[int]int),
}
+ state.Custom = st
+ }
- if state.MessageID == "" {
- state.MessageID = openaiChunk.ID
- codexEvent := map[string]interface{}{
- "type": "response.created",
- "response": map[string]interface{}{
- "id": openaiChunk.ID,
- "model": openaiChunk.Model,
- "status": "in_progress",
- "created_at": time.Now().Unix(),
- },
- }
- output = append(output, FormatSSE("", codexEvent)...)
+ root := gjson.ParseBytes(rawJSON)
+ obj := root.Get("object")
+ if obj.Exists() && obj.String() != "" && obj.String() != "chat.completion.chunk" {
+ return nil
+ }
+ if !root.Get("choices").Exists() || !root.Get("choices").IsArray() {
+ return nil
+ }
+
+ nextSeq := func() int { st.Seq++; return st.Seq }
+ var out [][]byte
+
+ if !st.Started {
+ st.ResponseID = root.Get("id").String()
+ if st.ResponseID == "" {
+ st.ResponseID = synthesizeResponseID()
+ }
+ st.Created = root.Get("created").Int()
+ if st.Created == 0 {
+ st.Created = time.Now().Unix()
+ }
+ st.MsgTextBuf = make(map[int]*strings.Builder)
+ st.ReasoningBuf.Reset()
+ st.ReasoningID = ""
+ st.ReasoningIndex = 0
+ st.FuncArgsBuf = make(map[int]*strings.Builder)
+ st.FuncNames = make(map[int]string)
+ st.FuncCallIDs = make(map[int]string)
+ st.MsgItemAdded = make(map[int]bool)
+ st.MsgContentAdded = make(map[int]bool)
+ st.MsgItemDone = make(map[int]bool)
+ st.FuncArgsDone = make(map[int]bool)
+ st.FuncItemDone = make(map[int]bool)
+ st.MsgOutputIndex = make(map[int]int)
+ st.FuncOutputIndex = make(map[int]int)
+ st.NextOutputIndex = 0
+ st.CompletedSent = false
+ st.PromptTokens = 0
+ st.CachedTokens = 0
+ st.CompletionTokens = 0
+ st.TotalTokens = 0
+ st.ReasoningTokens = 0
+ st.UsageSeen = false
+
+ created := `{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null,"output":[]}}`
+ created, _ = sjson.Set(created, "sequence_number", nextSeq())
+ created, _ = sjson.Set(created, "response.id", st.ResponseID)
+ created, _ = sjson.Set(created, "response.created_at", st.Created)
+ out = append(out, FormatSSE("response.created", []byte(created)))
+
+ inprog := `{"type":"response.in_progress","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress"}}`
+ inprog, _ = sjson.Set(inprog, "sequence_number", nextSeq())
+ inprog, _ = sjson.Set(inprog, "response.id", st.ResponseID)
+ inprog, _ = sjson.Set(inprog, "response.created_at", st.Created)
+ out = append(out, FormatSSE("response.in_progress", []byte(inprog)))
+ st.Started = true
+ }
+
+ if usage := root.Get("usage"); usage.Exists() {
+ if v := usage.Get("prompt_tokens"); v.Exists() {
+ st.PromptTokens = v.Int()
+ st.UsageSeen = true
+ }
+ if v := usage.Get("prompt_tokens_details.cached_tokens"); v.Exists() {
+ st.CachedTokens = v.Int()
+ st.UsageSeen = true
+ }
+ if v := usage.Get("completion_tokens"); v.Exists() {
+ st.CompletionTokens = v.Int()
+ st.UsageSeen = true
+ } else if v := usage.Get("output_tokens"); v.Exists() {
+ st.CompletionTokens = v.Int()
+ st.UsageSeen = true
+ }
+ if v := usage.Get("output_tokens_details.reasoning_tokens"); v.Exists() {
+ st.ReasoningTokens = v.Int()
+ st.UsageSeen = true
+ } else if v := usage.Get("completion_tokens_details.reasoning_tokens"); v.Exists() {
+ st.ReasoningTokens = v.Int()
+ st.UsageSeen = true
+ }
+ if v := usage.Get("total_tokens"); v.Exists() {
+ st.TotalTokens = v.Int()
+ st.UsageSeen = true
}
+ }
+
+ stopReasoning := func(text string) {
+ textDone := `{"type":"response.reasoning_summary_text.done","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"text":""}`
+ textDone, _ = sjson.Set(textDone, "sequence_number", nextSeq())
+ textDone, _ = sjson.Set(textDone, "item_id", st.ReasoningID)
+ textDone, _ = sjson.Set(textDone, "output_index", st.ReasoningIndex)
+ textDone, _ = sjson.Set(textDone, "text", text)
+ out = append(out, FormatSSE("response.reasoning_summary_text.done", []byte(textDone)))
+
+ partDone := `{"type":"response.reasoning_summary_part.done","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"part":{"type":"summary_text","text":""}}`
+ partDone, _ = sjson.Set(partDone, "sequence_number", nextSeq())
+ partDone, _ = sjson.Set(partDone, "item_id", st.ReasoningID)
+ partDone, _ = sjson.Set(partDone, "output_index", st.ReasoningIndex)
+ partDone, _ = sjson.Set(partDone, "part.text", text)
+ out = append(out, FormatSSE("response.reasoning_summary_part.done", []byte(partDone)))
+
+ outputItemDone := `{"type":"response.output_item.done","item":{"id":"","type":"reasoning","encrypted_content":"","summary":[{"type":"summary_text","text":""}]},"output_index":0,"sequence_number":0}`
+ outputItemDone, _ = sjson.Set(outputItemDone, "sequence_number", nextSeq())
+ outputItemDone, _ = sjson.Set(outputItemDone, "item.id", st.ReasoningID)
+ outputItemDone, _ = sjson.Set(outputItemDone, "output_index", st.ReasoningIndex)
+ outputItemDone, _ = sjson.Set(outputItemDone, "item.summary.0.text", text)
+ out = append(out, FormatSSE("response.output_item.done", []byte(outputItemDone)))
+
+ st.Reasonings = append(st.Reasonings, openaiToResponsesStateReasoning{ReasoningID: st.ReasoningID, ReasoningData: text})
+ st.ReasoningID = ""
+ }
+
+ if choices := root.Get("choices"); choices.Exists() && choices.IsArray() {
+ choices.ForEach(func(_, choice gjson.Result) bool {
+ idx := int(choice.Get("index").Int())
+ delta := choice.Get("delta")
+ if delta.Exists() {
+ if c := delta.Get("content"); c.Exists() && c.String() != "" {
+ if st.ReasoningID != "" {
+ stopReasoning(st.ReasoningBuf.String())
+ st.ReasoningBuf.Reset()
+ }
+ if !st.MsgItemAdded[idx] {
+ item := `{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"in_progress","content":[],"role":"assistant"}}`
+ item, _ = sjson.Set(item, "sequence_number", nextSeq())
+ item, _ = sjson.Set(item, "output_index", st.msgOutIdx(idx))
+ item, _ = sjson.Set(item, "item.id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
+ out = append(out, FormatSSE("response.output_item.added", []byte(item)))
+ st.MsgItemAdded[idx] = true
+ }
+ if !st.MsgContentAdded[idx] {
+ part := `{"type":"response.content_part.added","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`
+ part, _ = sjson.Set(part, "sequence_number", nextSeq())
+ part, _ = sjson.Set(part, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
+ part, _ = sjson.Set(part, "output_index", st.msgOutIdx(idx))
+ part, _ = sjson.Set(part, "content_index", 0)
+ out = append(out, FormatSSE("response.content_part.added", []byte(part)))
+ st.MsgContentAdded[idx] = true
+ }
+
+ msg := `{"type":"response.output_text.delta","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"delta":"","logprobs":[]}`
+ msg, _ = sjson.Set(msg, "sequence_number", nextSeq())
+ msg, _ = sjson.Set(msg, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
+ msg, _ = sjson.Set(msg, "output_index", st.msgOutIdx(idx))
+ msg, _ = sjson.Set(msg, "content_index", 0)
+ msg, _ = sjson.Set(msg, "delta", c.String())
+ out = append(out, FormatSSE("response.output_text.delta", []byte(msg)))
+ if st.MsgTextBuf[idx] == nil {
+ st.MsgTextBuf[idx] = &strings.Builder{}
+ }
+ st.MsgTextBuf[idx].WriteString(c.String())
+ }
+
+ if rc := delta.Get("reasoning_content"); rc.Exists() && rc.String() != "" {
+ if st.ReasoningID == "" {
+ st.ReasoningID = fmt.Sprintf("rs_%s_%d", st.ResponseID, idx)
+ st.ReasoningIndex = st.NextOutputIndex
+ st.NextOutputIndex++
+ item := `{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"reasoning","status":"in_progress","summary":[]}}`
+ item, _ = sjson.Set(item, "sequence_number", nextSeq())
+ item, _ = sjson.Set(item, "output_index", st.ReasoningIndex)
+ item, _ = sjson.Set(item, "item.id", st.ReasoningID)
+ out = append(out, FormatSSE("response.output_item.added", []byte(item)))
+ part := `{"type":"response.reasoning_summary_part.added","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"part":{"type":"summary_text","text":""}}`
+ part, _ = sjson.Set(part, "sequence_number", nextSeq())
+ part, _ = sjson.Set(part, "item_id", st.ReasoningID)
+ part, _ = sjson.Set(part, "output_index", st.ReasoningIndex)
+ out = append(out, FormatSSE("response.reasoning_summary_part.added", []byte(part)))
+ }
+ st.ReasoningBuf.WriteString(rc.String())
+ msg := `{"type":"response.reasoning_summary_text.delta","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"delta":""}`
+ msg, _ = sjson.Set(msg, "sequence_number", nextSeq())
+ msg, _ = sjson.Set(msg, "item_id", st.ReasoningID)
+ msg, _ = sjson.Set(msg, "output_index", st.ReasoningIndex)
+ msg, _ = sjson.Set(msg, "delta", rc.String())
+ out = append(out, FormatSSE("response.reasoning_summary_text.delta", []byte(msg)))
+ }
+
+ if tcs := delta.Get("tool_calls"); tcs.Exists() && tcs.IsArray() {
+ if st.ReasoningID != "" {
+ stopReasoning(st.ReasoningBuf.String())
+ st.ReasoningBuf.Reset()
+ }
+ if st.MsgItemAdded[idx] && !st.MsgItemDone[idx] {
+ fullText := ""
+ if b := st.MsgTextBuf[idx]; b != nil {
+ fullText = b.String()
+ }
+ done := `{"type":"response.output_text.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"text":"","logprobs":[]}`
+ done, _ = sjson.Set(done, "sequence_number", nextSeq())
+ done, _ = sjson.Set(done, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
+ done, _ = sjson.Set(done, "output_index", st.msgOutIdx(idx))
+ done, _ = sjson.Set(done, "content_index", 0)
+ done, _ = sjson.Set(done, "text", fullText)
+ out = append(out, FormatSSE("response.output_text.done", []byte(done)))
+
+ partDone := `{"type":"response.content_part.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`
+ partDone, _ = sjson.Set(partDone, "sequence_number", nextSeq())
+ partDone, _ = sjson.Set(partDone, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
+ partDone, _ = sjson.Set(partDone, "output_index", st.msgOutIdx(idx))
+ partDone, _ = sjson.Set(partDone, "content_index", 0)
+ partDone, _ = sjson.Set(partDone, "part.text", fullText)
+ out = append(out, FormatSSE("response.content_part.done", []byte(partDone)))
+
+ itemDone := `{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}}`
+ itemDone, _ = sjson.Set(itemDone, "sequence_number", nextSeq())
+ itemDone, _ = sjson.Set(itemDone, "output_index", st.msgOutIdx(idx))
+ itemDone, _ = sjson.Set(itemDone, "item.id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
+ itemDone, _ = sjson.Set(itemDone, "item.content.0.text", fullText)
+ out = append(out, FormatSSE("response.output_item.done", []byte(itemDone)))
+ st.MsgItemDone[idx] = true
+ }
+
+ for tcIndex, tc := range tcs.Array() {
+ callIndex := tcIndex
+ if v := tc.Get("index"); v.Exists() {
+ callIndex = int(v.Int())
+ }
+
+ newCallID := tc.Get("id").String()
+ nameChunk := tc.Get("function.name").String()
+ if nameChunk != "" {
+ st.FuncNames[callIndex] = nameChunk
+ }
+ existingCallID := st.FuncCallIDs[callIndex]
+ effectiveCallID := existingCallID
+ shouldEmitItem := false
+ if existingCallID == "" && newCallID != "" {
+ effectiveCallID = newCallID
+ st.FuncCallIDs[callIndex] = newCallID
+ shouldEmitItem = true
+ }
+
+ if shouldEmitItem && effectiveCallID != "" {
+ o := `{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"in_progress","arguments":"","call_id":"","name":""}}`
+ o, _ = sjson.Set(o, "sequence_number", nextSeq())
+ o, _ = sjson.Set(o, "output_index", st.funcOutIdx(callIndex))
+ o, _ = sjson.Set(o, "item.id", fmt.Sprintf("fc_%s", effectiveCallID))
+ o, _ = sjson.Set(o, "item.call_id", effectiveCallID)
+ o, _ = sjson.Set(o, "item.name", st.FuncNames[callIndex])
+ out = append(out, FormatSSE("response.output_item.added", []byte(o)))
+ }
+
+ if st.FuncArgsBuf[callIndex] == nil {
+ st.FuncArgsBuf[callIndex] = &strings.Builder{}
+ }
+ if args := tc.Get("function.arguments"); args.Exists() && args.String() != "" {
+ refCallID := st.FuncCallIDs[callIndex]
+ if refCallID == "" {
+ refCallID = newCallID
+ }
+ if refCallID != "" {
+ ad := `{"type":"response.function_call_arguments.delta","sequence_number":0,"item_id":"","output_index":0,"delta":""}`
+ ad, _ = sjson.Set(ad, "sequence_number", nextSeq())
+ ad, _ = sjson.Set(ad, "item_id", fmt.Sprintf("fc_%s", refCallID))
+ ad, _ = sjson.Set(ad, "output_index", st.funcOutIdx(callIndex))
+ ad, _ = sjson.Set(ad, "delta", args.String())
+ out = append(out, FormatSSE("response.function_call_arguments.delta", []byte(ad)))
+ }
+ st.FuncArgsBuf[callIndex].WriteString(args.String())
+ }
+ }
+ }
+ }
+
+ if fr := choice.Get("finish_reason"); fr.Exists() && fr.String() != "" {
+ if len(st.MsgItemAdded) > 0 {
+ for _, i := range sortedKeys(st.MsgItemAdded) {
+ if st.MsgItemAdded[i] && !st.MsgItemDone[i] {
+ fullText := ""
+ if b := st.MsgTextBuf[i]; b != nil {
+ fullText = b.String()
+ }
+ done := `{"type":"response.output_text.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"text":"","logprobs":[]}`
+ done, _ = sjson.Set(done, "sequence_number", nextSeq())
+ done, _ = sjson.Set(done, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i))
+ done, _ = sjson.Set(done, "output_index", st.msgOutIdx(i))
+ done, _ = sjson.Set(done, "content_index", 0)
+ done, _ = sjson.Set(done, "text", fullText)
+ out = append(out, FormatSSE("response.output_text.done", []byte(done)))
+
+ partDone := `{"type":"response.content_part.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`
+ partDone, _ = sjson.Set(partDone, "sequence_number", nextSeq())
+ partDone, _ = sjson.Set(partDone, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i))
+ partDone, _ = sjson.Set(partDone, "output_index", st.msgOutIdx(i))
+ partDone, _ = sjson.Set(partDone, "content_index", 0)
+ partDone, _ = sjson.Set(partDone, "part.text", fullText)
+ out = append(out, FormatSSE("response.content_part.done", []byte(partDone)))
+
+ itemDone := `{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}}`
+ itemDone, _ = sjson.Set(itemDone, "sequence_number", nextSeq())
+ itemDone, _ = sjson.Set(itemDone, "output_index", st.msgOutIdx(i))
+ itemDone, _ = sjson.Set(itemDone, "item.id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i))
+ itemDone, _ = sjson.Set(itemDone, "item.content.0.text", fullText)
+ out = append(out, FormatSSE("response.output_item.done", []byte(itemDone)))
+ st.MsgItemDone[i] = true
+ }
+ }
+ }
+
+ if st.ReasoningID != "" {
+ stopReasoning(st.ReasoningBuf.String())
+ st.ReasoningBuf.Reset()
+ }
+
+ if len(st.FuncCallIDs) > 0 {
+ for _, i := range sortedKeys(st.FuncCallIDs) {
+ callID := st.FuncCallIDs[i]
+ if callID == "" || st.FuncItemDone[i] {
+ continue
+ }
+ args := "{}"
+ if b := st.FuncArgsBuf[i]; b != nil && b.Len() > 0 {
+ args = b.String()
+ }
+ fcDone := `{"type":"response.function_call_arguments.done","sequence_number":0,"item_id":"","output_index":0,"arguments":""}`
+ fcDone, _ = sjson.Set(fcDone, "sequence_number", nextSeq())
+ fcDone, _ = sjson.Set(fcDone, "item_id", fmt.Sprintf("fc_%s", callID))
+ fcDone, _ = sjson.Set(fcDone, "output_index", st.funcOutIdx(i))
+ fcDone, _ = sjson.Set(fcDone, "arguments", args)
+ out = append(out, FormatSSE("response.function_call_arguments.done", []byte(fcDone)))
- if len(openaiChunk.Choices) > 0 {
- choice := openaiChunk.Choices[0]
- if choice.Delta != nil {
- if content, ok := choice.Delta.Content.(string); ok && content != "" {
- codexEvent := map[string]interface{}{
- "type": "response.output_item.delta",
- "delta": map[string]interface{}{
- "type": "text",
- "text": content,
- },
+ itemDone := `{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}}`
+ itemDone, _ = sjson.Set(itemDone, "sequence_number", nextSeq())
+ itemDone, _ = sjson.Set(itemDone, "output_index", st.funcOutIdx(i))
+ itemDone, _ = sjson.Set(itemDone, "item.id", fmt.Sprintf("fc_%s", callID))
+ itemDone, _ = sjson.Set(itemDone, "item.arguments", args)
+ itemDone, _ = sjson.Set(itemDone, "item.call_id", callID)
+ itemDone, _ = sjson.Set(itemDone, "item.name", st.FuncNames[i])
+ out = append(out, FormatSSE("response.output_item.done", []byte(itemDone)))
+ st.FuncItemDone[i] = true
+ st.FuncArgsDone[i] = true
}
- output = append(output, FormatSSE("", codexEvent)...)
}
}
+ return true
+ })
+ }
- if choice.FinishReason != "" {
- codexEvent := map[string]interface{}{
- "type": "response.done",
- "response": map[string]interface{}{
- "id": state.MessageID,
- "status": "completed",
- },
+ // Emit response.completed once after all choices have been processed
+ if !st.CompletedSent {
+ // Check if any choice had a finish_reason
+ hasFinish := false
+ if choices := root.Get("choices"); choices.Exists() && choices.IsArray() {
+ choices.ForEach(func(_, choice gjson.Result) bool {
+ if fr := choice.Get("finish_reason"); fr.Exists() && fr.String() != "" {
+ hasFinish = true
+ return false
+ }
+ return true
+ })
+ }
+ if hasFinish {
+ st.CompletedSent = true
+ completed := `{"type":"response.completed","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null}}`
+ completed, _ = sjson.Set(completed, "sequence_number", nextSeq())
+ completed, _ = sjson.Set(completed, "response.id", st.ResponseID)
+ completed, _ = sjson.Set(completed, "response.created_at", st.Created)
+
+ outputsWrapper := `{"arr":[]}`
+ if len(st.Reasonings) > 0 {
+ for _, r := range st.Reasonings {
+ item := `{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`
+ item, _ = sjson.Set(item, "id", r.ReasoningID)
+ item, _ = sjson.Set(item, "summary.0.text", r.ReasoningData)
+ outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
+ }
+ }
+ if len(st.MsgItemAdded) > 0 {
+ for _, i := range sortedKeys(st.MsgItemAdded) {
+ txt := ""
+ if b := st.MsgTextBuf[i]; b != nil {
+ txt = b.String()
+ }
+ item := `{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`
+ item, _ = sjson.Set(item, "id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i))
+ item, _ = sjson.Set(item, "content.0.text", txt)
+ outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
}
- output = append(output, FormatSSE("", codexEvent)...)
}
+ if len(st.FuncCallIDs) > 0 {
+ for _, i := range sortedKeys(st.FuncCallIDs) {
+ args := ""
+ if b := st.FuncArgsBuf[i]; b != nil {
+ args = b.String()
+ }
+ callID := st.FuncCallIDs[i]
+ name := st.FuncNames[i]
+ item := `{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`
+ item, _ = sjson.Set(item, "id", fmt.Sprintf("fc_%s", callID))
+ item, _ = sjson.Set(item, "arguments", args)
+ item, _ = sjson.Set(item, "call_id", callID)
+ item, _ = sjson.Set(item, "name", name)
+ outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
+ }
+ }
+ if gjson.Get(outputsWrapper, "arr.#").Int() > 0 {
+ completed, _ = sjson.SetRaw(completed, "response.output", gjson.Get(outputsWrapper, "arr").Raw)
+ }
+ if st.UsageSeen {
+ completed, _ = sjson.Set(completed, "response.usage.input_tokens", st.PromptTokens)
+ completed, _ = sjson.Set(completed, "response.usage.input_tokens_details.cached_tokens", st.CachedTokens)
+ completed, _ = sjson.Set(completed, "response.usage.output_tokens", st.CompletionTokens)
+ if st.ReasoningTokens > 0 {
+ completed, _ = sjson.Set(completed, "response.usage.output_tokens_details.reasoning_tokens", st.ReasoningTokens)
+ }
+ total := st.TotalTokens
+ if total == 0 {
+ total = st.PromptTokens + st.CompletionTokens
+ }
+ completed, _ = sjson.Set(completed, "response.usage.total_tokens", total)
+ }
+ if len(state.OriginalRequestBody) > 0 {
+ completed = applyRequestEchoToResponse(completed, "response.", state.OriginalRequestBody)
+ }
+ out = append(out, FormatSSE("response.completed", []byte(completed)))
}
}
- return output, nil
+ return out
+}
+
+func sortedKeys[T any](m map[int]T) []int {
+ keys := make([]int, 0, len(m))
+ for k := range m {
+ keys = append(keys, k)
+ }
+ sort.Ints(keys)
+ return keys
+}
+
+func applyRequestEchoToResponse(responseJSON string, prefix string, requestRaw []byte) string {
+ if len(requestRaw) == 0 {
+ return responseJSON
+ }
+ req := gjson.ParseBytes(requestRaw)
+ paths := []string{
+ "model",
+ "instructions",
+ "input",
+ "tools",
+ "tool_choice",
+ "metadata",
+ "store",
+ "max_output_tokens",
+ "temperature",
+ "top_p",
+ "reasoning",
+ "parallel_tool_calls",
+ "include",
+ "previous_response_id",
+ "text",
+ "truncation",
+ }
+ for _, path := range paths {
+ val := req.Get(path)
+ if !val.Exists() {
+ continue
+ }
+ fullPath := prefix + path
+ if gjson.Get(responseJSON, fullPath).Exists() {
+ continue
+ }
+ switch val.Type {
+ case gjson.String, gjson.Number, gjson.True, gjson.False:
+ responseJSON, _ = sjson.Set(responseJSON, fullPath, val.Value())
+ default:
+ responseJSON, _ = sjson.SetRaw(responseJSON, fullPath, val.Raw)
+ }
+ }
+ return responseJSON
}
diff --git a/internal/converter/registry.go b/internal/converter/registry.go
index 77c6081d..a4d68b2b 100644
--- a/internal/converter/registry.go
+++ b/internal/converter/registry.go
@@ -9,13 +9,15 @@ import (
// TransformState holds state for streaming response conversion
type TransformState struct {
- MessageID string
- CurrentIndex int
- CurrentBlockType string // "text", "thinking", "tool_use"
- ToolCalls map[int]*ToolCallState
- Buffer string // SSE line buffer
- Usage *Usage
- StopReason string
+ MessageID string
+ CurrentIndex int
+ CurrentBlockType string // "text", "thinking", "tool_use"
+ ToolCalls map[int]*ToolCallState
+ Buffer string // SSE line buffer
+ Usage *Usage
+ StopReason string
+ Custom interface{}
+ OriginalRequestBody []byte
}
// ToolCallState tracks tool call conversion state
@@ -47,6 +49,10 @@ type ResponseTransformer interface {
TransformChunk(chunk []byte, state *TransformState) ([]byte, error)
}
+type ResponseTransformerWithState interface {
+ TransformWithState(body []byte, state *TransformState) ([]byte, error)
+}
+
// Registry holds all format converters
type Registry struct {
requests map[domain.ClientType]map[domain.ClientType]RequestTransformer
@@ -131,6 +137,26 @@ func (r *Registry) TransformResponse(from, to domain.ClientType, body []byte) ([
return transformer.Transform(body)
}
+// TransformResponseWithState converts a non-streaming response with state
+func (r *Registry) TransformResponseWithState(from, to domain.ClientType, body []byte, state *TransformState) ([]byte, error) {
+ if from == to {
+ return body, nil
+ }
+
+ fromMap := r.responses[from]
+ if fromMap == nil {
+ return nil, fmt.Errorf("no response transformer from %s", from)
+ }
+ transformer := fromMap[to]
+ if transformer == nil {
+ return nil, fmt.Errorf("no response transformer from %s to %s", from, to)
+ }
+ if withState, ok := transformer.(ResponseTransformerWithState); ok {
+ return withState.TransformWithState(body, state)
+ }
+ return transformer.Transform(body)
+}
+
// TransformStreamChunk converts a streaming chunk
func (r *Registry) TransformStreamChunk(from, to domain.ClientType, chunk []byte, state *TransformState) ([]byte, error) {
if from == to {
diff --git a/internal/converter/test_helpers_test.go b/internal/converter/test_helpers_test.go
index d7c44cd7..0b67fd05 100644
--- a/internal/converter/test_helpers_test.go
+++ b/internal/converter/test_helpers_test.go
@@ -10,8 +10,18 @@ func codexInputHasRoleText(input interface{}, role string, text string) bool {
if !ok || m["type"] != "message" || m["role"] != role {
continue
}
- if content, ok := m["content"].(string); ok && content == text {
- return true
+ switch content := m["content"].(type) {
+ case string:
+ if content == text {
+ return true
+ }
+ case []interface{}:
+ for _, part := range content {
+ pm, ok := part.(map[string]interface{})
+ if ok && pm["text"] == text {
+ return true
+ }
+ }
}
}
return false
diff --git a/internal/converter/tool_name.go b/internal/converter/tool_name.go
index e5c809c6..16810711 100644
--- a/internal/converter/tool_name.go
+++ b/internal/converter/tool_name.go
@@ -1,8 +1,8 @@
package converter
import (
- "fmt"
- "hash/crc32"
+ "strconv"
+ "strings"
)
const maxToolNameLen = 64
@@ -11,25 +11,54 @@ func shortenNameIfNeeded(name string) string {
if len(name) <= maxToolNameLen {
return name
}
- hash := crc32.ChecksumIEEE([]byte(name))
- // Keep a stable prefix to preserve readability, add hash suffix for uniqueness.
- prefixLen := maxToolNameLen - 9 // "_" + 8 hex
- return fmt.Sprintf("%s_%08x", name[:prefixLen], hash)
+ if strings.HasPrefix(name, "mcp__") {
+ idx := strings.LastIndex(name, "__")
+ if idx > 3 {
+ candidate := "mcp__" + name[idx+2:]
+ if len(candidate) > maxToolNameLen {
+ return candidate[:maxToolNameLen]
+ }
+ return candidate
+ }
+ }
+ return name[:maxToolNameLen]
}
func buildShortNameMap(names []string) map[string]string {
+ used := map[string]struct{}{}
result := make(map[string]string, len(names))
- used := make(map[string]int)
- for _, name := range names {
- short := shortenNameIfNeeded(name)
- if count, ok := used[short]; ok {
- count++
- used[short] = count
- short = shortenNameIfNeeded(fmt.Sprintf("%s_%d", name, count))
- } else {
- used[short] = 0
+
+ baseCandidate := func(n string) string {
+ return shortenNameIfNeeded(n)
+ }
+
+ makeUnique := func(cand string) string {
+ if _, ok := used[cand]; !ok {
+ return cand
+ }
+ base := cand
+ for i := 1; ; i++ {
+ suffix := "_" + strconv.Itoa(i)
+ allowed := maxToolNameLen - len(suffix)
+ if allowed < 0 {
+ allowed = 0
+ }
+ tmp := base
+ if len(tmp) > allowed {
+ tmp = tmp[:allowed]
+ }
+ tmp = tmp + suffix
+ if _, ok := used[tmp]; !ok {
+ return tmp
+ }
}
- result[name] = short
+ }
+
+ for _, n := range names {
+ cand := baseCandidate(n)
+ uniq := makeUnique(cand)
+ used[uniq] = struct{}{}
+ result[n] = uniq
}
return result
}
diff --git a/internal/domain/model.go b/internal/domain/model.go
index f8c8c409..5f75e08e 100644
--- a/internal/domain/model.go
+++ b/internal/domain/model.go
@@ -1,6 +1,9 @@
package domain
-import "time"
+import (
+ "strings"
+ "time"
+)
// 各种请求的客户端
type ClientType string
@@ -811,52 +814,42 @@ type ResponseModel struct {
// MatchWildcard 检查输入是否匹配通配符模式
func MatchWildcard(pattern, input string) bool {
- // 简单情况
+ pattern = strings.TrimSpace(pattern)
+ input = strings.TrimSpace(input)
+ if pattern == "" {
+ return false
+ }
if pattern == "*" {
return true
}
- if !containsWildcard(pattern) {
- return pattern == input
- }
-
- parts := splitByWildcard(pattern)
-
- // 处理 prefix* 模式
- if len(parts) == 2 && parts[1] == "" {
- return hasPrefix(input, parts[0])
- }
-
- // 处理 *suffix 模式
- if len(parts) == 2 && parts[0] == "" {
- return hasSuffix(input, parts[1])
- }
-
- // 处理多通配符模式
- pos := 0
- for i, part := range parts {
- if part == "" {
+ // Iterative glob-style matcher supporting only '*' wildcard.
+ pi, si := 0, 0
+ starIdx := -1
+ matchIdx := 0
+ for si < len(input) {
+ if pi < len(pattern) && pattern[pi] == input[si] {
+ pi++
+ si++
continue
}
-
- idx := indexOf(input[pos:], part)
- if idx < 0 {
- return false
+ if pi < len(pattern) && pattern[pi] == '*' {
+ starIdx = pi
+ matchIdx = si
+ pi++
+ continue
}
-
- // 第一部分必须在开头(如果模式不以 * 开头)
- if i == 0 && idx != 0 {
- return false
+ if starIdx != -1 {
+ pi = starIdx + 1
+ matchIdx++
+ si = matchIdx
+ continue
}
-
- pos += idx + len(part)
- }
-
- // 最后一部分必须在结尾(如果模式不以 * 结尾)
- if parts[len(parts)-1] != "" && !hasSuffix(input, parts[len(parts)-1]) {
return false
}
-
- return true
+ for pi < len(pattern) && pattern[pi] == '*' {
+ pi++
+ }
+ return pi == len(pattern)
}
// 辅助函数
diff --git a/internal/executor/converting_writer.go b/internal/executor/converting_writer.go
index daeb0a42..b005cc42 100644
--- a/internal/executor/converting_writer.go
+++ b/internal/executor/converting_writer.go
@@ -152,7 +152,12 @@ func NewConvertingResponseWriter(
conv *converter.Registry,
originalType, targetType domain.ClientType,
isStream bool,
+ originalRequestBody []byte,
) *ConvertingResponseWriter {
+ state := converter.NewTransformState()
+ if len(originalRequestBody) > 0 {
+ state.OriginalRequestBody = bytes.Clone(originalRequestBody)
+ }
return &ConvertingResponseWriter{
underlying: w,
converter: conv,
@@ -161,7 +166,7 @@ func NewConvertingResponseWriter(
isStream: isStream,
statusCode: http.StatusOK,
headers: make(http.Header),
- streamState: converter.NewTransformState(),
+ streamState: state,
}
}
@@ -226,9 +231,9 @@ func (c *ConvertingResponseWriter) Finalize() error {
body := c.buffer.Bytes()
// Convert the response
- converted, err := c.converter.TransformResponse(c.targetType, c.originalType, body)
- if err != nil {
- // On conversion error, use original body
+ converted, err := c.converter.TransformResponseWithState(c.targetType, c.originalType, body, c.streamState)
+ if err != nil || converted == nil {
+ // On conversion error or nil result, use original body
converted = body
}
@@ -271,9 +276,9 @@ func NeedsConversion(originalType, targetType domain.ClientType) bool {
return originalType != targetType && originalType != "" && targetType != ""
}
-// GetPreferredTargetType returns the best target type for conversion
-// Prefers Claude as it has the richest format support
-func GetPreferredTargetType(supportedTypes []domain.ClientType, originalType domain.ClientType) domain.ClientType {
+// GetPreferredTargetType returns the best target type for conversion.
+// Prefers Codex only for codex providers, otherwise Gemini then Claude.
+func GetPreferredTargetType(supportedTypes []domain.ClientType, originalType domain.ClientType, providerType string) domain.ClientType {
// If original type is supported, no conversion needed
for _, t := range supportedTypes {
if t == originalType {
@@ -281,7 +286,23 @@ func GetPreferredTargetType(supportedTypes []domain.ClientType, originalType dom
}
}
- // Prefer Claude as target (richest format)
+ if providerType == "codex" {
+ // Prefer Codex when available (best fit for Codex provider)
+ for _, t := range supportedTypes {
+ if t == domain.ClientTypeCodex {
+ return t
+ }
+ }
+ }
+
+ // Prefer Gemini as target (best fit for Antigravity)
+ for _, t := range supportedTypes {
+ if t == domain.ClientTypeGemini {
+ return t
+ }
+ }
+
+ // Prefer Claude as target (fallback)
for _, t := range supportedTypes {
if t == domain.ClientTypeClaude {
return t
diff --git a/internal/executor/converting_writer_test.go b/internal/executor/converting_writer_test.go
deleted file mode 100644
index e4da5439..00000000
--- a/internal/executor/converting_writer_test.go
+++ /dev/null
@@ -1,90 +0,0 @@
-package executor
-
-import (
- "testing"
-
- "github.com/awsl-project/maxx/internal/domain"
-)
-
-func TestConvertRequestURI(t *testing.T) {
- tests := []struct {
- name string
- original string
- from domain.ClientType
- to domain.ClientType
- mappedModel string
- isStream bool
- want string
- }{
- {
- name: "same type passthrough",
- original: "/v1/chat/completions",
- from: domain.ClientTypeOpenAI,
- to: domain.ClientTypeOpenAI,
- want: "/v1/chat/completions",
- },
- {
- name: "openai to claude with query",
- original: "/v1/chat/completions?foo=1",
- from: domain.ClientTypeOpenAI,
- to: domain.ClientTypeClaude,
- want: "/v1/messages?foo=1",
- },
- {
- name: "claude to codex",
- original: "/v1/messages",
- from: domain.ClientTypeClaude,
- to: domain.ClientTypeCodex,
- want: "/responses",
- },
- {
- name: "claude count tokens to openai",
- original: "/v1/messages/count_tokens",
- from: domain.ClientTypeClaude,
- to: domain.ClientTypeOpenAI,
- want: "/v1/chat/completions",
- },
- {
- name: "openai to gemini stream",
- original: "/v1/chat/completions",
- from: domain.ClientTypeOpenAI,
- to: domain.ClientTypeGemini,
- mappedModel: "gemini-2.5-pro",
- isStream: true,
- want: "/v1beta/models/gemini-2.5-pro:streamGenerateContent",
- },
- {
- name: "claude count tokens to gemini",
- original: "/v1/messages/count_tokens",
- from: domain.ClientTypeClaude,
- to: domain.ClientTypeGemini,
- mappedModel: "gemini-2.5-pro",
- want: "/v1beta/models/gemini-2.5-pro:countTokens",
- },
- {
- name: "gemini internal preserves version and action",
- original: "/v1internal/models/gemini-2.0:generateContent?alt=sse",
- from: domain.ClientTypeOpenAI,
- to: domain.ClientTypeGemini,
- mappedModel: "gemini-2.5-pro",
- isStream: true,
- want: "/v1internal/models/gemini-2.5-pro:generateContent?alt=sse",
- },
- {
- name: "gemini target without model keeps original",
- original: "/v1/chat/completions",
- from: domain.ClientTypeOpenAI,
- to: domain.ClientTypeGemini,
- want: "/v1/chat/completions",
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got := ConvertRequestURI(tt.original, tt.from, tt.to, tt.mappedModel, tt.isStream)
- if got != tt.want {
- t.Fatalf("ConvertRequestURI(%q) = %q, want %q", tt.original, got, tt.want)
- }
- })
- }
-}
diff --git a/internal/executor/executor.go b/internal/executor/executor.go
index 48dff280..ca230187 100644
--- a/internal/executor/executor.go
+++ b/internal/executor/executor.go
@@ -2,21 +2,18 @@ package executor
import (
"context"
- "log"
"net/http"
"strconv"
"time"
- ctxutil "github.com/awsl-project/maxx/internal/context"
"github.com/awsl-project/maxx/internal/converter"
"github.com/awsl-project/maxx/internal/cooldown"
"github.com/awsl-project/maxx/internal/domain"
"github.com/awsl-project/maxx/internal/event"
- "github.com/awsl-project/maxx/internal/pricing"
+ "github.com/awsl-project/maxx/internal/flow"
"github.com/awsl-project/maxx/internal/repository"
"github.com/awsl-project/maxx/internal/router"
"github.com/awsl-project/maxx/internal/stats"
- "github.com/awsl-project/maxx/internal/usage"
"github.com/awsl-project/maxx/internal/waiter"
)
@@ -34,6 +31,8 @@ type Executor struct {
instanceID string
statsAggregator *stats.StatsAggregator
converter *converter.Registry
+ engine *flow.Engine
+ middlewares []flow.HandlerFunc
}
// NewExecutor creates a new executor
@@ -63,630 +62,41 @@ func NewExecutor(
instanceID: instanceID,
statsAggregator: statsAggregator,
converter: converter.GetGlobalRegistry(),
+ engine: flow.NewEngine(),
}
}
-// Execute handles the proxy request with routing and retry logic
-func (e *Executor) Execute(ctx context.Context, w http.ResponseWriter, req *http.Request) error {
- clientType := ctxutil.GetClientType(ctx)
- projectID := ctxutil.GetProjectID(ctx)
- sessionID := ctxutil.GetSessionID(ctx)
- requestModel := ctxutil.GetRequestModel(ctx)
- isStream := ctxutil.GetIsStream(ctx)
-
- // Get API Token ID from context
- apiTokenID := ctxutil.GetAPITokenID(ctx)
-
- // Create proxy request record immediately (PENDING status)
- proxyReq := &domain.ProxyRequest{
- InstanceID: e.instanceID,
- RequestID: generateRequestID(),
- SessionID: sessionID,
- ClientType: clientType,
- ProjectID: projectID,
- RequestModel: requestModel,
- StartTime: time.Now(),
- IsStream: isStream,
- Status: "PENDING",
- APITokenID: apiTokenID,
- }
-
- // Capture client's original request info unless detail retention is disabled.
- if !e.shouldClearRequestDetail() {
- requestURI := ctxutil.GetRequestURI(ctx)
- requestHeaders := ctxutil.GetRequestHeaders(ctx)
- requestBody := ctxutil.GetRequestBody(ctx)
- headers := flattenHeaders(requestHeaders)
- // Go stores Host separately from headers, add it explicitly
- if req.Host != "" {
- if headers == nil {
- headers = make(map[string]string)
- }
- headers["Host"] = req.Host
- }
- proxyReq.RequestInfo = &domain.RequestInfo{
- Method: req.Method,
- URL: requestURI,
- Headers: headers,
- Body: string(requestBody),
- }
- }
-
- if err := e.proxyRequestRepo.Create(proxyReq); err != nil {
- log.Printf("[Executor] Failed to create proxy request: %v", err)
- }
-
- // Broadcast the new request immediately
- if e.broadcaster != nil {
- e.broadcaster.BroadcastProxyRequest(proxyReq)
- }
-
- ctx = ctxutil.WithProxyRequest(ctx, proxyReq)
-
- // Check for project binding if required
- if projectID == 0 && e.projectWaiter != nil {
- // Get session for project waiter
- session, _ := e.sessionRepo.GetBySessionID(sessionID)
- if session == nil {
- session = &domain.Session{
- SessionID: sessionID,
- ClientType: clientType,
- ProjectID: 0,
- }
- }
-
- if err := e.projectWaiter.WaitForProject(ctx, session); err != nil {
- // Determine status based on error type
- status := "REJECTED"
- errorMsg := "project binding timeout: " + err.Error()
- if err == context.Canceled {
- status = "CANCELLED"
- errorMsg = "client cancelled: " + err.Error()
- // Notify frontend to close the dialog
- if e.broadcaster != nil {
- e.broadcaster.BroadcastMessage("session_pending_cancelled", map[string]interface{}{
- "sessionID": sessionID,
- })
- }
- }
-
- // Update request record with final status
- proxyReq.Status = status
- proxyReq.Error = errorMsg
- proxyReq.EndTime = time.Now()
- proxyReq.Duration = proxyReq.EndTime.Sub(proxyReq.StartTime)
- _ = e.proxyRequestRepo.Update(proxyReq)
-
- // Broadcast the updated request
- if e.broadcaster != nil {
- e.broadcaster.BroadcastProxyRequest(proxyReq)
- }
-
- return domain.NewProxyErrorWithMessage(err, false, "project binding required: "+err.Error())
- }
-
- // Update projectID from the now-bound session
- projectID = session.ProjectID
- proxyReq.ProjectID = projectID
- ctx = ctxutil.WithProjectID(ctx, projectID)
- }
-
- // Match routes
- routes, err := e.router.Match(&router.MatchContext{
- ClientType: clientType,
- ProjectID: projectID,
- RequestModel: requestModel,
- APITokenID: apiTokenID,
- })
- if err != nil {
- proxyReq.Status = "FAILED"
- proxyReq.Error = "no routes available"
- proxyReq.EndTime = time.Now()
- proxyReq.Duration = proxyReq.EndTime.Sub(proxyReq.StartTime)
- _ = e.proxyRequestRepo.Update(proxyReq)
- if e.broadcaster != nil {
- e.broadcaster.BroadcastProxyRequest(proxyReq)
- }
- return domain.NewProxyErrorWithMessage(domain.ErrNoRoutes, false, "no routes available")
- }
-
- if len(routes) == 0 {
- proxyReq.Status = "FAILED"
- proxyReq.Error = "no routes configured"
- proxyReq.EndTime = time.Now()
- proxyReq.Duration = proxyReq.EndTime.Sub(proxyReq.StartTime)
- _ = e.proxyRequestRepo.Update(proxyReq)
- if e.broadcaster != nil {
- e.broadcaster.BroadcastProxyRequest(proxyReq)
- }
- return domain.NewProxyErrorWithMessage(domain.ErrNoRoutes, false, "no routes configured")
- }
-
- // Update status to IN_PROGRESS
- proxyReq.Status = "IN_PROGRESS"
- _ = e.proxyRequestRepo.Update(proxyReq)
- ctx = ctxutil.WithProxyRequest(ctx, proxyReq)
+func (e *Executor) Use(handlers ...flow.HandlerFunc) {
+ e.middlewares = append(e.middlewares, handlers...)
+}
- // Add broadcaster to context so adapters can send updates
- if e.broadcaster != nil {
- ctx = ctxutil.WithBroadcaster(ctx, e.broadcaster)
+// Execute runs the executor middleware chain with a new flow context.
+func (e *Executor) Execute(ctx context.Context, w http.ResponseWriter, req *http.Request) error {
+ c := flow.NewCtx(w, req)
+ if ctx != nil {
+ c.Set(flow.KeyProxyContext, ctx)
}
+ return e.ExecuteWith(c)
+}
- // Broadcast new request immediately so frontend sees it
- if e.broadcaster != nil {
- e.broadcaster.BroadcastProxyRequest(proxyReq)
+// ExecuteWith runs the executor middleware chain using an existing flow context.
+func (e *Executor) ExecuteWith(c *flow.Ctx) error {
+ if c == nil {
+ return domain.NewProxyErrorWithMessage(domain.ErrInvalidInput, false, "flow context missing")
}
-
- // Track current attempt for cleanup
- var currentAttempt *domain.ProxyUpstreamAttempt
-
- // Ensure final state is always updated
- defer func() {
- // If still IN_PROGRESS, mark as cancelled/failed
- if proxyReq.Status == "IN_PROGRESS" {
- proxyReq.EndTime = time.Now()
- proxyReq.Duration = proxyReq.EndTime.Sub(proxyReq.StartTime)
- if ctx.Err() != nil {
- proxyReq.Status = "CANCELLED"
- if ctx.Err() == context.Canceled {
- proxyReq.Error = "client disconnected"
- } else if ctx.Err() == context.DeadlineExceeded {
- proxyReq.Error = "request timeout"
- } else {
- proxyReq.Error = ctx.Err().Error()
- }
- } else {
- proxyReq.Status = "FAILED"
- }
- _ = e.proxyRequestRepo.Update(proxyReq)
- if e.broadcaster != nil {
- e.broadcaster.BroadcastProxyRequest(proxyReq)
- }
- }
-
- // If current attempt is still IN_PROGRESS, mark as cancelled/failed
- if currentAttempt != nil && currentAttempt.Status == "IN_PROGRESS" {
- currentAttempt.EndTime = time.Now()
- currentAttempt.Duration = currentAttempt.EndTime.Sub(currentAttempt.StartTime)
- if ctx.Err() != nil {
- currentAttempt.Status = "CANCELLED"
- } else {
- currentAttempt.Status = "FAILED"
- }
- _ = e.attemptRepo.Update(currentAttempt)
- if e.broadcaster != nil {
- e.broadcaster.BroadcastProxyUpstreamAttempt(currentAttempt)
- }
- }
- }()
-
- // Try routes in order with retry logic
- var lastErr error
- for _, matchedRoute := range routes {
- // Check context before starting new route
- if ctx.Err() != nil {
- return ctx.Err()
- }
-
- // Update proxyReq with current route/provider for real-time tracking
- proxyReq.RouteID = matchedRoute.Route.ID
- proxyReq.ProviderID = matchedRoute.Provider.ID
- _ = e.proxyRequestRepo.Update(proxyReq)
- if e.broadcaster != nil {
- e.broadcaster.BroadcastProxyRequest(proxyReq)
- }
-
- // Determine model mapping
- // Model mapping is done in Executor after Router has filtered by SupportModels
- clientType := ctxutil.GetClientType(ctx)
- mappedModel := e.mapModel(requestModel, matchedRoute.Route, matchedRoute.Provider, clientType, projectID, apiTokenID)
- ctx = ctxutil.WithMappedModel(ctx, mappedModel)
-
- // Format conversion: check if client type is supported by provider
- // If not, convert request to a supported format
- originalClientType := clientType
- targetClientType := clientType
- needsConversion := false
- convertedBody := []byte(nil)
- var convErr error
-
- supportedTypes := matchedRoute.ProviderAdapter.SupportedClientTypes()
- if e.converter.NeedConvert(clientType, supportedTypes) {
- targetClientType = GetPreferredTargetType(supportedTypes, clientType)
- if targetClientType != clientType {
- needsConversion = true
- log.Printf("[Executor] Format conversion needed: %s -> %s for provider %s",
- clientType, targetClientType, matchedRoute.Provider.Name)
-
- // Convert request body
- requestBody := ctxutil.GetRequestBody(ctx)
- if targetClientType == domain.ClientTypeCodex {
- if headers := ctxutil.GetRequestHeaders(ctx); headers != nil {
- requestBody = converter.InjectCodexUserAgent(requestBody, headers.Get("User-Agent"))
- }
- }
- convertedBody, convErr = e.converter.TransformRequest(
- clientType, targetClientType, requestBody, mappedModel, isStream)
- if convErr != nil {
- log.Printf("[Executor] Request conversion failed: %v, proceeding with original format", convErr)
- needsConversion = false
- } else {
- // Update context with converted body and new client type
- ctx = ctxutil.WithRequestBody(ctx, convertedBody)
- ctx = ctxutil.WithClientType(ctx, targetClientType)
- ctx = ctxutil.WithOriginalClientType(ctx, originalClientType)
-
- // Convert request URI to match the target client type
- originalURI := ctxutil.GetRequestURI(ctx)
- convertedURI := ConvertRequestURI(originalURI, clientType, targetClientType, mappedModel, isStream)
- if convertedURI != originalURI {
- ctx = ctxutil.WithRequestURI(ctx, convertedURI)
- log.Printf("[Executor] URI converted: %s -> %s", originalURI, convertedURI)
- }
- }
- }
- }
-
- // Get retry config
- retryConfig := e.getRetryConfig(matchedRoute.RetryConfig)
-
- // Execute with retries
- for attempt := 0; attempt <= retryConfig.MaxRetries; attempt++ {
- // Check context before each attempt
- if ctx.Err() != nil {
- return ctx.Err()
- }
-
- // Create attempt record with start time and request info
- attemptStartTime := time.Now()
- attemptRecord := &domain.ProxyUpstreamAttempt{
- ProxyRequestID: proxyReq.ID,
- RouteID: matchedRoute.Route.ID,
- ProviderID: matchedRoute.Provider.ID,
- IsStream: isStream,
- Status: "IN_PROGRESS",
- StartTime: attemptStartTime,
- RequestModel: requestModel,
- MappedModel: mappedModel,
- RequestInfo: proxyReq.RequestInfo, // Use original request info initially
- }
- if err := e.attemptRepo.Create(attemptRecord); err != nil {
- log.Printf("[Executor] Failed to create attempt record: %v", err)
- }
- currentAttempt = attemptRecord
-
- // Increment attempt count when creating a new attempt
- proxyReq.ProxyUpstreamAttemptCount++
-
- // Broadcast updated request with new attempt count
- if e.broadcaster != nil {
- e.broadcaster.BroadcastProxyRequest(proxyReq)
- }
-
- // Broadcast new attempt immediately
- if e.broadcaster != nil {
- e.broadcaster.BroadcastProxyUpstreamAttempt(attemptRecord)
- }
-
- // Put attempt into context so adapter can populate request/response info
- attemptCtx := ctxutil.WithUpstreamAttempt(ctx, attemptRecord)
-
- // Create event channel for adapter to send events
- eventChan := domain.NewAdapterEventChan()
- attemptCtx = ctxutil.WithEventChan(attemptCtx, eventChan)
-
- // Start real-time event processing goroutine
- // This ensures RequestInfo is broadcast as soon as adapter sends it
- eventDone := make(chan struct{})
- go e.processAdapterEventsRealtime(eventChan, attemptRecord, eventDone)
-
- // Wrap ResponseWriter to capture actual client response
- // If format conversion is needed, use ConvertingResponseWriter
- var responseWriter http.ResponseWriter
- var convertingWriter *ConvertingResponseWriter
- responseCapture := NewResponseCapture(w)
-
- if needsConversion {
- // Use ConvertingResponseWriter to transform response from targetType back to originalType
- convertingWriter = NewConvertingResponseWriter(
- responseCapture, e.converter, originalClientType, targetClientType, isStream)
- responseWriter = convertingWriter
- } else {
- responseWriter = responseCapture
- }
-
- // Execute request
- err := matchedRoute.ProviderAdapter.Execute(attemptCtx, responseWriter, req, matchedRoute.Provider)
-
- // For non-streaming responses with conversion, finalize the conversion
- if needsConversion && convertingWriter != nil && !isStream {
- if finalizeErr := convertingWriter.Finalize(); finalizeErr != nil {
- log.Printf("[Executor] Response conversion finalize failed: %v", finalizeErr)
- }
- }
-
- // Close event channel and wait for processing goroutine to finish
- eventChan.Close()
- <-eventDone
-
- if err == nil {
- // Success - set end time and duration
- attemptRecord.EndTime = time.Now()
- attemptRecord.Duration = attemptRecord.EndTime.Sub(attemptRecord.StartTime)
- attemptRecord.Status = "COMPLETED"
-
- // Calculate cost in executor (unified for all adapters)
- // Adapter only needs to set token counts, executor handles pricing
- if attemptRecord.InputTokenCount > 0 || attemptRecord.OutputTokenCount > 0 {
- metrics := &usage.Metrics{
- InputTokens: attemptRecord.InputTokenCount,
- OutputTokens: attemptRecord.OutputTokenCount,
- CacheReadCount: attemptRecord.CacheReadCount,
- CacheCreationCount: attemptRecord.CacheWriteCount,
- Cache5mCreationCount: attemptRecord.Cache5mWriteCount,
- Cache1hCreationCount: attemptRecord.Cache1hWriteCount,
- }
- // Use ResponseModel for pricing (actual model from API response), fallback to MappedModel
- pricingModel := attemptRecord.ResponseModel
- if pricingModel == "" {
- pricingModel = attemptRecord.MappedModel
- }
- // Get multiplier from provider config
- multiplier := getProviderMultiplier(matchedRoute.Provider, clientType)
- result := pricing.GlobalCalculator().CalculateWithResult(pricingModel, metrics, multiplier)
- attemptRecord.Cost = result.Cost
- attemptRecord.ModelPriceID = result.ModelPriceID
- attemptRecord.Multiplier = result.Multiplier
- }
-
- // 检查是否需要立即清理 attempt 详情(设置为 0 时不保存)
- if e.shouldClearRequestDetail() {
- attemptRecord.RequestInfo = nil
- attemptRecord.ResponseInfo = nil
- }
-
- _ = e.attemptRepo.Update(attemptRecord)
- if e.broadcaster != nil {
- e.broadcaster.BroadcastProxyUpstreamAttempt(attemptRecord)
- }
- currentAttempt = nil // Clear so defer doesn't update
-
- // Reset failure counts on success
- clientType := string(ctxutil.GetClientType(attemptCtx))
- cooldown.Default().RecordSuccess(matchedRoute.Provider.ID, clientType)
-
- proxyReq.Status = "COMPLETED"
- proxyReq.EndTime = time.Now()
- proxyReq.Duration = proxyReq.EndTime.Sub(proxyReq.StartTime)
- proxyReq.FinalProxyUpstreamAttemptID = attemptRecord.ID
- proxyReq.ModelPriceID = attemptRecord.ModelPriceID
- proxyReq.Multiplier = attemptRecord.Multiplier
- proxyReq.ResponseModel = mappedModel // Record the actual model used
-
- // Capture actual client response (what was sent to client, e.g. Claude format)
- // This is different from attemptRecord.ResponseInfo which is upstream response (Gemini format)
- if !e.shouldClearRequestDetail() {
- proxyReq.ResponseInfo = &domain.ResponseInfo{
- Status: responseCapture.StatusCode(),
- Headers: responseCapture.CapturedHeaders(),
- Body: responseCapture.Body(),
- }
- }
- proxyReq.StatusCode = responseCapture.StatusCode()
-
- // Extract token usage from final client response (not from upstream attempt)
- // This ensures we use the correct format (Claude/OpenAI/Gemini) for the client type
- if metrics := usage.ExtractFromResponse(responseCapture.Body()); metrics != nil {
- proxyReq.InputTokenCount = metrics.InputTokens
- proxyReq.OutputTokenCount = metrics.OutputTokens
- proxyReq.CacheReadCount = metrics.CacheReadCount
- proxyReq.CacheWriteCount = metrics.CacheCreationCount
- proxyReq.Cache5mWriteCount = metrics.Cache5mCreationCount
- proxyReq.Cache1hWriteCount = metrics.Cache1hCreationCount
- }
- proxyReq.Cost = attemptRecord.Cost
- proxyReq.TTFT = attemptRecord.TTFT
-
- // 检查是否需要立即清理 proxyReq 详情(设置为 0 时不保存)
- if e.shouldClearRequestDetail() {
- proxyReq.RequestInfo = nil
- proxyReq.ResponseInfo = nil
- }
-
- _ = e.proxyRequestRepo.Update(proxyReq)
-
- // Broadcast to WebSocket clients
- if e.broadcaster != nil {
- e.broadcaster.BroadcastProxyRequest(proxyReq)
- }
-
- return nil
- }
-
- // Handle error - set end time and duration
- attemptRecord.EndTime = time.Now()
- attemptRecord.Duration = attemptRecord.EndTime.Sub(attemptRecord.StartTime)
- lastErr = err
-
- // Update attempt status first (before checking context)
- if ctx.Err() != nil {
- attemptRecord.Status = "CANCELLED"
- } else {
- attemptRecord.Status = "FAILED"
- }
-
- // Calculate cost in executor even for failed attempts (may have partial token usage)
- if attemptRecord.InputTokenCount > 0 || attemptRecord.OutputTokenCount > 0 {
- metrics := &usage.Metrics{
- InputTokens: attemptRecord.InputTokenCount,
- OutputTokens: attemptRecord.OutputTokenCount,
- CacheReadCount: attemptRecord.CacheReadCount,
- CacheCreationCount: attemptRecord.CacheWriteCount,
- Cache5mCreationCount: attemptRecord.Cache5mWriteCount,
- Cache1hCreationCount: attemptRecord.Cache1hWriteCount,
- }
- // Use ResponseModel for pricing (actual model from API response), fallback to MappedModel
- pricingModel := attemptRecord.ResponseModel
- if pricingModel == "" {
- pricingModel = attemptRecord.MappedModel
- }
- // Get multiplier from provider config
- multiplier := getProviderMultiplier(matchedRoute.Provider, clientType)
- result := pricing.GlobalCalculator().CalculateWithResult(pricingModel, metrics, multiplier)
- attemptRecord.Cost = result.Cost
- attemptRecord.ModelPriceID = result.ModelPriceID
- attemptRecord.Multiplier = result.Multiplier
- }
-
- // 检查是否需要立即清理 attempt 详情(设置为 0 时不保存)
- if e.shouldClearRequestDetail() {
- attemptRecord.RequestInfo = nil
- attemptRecord.ResponseInfo = nil
- }
-
- _ = e.attemptRepo.Update(attemptRecord)
- if e.broadcaster != nil {
- e.broadcaster.BroadcastProxyUpstreamAttempt(attemptRecord)
- }
- currentAttempt = nil // Clear so defer doesn't double update
-
- // Update proxyReq with latest attempt info (even on failure)
- proxyReq.FinalProxyUpstreamAttemptID = attemptRecord.ID
- proxyReq.ModelPriceID = attemptRecord.ModelPriceID
- proxyReq.Multiplier = attemptRecord.Multiplier
-
- // Capture actual client response (even on failure, if any response was sent)
- if responseCapture.Body() != "" {
- proxyReq.StatusCode = responseCapture.StatusCode()
- if !e.shouldClearRequestDetail() {
- proxyReq.ResponseInfo = &domain.ResponseInfo{
- Status: responseCapture.StatusCode(),
- Headers: responseCapture.CapturedHeaders(),
- Body: responseCapture.Body(),
- }
- }
-
- // Extract token usage from final client response
- if metrics := usage.ExtractFromResponse(responseCapture.Body()); metrics != nil {
- proxyReq.InputTokenCount = metrics.InputTokens
- proxyReq.OutputTokenCount = metrics.OutputTokens
- proxyReq.CacheReadCount = metrics.CacheReadCount
- proxyReq.CacheWriteCount = metrics.CacheCreationCount
- proxyReq.Cache5mWriteCount = metrics.Cache5mCreationCount
- proxyReq.Cache1hWriteCount = metrics.Cache1hCreationCount
- }
- }
- proxyReq.Cost = attemptRecord.Cost
- proxyReq.TTFT = attemptRecord.TTFT
-
- _ = e.proxyRequestRepo.Update(proxyReq)
- if e.broadcaster != nil {
- e.broadcaster.BroadcastProxyRequest(proxyReq)
- }
-
- // Handle cooldown only for real server/network errors, NOT client-side cancellations
- proxyErr, ok := err.(*domain.ProxyError)
- if ok && ctx.Err() != context.Canceled {
- log.Printf("[Executor] ProxyError - IsNetworkError: %v, IsServerError: %v, Retryable: %v, Provider: %d",
- proxyErr.IsNetworkError, proxyErr.IsServerError, proxyErr.Retryable, matchedRoute.Provider.ID)
- // Handle cooldown (unified cooldown logic for all providers)
- e.handleCooldown(attemptCtx, proxyErr, matchedRoute.Provider)
- // Broadcast cooldown update event to frontend
- if e.broadcaster != nil {
- e.broadcaster.BroadcastMessage("cooldown_update", map[string]interface{}{
- "providerID": matchedRoute.Provider.ID,
- })
- }
- } else if ok && ctx.Err() == context.Canceled {
- log.Printf("[Executor] Client disconnected, skipping cooldown for Provider: %d", matchedRoute.Provider.ID)
- } else if !ok {
- log.Printf("[Executor] Error is not ProxyError, type: %T, error: %v", err, err)
- }
-
- // Check if context was cancelled or timed out
- if ctx.Err() != nil {
- proxyReq.Status = "CANCELLED"
- proxyReq.EndTime = time.Now()
- proxyReq.Duration = proxyReq.EndTime.Sub(proxyReq.StartTime)
- if ctx.Err() == context.Canceled {
- proxyReq.Error = "client disconnected"
- } else if ctx.Err() == context.DeadlineExceeded {
- proxyReq.Error = "request timeout"
- } else {
- proxyReq.Error = ctx.Err().Error()
- }
- _ = e.proxyRequestRepo.Update(proxyReq)
- if e.broadcaster != nil {
- e.broadcaster.BroadcastProxyRequest(proxyReq)
- }
- return ctx.Err()
- }
-
- // Check if retryable
- if !ok {
- break // Move to next route
- }
-
- if !proxyErr.Retryable {
- break // Move to next route
- }
-
- // Wait before retry (unless last attempt)
- if attempt < retryConfig.MaxRetries {
- waitTime := e.calculateBackoff(retryConfig, attempt)
- if proxyErr.RetryAfter > 0 {
- waitTime = proxyErr.RetryAfter
- }
- select {
- case <-ctx.Done():
- // Set final status before returning
- proxyReq.Status = "CANCELLED"
- proxyReq.EndTime = time.Now()
- proxyReq.Duration = proxyReq.EndTime.Sub(proxyReq.StartTime)
- if ctx.Err() == context.Canceled {
- proxyReq.Error = "client disconnected during retry wait"
- } else if ctx.Err() == context.DeadlineExceeded {
- proxyReq.Error = "request timeout during retry wait"
- } else {
- proxyReq.Error = ctx.Err().Error()
- }
- _ = e.proxyRequestRepo.Update(proxyReq)
- if e.broadcaster != nil {
- e.broadcaster.BroadcastProxyRequest(proxyReq)
- }
- return ctx.Err()
- case <-time.After(waitTime):
- }
- }
+ ctx := context.Background()
+ if v, ok := c.Get(flow.KeyProxyContext); ok {
+ if stored, ok := v.(context.Context); ok && stored != nil {
+ ctx = stored
}
- // Inner loop ended, will try next route if available
- }
-
- // All routes failed
- proxyReq.Status = "FAILED"
- proxyReq.EndTime = time.Now()
- proxyReq.Duration = proxyReq.EndTime.Sub(proxyReq.StartTime)
- if lastErr != nil {
- proxyReq.Error = lastErr.Error()
- }
-
- // 检查是否需要立即清理详情(设置为 0 时不保存)
- if e.shouldClearRequestDetail() {
- proxyReq.RequestInfo = nil
- proxyReq.ResponseInfo = nil
- }
-
- _ = e.proxyRequestRepo.Update(proxyReq)
-
- // Broadcast to WebSocket clients
- if e.broadcaster != nil {
- e.broadcaster.BroadcastProxyRequest(proxyReq)
- }
-
- if lastErr != nil {
- return lastErr
}
- return domain.NewProxyErrorWithMessage(domain.ErrAllRoutesFailed, false, "all routes exhausted")
+ state := &execState{ctx: ctx}
+ c.Set(flow.KeyExecutorState, state)
+ chain := []flow.HandlerFunc{e.egress, e.ingress}
+ chain = append(chain, e.middlewares...)
+ chain = append(chain, e.routeMatch, e.dispatch)
+ e.engine.HandleWith(c, chain...)
+ return state.lastErr
}
func (e *Executor) mapModel(requestModel string, route *domain.Route, provider *domain.Provider, clientType domain.ClientType, projectID uint64, apiTokenID uint64) string {
@@ -761,19 +171,16 @@ func flattenHeaders(h http.Header) map[string]string {
// handleCooldown processes cooldown information from ProxyError and sets provider cooldown
// Priority: 1) Explicit time from API, 2) Policy-based calculation based on failure reason
-func (e *Executor) handleCooldown(ctx context.Context, proxyErr *domain.ProxyError, provider *domain.Provider) {
- // Determine which client type to apply cooldown to
- clientType := proxyErr.CooldownClientType
+func (e *Executor) handleCooldown(proxyErr *domain.ProxyError, provider *domain.Provider, clientType domain.ClientType, originalClientType domain.ClientType) {
+ selectedClientType := proxyErr.CooldownClientType
if proxyErr.RateLimitInfo != nil && proxyErr.RateLimitInfo.ClientType != "" {
- clientType = proxyErr.RateLimitInfo.ClientType
+ selectedClientType = proxyErr.RateLimitInfo.ClientType
}
- // Fallback to original client type (before format conversion) if not specified
- if clientType == "" {
- // Prefer original client type over converted type
- if origCT := ctxutil.GetOriginalClientType(ctx); origCT != "" {
- clientType = string(origCT)
+ if selectedClientType == "" {
+ if originalClientType != "" {
+ selectedClientType = string(originalClientType)
} else {
- clientType = string(ctxutil.GetClientType(ctx))
+ selectedClientType = string(clientType)
}
}
@@ -815,11 +222,11 @@ func (e *Executor) handleCooldown(ctx context.Context, proxyErr *domain.ProxyErr
// Record failure and apply cooldown
// If explicitUntil is not nil, it will be used directly
// Otherwise, cooldown duration is calculated based on policy and failure count
- cooldown.Default().RecordFailure(provider.ID, clientType, reason, explicitUntil)
+ cooldown.Default().RecordFailure(provider.ID, selectedClientType, reason, explicitUntil)
// If there's an async update channel, listen for updates
if proxyErr.CooldownUpdateChan != nil {
- go e.handleAsyncCooldownUpdate(proxyErr.CooldownUpdateChan, provider, clientType)
+ go e.handleAsyncCooldownUpdate(proxyErr.CooldownUpdateChan, provider, selectedClientType)
}
}
diff --git a/internal/executor/flow_state.go b/internal/executor/flow_state.go
new file mode 100644
index 00000000..0508eb6d
--- /dev/null
+++ b/internal/executor/flow_state.go
@@ -0,0 +1,38 @@
+package executor
+
+import (
+ "context"
+ "net/http"
+
+ "github.com/awsl-project/maxx/internal/domain"
+ "github.com/awsl-project/maxx/internal/flow"
+ "github.com/awsl-project/maxx/internal/router"
+)
+
+type execState struct {
+ ctx context.Context
+ proxyReq *domain.ProxyRequest
+ routes []*router.MatchedRoute
+ currentAttempt *domain.ProxyUpstreamAttempt
+ lastErr error
+
+ clientType domain.ClientType
+ projectID uint64
+ sessionID string
+ requestModel string
+ isStream bool
+ apiTokenID uint64
+ requestBody []byte
+ originalRequestBody []byte
+ requestHeaders http.Header
+ requestURI string
+}
+
+func getExecState(c *flow.Ctx) (*execState, bool) {
+ v, ok := c.Get(flow.KeyExecutorState)
+ if !ok {
+ return nil, false
+ }
+ st, ok := v.(*execState)
+ return st, ok
+}
diff --git a/internal/executor/middleware_dispatch.go b/internal/executor/middleware_dispatch.go
new file mode 100644
index 00000000..3b49d614
--- /dev/null
+++ b/internal/executor/middleware_dispatch.go
@@ -0,0 +1,399 @@
+package executor
+
+import (
+ "context"
+ "log"
+ "net/http"
+ "time"
+
+ "github.com/awsl-project/maxx/internal/converter"
+ "github.com/awsl-project/maxx/internal/cooldown"
+ "github.com/awsl-project/maxx/internal/domain"
+ "github.com/awsl-project/maxx/internal/flow"
+ "github.com/awsl-project/maxx/internal/pricing"
+ "github.com/awsl-project/maxx/internal/usage"
+)
+
+func (e *Executor) dispatch(c *flow.Ctx) {
+ state, ok := getExecState(c)
+ if !ok {
+ err := domain.NewProxyErrorWithMessage(domain.ErrInvalidInput, false, "executor state missing")
+ c.Err = err
+ c.Abort()
+ return
+ }
+
+ proxyReq := state.proxyReq
+ ctx := state.ctx
+
+ for _, matchedRoute := range state.routes {
+ if ctx.Err() != nil {
+ state.lastErr = ctx.Err()
+ c.Err = state.lastErr
+ return
+ }
+
+ proxyReq.RouteID = matchedRoute.Route.ID
+ proxyReq.ProviderID = matchedRoute.Provider.ID
+ _ = e.proxyRequestRepo.Update(proxyReq)
+ if e.broadcaster != nil {
+ e.broadcaster.BroadcastProxyRequest(proxyReq)
+ }
+
+ clientType := state.clientType
+ mappedModel := e.mapModel(state.requestModel, matchedRoute.Route, matchedRoute.Provider, clientType, state.projectID, state.apiTokenID)
+
+ originalClientType := clientType
+ currentClientType := clientType
+ needsConversion := false
+ convertedBody := []byte(nil)
+ var convErr error
+ requestBody := state.requestBody
+ requestURI := state.requestURI
+
+ supportedTypes := matchedRoute.ProviderAdapter.SupportedClientTypes()
+ if e.converter.NeedConvert(clientType, supportedTypes) {
+ currentClientType = GetPreferredTargetType(supportedTypes, clientType, matchedRoute.Provider.Type)
+ if currentClientType != clientType {
+ needsConversion = true
+ log.Printf("[Executor] Format conversion needed: %s -> %s for provider %s",
+ clientType, currentClientType, matchedRoute.Provider.Name)
+
+ if currentClientType == domain.ClientTypeCodex {
+ if headers := state.requestHeaders; headers != nil {
+ requestBody = converter.InjectCodexUserAgent(requestBody, headers.Get("User-Agent"))
+ }
+ }
+ convertedBody, convErr = e.converter.TransformRequest(
+ clientType, currentClientType, requestBody, mappedModel, state.isStream)
+ if convErr != nil {
+ log.Printf("[Executor] Request conversion failed: %v, proceeding with original format", convErr)
+ needsConversion = false
+ currentClientType = clientType
+ } else {
+ requestBody = convertedBody
+
+ originalURI := requestURI
+ convertedURI := ConvertRequestURI(requestURI, clientType, currentClientType, mappedModel, state.isStream)
+ if convertedURI != originalURI {
+ requestURI = convertedURI
+ log.Printf("[Executor] URI converted: %s -> %s", originalURI, convertedURI)
+ }
+ }
+ }
+ }
+
+ retryConfig := e.getRetryConfig(matchedRoute.RetryConfig)
+
+ for attempt := 0; attempt <= retryConfig.MaxRetries; attempt++ {
+ if ctx.Err() != nil {
+ state.lastErr = ctx.Err()
+ c.Err = state.lastErr
+ return
+ }
+
+ attemptStartTime := time.Now()
+ attemptRecord := &domain.ProxyUpstreamAttempt{
+ ProxyRequestID: proxyReq.ID,
+ RouteID: matchedRoute.Route.ID,
+ ProviderID: matchedRoute.Provider.ID,
+ IsStream: state.isStream,
+ Status: "IN_PROGRESS",
+ StartTime: attemptStartTime,
+ RequestModel: state.requestModel,
+ MappedModel: mappedModel,
+ RequestInfo: proxyReq.RequestInfo,
+ }
+ if err := e.attemptRepo.Create(attemptRecord); err != nil {
+ log.Printf("[Executor] Failed to create attempt record: %v", err)
+ }
+ state.currentAttempt = attemptRecord
+
+ proxyReq.ProxyUpstreamAttemptCount++
+ if e.broadcaster != nil {
+ e.broadcaster.BroadcastProxyRequest(proxyReq)
+ e.broadcaster.BroadcastProxyUpstreamAttempt(attemptRecord)
+ }
+
+ eventChan := domain.NewAdapterEventChan()
+ c.Set(flow.KeyClientType, currentClientType)
+ c.Set(flow.KeyOriginalClientType, originalClientType)
+ c.Set(flow.KeyMappedModel, mappedModel)
+ c.Set(flow.KeyRequestBody, requestBody)
+ c.Set(flow.KeyRequestURI, requestURI)
+ c.Set(flow.KeyRequestHeaders, state.requestHeaders)
+ c.Set(flow.KeyProxyRequest, state.proxyReq)
+ c.Set(flow.KeyUpstreamAttempt, attemptRecord)
+ c.Set(flow.KeyEventChan, eventChan)
+ c.Set(flow.KeyBroadcaster, e.broadcaster)
+ eventDone := make(chan struct{})
+ go e.processAdapterEventsRealtime(eventChan, attemptRecord, eventDone)
+
+ var responseWriter http.ResponseWriter
+ var convertingWriter *ConvertingResponseWriter
+ responseCapture := NewResponseCapture(c.Writer)
+ if needsConversion {
+ convertingWriter = NewConvertingResponseWriter(
+ responseCapture, e.converter, originalClientType, currentClientType, state.isStream, state.originalRequestBody)
+ responseWriter = convertingWriter
+ } else {
+ responseWriter = responseCapture
+ }
+
+ originalWriter := c.Writer
+ c.Writer = responseWriter
+ err := matchedRoute.ProviderAdapter.Execute(c, matchedRoute.Provider)
+ c.Writer = originalWriter
+
+ if needsConversion && convertingWriter != nil && !state.isStream {
+ if finalizeErr := convertingWriter.Finalize(); finalizeErr != nil {
+ log.Printf("[Executor] Response conversion finalize failed: %v", finalizeErr)
+ }
+ }
+
+ eventChan.Close()
+ <-eventDone
+
+ if err == nil {
+ attemptRecord.EndTime = time.Now()
+ attemptRecord.Duration = attemptRecord.EndTime.Sub(attemptRecord.StartTime)
+ attemptRecord.Status = "COMPLETED"
+
+ if attemptRecord.InputTokenCount > 0 || attemptRecord.OutputTokenCount > 0 {
+ metrics := &usage.Metrics{
+ InputTokens: attemptRecord.InputTokenCount,
+ OutputTokens: attemptRecord.OutputTokenCount,
+ CacheReadCount: attemptRecord.CacheReadCount,
+ CacheCreationCount: attemptRecord.CacheWriteCount,
+ Cache5mCreationCount: attemptRecord.Cache5mWriteCount,
+ Cache1hCreationCount: attemptRecord.Cache1hWriteCount,
+ }
+ pricingModel := attemptRecord.ResponseModel
+ if pricingModel == "" {
+ pricingModel = attemptRecord.MappedModel
+ }
+ multiplier := getProviderMultiplier(matchedRoute.Provider, clientType)
+ result := pricing.GlobalCalculator().CalculateWithResult(pricingModel, metrics, multiplier)
+ attemptRecord.Cost = result.Cost
+ attemptRecord.ModelPriceID = result.ModelPriceID
+ attemptRecord.Multiplier = result.Multiplier
+ }
+
+ if e.shouldClearRequestDetail() {
+ attemptRecord.RequestInfo = nil
+ attemptRecord.ResponseInfo = nil
+ }
+
+ _ = e.attemptRepo.Update(attemptRecord)
+ if e.broadcaster != nil {
+ e.broadcaster.BroadcastProxyUpstreamAttempt(attemptRecord)
+ }
+ state.currentAttempt = nil
+
+ cooldown.Default().RecordSuccess(matchedRoute.Provider.ID, string(currentClientType))
+
+ proxyReq.Status = "COMPLETED"
+ proxyReq.EndTime = time.Now()
+ proxyReq.Duration = proxyReq.EndTime.Sub(proxyReq.StartTime)
+ proxyReq.FinalProxyUpstreamAttemptID = attemptRecord.ID
+ proxyReq.ModelPriceID = attemptRecord.ModelPriceID
+ proxyReq.Multiplier = attemptRecord.Multiplier
+ proxyReq.ResponseModel = mappedModel
+
+ if !e.shouldClearRequestDetail() {
+ proxyReq.ResponseInfo = &domain.ResponseInfo{
+ Status: responseCapture.StatusCode(),
+ Headers: responseCapture.CapturedHeaders(),
+ Body: responseCapture.Body(),
+ }
+ }
+ proxyReq.StatusCode = responseCapture.StatusCode()
+
+ if metrics := usage.ExtractFromResponse(responseCapture.Body()); metrics != nil {
+ proxyReq.InputTokenCount = metrics.InputTokens
+ proxyReq.OutputTokenCount = metrics.OutputTokens
+ proxyReq.CacheReadCount = metrics.CacheReadCount
+ proxyReq.CacheWriteCount = metrics.CacheCreationCount
+ proxyReq.Cache5mWriteCount = metrics.Cache5mCreationCount
+ proxyReq.Cache1hWriteCount = metrics.Cache1hCreationCount
+ }
+ proxyReq.Cost = attemptRecord.Cost
+ proxyReq.TTFT = attemptRecord.TTFT
+
+ if e.shouldClearRequestDetail() {
+ proxyReq.RequestInfo = nil
+ proxyReq.ResponseInfo = nil
+ }
+
+ _ = e.proxyRequestRepo.Update(proxyReq)
+ if e.broadcaster != nil {
+ e.broadcaster.BroadcastProxyRequest(proxyReq)
+ }
+
+ state.lastErr = nil
+ state.ctx = ctx
+ return
+ }
+
+ attemptRecord.EndTime = time.Now()
+ attemptRecord.Duration = attemptRecord.EndTime.Sub(attemptRecord.StartTime)
+ state.lastErr = err
+
+ if ctx.Err() != nil {
+ attemptRecord.Status = "CANCELLED"
+ } else {
+ attemptRecord.Status = "FAILED"
+ }
+
+ if attemptRecord.InputTokenCount > 0 || attemptRecord.OutputTokenCount > 0 {
+ metrics := &usage.Metrics{
+ InputTokens: attemptRecord.InputTokenCount,
+ OutputTokens: attemptRecord.OutputTokenCount,
+ CacheReadCount: attemptRecord.CacheReadCount,
+ CacheCreationCount: attemptRecord.CacheWriteCount,
+ Cache5mCreationCount: attemptRecord.Cache5mWriteCount,
+ Cache1hCreationCount: attemptRecord.Cache1hWriteCount,
+ }
+ pricingModel := attemptRecord.ResponseModel
+ if pricingModel == "" {
+ pricingModel = attemptRecord.MappedModel
+ }
+ multiplier := getProviderMultiplier(matchedRoute.Provider, clientType)
+ result := pricing.GlobalCalculator().CalculateWithResult(pricingModel, metrics, multiplier)
+ attemptRecord.Cost = result.Cost
+ attemptRecord.ModelPriceID = result.ModelPriceID
+ attemptRecord.Multiplier = result.Multiplier
+ }
+
+ if e.shouldClearRequestDetail() {
+ attemptRecord.RequestInfo = nil
+ attemptRecord.ResponseInfo = nil
+ }
+
+ _ = e.attemptRepo.Update(attemptRecord)
+ if e.broadcaster != nil {
+ e.broadcaster.BroadcastProxyUpstreamAttempt(attemptRecord)
+ }
+ state.currentAttempt = nil
+
+ proxyReq.FinalProxyUpstreamAttemptID = attemptRecord.ID
+ proxyReq.ModelPriceID = attemptRecord.ModelPriceID
+ proxyReq.Multiplier = attemptRecord.Multiplier
+
+ if responseCapture.Body() != "" {
+ proxyReq.StatusCode = responseCapture.StatusCode()
+ if !e.shouldClearRequestDetail() {
+ proxyReq.ResponseInfo = &domain.ResponseInfo{
+ Status: responseCapture.StatusCode(),
+ Headers: responseCapture.CapturedHeaders(),
+ Body: responseCapture.Body(),
+ }
+ }
+ if metrics := usage.ExtractFromResponse(responseCapture.Body()); metrics != nil {
+ proxyReq.InputTokenCount = metrics.InputTokens
+ proxyReq.OutputTokenCount = metrics.OutputTokens
+ proxyReq.CacheReadCount = metrics.CacheReadCount
+ proxyReq.CacheWriteCount = metrics.CacheCreationCount
+ proxyReq.Cache5mWriteCount = metrics.Cache5mCreationCount
+ proxyReq.Cache1hWriteCount = metrics.Cache1hCreationCount
+ }
+ }
+ proxyReq.Cost = attemptRecord.Cost
+ proxyReq.TTFT = attemptRecord.TTFT
+
+ _ = e.proxyRequestRepo.Update(proxyReq)
+ if e.broadcaster != nil {
+ e.broadcaster.BroadcastProxyRequest(proxyReq)
+ }
+
+ proxyErr, ok := err.(*domain.ProxyError)
+ if ok && ctx.Err() != nil {
+ proxyReq.Status = "CANCELLED"
+ proxyReq.EndTime = time.Now()
+ proxyReq.Duration = proxyReq.EndTime.Sub(proxyReq.StartTime)
+ if ctx.Err() == context.Canceled {
+ proxyReq.Error = "client disconnected"
+ } else if ctx.Err() == context.DeadlineExceeded {
+ proxyReq.Error = "request timeout"
+ } else {
+ proxyReq.Error = ctx.Err().Error()
+ }
+ _ = e.proxyRequestRepo.Update(proxyReq)
+ if e.broadcaster != nil {
+ e.broadcaster.BroadcastProxyRequest(proxyReq)
+ }
+ state.lastErr = ctx.Err()
+ c.Err = state.lastErr
+ return
+ }
+
+ if ok && ctx.Err() != context.Canceled {
+ log.Printf("[Executor] ProxyError - IsNetworkError: %v, IsServerError: %v, Retryable: %v, Provider: %d",
+ proxyErr.IsNetworkError, proxyErr.IsServerError, proxyErr.Retryable, matchedRoute.Provider.ID)
+ e.handleCooldown(proxyErr, matchedRoute.Provider, currentClientType, originalClientType)
+ if e.broadcaster != nil {
+ e.broadcaster.BroadcastMessage("cooldown_update", map[string]interface{}{
+ "providerID": matchedRoute.Provider.ID,
+ })
+ }
+ } else if ok && ctx.Err() == context.Canceled {
+ log.Printf("[Executor] Client disconnected, skipping cooldown for Provider: %d", matchedRoute.Provider.ID)
+ } else if !ok {
+ log.Printf("[Executor] Error is not ProxyError, type: %T, error: %v", err, err)
+ }
+
+ if !ok || !proxyErr.Retryable {
+ break
+ }
+
+ if attempt < retryConfig.MaxRetries {
+ waitTime := e.calculateBackoff(retryConfig, attempt)
+ if proxyErr.RetryAfter > 0 {
+ waitTime = proxyErr.RetryAfter
+ }
+ select {
+ case <-ctx.Done():
+ proxyReq.Status = "CANCELLED"
+ proxyReq.EndTime = time.Now()
+ proxyReq.Duration = proxyReq.EndTime.Sub(proxyReq.StartTime)
+ if ctx.Err() == context.Canceled {
+ proxyReq.Error = "client disconnected during retry wait"
+ } else if ctx.Err() == context.DeadlineExceeded {
+ proxyReq.Error = "request timeout during retry wait"
+ } else {
+ proxyReq.Error = ctx.Err().Error()
+ }
+ _ = e.proxyRequestRepo.Update(proxyReq)
+ if e.broadcaster != nil {
+ e.broadcaster.BroadcastProxyRequest(proxyReq)
+ }
+ state.lastErr = ctx.Err()
+ c.Err = state.lastErr
+ return
+ case <-time.After(waitTime):
+ }
+ }
+ }
+ }
+
+ proxyReq.Status = "FAILED"
+ proxyReq.EndTime = time.Now()
+ proxyReq.Duration = proxyReq.EndTime.Sub(proxyReq.StartTime)
+ if state.lastErr != nil {
+ proxyReq.Error = state.lastErr.Error()
+ }
+ if e.shouldClearRequestDetail() {
+ proxyReq.RequestInfo = nil
+ proxyReq.ResponseInfo = nil
+ }
+ _ = e.proxyRequestRepo.Update(proxyReq)
+ if e.broadcaster != nil {
+ e.broadcaster.BroadcastProxyRequest(proxyReq)
+ }
+
+ if state.lastErr == nil {
+ state.lastErr = domain.NewProxyErrorWithMessage(domain.ErrAllRoutesFailed, false, "all routes exhausted")
+ }
+ state.ctx = ctx
+ c.Err = state.lastErr
+}
diff --git a/internal/executor/middleware_egress.go b/internal/executor/middleware_egress.go
new file mode 100644
index 00000000..9b41cd1b
--- /dev/null
+++ b/internal/executor/middleware_egress.go
@@ -0,0 +1,56 @@
+package executor
+
+import (
+ "context"
+ "time"
+
+ "github.com/awsl-project/maxx/internal/flow"
+)
+
+func (e *Executor) egress(c *flow.Ctx) {
+ state, ok := getExecState(c)
+ if !ok {
+ c.Next()
+ return
+ }
+
+ c.Next()
+
+ proxyReq := state.proxyReq
+ if proxyReq != nil && proxyReq.Status == "IN_PROGRESS" {
+ proxyReq.EndTime = time.Now()
+ proxyReq.Duration = proxyReq.EndTime.Sub(proxyReq.StartTime)
+ if state.ctx != nil && state.ctx.Err() != nil {
+ proxyReq.Status = "CANCELLED"
+ if state.ctx.Err() == context.Canceled {
+ proxyReq.Error = "client disconnected"
+ } else if state.ctx.Err() == context.DeadlineExceeded {
+ proxyReq.Error = "request timeout"
+ } else {
+ proxyReq.Error = state.ctx.Err().Error()
+ }
+ } else {
+ proxyReq.Status = "FAILED"
+ }
+ _ = e.proxyRequestRepo.Update(proxyReq)
+ if e.broadcaster != nil {
+ e.broadcaster.BroadcastProxyRequest(proxyReq)
+ }
+ }
+
+ if state.currentAttempt != nil && state.currentAttempt.Status == "IN_PROGRESS" {
+ state.currentAttempt.EndTime = time.Now()
+ state.currentAttempt.Duration = state.currentAttempt.EndTime.Sub(state.currentAttempt.StartTime)
+ if state.ctx != nil && state.ctx.Err() != nil {
+ state.currentAttempt.Status = "CANCELLED"
+ } else {
+ state.currentAttempt.Status = "FAILED"
+ }
+ _ = e.attemptRepo.Update(state.currentAttempt)
+ if e.broadcaster != nil {
+ e.broadcaster.BroadcastProxyUpstreamAttempt(state.currentAttempt)
+ }
+ }
+
+ _ = state.lastErr
+}
diff --git a/internal/executor/middleware_ingress.go b/internal/executor/middleware_ingress.go
new file mode 100644
index 00000000..00a959d7
--- /dev/null
+++ b/internal/executor/middleware_ingress.go
@@ -0,0 +1,169 @@
+package executor
+
+import (
+ "context"
+ "errors"
+ "log"
+ "net/http"
+ "time"
+
+ "github.com/awsl-project/maxx/internal/domain"
+ "github.com/awsl-project/maxx/internal/flow"
+)
+
+func (e *Executor) ingress(c *flow.Ctx) {
+ state, ok := getExecState(c)
+ if !ok {
+ err := domain.NewProxyErrorWithMessage(domain.ErrInvalidInput, false, "executor state missing")
+ c.Err = err
+ c.Abort()
+ return
+ }
+
+ ctx := state.ctx
+ if v, ok := c.Get(flow.KeyClientType); ok {
+ if ct, ok := v.(domain.ClientType); ok {
+ state.clientType = ct
+ }
+ }
+ if v, ok := c.Get(flow.KeyProjectID); ok {
+ if pid, ok := v.(uint64); ok {
+ state.projectID = pid
+ }
+ }
+ if v, ok := c.Get(flow.KeySessionID); ok {
+ if sid, ok := v.(string); ok {
+ state.sessionID = sid
+ }
+ }
+ if v, ok := c.Get(flow.KeyRequestModel); ok {
+ if model, ok := v.(string); ok {
+ state.requestModel = model
+ }
+ }
+ if v, ok := c.Get(flow.KeyIsStream); ok {
+ if s, ok := v.(bool); ok {
+ state.isStream = s
+ }
+ }
+ if v, ok := c.Get(flow.KeyAPITokenID); ok {
+ if id, ok := v.(uint64); ok {
+ state.apiTokenID = id
+ }
+ }
+ if v, ok := c.Get(flow.KeyRequestBody); ok {
+ if body, ok := v.([]byte); ok {
+ state.requestBody = body
+ }
+ }
+ if v, ok := c.Get(flow.KeyOriginalRequestBody); ok {
+ if body, ok := v.([]byte); ok {
+ state.originalRequestBody = body
+ }
+ }
+ if v, ok := c.Get(flow.KeyRequestHeaders); ok {
+ if headers, ok := v.(map[string][]string); ok {
+ state.requestHeaders = headers
+ }
+ if headers, ok := v.(http.Header); ok {
+ state.requestHeaders = headers
+ }
+ }
+ if v, ok := c.Get(flow.KeyRequestURI); ok {
+ if uri, ok := v.(string); ok {
+ state.requestURI = uri
+ }
+ }
+
+ proxyReq := &domain.ProxyRequest{
+ InstanceID: e.instanceID,
+ RequestID: generateRequestID(),
+ SessionID: state.sessionID,
+ ClientType: state.clientType,
+ ProjectID: state.projectID,
+ RequestModel: state.requestModel,
+ StartTime: time.Now(),
+ IsStream: state.isStream,
+ Status: "PENDING",
+ APITokenID: state.apiTokenID,
+ }
+
+ if !e.shouldClearRequestDetail() {
+ requestURI := state.requestURI
+ requestHeaders := state.requestHeaders
+ requestBody := state.requestBody
+ headers := flattenHeaders(requestHeaders)
+ if c.Request != nil {
+ if c.Request.Host != "" {
+ if headers == nil {
+ headers = make(map[string]string)
+ }
+ headers["Host"] = c.Request.Host
+ }
+ proxyReq.RequestInfo = &domain.RequestInfo{
+ Method: c.Request.Method,
+ URL: requestURI,
+ Headers: headers,
+ Body: string(requestBody),
+ }
+ }
+ }
+
+ if err := e.proxyRequestRepo.Create(proxyReq); err != nil {
+ log.Printf("[Executor] Failed to create proxy request: %v", err)
+ }
+
+ if e.broadcaster != nil {
+ e.broadcaster.BroadcastProxyRequest(proxyReq)
+ }
+
+ state.proxyReq = proxyReq
+ state.ctx = ctx
+
+ if state.projectID == 0 && e.projectWaiter != nil {
+ session, _ := e.sessionRepo.GetBySessionID(state.sessionID)
+ if session == nil {
+ session = &domain.Session{
+ SessionID: state.sessionID,
+ ClientType: state.clientType,
+ ProjectID: 0,
+ }
+ }
+
+ if err := e.projectWaiter.WaitForProject(ctx, session); err != nil {
+ status := "REJECTED"
+ errorMsg := "project binding timeout: " + err.Error()
+ if errors.Is(err, context.Canceled) {
+ status = "CANCELLED"
+ errorMsg = "client cancelled: " + err.Error()
+ if e.broadcaster != nil {
+ e.broadcaster.BroadcastMessage("session_pending_cancelled", map[string]interface{}{
+ "sessionID": state.sessionID,
+ })
+ }
+ }
+
+ proxyReq.Status = status
+ proxyReq.Error = errorMsg
+ proxyReq.EndTime = time.Now()
+ proxyReq.Duration = proxyReq.EndTime.Sub(proxyReq.StartTime)
+ _ = e.proxyRequestRepo.Update(proxyReq)
+
+ if e.broadcaster != nil {
+ e.broadcaster.BroadcastProxyRequest(proxyReq)
+ }
+
+ err := domain.NewProxyErrorWithMessage(err, false, "project binding required: "+err.Error())
+ state.lastErr = err
+ c.Err = err
+ c.Abort()
+ return
+ }
+
+ state.projectID = session.ProjectID
+ proxyReq.ProjectID = state.projectID
+ state.ctx = ctx
+ }
+
+ c.Next()
+}
diff --git a/internal/executor/middleware_route_match.go b/internal/executor/middleware_route_match.go
new file mode 100644
index 00000000..fb11ad71
--- /dev/null
+++ b/internal/executor/middleware_route_match.go
@@ -0,0 +1,75 @@
+package executor
+
+import (
+ "fmt"
+ "log"
+ "time"
+
+ "github.com/awsl-project/maxx/internal/domain"
+ "github.com/awsl-project/maxx/internal/flow"
+ "github.com/awsl-project/maxx/internal/router"
+)
+
+func (e *Executor) routeMatch(c *flow.Ctx) {
+ state, ok := getExecState(c)
+ if !ok {
+ err := domain.NewProxyErrorWithMessage(domain.ErrInvalidInput, false, "executor state missing")
+ c.Err = err
+ c.Abort()
+ return
+ }
+
+ proxyReq := state.proxyReq
+ routes, err := e.router.Match(&router.MatchContext{
+ ClientType: state.clientType,
+ ProjectID: state.projectID,
+ RequestModel: state.requestModel,
+ APITokenID: state.apiTokenID,
+ })
+ if err != nil {
+ proxyReq.Status = "FAILED"
+ proxyReq.Error = "no routes available"
+ proxyReq.EndTime = time.Now()
+ proxyReq.Duration = proxyReq.EndTime.Sub(proxyReq.StartTime)
+ if err := e.proxyRequestRepo.Update(proxyReq); err != nil {
+ log.Printf("[Executor] failed to update proxy request: %v", err)
+ }
+ if e.broadcaster != nil {
+ e.broadcaster.BroadcastProxyRequest(proxyReq)
+ }
+ err = domain.NewProxyErrorWithMessage(domain.ErrNoRoutes, false, fmt.Sprintf("route match failed: %v", err))
+ state.lastErr = err
+ c.Err = err
+ c.Abort()
+ return
+ }
+
+ if len(routes) == 0 {
+ proxyReq.Status = "FAILED"
+ proxyReq.Error = "no routes configured"
+ proxyReq.EndTime = time.Now()
+ proxyReq.Duration = proxyReq.EndTime.Sub(proxyReq.StartTime)
+ if err := e.proxyRequestRepo.Update(proxyReq); err != nil {
+ log.Printf("[Executor] failed to update proxy request: %v", err)
+ }
+ if e.broadcaster != nil {
+ e.broadcaster.BroadcastProxyRequest(proxyReq)
+ }
+ err = domain.NewProxyErrorWithMessage(domain.ErrNoRoutes, false, "no routes configured")
+ state.lastErr = err
+ c.Err = err
+ c.Abort()
+ return
+ }
+
+ proxyReq.Status = "IN_PROGRESS"
+ if err := e.proxyRequestRepo.Update(proxyReq); err != nil {
+ log.Printf("[Executor] failed to update proxy request: %v", err)
+ }
+ if e.broadcaster != nil {
+ e.broadcaster.BroadcastProxyRequest(proxyReq)
+ }
+ state.routes = routes
+
+ c.Next()
+}
diff --git a/internal/flow/engine.go b/internal/flow/engine.go
new file mode 100644
index 00000000..34c81191
--- /dev/null
+++ b/internal/flow/engine.go
@@ -0,0 +1,88 @@
+package flow
+
+import (
+ "io"
+ "net/http"
+)
+
+type HandlerFunc func(*Ctx)
+
+type Engine struct {
+ handlers []HandlerFunc
+}
+
+func NewEngine() *Engine {
+ return &Engine{}
+}
+
+func (e *Engine) Use(handlers ...HandlerFunc) {
+ e.handlers = append(e.handlers, handlers...)
+}
+
+func (e *Engine) Handle(c *Ctx) {
+ c.handlers = e.handlers
+ c.index = -1
+ c.Next()
+}
+
+func (e *Engine) HandleWith(c *Ctx, handlers ...HandlerFunc) {
+ c.handlers = append(append([]HandlerFunc{}, e.handlers...), handlers...)
+ c.index = -1
+ c.Next()
+}
+
+type Ctx struct {
+ Writer http.ResponseWriter
+ Request *http.Request
+ InboundBody []byte
+ OutboundBody []byte
+ StreamBody io.ReadCloser
+ IsStream bool
+ Keys map[string]interface{}
+ Err error
+
+ handlers []HandlerFunc
+ index int
+ aborted bool
+}
+
+func NewCtx(w http.ResponseWriter, r *http.Request) *Ctx {
+ return &Ctx{
+ Writer: w,
+ Request: r,
+ Keys: make(map[string]interface{}),
+ }
+}
+
+func (c *Ctx) Next() {
+ if c.aborted {
+ return
+ }
+ c.index++
+ for c.index < len(c.handlers) {
+ c.handlers[c.index](c)
+ if c.aborted {
+ return
+ }
+ c.index++
+ }
+}
+
+func (c *Ctx) Abort() {
+ c.aborted = true
+}
+
+func (c *Ctx) Set(key string, value interface{}) {
+ if c.Keys == nil {
+ c.Keys = make(map[string]interface{})
+ }
+ c.Keys[key] = value
+}
+
+func (c *Ctx) Get(key string) (interface{}, bool) {
+ if c.Keys == nil {
+ return nil, false
+ }
+ v, ok := c.Keys[key]
+ return v, ok
+}
diff --git a/internal/flow/helpers.go b/internal/flow/helpers.go
new file mode 100644
index 00000000..37cb14eb
--- /dev/null
+++ b/internal/flow/helpers.go
@@ -0,0 +1,152 @@
+package flow
+
+import (
+ "net/http"
+
+ "github.com/awsl-project/maxx/internal/domain"
+ "github.com/awsl-project/maxx/internal/event"
+)
+
+func GetClientType(c *Ctx) domain.ClientType {
+ if v, ok := c.Get(KeyClientType); ok {
+ if ct, ok := v.(domain.ClientType); ok {
+ return ct
+ }
+ }
+ return ""
+}
+
+func GetOriginalClientType(c *Ctx) domain.ClientType {
+ if v, ok := c.Get(KeyOriginalClientType); ok {
+ if ct, ok := v.(domain.ClientType); ok {
+ return ct
+ }
+ }
+ return ""
+}
+
+func GetSessionID(c *Ctx) string {
+ if v, ok := c.Get(KeySessionID); ok {
+ if s, ok := v.(string); ok {
+ return s
+ }
+ }
+ return ""
+}
+
+func GetProjectID(c *Ctx) uint64 {
+ if v, ok := c.Get(KeyProjectID); ok {
+ if id, ok := v.(uint64); ok {
+ return id
+ }
+ }
+ return 0
+}
+
+func GetRequestModel(c *Ctx) string {
+ if v, ok := c.Get(KeyRequestModel); ok {
+ if s, ok := v.(string); ok {
+ return s
+ }
+ }
+ return ""
+}
+
+func GetMappedModel(c *Ctx) string {
+ if v, ok := c.Get(KeyMappedModel); ok {
+ if s, ok := v.(string); ok {
+ return s
+ }
+ }
+ return ""
+}
+
+func GetRequestBody(c *Ctx) []byte {
+ if v, ok := c.Get(KeyRequestBody); ok {
+ if b, ok := v.([]byte); ok {
+ return b
+ }
+ }
+ return nil
+}
+
+func GetOriginalRequestBody(c *Ctx) []byte {
+ if v, ok := c.Get(KeyOriginalRequestBody); ok {
+ if b, ok := v.([]byte); ok {
+ return b
+ }
+ }
+ return nil
+}
+
+func GetRequestHeaders(c *Ctx) http.Header {
+ if v, ok := c.Get(KeyRequestHeaders); ok {
+ if h, ok := v.(http.Header); ok {
+ return h
+ }
+ }
+ return nil
+}
+
+func GetRequestURI(c *Ctx) string {
+ if v, ok := c.Get(KeyRequestURI); ok {
+ if s, ok := v.(string); ok {
+ return s
+ }
+ }
+ return ""
+}
+
+func GetIsStream(c *Ctx) bool {
+ if v, ok := c.Get(KeyIsStream); ok {
+ if s, ok := v.(bool); ok {
+ return s
+ }
+ }
+ return false
+}
+
+func GetAPITokenID(c *Ctx) uint64 {
+ if v, ok := c.Get(KeyAPITokenID); ok {
+ if id, ok := v.(uint64); ok {
+ return id
+ }
+ }
+ return 0
+}
+
+func GetProxyRequest(c *Ctx) *domain.ProxyRequest {
+ if v, ok := c.Get(KeyProxyRequest); ok {
+ if pr, ok := v.(*domain.ProxyRequest); ok {
+ return pr
+ }
+ }
+ return nil
+}
+
+func GetUpstreamAttempt(c *Ctx) *domain.ProxyUpstreamAttempt {
+ if v, ok := c.Get(KeyUpstreamAttempt); ok {
+ if at, ok := v.(*domain.ProxyUpstreamAttempt); ok {
+ return at
+ }
+ }
+ return nil
+}
+
+func GetEventChan(c *Ctx) domain.AdapterEventChan {
+ if v, ok := c.Get(KeyEventChan); ok {
+ if ch, ok := v.(domain.AdapterEventChan); ok {
+ return ch
+ }
+ }
+ return nil
+}
+
+func GetBroadcaster(c *Ctx) event.Broadcaster {
+ if v, ok := c.Get(KeyBroadcaster); ok {
+ if b, ok := v.(event.Broadcaster); ok {
+ return b
+ }
+ }
+ return nil
+}
diff --git a/internal/flow/keys.go b/internal/flow/keys.go
new file mode 100644
index 00000000..caaec1cf
--- /dev/null
+++ b/internal/flow/keys.go
@@ -0,0 +1,25 @@
+package flow
+
+const (
+ KeyProxyContext = "proxy_context"
+ KeyProxyStream = "proxy_stream"
+ KeyProxyRequestModel = "proxy_request_model"
+ KeyExecutorState = "executor_state"
+
+ KeyClientType = "client_type"
+ KeyOriginalClientType = "original_client_type"
+ KeySessionID = "session_id"
+ KeyProjectID = "project_id"
+ KeyRequestModel = "request_model"
+ KeyMappedModel = "mapped_model"
+ KeyRequestBody = "request_body"
+ KeyOriginalRequestBody = "original_request_body"
+ KeyRequestHeaders = "request_headers"
+ KeyRequestURI = "request_uri"
+ KeyIsStream = "is_stream"
+ KeyAPITokenID = "api_token_id"
+ KeyProxyRequest = "proxy_request"
+ KeyUpstreamAttempt = "upstream_attempt"
+ KeyEventChan = "event_chan"
+ KeyBroadcaster = "broadcaster"
+)
diff --git a/internal/handler/proxy.go b/internal/handler/proxy.go
index 5934e27e..f2e63dfc 100644
--- a/internal/handler/proxy.go
+++ b/internal/handler/proxy.go
@@ -1,6 +1,7 @@
package handler
import (
+ "bytes"
"encoding/json"
"io"
"log"
@@ -10,9 +11,10 @@ import (
"sync"
"github.com/awsl-project/maxx/internal/adapter/client"
- ctxutil "github.com/awsl-project/maxx/internal/context"
+ "github.com/awsl-project/maxx/internal/converter"
"github.com/awsl-project/maxx/internal/domain"
"github.com/awsl-project/maxx/internal/executor"
+ "github.com/awsl-project/maxx/internal/flow"
"github.com/awsl-project/maxx/internal/repository/cached"
)
@@ -31,6 +33,8 @@ type ProxyHandler struct {
tokenAuth *TokenAuthMiddleware
tracker RequestTracker
trackerMu sync.RWMutex
+ engine *flow.Engine
+ extra []flow.HandlerFunc
}
// NewProxyHandler creates a new proxy handler
@@ -40,12 +44,19 @@ func NewProxyHandler(
sessionRepo *cached.SessionRepository,
tokenAuth *TokenAuthMiddleware,
) *ProxyHandler {
- return &ProxyHandler{
+ h := &ProxyHandler{
clientAdapter: clientAdapter,
executor: exec,
sessionRepo: sessionRepo,
tokenAuth: tokenAuth,
+ engine: flow.NewEngine(),
}
+ h.engine.Use(h.ingress)
+ return h
+}
+
+func (h *ProxyHandler) Use(handlers ...flow.HandlerFunc) {
+ h.extra = append(h.extra, handlers...)
}
// SetRequestTracker sets the request tracker for graceful shutdown
@@ -57,6 +68,16 @@ func (h *ProxyHandler) SetRequestTracker(tracker RequestTracker) {
// ServeHTTP handles proxy requests
func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ ctx := flow.NewCtx(w, r)
+ handlers := make([]flow.HandlerFunc, len(h.extra)+1)
+ copy(handlers, h.extra)
+ handlers[len(h.extra)] = h.dispatch
+ h.engine.HandleWith(ctx, handlers...)
+}
+
+func (h *ProxyHandler) ingress(c *flow.Ctx) {
+ r := c.Request
+ w := c.Writer
log.Printf("[Proxy] Received request: %s %s", r.Method, r.URL.Path)
// Track request for graceful shutdown
@@ -66,9 +87,9 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if tracker != nil {
if !tracker.Add() {
- // Server is shutting down, reject new requests
log.Printf("[Proxy] Rejecting request during shutdown: %s %s", r.Method, r.URL.Path)
writeError(w, http.StatusServiceUnavailable, "server is shutting down")
+ c.Abort()
return
}
defer tracker.Done()
@@ -76,6 +97,7 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
+ c.Abort()
return
}
@@ -88,26 +110,36 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
_, _ = io.Copy(io.Discard, r.Body)
_ = r.Body.Close()
writeCountTokensResponse(w)
+ c.Abort()
return
}
- // Read body
body, err := io.ReadAll(r.Body)
if err != nil {
writeError(w, http.StatusBadRequest, "failed to read request body")
+ c.Abort()
return
}
- defer r.Body.Close()
+ _ = r.Body.Close()
+
+ // Normalize OpenAI Responses payloads sent to chat/completions
+ if strings.HasPrefix(r.URL.Path, "/v1/chat/completions") {
+ if normalized, ok := normalizeOpenAIChatCompletionsPayload(body); ok {
+ body = normalized
+ }
+ }
+
+ r.Body = io.NopCloser(bytes.NewReader(body))
+ ctx := r.Context()
- // Detect client type and extract info
clientType := h.clientAdapter.DetectClientType(r, body)
log.Printf("[Proxy] Detected client type: %s", clientType)
if clientType == "" {
writeError(w, http.StatusBadRequest, "unable to detect client type")
+ c.Abort()
return
}
- // Token authentication (uses clientType for primary header, with fallback)
var apiToken *domain.APIToken
var apiTokenID uint64
if h.tokenAuth != nil {
@@ -115,6 +147,7 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if err != nil {
log.Printf("[Proxy] Token auth failed: %v", err)
writeError(w, http.StatusUnauthorized, err.Error())
+ c.Abort()
return
}
if apiToken != nil {
@@ -128,18 +161,17 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
sessionID := h.clientAdapter.ExtractSessionID(r, body, clientType)
stream := h.clientAdapter.IsStreamRequest(r, body)
- // Build context
- ctx := r.Context()
- ctx = ctxutil.WithClientType(ctx, clientType)
- ctx = ctxutil.WithSessionID(ctx, sessionID)
- ctx = ctxutil.WithRequestModel(ctx, requestModel)
- ctx = ctxutil.WithRequestBody(ctx, body)
- ctx = ctxutil.WithRequestHeaders(ctx, r.Header)
- ctx = ctxutil.WithRequestURI(ctx, r.URL.RequestURI())
- ctx = ctxutil.WithIsStream(ctx, stream)
- ctx = ctxutil.WithAPITokenID(ctx, apiTokenID)
-
- // Check for project ID from header (set by ProjectProxyHandler)
+ c.Set(flow.KeyClientType, clientType)
+ c.Set(flow.KeySessionID, sessionID)
+ c.Set(flow.KeyRequestModel, requestModel)
+ originalBody := bytes.Clone(body)
+ c.Set(flow.KeyRequestBody, body)
+ c.Set(flow.KeyOriginalRequestBody, originalBody)
+ c.Set(flow.KeyRequestHeaders, r.Header)
+ c.Set(flow.KeyRequestURI, r.URL.RequestURI())
+ c.Set(flow.KeyIsStream, stream)
+ c.Set(flow.KeyAPITokenID, apiTokenID)
+
var projectID uint64
if pidStr := r.Header.Get("X-Maxx-Project-ID"); pidStr != "" {
if pid, err := strconv.ParseUint(pidStr, 10, 64); err == nil {
@@ -148,10 +180,8 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}
- // Get or create session to get project ID
session, _ := h.sessionRepo.GetBySessionID(sessionID)
if session != nil {
- // Priority: Session binding (Admin configured) > Token association > Header > 0
if session.ProjectID > 0 {
projectID = session.ProjectID
log.Printf("[Proxy] Using project ID from session binding: %d", projectID)
@@ -160,8 +190,6 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
log.Printf("[Proxy] Using project ID from token: %d", projectID)
}
} else {
- // Create new session
- // If no project from header, use token's project
if projectID == 0 && apiToken != nil && apiToken.ProjectID > 0 {
projectID = apiToken.ProjectID
log.Printf("[Proxy] Using project ID from token for new session: %d", projectID)
@@ -174,22 +202,74 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
_ = h.sessionRepo.Create(session)
}
- ctx = ctxutil.WithProjectID(ctx, projectID)
+ c.Set(flow.KeyProjectID, projectID)
- // Execute request (executor handles request recording, project binding, routing, etc.)
- err = h.executor.Execute(ctx, w, r)
- if err != nil {
- proxyErr, ok := err.(*domain.ProxyError)
- if ok {
- if stream {
- writeStreamError(w, proxyErr)
- } else {
- writeProxyError(w, proxyErr)
- }
+ r = r.WithContext(ctx)
+ c.Request = r
+ c.InboundBody = body
+ c.IsStream = stream
+ c.Set(flow.KeyProxyContext, ctx)
+ c.Set(flow.KeyProxyStream, stream)
+ c.Set(flow.KeyProxyRequestModel, requestModel)
+
+ c.Next()
+}
+
+func (h *ProxyHandler) dispatch(c *flow.Ctx) {
+ stream := c.IsStream
+ if v, ok := c.Get(flow.KeyProxyStream); ok {
+ if s, ok := v.(bool); ok {
+ stream = s
+ }
+ }
+
+ err := h.executor.ExecuteWith(c)
+ if err == nil {
+ return
+ }
+ proxyErr, ok := err.(*domain.ProxyError)
+ if ok {
+ if stream {
+ writeStreamError(c.Writer, proxyErr)
} else {
- writeError(w, http.StatusInternalServerError, err.Error())
+ writeProxyError(c.Writer, proxyErr)
}
+ c.Err = err
+ c.Abort()
+ return
+ }
+ writeError(c.Writer, http.StatusInternalServerError, err.Error())
+ c.Err = err
+ c.Abort()
+}
+
+func normalizeOpenAIChatCompletionsPayload(body []byte) ([]byte, bool) {
+ var data map[string]interface{}
+ if err := json.Unmarshal(body, &data); err != nil {
+ return nil, false
+ }
+ if _, hasMessages := data["messages"]; hasMessages {
+ return nil, false
+ }
+ if _, hasInput := data["input"]; !hasInput {
+ if _, hasInstructions := data["instructions"]; !hasInstructions {
+ return nil, false
+ }
+ }
+
+ model, _ := data["model"].(string)
+ stream, _ := data["stream"].(bool)
+ converted, err := converter.GetGlobalRegistry().TransformRequest(
+ domain.ClientTypeCodex,
+ domain.ClientTypeOpenAI,
+ body,
+ model,
+ stream,
+ )
+ if err != nil {
+ return nil, false
}
+ return converted, true
}
// Helper functions
diff --git a/internal/service/admin.go b/internal/service/admin.go
index bcd56147..12a7c32a 100644
--- a/internal/service/admin.go
+++ b/internal/service/admin.go
@@ -553,11 +553,11 @@ func (s *AdminService) GetLogs(limit int) (*LogsResult, error) {
func (s *AdminService) autoSetSupportedClientTypes(provider *domain.Provider) {
switch provider.Type {
case "antigravity":
- // Antigravity natively supports Claude and Gemini
- // OpenAI requests will be converted to Claude format by Executor
+ // Antigravity natively supports Claude and Gemini.
+ // Conversion preference is Gemini-first.
provider.SupportedClientTypes = []domain.ClientType{
- domain.ClientTypeClaude,
domain.ClientTypeGemini,
+ domain.ClientTypeClaude,
}
case "kiro":
// Kiro natively supports Claude protocol only
diff --git a/web/package.json b/web/package.json
index 817611a9..373b61b1 100644
--- a/web/package.json
+++ b/web/package.json
@@ -52,6 +52,11 @@
"tw-animate-css": "^1.4.0",
"zustand": "^5.0.9"
},
+ "pnpm": {
+ "onlyBuiltDependencies": [
+ "esbuild"
+ ]
+ },
"devDependencies": {
"@eslint/js": "^9.39.1",
"@tailwindcss/postcss": "^4.1.18",
diff --git a/web/src/hooks/queries/use-providers.ts b/web/src/hooks/queries/use-providers.ts
index 009b8731..d5bcfa16 100644
--- a/web/src/hooks/queries/use-providers.ts
+++ b/web/src/hooks/queries/use-providers.ts
@@ -51,8 +51,16 @@ export function useUpdateProvider() {
const queryClient = useQueryClient();
return useMutation({
- mutationFn: ({ id, data }: { id: number; data: Partial
+
{t('provider.cloakStrictModeDesc')}
{t('provider.cloakSensitiveWordsDesc')}