diff --git a/app/app.go b/app/app.go index d1e42171..cf9f5935 100644 --- a/app/app.go +++ b/app/app.go @@ -13,12 +13,14 @@ import ( _ "github.com/awsl-project/maxx/internal/adapter/provider/antigravity" _ "github.com/awsl-project/maxx/internal/adapter/provider/custom" "github.com/awsl-project/maxx/internal/cooldown" + "github.com/awsl-project/maxx/internal/event" "github.com/awsl-project/maxx/internal/executor" "github.com/awsl-project/maxx/internal/handler" "github.com/awsl-project/maxx/internal/repository/cached" "github.com/awsl-project/maxx/internal/repository/sqlite" "github.com/awsl-project/maxx/internal/router" "github.com/awsl-project/maxx/internal/service" + "github.com/awsl-project/maxx/internal/waiter" "github.com/wailsapp/wails/v2/pkg/runtime" ) @@ -174,7 +176,13 @@ func (a *App) initializeServices() { a.antigravitySvc = NewAntigravityService(antigravityQuotaRepo, a.wsHub, a.adminService) - exec := executor.NewExecutor(r, proxyRequestRepo, attemptRepo, cachedRetryConfigRepo, a.wsHub, "") + // Create broadcaster (wraps WebSocket hub) + broadcaster := event.NewWailsBroadcaster(a.wsHub) + + // Create project waiter for force project binding + projectWaiter := waiter.NewProjectWaiter(cachedSessionRepo, settingRepo, broadcaster) + + exec := executor.NewExecutor(r, proxyRequestRepo, attemptRepo, cachedRetryConfigRepo, cachedSessionRepo, broadcaster, projectWaiter, "") proxyHandler := handler.NewProxyHandler(clientAdapter, exec, cachedSessionRepo) projectProxyHandler := handler.NewProjectProxyHandler(proxyHandler, cachedProjectRepo) diff --git a/app/bindings.go b/app/bindings.go index a8d9da32..349be217 100644 --- a/app/bindings.go +++ b/app/bindings.go @@ -285,3 +285,20 @@ func (a *App) StartAntigravityOAuth(ctx context.Context) (map[string]interface{} redirectURI := "http://localhost:19380/antigravity/oauth/callback" return a.antigravitySvc.StartOAuth(ctx, redirectURI) } + +// Antigravity Global Settings methods + +//wails:bind +func (a *App) GetAntigravityGlobalSettings() (*service.AntigravityGlobalSettings, error) { + return a.adminService.GetAntigravityGlobalSettings() +} + +//wails:bind +func (a *App) UpdateAntigravityGlobalSettings(settings *service.AntigravityGlobalSettings) error { + return a.adminService.UpdateAntigravityGlobalSettings(settings) +} + +//wails:bind +func (a *App) ResetAntigravityGlobalSettings() (*service.AntigravityGlobalSettings, error) { + return a.adminService.ResetAntigravityGlobalSettings() +} diff --git a/cmd/maxx/main.go b/cmd/maxx/main.go index dcbdae8b..a628a4d9 100644 --- a/cmd/maxx/main.go +++ b/cmd/maxx/main.go @@ -10,9 +10,10 @@ import ( "time" "github.com/awsl-project/maxx/internal/adapter/client" - _ "github.com/awsl-project/maxx/internal/adapter/provider/antigravity" // Register antigravity adapter - _ "github.com/awsl-project/maxx/internal/adapter/provider/custom" // Register custom adapter + "github.com/awsl-project/maxx/internal/adapter/provider/antigravity" + _ "github.com/awsl-project/maxx/internal/adapter/provider/custom" // Register custom adapter "github.com/awsl-project/maxx/internal/cooldown" + "github.com/awsl-project/maxx/internal/domain" "github.com/awsl-project/maxx/internal/executor" "github.com/awsl-project/maxx/internal/handler" "github.com/awsl-project/maxx/internal/repository/cached" @@ -187,6 +188,18 @@ func main() { r, // Router implements ProviderAdapterRefresher interface ) + // Initialize Antigravity global settings getter + antigravity.SetGlobalSettingsGetter(func() (*antigravity.GlobalSettings, error) { + rulesJSON, _ := settingRepo.Get(domain.SettingKeyAntigravityModelMapping) + rules, err := antigravity.ParseModelMappingRules(rulesJSON) + if err != nil { + return nil, err + } + return &antigravity.GlobalSettings{ + ModelMappingRules: rules, + }, nil + }) + // Create handlers proxyHandler := handler.NewProxyHandler(clientAdapter, exec, cachedSessionRepo) adminHandler := handler.NewAdminHandler(adminService, logPath) diff --git a/go.mod b/go.mod index f7aacf5d..2ef2276c 100644 --- a/go.mod +++ b/go.mod @@ -37,9 +37,4 @@ require ( golang.org/x/text v0.22.0 // indirect ) -require ( - github.com/emersion/go-autostart v0.0.0-20250403115856-34830d6457d2 - github.com/google/uuid v1.6.0 - github.com/gorilla/websocket v1.5.3 - github.com/wailsapp/wails/v2 v2.11.0 -) +require github.com/emersion/go-autostart v0.0.0-20250403115856-34830d6457d2 diff --git a/internal/adapter/provider/antigravity/adapter.go b/internal/adapter/provider/antigravity/adapter.go index e39592af..98e1e063 100644 --- a/internal/adapter/provider/antigravity/adapter.go +++ b/internal/adapter/provider/antigravity/adapter.go @@ -57,7 +57,7 @@ func (a *AntigravityAdapter) Execute(ctx context.Context, w http.ResponseWriter, clientType := ctxutil.GetClientType(ctx) baseCtx := ctx requestModel := ctxutil.GetRequestModel(ctx) // Original model from request (e.g., "claude-3-5-sonnet-20241022-online") - mappedModel := ctxutil.GetMappedModel(ctx) // Mapped model after route resolution + mappedModel := ctxutil.GetMappedModel(ctx) // Mapped model after executor's unified mapping requestBody := ctxutil.GetRequestBody(ctx) backgroundDowngrade := false backgroundModel := "" @@ -72,7 +72,6 @@ func (a *AntigravityAdapter) Execute(ctx context.Context, w http.ResponseWriter, } } - // [Model Mapping] Apply Antigravity model mapping (like Antigravity-Manager) // We'll attempt at most twice: original + retry without thinking on signature errors retriedWithoutThinking := false @@ -80,21 +79,16 @@ func (a *AntigravityAdapter) Execute(ctx context.Context, w http.ResponseWriter, ctx = ctxutil.WithRequestModel(baseCtx, requestModel) ctx = ctxutil.WithRequestBody(ctx, requestBody) - // Only map if route didn't provide a mapping (mappedModel empty or same as request) + // Apply background downgrade override if needed config := provider.Config.Antigravity - if mappedModel == "" || mappedModel == requestModel { - // Route didn't provide mapping, use our internal mapping with haikuTarget config - haikuTarget := "" - if config != nil { - haikuTarget = config.HaikuTarget - } - mappedModel = MapClaudeModelToGeminiWithConfig(requestModel, haikuTarget) - } if backgroundDowngrade && backgroundModel != "" { mappedModel = backgroundModel } - // If route provided a different mappedModel, trust it and don't re-map - // (user/route has explicitly configured the target model) + + // Update attempt record with the final mapped model (in case of background downgrade) + if attempt := ctxutil.GetUpstreamAttempt(ctx); attempt != nil { + attempt.MappedModel = mappedModel + } // Get streaming flag from context (already detected correctly for Gemini URL path) stream := ctxutil.GetIsStream(ctx) @@ -538,6 +532,9 @@ func (a *AntigravityAdapter) handleNonStreamResponse(ctx context.Context, w http attempt.Cache1hWriteCount = metrics.Cache1hCreationCount } + // Extract modelVersion from Gemini response + attempt.ResponseModel = extractModelVersion(unwrappedBody) + // Broadcast attempt update with token info if bc := ctxutil.GetBroadcaster(ctx); bc != nil { bc.BroadcastProxyUpstreamAttempt(attempt) @@ -629,6 +626,12 @@ func (a *AntigravityAdapter) handleStreamResponse(ctx context.Context, w http.Re attempt.Cache5mWriteCount = metrics.Cache5mCreationCount attempt.Cache1hWriteCount = metrics.Cache1hCreationCount } + // Extract responseModel from claudeState (for Claude clients) or SSE content + if claudeState != nil { + attempt.ResponseModel = claudeState.GetModelVersion() + } else { + attempt.ResponseModel = extractModelVersionFromSSE(sseBuffer.String()) + } // Broadcast attempt update with token info if bc := ctxutil.GetBroadcaster(ctx); bc != nil { bc.BroadcastProxyUpstreamAttempt(attempt) @@ -846,6 +849,12 @@ func (a *AntigravityAdapter) handleCollectedStreamResponse(ctx context.Context, attempt.Cache5mWriteCount = metrics.Cache5mCreationCount attempt.Cache1hWriteCount = metrics.Cache1hCreationCount } + // Extract responseModel from claudeState (for Claude clients) or SSE content + if claudeState != nil { + attempt.ResponseModel = claudeState.GetModelVersion() + } else { + attempt.ResponseModel = extractModelVersionFromSSE(upstreamSSE.String()) + } if bc := ctxutil.GetBroadcaster(ctx); bc != nil { bc.BroadcastProxyUpstreamAttempt(attempt) } @@ -998,3 +1007,33 @@ func (a *AntigravityAdapter) parseRateLimitInfo(ctx context.Context, body []byte ClientType: "", // Global cooldown }, updateChan } + +// extractModelVersion extracts modelVersion from Gemini response JSON +func extractModelVersion(body []byte) string { + var resp struct { + ModelVersion string `json:"modelVersion"` + } + if err := json.Unmarshal(body, &resp); err != nil { + return "" + } + return resp.ModelVersion +} + +// extractModelVersionFromSSE extracts modelVersion from SSE content +// Looks for the last "modelVersion" field in the SSE data +func extractModelVersionFromSSE(sseContent string) string { + var lastModelVersion string + for _, line := range strings.Split(sseContent, "\n") { + if !strings.HasPrefix(line, "data: ") { + continue + } + data := strings.TrimPrefix(line, "data: ") + var chunk struct { + ModelVersion string `json:"modelVersion"` + } + if err := json.Unmarshal([]byte(data), &chunk); err == nil && chunk.ModelVersion != "" { + lastModelVersion = chunk.ModelVersion + } + } + return lastModelVersion +} diff --git a/internal/adapter/provider/antigravity/claude_streaming.go b/internal/adapter/provider/antigravity/claude_streaming.go index 20065553..d23a9282 100644 --- a/internal/adapter/provider/antigravity/claude_streaming.go +++ b/internal/adapter/provider/antigravity/claude_streaming.go @@ -61,6 +61,11 @@ func NewClaudeStreamingStateWithSession(_ string, requestModel string) *ClaudeSt } } +// GetModelVersion returns the upstream model version captured during streaming +func (s *ClaudeStreamingState) GetModelVersion() string { + return s.modelVersion +} + // GeminiPart represents a part in Gemini response type GeminiPart struct { Text string `json:"text,omitempty"` diff --git a/internal/adapter/provider/antigravity/model_mapping.go b/internal/adapter/provider/antigravity/model_mapping.go index 1959f393..1b2813df 100644 --- a/internal/adapter/provider/antigravity/model_mapping.go +++ b/internal/adapter/provider/antigravity/model_mapping.go @@ -1,56 +1,65 @@ package antigravity -import "strings" - -// Claude to Gemini model mapping (like Antigravity-Manager) -var claudeToGeminiMap = map[string]string{ - // 直接支持的模型 - "claude-opus-4-5-thinking": "claude-opus-4-5-thinking", - "claude-sonnet-4-5": "claude-sonnet-4-5", - "claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking", - - // 别名映射 - "claude-sonnet-4-5-20250929": "claude-sonnet-4-5-thinking", - "claude-3-5-sonnet-20241022": "claude-sonnet-4-5", - "claude-3-5-sonnet-20240620": "claude-sonnet-4-5", - "claude-opus-4": "claude-opus-4-5-thinking", - "claude-opus-4-5-20251101": "claude-opus-4-5-thinking", - - // Haiku 映射: 默认使用 gemini-2.5-flash-lite (省钱) - // 可通过 Provider 配置 haikuTarget 覆盖为 "claude-sonnet-4-5" (更强) - "claude-haiku-4": "gemini-2.5-flash-lite", - "claude-3-haiku-20240307": "gemini-2.5-flash-lite", - "claude-haiku-4-5-20251001": "gemini-2.5-flash-lite", - - // OpenAI 协议映射表 - "gpt-4": "gemini-2.5-pro", - "gpt-4-turbo": "gemini-2.5-pro", - "gpt-4-turbo-preview": "gemini-2.5-pro", - "gpt-4-0125-preview": "gemini-2.5-pro", - "gpt-4-1106-preview": "gemini-2.5-pro", - "gpt-4-0613": "gemini-2.5-pro", - "gpt-4o": "gemini-2.5-pro", - "gpt-4o-2024-05-13": "gemini-2.5-pro", - "gpt-4o-2024-08-06": "gemini-2.5-pro", - "gpt-4o-mini": "gemini-2.5-flash", - "gpt-4o-mini-2024-07-18": "gemini-2.5-flash", - "gpt-3.5-turbo": "gemini-2.5-flash", - "gpt-3.5-turbo-16k": "gemini-2.5-flash", - "gpt-3.5-turbo-0125": "gemini-2.5-flash", - "gpt-3.5-turbo-1106": "gemini-2.5-flash", - "gpt-3.5-turbo-0613": "gemini-2.5-flash", - - // Gemini 协议映射表 (直接穿透) - "gemini-2.5-flash-lite": "gemini-2.5-flash-lite", - "gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking", - "gemini-3-pro-low": "gemini-3-pro-low", - "gemini-3-pro-high": "gemini-3-pro-high", - "gemini-3-pro-preview": "gemini-3-pro-preview", - "gemini-3-pro": "gemini-3-pro", - "gemini-2.5-flash": "gemini-2.5-flash", - "gemini-2.5-pro": "gemini-2.5-pro", - "gemini-3-flash": "gemini-3-flash", - "gemini-3-pro-image": "gemini-3-pro-image", +import ( + "log" + "strings" +) + +// ModelMappingRule represents a single model mapping rule +// Rules are matched in order, first match wins +type ModelMappingRule struct { + Pattern string // Source pattern, supports * wildcard + Target string // Target model name +} + +// defaultModelMappingRules is the ordered list of default mapping rules +// Rules are matched in order, first match wins (higher priority first) +// Supports wildcard patterns: * matches any characters +// Note: gemini-* models pass through automatically without needing a mapping rule +var defaultModelMappingRules = []ModelMappingRule{ + // Claude 模型 - 按优先级排序 + {"*opus*", "claude-opus-4-5-thinking"}, // 所有 opus 变体 + {"*sonnet*", "claude-sonnet-4-5"}, // 所有 sonnet 变体 + {"*haiku*", "gemini-2.5-flash-lite"}, // 所有 haiku 变体 + + // OpenAI 协议映射表 - gpt-4o-mini 优先于 gpt-4 + {"gpt-4o-mini*", "gemini-2.5-flash"}, // gpt-4o-mini 系列 + {"gpt-4*", "gemini-2.5-pro"}, // 所有 gpt-4 变体 + {"gpt-3.5-*", "gemini-2.5-flash"}, // 所有 gpt-3.5 变体 +} + +// GetDefaultModelMapping returns the default mapping as a map (for API compatibility) +// Note: The map loses ordering, use defaultModelMappingRules for ordered matching +func GetDefaultModelMapping() map[string]string { + result := make(map[string]string, len(defaultModelMappingRules)) + for _, rule := range defaultModelMappingRules { + result[rule.Pattern] = rule.Target + } + return result +} + +// AvailableTargetModels is the list of valid target models for mapping +var AvailableTargetModels = []string{ + // Claude models + "claude-opus-4-5-thinking", + "claude-sonnet-4-5", + "claude-sonnet-4-5-thinking", + // Gemini models + "gemini-2.5-flash-lite", + "gemini-2.5-flash", + "gemini-2.5-flash-thinking", + "gemini-2.5-pro", + "gemini-3-flash", + "gemini-3-pro", + "gemini-3-pro-low", + "gemini-3-pro-high", + "gemini-3-pro-preview", + "gemini-3-pro-image", +} + +// GetAvailableTargetModels returns the list of valid target models +func GetAvailableTargetModels() []string { + return AvailableTargetModels } // MapClaudeModelToGemini maps Claude model names to Gemini model names @@ -70,21 +79,100 @@ func MapClaudeModelToGeminiWithConfig(input string, haikuTarget string) string { return haikuTarget } - // 2. Check exact match in map - if mapped, ok := claudeToGeminiMap[cleanInput]; ok { + // 2. Check global settings first (highest priority for user customization) + // Rules are matched in order, first match wins + if globalSettings := GetGlobalSettings(); globalSettings != nil { + if len(globalSettings.ModelMappingRules) > 0 { + if mapped := MatchRulesInOrder(cleanInput, globalSettings.ModelMappingRules); mapped != "" { + return mapped + } + } + } + + // 3. Check default rules in order + if mapped := MatchRulesInOrder(cleanInput, defaultModelMappingRules); mapped != "" { return mapped } - // 3. Pass-through known prefixes (gemini-, -thinking) to support dynamic suffixes + // 4. Pass-through known prefixes (gemini-, -thinking) to support dynamic suffixes // (like Antigravity-Manager) if strings.HasPrefix(cleanInput, "gemini-") || strings.Contains(cleanInput, "thinking") { return cleanInput } - // 4. Fallback to default + // 5. Fallback to default return "claude-sonnet-4-5" } +// MatchRulesInOrder matches input against rules in order, first match wins +// Returns the target model or empty string if no match +func MatchRulesInOrder(input string, rules []ModelMappingRule) string { + for i, rule := range rules { + matched := MatchWildcard(rule.Pattern, input) + log.Printf("[MatchRulesInOrder] Rule[%d]: pattern=%q, input=%q, matched=%v", i, rule.Pattern, input, matched) + if matched { + log.Printf("[MatchRulesInOrder] Matched! Returning target=%q", rule.Target) + return rule.Target + } + } + return "" +} + +// MatchWildcard checks if input matches a wildcard pattern +// Supports * as wildcard matching any characters +// Examples: +// - "claude-3-5-sonnet-*" matches "claude-3-5-sonnet-20241022" +// - "*haiku*" matches "claude-haiku-4", "claude-3-haiku-20240307" +// - "gpt-4-*" matches "gpt-4-turbo", "gpt-4-0613" +func MatchWildcard(pattern, input string) bool { + // Simple cases + if pattern == "*" { + return true + } + if !strings.Contains(pattern, "*") { + return pattern == input + } + + parts := strings.Split(pattern, "*") + + // Handle prefix-only pattern: "prefix*" + if len(parts) == 2 && parts[1] == "" { + return strings.HasPrefix(input, parts[0]) + } + + // Handle suffix-only pattern: "*suffix" + if len(parts) == 2 && parts[0] == "" { + return strings.HasSuffix(input, parts[1]) + } + + // Handle patterns with multiple wildcards + pos := 0 + for i, part := range parts { + if part == "" { + continue + } + + idx := strings.Index(input[pos:], part) + if idx < 0 { + return false + } + + // First part must be at the beginning if pattern doesn't start with * + if i == 0 && idx != 0 { + return false + } + + pos += idx + len(part) + } + + // Last part must be at the end if pattern doesn't end with * + if parts[len(parts)-1] != "" && !strings.HasSuffix(input, parts[len(parts)-1]) { + return false + } + + return true +} + // isHaikuModel checks if the model name is a Haiku variant func isHaikuModel(model string) bool { modelLower := strings.ToLower(model) diff --git a/internal/adapter/provider/antigravity/settings.go b/internal/adapter/provider/antigravity/settings.go new file mode 100644 index 00000000..b24fe3fc --- /dev/null +++ b/internal/adapter/provider/antigravity/settings.go @@ -0,0 +1,80 @@ +package antigravity + +import ( + "encoding/json" + "sync" +) + +// GlobalSettings holds global Antigravity configuration +type GlobalSettings struct { + // ModelMappingRules is an ordered list of model mapping rules + // Rules are matched in order, first match wins + ModelMappingRules []ModelMappingRule +} + +var ( + globalSettings *GlobalSettings + globalSettingsMu sync.RWMutex + settingsGetterFunc func() (*GlobalSettings, error) +) + +// SetGlobalSettingsGetter sets the function to retrieve global settings +// This should be called during application initialization +func SetGlobalSettingsGetter(getter func() (*GlobalSettings, error)) { + globalSettingsMu.Lock() + defer globalSettingsMu.Unlock() + settingsGetterFunc = getter +} + +// GetGlobalSettings retrieves the current global settings +func GetGlobalSettings() *GlobalSettings { + globalSettingsMu.RLock() + defer globalSettingsMu.RUnlock() + + if settingsGetterFunc == nil { + return nil + } + + settings, err := settingsGetterFunc() + if err != nil { + return nil + } + return settings +} + +// ParseModelMappingRules parses a JSON string into model mapping rules +// Supports both new array format and legacy map format for backwards compatibility +func ParseModelMappingRules(jsonStr string) ([]ModelMappingRule, error) { + if jsonStr == "" { + return nil, nil + } + + // Try new array format first + var rules []ModelMappingRule + if err := json.Unmarshal([]byte(jsonStr), &rules); err == nil { + return rules, nil + } + + // Fall back to legacy map format: {"pattern": "target", ...} + var legacyMap map[string]string + if err := json.Unmarshal([]byte(jsonStr), &legacyMap); err != nil { + return nil, err + } + + // Convert legacy map to rules array + rules = make([]ModelMappingRule, 0, len(legacyMap)) + for pattern, target := range legacyMap { + rules = append(rules, ModelMappingRule{ + Pattern: pattern, + Target: target, + }) + } + return rules, nil +} + +// GetDefaultModelMappingRules returns a copy of the default mapping rules +func GetDefaultModelMappingRules() []ModelMappingRule { + result := make([]ModelMappingRule, len(defaultModelMappingRules)) + copy(result, defaultModelMappingRules) + return result +} diff --git a/internal/core/database.go b/internal/core/database.go index c6f2a21b..d7572dd8 100644 --- a/internal/core/database.go +++ b/internal/core/database.go @@ -6,9 +6,10 @@ import ( "time" "github.com/awsl-project/maxx/internal/adapter/client" - _ "github.com/awsl-project/maxx/internal/adapter/provider/antigravity" + "github.com/awsl-project/maxx/internal/adapter/provider/antigravity" _ "github.com/awsl-project/maxx/internal/adapter/provider/custom" "github.com/awsl-project/maxx/internal/cooldown" + "github.com/awsl-project/maxx/internal/domain" "github.com/awsl-project/maxx/internal/event" "github.com/awsl-project/maxx/internal/executor" "github.com/awsl-project/maxx/internal/handler" @@ -234,6 +235,20 @@ func InitializeServerComponents( r, ) + log.Printf("[Core] Initializing Antigravity global settings getter") + antigravity.SetGlobalSettingsGetter(func() (*antigravity.GlobalSettings, error) { + // Read model mapping rules from database + rulesJSON, _ := repos.SettingRepo.Get(domain.SettingKeyAntigravityModelMapping) + rules, err := antigravity.ParseModelMappingRules(rulesJSON) + if err != nil { + return nil, err + } + + return &antigravity.GlobalSettings{ + ModelMappingRules: rules, + }, nil + }) + log.Printf("[Core] Creating handlers") proxyHandler := handler.NewProxyHandler(clientAdapter, exec, repos.CachedSessionRepo) adminHandler := handler.NewAdminHandler(adminService, logPath) diff --git a/internal/desktop/api.go b/internal/desktop/api.go index 868b0f6b..0ba3a62b 100644 --- a/internal/desktop/api.go +++ b/internal/desktop/api.go @@ -349,6 +349,20 @@ func (a *DesktopApp) StartAntigravityOAuth() (*AntigravityOAuthResult, error) { }, nil } +// ===== Antigravity Global Settings API ===== + +func (a *DesktopApp) GetAntigravityGlobalSettings() (*service.AntigravityGlobalSettings, error) { + return a.components.AdminService.GetAntigravityGlobalSettings() +} + +func (a *DesktopApp) UpdateAntigravityGlobalSettings(settings *service.AntigravityGlobalSettings) error { + return a.components.AdminService.UpdateAntigravityGlobalSettings(settings) +} + +func (a *DesktopApp) ResetAntigravityGlobalSettings() (*service.AntigravityGlobalSettings, error) { + return a.components.AdminService.ResetAntigravityGlobalSettings() +} + // ===== Cooldown API ===== func (a *DesktopApp) GetCooldowns() ([]*domain.Cooldown, error) { diff --git a/internal/domain/model.go b/internal/domain/model.go index b2a621d4..5b24d535 100644 --- a/internal/domain/model.go +++ b/internal/domain/model.go @@ -221,6 +221,14 @@ type ProxyUpstreamAttempt struct { // 是否为 SSE 流式请求 IsStream bool `json:"isStream"` + // 模型信息 + // RequestModel: 客户端请求的原始模型 + // MappedModel: 经过映射后实际发送给上游的模型 + // ResponseModel: 上游响应中返回的模型名称 + RequestModel string `json:"requestModel"` + MappedModel string `json:"mappedModel"` + ResponseModel string `json:"responseModel"` + RequestInfo *RequestInfo `json:"requestInfo"` ResponseInfo *ResponseInfo `json:"responseInfo"` @@ -311,7 +319,8 @@ type SystemSetting struct { // 系统设置 Key 常量 const ( - SettingKeyProxyPort = "proxy_port" // 代理服务器端口,默认 9880 + SettingKeyProxyPort = "proxy_port" // 代理服务器端口,默认 9880 + SettingKeyAntigravityModelMapping = "antigravity_model_mapping" // Antigravity 全局模型映射 (JSON) ) // Antigravity 模型配额 diff --git a/internal/executor/executor.go b/internal/executor/executor.go index 40f0af32..ca8c5d1d 100644 --- a/internal/executor/executor.go +++ b/internal/executor/executor.go @@ -4,8 +4,10 @@ import ( "context" "log" "net/http" + "strings" "time" + "github.com/awsl-project/maxx/internal/adapter/provider/antigravity" "github.com/awsl-project/maxx/internal/cooldown" ctxutil "github.com/awsl-project/maxx/internal/context" "github.com/awsl-project/maxx/internal/domain" @@ -78,10 +80,18 @@ func (e *Executor) Execute(ctx context.Context, w http.ResponseWriter, req *http requestURI := ctxutil.GetRequestURI(ctx) requestHeaders := ctxutil.GetRequestHeaders(ctx) requestBody := ctxutil.GetRequestBody(ctx) + headers := flattenHeaders(requestHeaders) + // Go stores Host separately from headers, add it explicitly + if req.Host != "" { + if headers == nil { + headers = make(map[string]string) + } + headers["Host"] = req.Host + } proxyReq.RequestInfo = &domain.RequestInfo{ Method: req.Method, URL: requestURI, - Headers: flattenHeaders(requestHeaders), + Headers: headers, Body: string(requestBody), } @@ -267,6 +277,8 @@ func (e *Executor) Execute(ctx context.Context, w http.ResponseWriter, req *http IsStream: isStream, Status: "IN_PROGRESS", StartTime: attemptStartTime, + RequestModel: requestModel, + MappedModel: mappedModel, } log.Printf("[Executor] Creating attempt for route %d, attempt %d (proxyRequestID=%d, routeID=%d, providerID=%d)", routeIdx+1, attempt+1, proxyReq.ID, matchedRoute.Route.ID, matchedRoute.Provider.ID) @@ -319,6 +331,7 @@ func (e *Executor) Execute(ctx context.Context, w http.ResponseWriter, req *http proxyReq.EndTime = time.Now() proxyReq.Duration = proxyReq.EndTime.Sub(proxyReq.StartTime) proxyReq.FinalProxyUpstreamAttemptID = attemptRecord.ID + proxyReq.ResponseModel = mappedModel // Record the actual model used // Capture actual client response (what was sent to client, e.g. Claude format) // This is different from attemptRecord.ResponseInfo which is upstream response (Gemini format) @@ -480,31 +493,151 @@ func (e *Executor) Execute(ctx context.Context, w http.ResponseWriter, req *http } func (e *Executor) mapModel(requestModel string, route *domain.Route, provider *domain.Provider) string { - // Route mapping takes precedence + log.Printf("[ModelMapping] Input: requestModel=%q, routeID=%d, providerID=%d", requestModel, route.ID, provider.ID) + + // Route mapping takes precedence (supports wildcard patterns) if route.ModelMapping != nil { - if mapped, ok := route.ModelMapping[requestModel]; ok { + log.Printf("[ModelMapping] Route has mapping: %v", route.ModelMapping) + if mapped := matchModelMapping(requestModel, route.ModelMapping); mapped != "" { + log.Printf("[ModelMapping] Route mapping matched: %q -> %q", requestModel, mapped) return mapped } + log.Printf("[ModelMapping] Route mapping: no match") + } else { + log.Printf("[ModelMapping] Route has no mapping") } - // Provider mapping + // Provider mapping (supports wildcard patterns) if provider.Config != nil { if provider.Config.Custom != nil && provider.Config.Custom.ModelMapping != nil { - if mapped, ok := provider.Config.Custom.ModelMapping[requestModel]; ok { + log.Printf("[ModelMapping] Provider Custom has mapping: %v", provider.Config.Custom.ModelMapping) + if mapped := matchModelMapping(requestModel, provider.Config.Custom.ModelMapping); mapped != "" { + log.Printf("[ModelMapping] Provider Custom mapping matched: %q -> %q", requestModel, mapped) return mapped } + log.Printf("[ModelMapping] Provider Custom mapping: no match") } if provider.Config.Antigravity != nil && provider.Config.Antigravity.ModelMapping != nil { - if mapped, ok := provider.Config.Antigravity.ModelMapping[requestModel]; ok { + log.Printf("[ModelMapping] Provider Antigravity has mapping: %v", provider.Config.Antigravity.ModelMapping) + if mapped := matchModelMapping(requestModel, provider.Config.Antigravity.ModelMapping); mapped != "" { + log.Printf("[ModelMapping] Provider Antigravity mapping matched: %q -> %q", requestModel, mapped) return mapped } + log.Printf("[ModelMapping] Provider Antigravity mapping: no match") } + } else { + log.Printf("[ModelMapping] Provider has no config") + } + + // Global Antigravity model mapping rules (lowest priority fallback) + // This applies the global settings configured in Settings page + if globalSettings := antigravity.GetGlobalSettings(); globalSettings != nil { + log.Printf("[ModelMapping] Global settings found, rules count: %d", len(globalSettings.ModelMappingRules)) + if len(globalSettings.ModelMappingRules) > 0 { + for i, rule := range globalSettings.ModelMappingRules { + log.Printf("[ModelMapping] Global rule[%d]: pattern=%q, target=%q", i, rule.Pattern, rule.Target) + } + if mapped := antigravity.MatchRulesInOrder(requestModel, globalSettings.ModelMappingRules); mapped != "" { + log.Printf("[ModelMapping] Global mapping matched: %q -> %q", requestModel, mapped) + return mapped + } + log.Printf("[ModelMapping] Global mapping: no match") + } + } else { + log.Printf("[ModelMapping] No global settings found") + } + + // Fallback to default model mapping rules + defaultRules := antigravity.GetDefaultModelMappingRules() + log.Printf("[ModelMapping] Trying default rules, count: %d", len(defaultRules)) + if mapped := antigravity.MatchRulesInOrder(requestModel, defaultRules); mapped != "" { + log.Printf("[ModelMapping] Default mapping matched: %q -> %q", requestModel, mapped) + return mapped } // No mapping, use original + log.Printf("[ModelMapping] No mapping found, using original: %q", requestModel) return requestModel } +// matchModelMapping matches requestModel against mapping rules with wildcard support +// Returns the mapped model or empty string if no match +func matchModelMapping(requestModel string, mapping map[string]string) string { + // First try exact match (fast path) + if mapped, ok := mapping[requestModel]; ok { + log.Printf("[matchModelMapping] Exact match: %q -> %q", requestModel, mapped) + return mapped + } + + // Then try wildcard patterns + for pattern, target := range mapping { + if strings.Contains(pattern, "*") { + matched := matchWildcard(pattern, requestModel) + log.Printf("[matchModelMapping] Wildcard check: pattern=%q, input=%q, matched=%v", pattern, requestModel, matched) + if matched { + return target + } + } + } + + return "" +} + +// matchWildcard checks if input matches a wildcard pattern +// Supports * as wildcard matching any characters +// Examples: +// - "*sonnet*" matches "claude-sonnet-4-20250514" +// - "gpt-4*" matches "gpt-4-turbo" +// - "*-20241022" matches "claude-3-5-sonnet-20241022" +func matchWildcard(pattern, input string) bool { + // Simple cases + if pattern == "*" { + return true + } + if !strings.Contains(pattern, "*") { + return pattern == input + } + + parts := strings.Split(pattern, "*") + + // Handle prefix-only pattern: "prefix*" + if len(parts) == 2 && parts[1] == "" { + return strings.HasPrefix(input, parts[0]) + } + + // Handle suffix-only pattern: "*suffix" + if len(parts) == 2 && parts[0] == "" { + return strings.HasSuffix(input, parts[1]) + } + + // Handle patterns with multiple wildcards + pos := 0 + for i, part := range parts { + if part == "" { + continue + } + + idx := strings.Index(input[pos:], part) + if idx < 0 { + return false + } + + // First part must be at the beginning if pattern doesn't start with * + if i == 0 && idx != 0 { + return false + } + + pos += idx + len(part) + } + + // Last part must be at the end if pattern doesn't end with * + if parts[len(parts)-1] != "" && !strings.HasSuffix(input, parts[len(parts)-1]) { + return false + } + + return true +} + func (e *Executor) getRetryConfig(config *domain.RetryConfig) *domain.RetryConfig { if config != nil { log.Printf("[Executor] Using provided retry config: MaxRetries=%d", config.MaxRetries) diff --git a/internal/executor/wildcard_test.go b/internal/executor/wildcard_test.go new file mode 100644 index 00000000..a5cc01c4 --- /dev/null +++ b/internal/executor/wildcard_test.go @@ -0,0 +1,97 @@ +package executor + +import "testing" + +func TestMatchWildcard(t *testing.T) { + tests := []struct { + pattern string + input string + want bool + }{ + // Catch-all + {"*", "anything", true}, + {"*", "", true}, + + // Contains patterns (*xxx*) + {"*sonnet*", "claude-sonnet-4-20250514", true}, + {"*sonnet*", "claude-3-5-sonnet-20241022", true}, + {"*opus*", "claude-opus-4-20250514", true}, + {"*haiku*", "claude-3-5-haiku-20241022", true}, + {"*claude*", "claude-sonnet-4-20250514", true}, + {"*o1*", "o1", true}, + {"*o1*", "o1-mini", true}, + {"*o1*", "o1-pro", true}, + {"*flash*", "gemini-2.5-flash", true}, + + // Prefix patterns (xxx*) + {"gpt-4*", "gpt-4-turbo", true}, + {"gpt-4*", "gpt-4o", true}, + {"gpt-4o-mini*", "gpt-4o-mini", true}, + {"gpt-4o-mini*", "gpt-4o-mini-2024", true}, + {"claude-*", "claude-sonnet-4", true}, + + // Suffix patterns (*xxx) + {"*-20241022", "claude-3-5-sonnet-20241022", true}, + {"*-turbo", "gpt-4-turbo", true}, + + // Exact match (no wildcard) + {"claude-sonnet-4", "claude-sonnet-4", true}, + {"gpt-4", "gpt-4", true}, + + // Non-matches + {"*sonnet*", "claude-opus-4", false}, + {"*opus*", "claude-sonnet-4", false}, + {"gpt-4*", "gpt-3.5-turbo", false}, + {"claude-sonnet-4", "claude-sonnet-4-20250514", false}, + {"*-20241022", "claude-3-5-sonnet-20250514", false}, + } + + for _, tt := range tests { + t.Run(tt.pattern+"_"+tt.input, func(t *testing.T) { + got := matchWildcard(tt.pattern, tt.input) + if got != tt.want { + t.Errorf("matchWildcard(%q, %q) = %v, want %v", tt.pattern, tt.input, got, tt.want) + } + }) + } +} + +func TestMatchModelMapping(t *testing.T) { + mapping := map[string]string{ + "*sonnet*": "gemini-2.5-pro", + "*opus*": "claude-opus-4-5-thinking", + "*haiku*": "gemini-2.5-flash-lite", + "gpt-4o-mini*": "gemini-2.5-flash", + "gpt-4*": "gemini-2.5-pro", + "exact-model": "exact-target", + } + + tests := []struct { + requestModel string + want string + }{ + // Wildcard matches + {"claude-sonnet-4-20250514", "gemini-2.5-pro"}, + {"claude-3-5-sonnet-20241022", "gemini-2.5-pro"}, + {"claude-opus-4-20250514", "claude-opus-4-5-thinking"}, + {"claude-3-5-haiku-20241022", "gemini-2.5-flash-lite"}, + {"gpt-4-turbo", "gemini-2.5-pro"}, + {"gpt-4o", "gemini-2.5-pro"}, + + // Exact match + {"exact-model", "exact-target"}, + + // No match + {"unknown-model", ""}, + {"gemini-2.5-pro", ""}, + } + + for _, tt := range tests { + t.Run(tt.requestModel, func(t *testing.T) { + got := matchModelMapping(tt.requestModel, mapping) + if got != tt.want { + t.Errorf("matchModelMapping(%q) = %q, want %q", tt.requestModel, got, tt.want) + } + }) + } +} diff --git a/internal/handler/admin.go b/internal/handler/admin.go index dc1b884f..ea812f24 100644 --- a/internal/handler/admin.go +++ b/internal/handler/admin.go @@ -68,6 +68,10 @@ func (h *AdminHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { h.handleCooldowns(w, r, id) case "logs": h.handleLogs(w, r) + case "antigravity-settings": + h.handleAntigravitySettings(w, r) + case "antigravity-settings-reset": + h.handleAntigravitySettingsReset(w, r) default: writeJSON(w, http.StatusNotFound, map[string]string{"error": "not found"}) } @@ -854,6 +858,51 @@ func (h *AdminHandler) handleCooldowns(w http.ResponseWriter, r *http.Request, p } } +// Antigravity global settings handler +// GET /admin/antigravity-settings - get global Antigravity settings +// PUT /admin/antigravity-settings - update global Antigravity settings +func (h *AdminHandler) handleAntigravitySettings(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + settings, err := h.svc.GetAntigravityGlobalSettings() + if err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) + return + } + writeJSON(w, http.StatusOK, settings) + + case http.MethodPut: + var settings service.AntigravityGlobalSettings + if err := json.NewDecoder(r.Body).Decode(&settings); err != nil { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()}) + return + } + if err := h.svc.UpdateAntigravityGlobalSettings(&settings); err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) + return + } + writeJSON(w, http.StatusOK, settings) + + default: + writeJSON(w, http.StatusMethodNotAllowed, map[string]string{"error": "method not allowed"}) + } +} + +// POST /admin/antigravity-settings-reset - reset to preset defaults +func (h *AdminHandler) handleAntigravitySettingsReset(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + writeJSON(w, http.StatusMethodNotAllowed, map[string]string{"error": "method not allowed"}) + return + } + + settings, err := h.svc.ResetAntigravityGlobalSettings() + if err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) + return + } + writeJSON(w, http.StatusOK, settings) +} + func writeJSON(w http.ResponseWriter, status int, data interface{}) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(status) diff --git a/internal/repository/sqlite/db.go b/internal/repository/sqlite/db.go index 2dd8a087..f02579ae 100644 --- a/internal/repository/sqlite/db.go +++ b/internal/repository/sqlite/db.go @@ -333,6 +333,34 @@ func (d *DB) migrate() error { } } + // Migration: Add request_model and mapped_model columns to proxy_upstream_attempts + var hasRequestModel bool + row = d.db.QueryRow(`SELECT COUNT(*) FROM pragma_table_info('proxy_upstream_attempts') WHERE name='request_model'`) + row.Scan(&hasRequestModel) + + if !hasRequestModel { + _, err = d.db.Exec(`ALTER TABLE proxy_upstream_attempts ADD COLUMN request_model TEXT DEFAULT ''`) + if err != nil { + return err + } + _, err = d.db.Exec(`ALTER TABLE proxy_upstream_attempts ADD COLUMN mapped_model TEXT DEFAULT ''`) + if err != nil { + return err + } + } + + // Migration: Add response_model column to proxy_upstream_attempts + var hasResponseModel bool + row = d.db.QueryRow(`SELECT COUNT(*) FROM pragma_table_info('proxy_upstream_attempts') WHERE name='response_model'`) + row.Scan(&hasResponseModel) + + if !hasResponseModel { + _, err = d.db.Exec(`ALTER TABLE proxy_upstream_attempts ADD COLUMN response_model TEXT DEFAULT ''`) + if err != nil { + return err + } + } + return nil } diff --git a/internal/repository/sqlite/proxy_upstream_attempt.go b/internal/repository/sqlite/proxy_upstream_attempt.go index 053f7dd1..044ec5cc 100644 --- a/internal/repository/sqlite/proxy_upstream_attempt.go +++ b/internal/repository/sqlite/proxy_upstream_attempt.go @@ -22,8 +22,8 @@ func (r *ProxyUpstreamAttemptRepository) Create(a *domain.ProxyUpstreamAttempt) a.UpdatedAt = now result, err := r.db.db.Exec( - `INSERT INTO proxy_upstream_attempts (created_at, updated_at, start_time, end_time, duration_ms, status, proxy_request_id, is_stream, request_info, response_info, route_id, provider_id, input_token_count, output_token_count, cache_read_count, cache_write_count, cache_5m_write_count, cache_1h_write_count, cost) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, - a.CreatedAt, a.UpdatedAt, a.StartTime, a.EndTime, a.Duration.Milliseconds(), a.Status, a.ProxyRequestID, a.IsStream, toJSON(a.RequestInfo), toJSON(a.ResponseInfo), a.RouteID, a.ProviderID, a.InputTokenCount, a.OutputTokenCount, a.CacheReadCount, a.CacheWriteCount, a.Cache5mWriteCount, a.Cache1hWriteCount, a.Cost, + `INSERT INTO proxy_upstream_attempts (created_at, updated_at, start_time, end_time, duration_ms, status, proxy_request_id, is_stream, request_model, mapped_model, response_model, request_info, response_info, route_id, provider_id, input_token_count, output_token_count, cache_read_count, cache_write_count, cache_5m_write_count, cache_1h_write_count, cost) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + a.CreatedAt, a.UpdatedAt, a.StartTime, a.EndTime, a.Duration.Milliseconds(), a.Status, a.ProxyRequestID, a.IsStream, a.RequestModel, a.MappedModel, a.ResponseModel, toJSON(a.RequestInfo), toJSON(a.ResponseInfo), a.RouteID, a.ProviderID, a.InputTokenCount, a.OutputTokenCount, a.CacheReadCount, a.CacheWriteCount, a.Cache5mWriteCount, a.Cache1hWriteCount, a.Cost, ) if err != nil { return err @@ -40,14 +40,14 @@ func (r *ProxyUpstreamAttemptRepository) Create(a *domain.ProxyUpstreamAttempt) func (r *ProxyUpstreamAttemptRepository) Update(a *domain.ProxyUpstreamAttempt) error { a.UpdatedAt = time.Now() _, err := r.db.db.Exec( - `UPDATE proxy_upstream_attempts SET updated_at = ?, start_time = ?, end_time = ?, duration_ms = ?, status = ?, is_stream = ?, request_info = ?, response_info = ?, route_id = ?, provider_id = ?, input_token_count = ?, output_token_count = ?, cache_read_count = ?, cache_write_count = ?, cache_5m_write_count = ?, cache_1h_write_count = ?, cost = ? WHERE id = ?`, - a.UpdatedAt, a.StartTime, a.EndTime, a.Duration.Milliseconds(), a.Status, a.IsStream, toJSON(a.RequestInfo), toJSON(a.ResponseInfo), a.RouteID, a.ProviderID, a.InputTokenCount, a.OutputTokenCount, a.CacheReadCount, a.CacheWriteCount, a.Cache5mWriteCount, a.Cache1hWriteCount, a.Cost, a.ID, + `UPDATE proxy_upstream_attempts SET updated_at = ?, start_time = ?, end_time = ?, duration_ms = ?, status = ?, is_stream = ?, request_model = ?, mapped_model = ?, response_model = ?, request_info = ?, response_info = ?, route_id = ?, provider_id = ?, input_token_count = ?, output_token_count = ?, cache_read_count = ?, cache_write_count = ?, cache_5m_write_count = ?, cache_1h_write_count = ?, cost = ? WHERE id = ?`, + a.UpdatedAt, a.StartTime, a.EndTime, a.Duration.Milliseconds(), a.Status, a.IsStream, a.RequestModel, a.MappedModel, a.ResponseModel, toJSON(a.RequestInfo), toJSON(a.ResponseInfo), a.RouteID, a.ProviderID, a.InputTokenCount, a.OutputTokenCount, a.CacheReadCount, a.CacheWriteCount, a.Cache5mWriteCount, a.Cache1hWriteCount, a.Cost, a.ID, ) return err } func (r *ProxyUpstreamAttemptRepository) ListByProxyRequestID(proxyRequestID uint64) ([]*domain.ProxyUpstreamAttempt, error) { - rows, err := r.db.db.Query(`SELECT id, created_at, updated_at, start_time, end_time, duration_ms, status, proxy_request_id, is_stream, request_info, response_info, route_id, provider_id, input_token_count, output_token_count, cache_read_count, cache_write_count, cache_5m_write_count, cache_1h_write_count, cost FROM proxy_upstream_attempts WHERE proxy_request_id = ? ORDER BY id`, proxyRequestID) + rows, err := r.db.db.Query(`SELECT id, created_at, updated_at, start_time, end_time, duration_ms, status, proxy_request_id, is_stream, request_model, mapped_model, response_model, request_info, response_info, route_id, provider_id, input_token_count, output_token_count, cache_read_count, cache_write_count, cache_5m_write_count, cache_1h_write_count, cost FROM proxy_upstream_attempts WHERE proxy_request_id = ? ORDER BY id`, proxyRequestID) if err != nil { return nil, err } @@ -59,7 +59,7 @@ func (r *ProxyUpstreamAttemptRepository) ListByProxyRequestID(proxyRequestID uin var reqInfoJSON, respInfoJSON string var startTime, endTime sql.NullTime var durationMs int64 - err := rows.Scan(&a.ID, &a.CreatedAt, &a.UpdatedAt, &startTime, &endTime, &durationMs, &a.Status, &a.ProxyRequestID, &a.IsStream, &reqInfoJSON, &respInfoJSON, &a.RouteID, &a.ProviderID, &a.InputTokenCount, &a.OutputTokenCount, &a.CacheReadCount, &a.CacheWriteCount, &a.Cache5mWriteCount, &a.Cache1hWriteCount, &a.Cost) + err := rows.Scan(&a.ID, &a.CreatedAt, &a.UpdatedAt, &startTime, &endTime, &durationMs, &a.Status, &a.ProxyRequestID, &a.IsStream, &a.RequestModel, &a.MappedModel, &a.ResponseModel, &reqInfoJSON, &respInfoJSON, &a.RouteID, &a.ProviderID, &a.InputTokenCount, &a.OutputTokenCount, &a.CacheReadCount, &a.CacheWriteCount, &a.Cache5mWriteCount, &a.Cache1hWriteCount, &a.Cost) if err != nil { return nil, err } diff --git a/internal/router/router.go b/internal/router/router.go index 3528d412..1fe371ee 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -105,12 +105,6 @@ func (r *Router) Match(clientType domain.ClientType, projectID uint64) ([]*Match log.Printf("[Router] Match called: clientType=%s, projectID=%d, total routes in cache=%d", clientType, projectID, len(routes)) - // Debug: print all routes in cache - for _, rt := range routes { - log.Printf("[Router] Route in cache: id=%d, clientType=%s, projectID=%d, providerID=%d, isEnabled=%v", - rt.ID, rt.ClientType, rt.ProjectID, rt.ProviderID, rt.IsEnabled) - } - // Check if ClientType has custom routes enabled for this project useProjectRoutes := false if projectID != 0 { diff --git a/internal/service/admin.go b/internal/service/admin.go index 2e97e3df..17af2e32 100644 --- a/internal/service/admin.go +++ b/internal/service/admin.go @@ -1,10 +1,12 @@ package service import ( + "encoding/json" "strconv" "strings" "time" + "github.com/awsl-project/maxx/internal/adapter/provider/antigravity" "github.com/awsl-project/maxx/internal/domain" "github.com/awsl-project/maxx/internal/repository" ) @@ -404,6 +406,103 @@ func (s *AdminService) DeleteSetting(key string) error { return s.settingRepo.Delete(key) } +// ===== Antigravity Global Settings API ===== + +// ModelMappingRule represents a single model mapping rule (for API) +type ModelMappingRule struct { + Pattern string `json:"pattern"` // Source pattern, supports * wildcard + Target string `json:"target"` // Target model name +} + +// AntigravityGlobalSettings represents the global Antigravity configuration +type AntigravityGlobalSettings struct { + ModelMappingRules []ModelMappingRule `json:"modelMappingRules"` + AvailableTargetModels []string `json:"availableTargetModels"` +} + +// GetAntigravityGlobalSettings retrieves the global Antigravity settings +// If no custom mapping exists, returns the preset mapping as default +func (s *AdminService) GetAntigravityGlobalSettings() (*AntigravityGlobalSettings, error) { + settings := &AntigravityGlobalSettings{ + ModelMappingRules: []ModelMappingRule{}, + AvailableTargetModels: antigravity.GetAvailableTargetModels(), + } + + // Get model mapping rules from database + rulesJSON, err := s.settingRepo.Get(domain.SettingKeyAntigravityModelMapping) + if err == nil && rulesJSON != "" { + // Use ParseModelMappingRules which handles both new array format and legacy map format + agRules, parseErr := antigravity.ParseModelMappingRules(rulesJSON) + if parseErr != nil { + return nil, parseErr + } + // Convert antigravity.ModelMappingRule to service.ModelMappingRule + settings.ModelMappingRules = make([]ModelMappingRule, len(agRules)) + for i, r := range agRules { + settings.ModelMappingRules[i] = ModelMappingRule{Pattern: r.Pattern, Target: r.Target} + } + } + + // If no rules exist, initialize with preset rules + if len(settings.ModelMappingRules) == 0 { + defaultRules := antigravity.GetDefaultModelMappingRules() + settings.ModelMappingRules = make([]ModelMappingRule, len(defaultRules)) + for i, r := range defaultRules { + settings.ModelMappingRules[i] = ModelMappingRule{Pattern: r.Pattern, Target: r.Target} + } + // Save to database + if rulesJSON, err := json.Marshal(settings.ModelMappingRules); err == nil { + s.settingRepo.Set(domain.SettingKeyAntigravityModelMapping, string(rulesJSON)) + } + } + + return settings, nil +} + +// UpdateAntigravityGlobalSettings updates the global Antigravity settings +func (s *AdminService) UpdateAntigravityGlobalSettings(settings *AntigravityGlobalSettings) error { + // Update model mapping rules + if settings.ModelMappingRules != nil { + rulesJSON, err := json.Marshal(settings.ModelMappingRules) + if err != nil { + return err + } + if err := s.settingRepo.Set(domain.SettingKeyAntigravityModelMapping, string(rulesJSON)); err != nil { + return err + } + } else { + // Clear rules if nil + if err := s.settingRepo.Set(domain.SettingKeyAntigravityModelMapping, "[]"); err != nil { + return err + } + } + + return nil +} + +// ResetAntigravityGlobalSettings resets the model mapping to preset defaults +func (s *AdminService) ResetAntigravityGlobalSettings() (*AntigravityGlobalSettings, error) { + defaultRules := antigravity.GetDefaultModelMappingRules() + rules := make([]ModelMappingRule, len(defaultRules)) + for i, r := range defaultRules { + rules[i] = ModelMappingRule{Pattern: r.Pattern, Target: r.Target} + } + + rulesJSON, err := json.Marshal(rules) + if err != nil { + return nil, err + } + + if err := s.settingRepo.Set(domain.SettingKeyAntigravityModelMapping, string(rulesJSON)); err != nil { + return nil, err + } + + return &AntigravityGlobalSettings{ + ModelMappingRules: rules, + AvailableTargetModels: antigravity.GetAvailableTargetModels(), + }, nil +} + // ===== Proxy Status API ===== type ProxyStatus struct { diff --git a/web/package.json.md5 b/web/package.json.md5 index 009687a5..d39dab00 100644 --- a/web/package.json.md5 +++ b/web/package.json.md5 @@ -1 +1 @@ -bcf00bcbf32d60bcf629084b94e66ea8 +0916fdee20c14f10b77d7d2bcc6ac3b9 \ No newline at end of file diff --git a/web/src/components/ui/index.ts b/web/src/components/ui/index.ts index fec4fe30..16b13782 100644 --- a/web/src/components/ui/index.ts +++ b/web/src/components/ui/index.ts @@ -34,9 +34,23 @@ export { Input } from './input'; export type { InputProps } from './input'; // Select -export { Select } from './select'; +export { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from './select'; export type { SelectProps } from './select'; +// Dialog +export { + Dialog, + DialogClose, + DialogContent, + DialogDescription, + DialogFooter, + DialogHeader, + DialogOverlay, + DialogPortal, + DialogTitle, + DialogTrigger, +} from './dialog'; + // Switch export { Switch } from './switch'; diff --git a/web/src/components/ui/model-input.tsx b/web/src/components/ui/model-input.tsx new file mode 100644 index 00000000..f8e14af0 --- /dev/null +++ b/web/src/components/ui/model-input.tsx @@ -0,0 +1,385 @@ +import { useState, useMemo, useEffect, useRef } from 'react' +import { ChevronDown, Search, X } from 'lucide-react' +import { cn } from '@/lib/utils' +import { + Dialog, + DialogContent, + DialogHeader, + DialogTitle, +} from '@/components/ui/dialog' +import { Input } from '@/components/ui/input' + +// 常见模型列表 +const COMMON_MODELS = [ + // Claude wildcards (for source patterns) + { id: '*claude*', name: 'All Claude models', provider: 'Claude' }, + { id: '*sonnet*', name: 'All Sonnet models', provider: 'Claude' }, + { id: '*opus*', name: 'All Opus models', provider: 'Claude' }, + { id: '*haiku*', name: 'All Haiku models', provider: 'Claude' }, + // Claude models + { id: 'claude-sonnet-4-20250514', name: 'Claude Sonnet 4', provider: 'Claude' }, + { id: 'claude-opus-4-20250514', name: 'Claude Opus 4', provider: 'Claude' }, + { + id: 'claude-3-5-sonnet-20241022', + name: 'Claude 3.5 Sonnet', + provider: 'Claude', + }, + { + id: 'claude-3-5-haiku-20241022', + name: 'Claude 3.5 Haiku', + provider: 'Claude', + }, + { id: 'claude-3-opus-20240229', name: 'Claude 3 Opus', provider: 'Claude' }, + // Gemini wildcards + { id: '*gemini*', name: 'All Gemini models', provider: 'Gemini' }, + { id: '*flash*', name: 'All Flash models', provider: 'Gemini' }, + // Gemini models + { id: 'gemini-2.5-pro', name: 'Gemini 2.5 Pro', provider: 'Gemini' }, + { id: 'gemini-2.5-flash', name: 'Gemini 2.5 Flash', provider: 'Gemini' }, + { + id: 'gemini-2.5-flash-lite', + name: 'Gemini 2.5 Flash Lite', + provider: 'Gemini', + }, + { id: 'gemini-2.0-flash', name: 'Gemini 2.0 Flash', provider: 'Gemini' }, + { id: 'gemini-1.5-pro', name: 'Gemini 1.5 Pro', provider: 'Gemini' }, + { id: 'gemini-1.5-flash', name: 'Gemini 1.5 Flash', provider: 'Gemini' }, + // OpenAI wildcards + { id: '*gpt*', name: 'All GPT models', provider: 'OpenAI' }, + { id: '*o1*', name: 'All o1 models', provider: 'OpenAI' }, + { id: '*o3*', name: 'All o3 models', provider: 'OpenAI' }, + // OpenAI models + { id: 'gpt-4o', name: 'GPT-4o', provider: 'OpenAI' }, + { id: 'gpt-4o-mini', name: 'GPT-4o Mini', provider: 'OpenAI' }, + { id: 'gpt-4-turbo', name: 'GPT-4 Turbo', provider: 'OpenAI' }, + { id: 'gpt-4', name: 'GPT-4', provider: 'OpenAI' }, + { id: 'gpt-3.5-turbo', name: 'GPT-3.5 Turbo', provider: 'OpenAI' }, + { id: 'o1', name: 'o1', provider: 'OpenAI' }, + { id: 'o1-mini', name: 'o1 Mini', provider: 'OpenAI' }, + { id: 'o1-pro', name: 'o1 Pro', provider: 'OpenAI' }, + { id: 'o3-mini', name: 'o3 Mini', provider: 'OpenAI' }, + // Antigravity supported target models (use these as mapping targets) + { id: 'claude-opus-4-5-thinking', name: 'Claude Opus 4.5 Thinking', provider: 'Antigravity' }, + { id: 'claude-sonnet-4-5', name: 'Claude Sonnet 4.5', provider: 'Antigravity' }, + { id: 'claude-sonnet-4-5-thinking', name: 'Claude Sonnet 4.5 Thinking', provider: 'Antigravity' }, + { id: 'gemini-2.5-flash-lite', name: 'Gemini 2.5 Flash Lite', provider: 'Antigravity' }, + { id: 'gemini-2.5-flash', name: 'Gemini 2.5 Flash', provider: 'Antigravity' }, + { id: 'gemini-2.5-flash-thinking', name: 'Gemini 2.5 Flash Thinking', provider: 'Antigravity' }, + { id: 'gemini-2.5-pro', name: 'Gemini 2.5 Pro', provider: 'Antigravity' }, + { id: 'gemini-3-flash', name: 'Gemini 3 Flash', provider: 'Antigravity' }, + { id: 'gemini-3-pro', name: 'Gemini 3 Pro', provider: 'Antigravity' }, + { id: 'gemini-3-pro-low', name: 'Gemini 3 Pro Low', provider: 'Antigravity' }, + { id: 'gemini-3-pro-high', name: 'Gemini 3 Pro High', provider: 'Antigravity' }, + { id: 'gemini-3-pro-preview', name: 'Gemini 3 Pro Preview', provider: 'Antigravity' }, + { id: 'gemini-3-pro-image', name: 'Gemini 3 Pro Image', provider: 'Antigravity' }, + // Generic wildcard + { id: '*', name: 'All models (catch-all)', provider: 'Other' }, +] as const + +type Model = (typeof COMMON_MODELS)[number] +type Provider = Model['provider'] + +interface ModelInputProps { + value: string + onChange: (value: string) => void + placeholder?: string + disabled?: boolean + className?: string + /** Filter to only show models from specific providers */ + providers?: Provider[] +} + +// 简单的模糊匹配函数 +function fuzzyMatch(text: string, pattern: string): boolean { + const lowerText = text.toLowerCase() + const lowerPattern = pattern.toLowerCase() + + // 先尝试普通包含匹配 + if (lowerText.includes(lowerPattern)) return true + + // 模糊匹配:pattern 中的字符按顺序出现在 text 中 + let patternIdx = 0 + for (let i = 0; i < lowerText.length && patternIdx < lowerPattern.length; i++) { + if (lowerText[i] === lowerPattern[patternIdx]) { + patternIdx++ + } + } + return patternIdx === lowerPattern.length +} + +// 计算匹配分数(用于排序) +function matchScore(model: Model, pattern: string): number { + const lowerPattern = pattern.toLowerCase() + const lowerId = model.id.toLowerCase() + const lowerName = model.name.toLowerCase() + + // 精确匹配得分最高 + if (lowerId === lowerPattern || lowerName === lowerPattern) return 100 + + // 前缀匹配次之 + if (lowerId.startsWith(lowerPattern) || lowerName.startsWith(lowerPattern)) return 80 + + // 包含匹配 + if (lowerId.includes(lowerPattern) || lowerName.includes(lowerPattern)) return 60 + + // 模糊匹配得分最低 + return 40 +} + +export function ModelInput({ + value, + onChange, + placeholder = 'Select or enter model...', + disabled = false, + className, + providers, +}: ModelInputProps) { + const [isOpen, setIsOpen] = useState(false) + const [search, setSearch] = useState('') + const [focusedIndex, setFocusedIndex] = useState(-1) + const focusedRef = useRef(null) + + // Base models filtered by providers prop + const baseModels = useMemo(() => { + if (!providers || providers.length === 0) return [...COMMON_MODELS] + return COMMON_MODELS.filter(model => providers.includes(model.provider)) + }, [providers]) + + // 过滤和排序模型(支持模糊匹配) + const filteredModels = useMemo(() => { + if (!search.trim()) return baseModels + + return baseModels + .filter( + model => + fuzzyMatch(model.id, search) || + fuzzyMatch(model.name, search) || + fuzzyMatch(model.provider, search) + ) + .sort((a, b) => matchScore(b, search) - matchScore(a, search)) + }, [search, baseModels]) + + // 重置 focusedIndex 当过滤结果变化时 + useEffect(() => { + setFocusedIndex(-1) + }, [filteredModels.length]) + + // 自动滚动到高亮项 + useEffect(() => { + if (focusedIndex >= 0 && focusedRef.current) { + focusedRef.current.scrollIntoView({ block: 'nearest' }) + } + }, [focusedIndex]) + + // 按 provider 分组 + const groupedModels = useMemo(() => { + return filteredModels.reduce( + (acc, model) => { + if (!acc[model.provider]) { + acc[model.provider] = [] + } + acc[model.provider].push(model) + return acc + }, + {} as Record + ) + }, [filteredModels]) + + const handleOpen = () => { + if (!disabled) { + setSearch(value) // 初始化搜索框为当前值 + setIsOpen(true) + } + } + + const handleSelect = (modelId: string) => { + onChange(modelId) + setIsOpen(false) + setSearch('') + } + + const handleClear = (e: React.MouseEvent) => { + e.stopPropagation() + onChange('') + } + + const handleKeyDown = (e: React.KeyboardEvent) => { + if (e.key === 'Enter') { + e.preventDefault() + // 如果有高亮项,选择高亮项;否则使用搜索框内容 + if (focusedIndex >= 0 && focusedIndex < filteredModels.length) { + handleSelect(filteredModels[focusedIndex].id) + } else if (search.trim()) { + handleSelect(search.trim()) + } + return + } + + // Tab/Shift+Tab 切换高亮项 + if (e.key === 'Tab' && filteredModels.length > 0) { + e.preventDefault() + + if (e.shiftKey) { + // Shift+Tab: 上一个 + setFocusedIndex(prev => (prev <= 0 ? filteredModels.length - 1 : prev - 1)) + } else { + // Tab: 下一个 + setFocusedIndex(prev => (prev >= filteredModels.length - 1 ? 0 : prev + 1)) + } + } + + // 上下箭头也可以切换 + if (e.key === 'ArrowDown' && filteredModels.length > 0) { + e.preventDefault() + setFocusedIndex(prev => (prev >= filteredModels.length - 1 ? 0 : prev + 1)) + } + if (e.key === 'ArrowUp' && filteredModels.length > 0) { + e.preventDefault() + setFocusedIndex(prev => (prev <= 0 ? filteredModels.length - 1 : prev - 1)) + } + } + + return ( + <> + {/* 触发按钮 */} + + + {/* Dialog */} + + + + Select Model + + + {/* 搜索框 */} +
+ + setSearch(e.target.value)} + onKeyDown={handleKeyDown} + placeholder="Search or enter custom model..." + className="pl-9" + autoFocus + /> +
+ + {/* 模型列表 */} +
+ {Object.keys(groupedModels).length > 0 ? ( +
+ {Object.entries(groupedModels).map(([provider, models]) => ( +
+
+ {provider} +
+
+ {models.map(model => { + const modelIndex = filteredModels.findIndex(m => m.id === model.id) + const isFocused = modelIndex === focusedIndex + return ( + + ) + })} +
+
+ ))} +
+ ) : search.trim() ? ( +
+

+ No matching models found. +

+ +
+ ) : ( +
+ No models available +
+ )} +
+ + {/* 提示 */} +
+ Press{' '} + + Enter + {' '} + to use custom model name +
+
+
+ + ) +} diff --git a/web/src/hooks/queries/index.ts b/web/src/hooks/queries/index.ts index d2de0944..f96282d2 100644 --- a/web/src/hooks/queries/index.ts +++ b/web/src/hooks/queries/index.ts @@ -81,4 +81,7 @@ export { useSetting, useUpdateSetting, useDeleteSetting, + useAntigravityGlobalSettings, + useUpdateAntigravityGlobalSettings, + useResetAntigravityGlobalSettings, } from './use-settings'; diff --git a/web/src/hooks/queries/use-settings.ts b/web/src/hooks/queries/use-settings.ts index df5953b1..452d5659 100644 --- a/web/src/hooks/queries/use-settings.ts +++ b/web/src/hooks/queries/use-settings.ts @@ -4,10 +4,12 @@ import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query'; import { getTransport } from '@/lib/transport'; +import type { AntigravityGlobalSettings } from '@/lib/transport'; export const settingsKeys = { all: ['settings'] as const, detail: (key: string) => ['settings', key] as const, + antigravityGlobal: ['settings', 'antigravity-global'] as const, }; export function useSettings() { @@ -47,3 +49,35 @@ export function useDeleteSetting() { }, }); } + +// ===== Antigravity Global Settings ===== + +export function useAntigravityGlobalSettings() { + return useQuery({ + queryKey: settingsKeys.antigravityGlobal, + queryFn: () => getTransport().getAntigravityGlobalSettings(), + }); +} + +export function useUpdateAntigravityGlobalSettings() { + const queryClient = useQueryClient(); + + return useMutation({ + mutationFn: (settings: AntigravityGlobalSettings) => + getTransport().updateAntigravityGlobalSettings(settings), + onSuccess: () => { + queryClient.invalidateQueries({ queryKey: settingsKeys.antigravityGlobal }); + }, + }); +} + +export function useResetAntigravityGlobalSettings() { + const queryClient = useQueryClient(); + + return useMutation({ + mutationFn: () => getTransport().resetAntigravityGlobalSettings(), + onSuccess: () => { + queryClient.invalidateQueries({ queryKey: settingsKeys.antigravityGlobal }); + }, + }); +} diff --git a/web/src/lib/transport/http-transport.ts b/web/src/lib/transport/http-transport.ts index 17644640..a6fa8341 100644 --- a/web/src/lib/transport/http-transport.ts +++ b/web/src/lib/transport/http-transport.ts @@ -30,6 +30,7 @@ import type { AntigravityTokenValidationResult, AntigravityBatchValidationResult, AntigravityQuotaData, + AntigravityGlobalSettings, ImportResult, Cooldown, } from './types'; @@ -337,6 +338,21 @@ export class HttpTransport implements Transport { return data; } + async getAntigravityGlobalSettings(): Promise { + const { data } = await this.client.get('/antigravity-settings'); + return data; + } + + async updateAntigravityGlobalSettings(settings: AntigravityGlobalSettings): Promise { + const { data } = await this.client.put('/antigravity-settings', settings); + return data; + } + + async resetAntigravityGlobalSettings(): Promise { + const { data } = await this.client.post('/antigravity-settings-reset'); + return data; + } + // ===== Cooldown API ===== async getCooldowns(): Promise { diff --git a/web/src/lib/transport/index.ts b/web/src/lib/transport/index.ts index 36c4507b..6c9cce77 100644 --- a/web/src/lib/transport/index.ts +++ b/web/src/lib/transport/index.ts @@ -46,6 +46,7 @@ export type { AntigravityTokenValidationResult, AntigravityBatchValidationResult, AntigravityOAuthResult, + AntigravityGlobalSettings, // Import ImportResult, // Cooldown diff --git a/web/src/lib/transport/interface.ts b/web/src/lib/transport/interface.ts index 82aabc4f..4dc53c7b 100644 --- a/web/src/lib/transport/interface.ts +++ b/web/src/lib/transport/interface.ts @@ -27,6 +27,7 @@ import type { AntigravityTokenValidationResult, AntigravityBatchValidationResult, AntigravityQuotaData, + AntigravityGlobalSettings, ImportResult, Cooldown, } from './types'; @@ -105,6 +106,9 @@ export interface Transport { validateAntigravityTokenText(tokenText: string): Promise; getAntigravityProviderQuota(providerId: number, forceRefresh?: boolean): Promise; startAntigravityOAuth(): Promise<{ authURL: string; state: string }>; + getAntigravityGlobalSettings(): Promise; + updateAntigravityGlobalSettings(settings: AntigravityGlobalSettings): Promise; + resetAntigravityGlobalSettings(): Promise; // ===== Cooldown API ===== getCooldowns(): Promise; diff --git a/web/src/lib/transport/types.ts b/web/src/lib/transport/types.ts index e2dcdfa0..4877d48a 100644 --- a/web/src/lib/transport/types.ts +++ b/web/src/lib/transport/types.ts @@ -188,6 +188,10 @@ export interface ProxyUpstreamAttempt { status: ProxyUpstreamAttemptStatus; proxyRequestID: number; isStream: boolean; // 是否为 SSE 流式请求 + // 模型信息 + requestModel: string; // 客户端请求的原始模型 + mappedModel: string; // 映射后实际发送的模型 + responseModel: string; // 上游响应中返回的模型名称 requestInfo: RequestInfo | null; responseInfo: ResponseInfo | null; routeID: number; @@ -325,6 +329,18 @@ export interface AntigravityOAuthResult { error?: string; } +// Antigravity 模型映射规则 +export interface ModelMappingRule { + pattern: string; // 源模式,支持 * 通配符 + target: string; // 目标模型名 +} + +// Antigravity 全局设置 +export interface AntigravityGlobalSettings { + modelMappingRules: ModelMappingRule[]; + availableTargetModels?: string[]; // 只在响应中返回,更新时不需要 +} + // ===== 回调类型 ===== export type EventCallback = (data: T) => void; diff --git a/web/src/lib/transport/wails-transport.ts b/web/src/lib/transport/wails-transport.ts index 43a5b954..7e151849 100644 --- a/web/src/lib/transport/wails-transport.ts +++ b/web/src/lib/transport/wails-transport.ts @@ -30,6 +30,7 @@ import type { AntigravityTokenValidationResult, AntigravityBatchValidationResult, AntigravityQuotaData, + AntigravityGlobalSettings, Cooldown, ImportResult, } from './types'; @@ -301,6 +302,19 @@ export class WailsTransport implements Transport { return DesktopApp.StartAntigravityOAuth() as Promise<{ authURL: string; state: string }>; } + async getAntigravityGlobalSettings(): Promise { + return DesktopApp.GetAntigravityGlobalSettings() as Promise; + } + + async updateAntigravityGlobalSettings(settings: AntigravityGlobalSettings): Promise { + await DesktopApp.UpdateAntigravityGlobalSettings(settings as any); + return settings; + } + + async resetAntigravityGlobalSettings(): Promise { + return DesktopApp.ResetAntigravityGlobalSettings() as Promise; + } + // ===== Cooldown API ===== async getCooldowns(): Promise { diff --git a/web/src/pages/providers/components/antigravity-provider-view.tsx b/web/src/pages/providers/components/antigravity-provider-view.tsx index 02309d57..f2865393 100644 --- a/web/src/pages/providers/components/antigravity-provider-view.tsx +++ b/web/src/pages/providers/components/antigravity-provider-view.tsx @@ -7,6 +7,8 @@ import { RefreshCw, Clock, Lock, + Shuffle, + Check, } from 'lucide-react' import { ClientIcon } from '@/components/icons/client-icons' import type { @@ -16,6 +18,9 @@ import type { } from '@/lib/transport' import { getTransport } from '@/lib/transport' import { ANTIGRAVITY_COLOR } from '../types' +import { ModelMappingEditor } from './model-mapping-editor' +import { useUpdateProvider } from '@/hooks/queries' +import { Button } from '@/components/ui/button' interface AntigravityProviderViewProps { provider: Provider @@ -119,6 +124,14 @@ export function AntigravityProviderView({ const [quota, setQuota] = useState(null) const [loading, setLoading] = useState(false) const [error, setError] = useState(null) + const [modelMapping, setModelMapping] = useState>( + provider.config?.antigravity?.modelMapping || {} + ) + const [savingMapping, setSavingMapping] = useState(false) + const [mappingSaveStatus, setMappingSaveStatus] = useState< + 'idle' | 'success' | 'error' + >('idle') + const updateProvider = useUpdateProvider() const fetchQuota = async (forceRefresh = false) => { setLoading(true) @@ -140,6 +153,45 @@ export function AntigravityProviderView({ fetchQuota(false) }, [provider.id]) // eslint-disable-line react-hooks/exhaustive-deps + const handleSaveModelMapping = async () => { + setSavingMapping(true) + setMappingSaveStatus('idle') + try { + const antigravityConfig = provider.config?.antigravity + if (!antigravityConfig) return + + await updateProvider.mutateAsync({ + id: Number(provider.id), + data: { + name: provider.name, + type: 'antigravity', + config: { + antigravity: { + email: antigravityConfig.email, + refreshToken: antigravityConfig.refreshToken, + projectID: antigravityConfig.projectID, + endpoint: antigravityConfig.endpoint, + modelMapping: + Object.keys(modelMapping).length > 0 ? modelMapping : undefined, + }, + }, + supportedClientTypes: provider.supportedClientTypes, + }, + }) + setMappingSaveStatus('success') + setTimeout(() => setMappingSaveStatus('idle'), 2000) + } catch (err) { + console.error('Failed to save model mapping:', err) + setMappingSaveStatus('error') + } finally { + setSavingMapping(false) + } + } + + const hasModelMappingChanged = + JSON.stringify(modelMapping) !== + JSON.stringify(provider.config?.antigravity?.modelMapping || {}) + return (
@@ -288,6 +340,44 @@ export function AntigravityProviderView({ )}
+ {/* Model Mapping */} +
+
+

+ + Model Mapping +

+ {hasModelMappingChanged && ( + + )} +
+

+ Map request models to different upstream models. For example, map + "claude-sonnet-4-20250514" to "gemini-2.5-pro". +

+ + {mappingSaveStatus === 'error' && ( +

+ Failed to save model mapping. Please try again. +

+ )} +
+ {/* Supported Clients */}

diff --git a/web/src/pages/providers/components/model-mapping-editor.tsx b/web/src/pages/providers/components/model-mapping-editor.tsx new file mode 100644 index 00000000..f29f46a2 --- /dev/null +++ b/web/src/pages/providers/components/model-mapping-editor.tsx @@ -0,0 +1,130 @@ +import { useState } from 'react' +import { Plus, Trash2, ArrowRight } from 'lucide-react' +import { Button } from '@/components/ui/button' +import { ModelInput } from '@/components/ui/model-input' + +interface ModelMappingEditorProps { + value: Record + onChange: (value: Record) => void + disabled?: boolean + /** Only show Antigravity-supported models for target selection */ + targetOnlyAntigravity?: boolean +} + +export function ModelMappingEditor({ + value, + onChange, + disabled = false, + targetOnlyAntigravity = false, +}: ModelMappingEditorProps) { + const [newFrom, setNewFrom] = useState('') + const [newTo, setNewTo] = useState('') + + const entries = Object.entries(value) + + // Target model providers filter + const targetProviders = targetOnlyAntigravity ? ['Antigravity' as const] : undefined + + const handleAdd = () => { + if (!newFrom.trim() || !newTo.trim()) return + if (value[newFrom.trim()]) return // Already exists + + onChange({ + ...value, + [newFrom.trim()]: newTo.trim(), + }) + setNewFrom('') + setNewTo('') + } + + const handleRemove = (key: string) => { + const newValue = { ...value } + delete newValue[key] + onChange(newValue) + } + + const handleUpdate = (oldKey: string, newKey: string, newVal: string) => { + const newValue = { ...value } + if (oldKey !== newKey) { + delete newValue[oldKey] + } + newValue[newKey] = newVal + onChange(newValue) + } + + return ( +
+ {entries.length > 0 && ( +
+ {entries.map(([from, to]) => ( +
+ handleUpdate(from, newKey, to)} + placeholder="Request Model" + className="flex-1" + disabled={disabled} + /> + + handleUpdate(from, from, newVal)} + placeholder="Mapped Model" + className="flex-1" + disabled={disabled} + providers={targetProviders} + /> + +
+ ))} +
+ )} + + {/* Add new mapping */} +
+ + + + +
+ + {entries.length === 0 && ( +

+ No model mappings configured. Add mappings to transform request models + before sending to upstream. +

+ )} +
+ ) +} diff --git a/web/src/pages/providers/components/provider-edit-flow.tsx b/web/src/pages/providers/components/provider-edit-flow.tsx index 4aeb1a6d..419c18da 100644 --- a/web/src/pages/providers/components/provider-edit-flow.tsx +++ b/web/src/pages/providers/components/provider-edit-flow.tsx @@ -1,5 +1,5 @@ import { useState } from 'react' -import { Globe, ChevronLeft, Key, Check, Trash2 } from 'lucide-react' +import { Globe, ChevronLeft, Key, Check, Trash2, Shuffle } from 'lucide-react' import { Dialog, DialogContent, @@ -15,6 +15,7 @@ import { ClientsConfigSection } from './clients-config-section' import { AntigravityProviderView } from './antigravity-provider-view' import { Button } from '@/components/ui/button' import { Input } from '@/components/ui/input' +import { ModelMappingEditor } from './model-mapping-editor' interface ProviderEditFlowProps { provider: Provider @@ -26,6 +27,7 @@ type EditFormData = { baseURL: string apiKey: string clients: ClientConfig[] + modelMapping: Record } export function ProviderEditFlow({ provider, onClose }: ProviderEditFlowProps) { @@ -53,6 +55,7 @@ export function ProviderEditFlow({ provider, onClose }: ProviderEditFlowProps) { baseURL: provider.config?.custom?.baseURL || '', apiKey: provider.config?.custom?.apiKey || '', clients: initClients(), + modelMapping: provider.config?.custom?.modelMapping || {}, }) const updateClient = ( @@ -102,6 +105,10 @@ export function ProviderEditFlow({ provider, onClose }: ProviderEditFlowProps) { apiKey: formData.apiKey || provider.config?.custom?.apiKey || '', clientBaseURL: Object.keys(clientBaseURL).length > 0 ? clientBaseURL : undefined, + modelMapping: + Object.keys(formData.modelMapping).length > 0 + ? formData.modelMapping + : undefined, }, }, supportedClientTypes, @@ -276,6 +283,25 @@ export function ProviderEditFlow({ provider, onClose }: ProviderEditFlowProps) { />

+
+

+
+ + 3. Model Mapping +
+

+

+ Map request models to different upstream models. For example, map + "claude-sonnet-4-20250514" to "gemini-2.5-pro". +

+ + setFormData(prev => ({ ...prev, modelMapping })) + } + /> +
+ {saveStatus === 'error' && (
diff --git a/web/src/pages/requests/detail/RequestDetailPanel.tsx b/web/src/pages/requests/detail/RequestDetailPanel.tsx index 63334a22..230017b2 100644 --- a/web/src/pages/requests/detail/RequestDetailPanel.tsx +++ b/web/src/pages/requests/detail/RequestDetailPanel.tsx @@ -12,7 +12,7 @@ import { import { Server, Code, Database, Info, Zap } from 'lucide-react' import type { ProxyUpstreamAttempt, ProxyRequest } from '@/lib/transport' import { cn } from '@/lib/utils' -import { CopyButton, DiffButton, EmptyState } from './components' +import { CopyButton, CopyAsCurlButton, DiffButton, EmptyState } from './components' import { RequestDetailView } from './RequestDetailView' // Selection type: either the main request or an attempt @@ -114,6 +114,13 @@ export function RequestDetailPanel({
Attempt #{selectedAttempt.id} + {selectedAttempt.mappedModel && selectedAttempt.requestModel !== selectedAttempt.mappedModel && ( + + {selectedAttempt.requestModel} + + {selectedAttempt.mappedModel} + + )} {selectedAttempt.cost > 0 && ( Cost: {formatCost(selectedAttempt.cost)} @@ -148,6 +155,7 @@ export function RequestDetailPanel({ {selectedAttempt.requestInfo.url} +
@@ -361,49 +369,69 @@ export function RequestDetailPanel({ - Request Info + Attempt Info
- Request ID + Attempt ID
- {request.requestID || '-'} + #{selectedAttempt.id}
- Session ID + Provider
-
- {request.sessionID || '-'} +
+ {providerMap.get(selectedAttempt.providerID) || `Provider #${selectedAttempt.providerID}`}
- Instance ID + Request Model
-
- {request.instanceID || '-'} +
+ {selectedAttempt.requestModel || '-'}
- Request Model + Mapped Model
- {request.requestModel || '-'} + {selectedAttempt.mappedModel || '-'} + {selectedAttempt.mappedModel && selectedAttempt.requestModel !== selectedAttempt.mappedModel && ( + + (converted) + + )}
+ {selectedAttempt.responseModel && ( +
+
+ Response Model +
+
+ {selectedAttempt.responseModel} + {selectedAttempt.responseModel !== selectedAttempt.mappedModel && ( + + (upstream) + + )} +
+
+ )}
- Response Model + Status
- {request.responseModel || '-'} + {selectedAttempt.status}
diff --git a/web/src/pages/requests/detail/RequestDetailView.tsx b/web/src/pages/requests/detail/RequestDetailView.tsx index 3bbbdd0e..cbef56de 100644 --- a/web/src/pages/requests/detail/RequestDetailView.tsx +++ b/web/src/pages/requests/detail/RequestDetailView.tsx @@ -17,7 +17,7 @@ import { getClientName, getClientColor, } from '@/components/icons/client-icons' -import { CopyButton, EmptyState } from './components' +import { CopyButton, CopyAsCurlButton, EmptyState } from './components' interface RequestDetailViewProps { request: ProxyRequest @@ -101,6 +101,7 @@ export function RequestDetailView({ {request.requestInfo.url} +
diff --git a/web/src/pages/requests/detail/components/CopyAsCurlButton.tsx b/web/src/pages/requests/detail/components/CopyAsCurlButton.tsx new file mode 100644 index 00000000..63608bbe --- /dev/null +++ b/web/src/pages/requests/detail/components/CopyAsCurlButton.tsx @@ -0,0 +1,89 @@ +import { useState } from 'react' +import { Button } from '@/components/ui' +import { Terminal, Check } from 'lucide-react' +import type { RequestInfo } from '@/lib/transport' +import { useSetting } from '@/hooks/queries' + +interface CopyAsCurlButtonProps { + requestInfo: RequestInfo +} + +function generateCurlCommand(requestInfo: RequestInfo, proxyPort: string): string { + const parts: string[] = ['curl'] + + // Method (default is GET, so only add if different) + if (requestInfo.method && requestInfo.method !== 'GET') { + parts.push(`-X ${requestInfo.method}`) + } + + // Build full URL using proxy server address + const port = proxyPort || '9880' + const baseUrl = `http://localhost:${port}` + let fullUrl = requestInfo.url + if (fullUrl && !fullUrl.startsWith('http://') && !fullUrl.startsWith('https://')) { + fullUrl = `${baseUrl}${fullUrl}` + } + + parts.push(`'${fullUrl}'`) + + // Headers + if (requestInfo.headers) { + for (const [key, value] of Object.entries(requestInfo.headers)) { + // Skip some headers that curl handles automatically or are not useful + // Also skip Host header since we're using proxy server address + const skipHeaders = ['content-length', 'connection', 'accept-encoding', 'host'] + if (skipHeaders.includes(key.toLowerCase())) continue + + // Escape single quotes in header values + const escapedValue = value.replace(/'/g, "'\\''") + parts.push(`-H '${key}: ${escapedValue}'`) + } + } + + // Body + if (requestInfo.body) { + // Escape single quotes in body + const escapedBody = requestInfo.body.replace(/'/g, "'\\''") + parts.push(`-d '${escapedBody}'`) + } + + return parts.join(' \\\n ') +} + +export function CopyAsCurlButton({ requestInfo }: CopyAsCurlButtonProps) { + const [copied, setCopied] = useState(false) + const { data: settingData } = useSetting('proxy_port') + const proxyPort = settingData?.value || '9880' + + const handleCopy = async () => { + try { + const curlCommand = generateCurlCommand(requestInfo, proxyPort) + await navigator.clipboard.writeText(curlCommand) + setCopied(true) + setTimeout(() => setCopied(false), 2000) + } catch (err) { + console.error('Failed to copy:', err) + } + } + + return ( + + ) +} diff --git a/web/src/pages/requests/detail/components/index.ts b/web/src/pages/requests/detail/components/index.ts index abaa0edd..b24524b2 100644 --- a/web/src/pages/requests/detail/components/index.ts +++ b/web/src/pages/requests/detail/components/index.ts @@ -1,4 +1,5 @@ export { CopyButton } from './CopyButton' +export { CopyAsCurlButton } from './CopyAsCurlButton' export { DiffModal } from './DiffModal' export { DiffButton } from './DiffButton' export { EmptyState } from './EmptyState' diff --git a/web/src/pages/routes/form.tsx b/web/src/pages/routes/form.tsx index 9fe79f97..9a71a357 100644 --- a/web/src/pages/routes/form.tsx +++ b/web/src/pages/routes/form.tsx @@ -7,6 +7,7 @@ import { useProjects, } from '@/hooks/queries' import type { ClientType, Route } from '@/lib/transport' +import { ModelMappingEditor } from '@/pages/providers/components/model-mapping-editor' interface RouteFormProps { route?: Route @@ -34,6 +35,7 @@ export function RouteForm({ ) const [position, setPosition] = useState('1') const [isEnabled, setIsEnabled] = useState(true) + const [modelMapping, setModelMapping] = useState>({}) useEffect(() => { if (route) { @@ -42,6 +44,7 @@ export function RouteForm({ setProjectID(String(route.projectID)) setPosition(String(route.position)) setIsEnabled(route.isEnabled) + setModelMapping(route.modelMapping || {}) } }, [route]) @@ -65,6 +68,7 @@ export function RouteForm({ isEnabled, isNative: route?.isNative ?? false, // 手动创建的 Route 默认为转换路由 retryConfigID: route?.retryConfigID ?? 0, + modelMapping: Object.keys(modelMapping).length > 0 ? modelMapping : undefined, } if (isEditing) { @@ -143,6 +147,21 @@ export function RouteForm({
+ {/* Model Mapping (route-level override) */} +
+ +

+ Route-level model mappings take priority over provider and global settings. +

+ +
+
+
@@ -147,4 +155,222 @@ function ForceProjectSection() { ) } +interface SortableRuleItemProps { + id: string + index: number + rule: ModelMappingRule + onRemove: () => void + onUpdate: (pattern: string, target: string) => void + disabled: boolean +} + +function SortableRuleItem({ id, index, rule, onRemove, onUpdate, disabled }: SortableRuleItemProps) { + const { + attributes, + listeners, + setNodeRef, + transform, + transition, + isDragging, + } = useSortable({ id }) + + const style = { + transform: CSS.Transform.toString(transform), + transition, + } + + return ( +
+ + {index + 1}. + onUpdate(pattern, rule.target)} + placeholder="匹配模式" + disabled={disabled} + className="flex-1 max-w-xs h-7 text-xs" + /> + + onUpdate(rule.pattern, target)} + placeholder="目标模型" + disabled={disabled} + className="flex-1 max-w-xs h-7 text-xs" + providers={['Antigravity']} + /> + +
+ ) +} + +function AntigravityModelMappingSection() { + const { data: settings, isLoading } = useAntigravityGlobalSettings() + const updateSettings = useUpdateAntigravityGlobalSettings() + const resetSettings = useResetAntigravityGlobalSettings() + const [newPattern, setNewPattern] = useState('') + const [newTarget, setNewTarget] = useState('') + + const rules = settings?.modelMappingRules || [] + + const sensors = useSensors( + useSensor(PointerSensor), + useSensor(KeyboardSensor, { + coordinateGetter: sortableKeyboardCoordinates, + }) + ) + + const handleDragEnd = async (event: DragEndEvent) => { + const { active, over } = event + if (!over || active.id === over.id) return + + const oldIndex = rules.findIndex((_, i) => `rule-${i}` === active.id) + const newIndex = rules.findIndex((_, i) => `rule-${i}` === over.id) + + if (oldIndex !== -1 && newIndex !== -1) { + const newRules = arrayMove(rules, oldIndex, newIndex) + await updateSettings.mutateAsync({ + modelMappingRules: newRules, + }) + } + } + + const handleAddRule = async () => { + if (!newPattern.trim() || !newTarget.trim()) return + + const newRule: ModelMappingRule = { + pattern: newPattern.trim(), + target: newTarget.trim(), + } + await updateSettings.mutateAsync({ + modelMappingRules: [...rules, newRule], + }) + setNewPattern('') + setNewTarget('') + } + + const handleRemoveRule = async (index: number) => { + const newRules = rules.filter((_, i) => i !== index) + await updateSettings.mutateAsync({ + modelMappingRules: newRules, + }) + } + + const handleUpdateRule = async (index: number, pattern: string, target: string) => { + const newRules = [...rules] + newRules[index] = { pattern, target } + await updateSettings.mutateAsync({ + modelMappingRules: newRules, + }) + } + + const handleReset = async () => { + await resetSettings.mutateAsync() + } + + if (isLoading) return null + + const isPending = updateSettings.isPending || resetSettings.isPending + + return ( + + +
+ + + Antigravity 全局模型映射 + + +
+
+ +

+ 模型名映射规则,按顺序匹配(越靠前优先级越高)。支持通配符 *,例如 *sonnet* 匹配所有 sonnet 变体 +

+ + {rules.length > 0 && ( + + `rule-${i}`)} + strategy={verticalListSortingStrategy} + > +
+ {rules.map((rule, index) => ( + handleRemoveRule(index)} + onUpdate={(pattern, target) => handleUpdateRule(index, pattern, target)} + disabled={isPending} + /> + ))} +
+
+
+ )} + +
+ + + + +
+
+
+ ) +} + export default SettingsPage