Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 18 additions & 6 deletions cmd/maxx/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,16 +174,27 @@ func main() {
}()
log.Println("[Cooldown] Background cleanup started (runs every 1 hour)")

// Create WebSocket hub
wsHub := handler.NewWebSocketHub()

// Create Antigravity task service for periodic quota refresh and auto-sorting
antigravityTaskSvc := service.NewAntigravityTaskService(
cachedProviderRepo,
cachedRouteRepo,
antigravityQuotaRepo,
settingRepo,
proxyRequestRepo,
wsHub,
)

// Start background tasks
core.StartBackgroundTasks(core.BackgroundTaskDeps{
UsageStats: usageStatsRepo,
ProxyRequest: proxyRequestRepo,
Settings: settingRepo,
UsageStats: usageStatsRepo,
ProxyRequest: proxyRequestRepo,
Settings: settingRepo,
AntigravityTaskSvc: antigravityTaskSvc,
})

// Create WebSocket hub
wsHub := handler.NewWebSocketHub()

// Setup log output to broadcast via WebSocket
logWriter := handler.NewWebSocketLogWriter(wsHub, os.Stdout, logPath)
log.SetOutput(logWriter)
Expand Down Expand Up @@ -251,6 +262,7 @@ func main() {
adminHandler := handler.NewAdminHandler(adminService, backupService, logPath)
authHandler := handler.NewAuthHandler(authMiddleware)
antigravityHandler := handler.NewAntigravityHandler(adminService, antigravityQuotaRepo, wsHub)
antigravityHandler.SetTaskService(antigravityTaskSvc)
kiroHandler := handler.NewKiroHandler(adminService)

// Use already-created cached project repository for project proxy handler
Expand Down
35 changes: 32 additions & 3 deletions internal/core/task.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
package core

import (
"context"
"log"
"strconv"
"time"

"github.com/awsl-project/maxx/internal/domain"
"github.com/awsl-project/maxx/internal/repository"
"github.com/awsl-project/maxx/internal/service"
)

const (
Expand All @@ -15,9 +17,10 @@ const (

// BackgroundTaskDeps 后台任务依赖
type BackgroundTaskDeps struct {
UsageStats repository.UsageStatsRepository
ProxyRequest repository.ProxyRequestRepository
Settings repository.SystemSettingRepository
UsageStats repository.UsageStatsRepository
ProxyRequest repository.ProxyRequestRepository
Settings repository.SystemSettingRepository
AntigravityTaskSvc *service.AntigravityTaskService
}

// StartBackgroundTasks 启动所有后台任务
Expand Down Expand Up @@ -66,6 +69,11 @@ func StartBackgroundTasks(deps BackgroundTaskDeps) {
}
}()

// Antigravity 配额刷新任务(动态间隔)
if deps.AntigravityTaskSvc != nil {
go deps.runAntigravityQuotaRefresh()
}

log.Println("[Task] Background tasks started (minute:30s, hour:1m, day:5m, cleanup:1h)")
}

Expand Down Expand Up @@ -124,3 +132,24 @@ func (d *BackgroundTaskDeps) cleanupOldRequests() {
log.Printf("[Task] Deleted %d requests older than %d hours", deleted, retentionHours)
}
}

// runAntigravityQuotaRefresh 定期刷新 Antigravity 配额
func (d *BackgroundTaskDeps) runAntigravityQuotaRefresh() {
time.Sleep(30 * time.Second) // 初始延迟

for {
interval := d.AntigravityTaskSvc.GetRefreshInterval()
if interval <= 0 {
// 禁用状态,每分钟检查一次配置
time.Sleep(1 * time.Minute)
continue
}

// 执行刷新
ctx := context.Background()
d.AntigravityTaskSvc.RefreshQuotas(ctx)

// 等待下一次刷新
time.Sleep(time.Duration(interval) * time.Minute)
}
}
8 changes: 5 additions & 3 deletions internal/domain/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -368,9 +368,11 @@ type SystemSetting struct {

// 系统设置 Key 常量
const (
SettingKeyProxyPort = "proxy_port" // 代理服务器端口,默认 9880
SettingKeyRequestRetentionHours = "request_retention_hours" // 请求记录保留小时数,默认 168 小时(7天),0 表示不清理
SettingKeyTimezone = "timezone" // 时区设置,默认 Asia/Shanghai
SettingKeyProxyPort = "proxy_port" // 代理服务器端口,默认 9880
SettingKeyRequestRetentionHours = "request_retention_hours" // 请求记录保留小时数,默认 168 小时(7天),0 表示不清理
SettingKeyTimezone = "timezone" // 时区设置,默认 Asia/Shanghai
SettingKeyQuotaRefreshInterval = "quota_refresh_interval" // Antigravity 配额刷新间隔(分钟),0 表示禁用
SettingKeyAutoSortAntigravity = "auto_sort_antigravity" // 是否自动排序 Antigravity 路由,"true" 或 "false"
)

// Antigravity 模型配额
Expand Down
70 changes: 54 additions & 16 deletions internal/handler/antigravity.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ type AntigravityHandler struct {
svc *service.AdminService
quotaRepo repository.AntigravityQuotaRepository
oauthManager *antigravity.OAuthManager
taskSvc *service.AntigravityTaskService
}

// NewAntigravityHandler creates a new Antigravity handler
Expand All @@ -32,6 +33,11 @@ func NewAntigravityHandler(svc *service.AdminService, quotaRepo repository.Antig
}
}

// SetTaskService sets the AntigravityTaskService for background task operations
func (h *AntigravityHandler) SetTaskService(taskSvc *service.AntigravityTaskService) {
h.taskSvc = taskSvc
}

// ServeHTTP routes Antigravity requests
// Routes:
// POST /antigravity/validate-token - 验证单个 refresh token
Expand All @@ -40,6 +46,8 @@ func NewAntigravityHandler(svc *service.AdminService, quotaRepo repository.Antig
// GET /antigravity/providers/quotas - 批量获取所有 Antigravity provider 的配额信息
// POST /antigravity/oauth/start - 启动 OAuth 流程
// GET /antigravity/oauth/callback - OAuth 回调
// POST /antigravity/refresh-quotas - 强制刷新所有配额
// POST /antigravity/sort-routes - 手动排序路由
func (h *AntigravityHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
path := strings.TrimPrefix(r.URL.Path, "/antigravity")
path = strings.TrimSuffix(path, "/")
Expand All @@ -58,6 +66,18 @@ func (h *AntigravityHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}

// POST /antigravity/refresh-quotas - 强制刷新所有配额
if len(parts) >= 2 && parts[1] == "refresh-quotas" && r.Method == http.MethodPost {
h.handleForceRefreshQuotas(w, r)
return
}

// POST /antigravity/sort-routes - 手动排序路由
if len(parts) >= 2 && parts[1] == "sort-routes" && r.Method == http.MethodPost {
h.handleSortRoutes(w, r)
return
}
Comment on lines +69 to +79
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

需鉴权:refresh-quotas / sort-routes 当前可匿名调用。
这些是管理操作,但 /api/antigravity/ 在主路由里未加 auth;任何人都可触发刷新/排序。建议在路由层加管理员鉴权,或在 handler 内校验权限/令牌。

Also applies to: 456-481

🤖 Prompt for AI Agents
In `@internal/handler/antigravity.go` around lines 69 - 79, The refresh-quotas and
sort-routes endpoints are currently callable anonymously—ensure admin-only
access by adding an auth check before invoking handlers: update the routing
branch that dispatches to handleForceRefreshQuotas and handleSortRoutes (and the
other similar antigravity handlers around the 456-481 range) to verify the
requester is an authenticated admin (e.g., call the existing auth middleware or
validate an admin token/roles from r.Context or Authorization header) and return
401/403 on failure; alternatively, add the same admin permission check at the
start of handleForceRefreshQuotas and handleSortRoutes so they reject non-admin
requests.


// GET /antigravity/providers/quotas - 批量获取配额(必须在单个 provider 路由之前匹配)
if len(parts) >= 3 && parts[1] == "providers" && parts[2] == "quotas" && r.Method == http.MethodGet {
h.handleGetBatchQuotas(w, r)
Expand Down Expand Up @@ -368,6 +388,8 @@ type BatchQuotaResult struct {
}

// GetBatchQuotas 批量获取所有 Antigravity provider 的配额信息(供 HTTP handler 和 Wails 共用)
// 优先从数据库返回缓存数据,即使过期也会返回(避免 API 请求阻塞)
// 配额刷新由后台任务负责
func (h *AntigravityHandler) GetBatchQuotas(ctx context.Context) (*BatchQuotaResult, error) {
// 获取所有 providers
providers, err := h.svc.GetProviders()
Expand All @@ -388,30 +410,19 @@ func (h *AntigravityHandler) GetBatchQuotas(ctx context.Context) (*BatchQuotaRes
config := provider.Config.Antigravity
email := config.Email

// 尝试从数据库获取缓存的配额
// 优先从数据库获取缓存的配额(无论是否过期)
if email != "" && h.quotaRepo != nil {
cachedQuota, err := h.quotaRepo.GetByEmail(email)
if err == nil && cachedQuota != nil {
// 检查是否过期(10分钟)- 如果未过期,直接使用缓存
if time.Since(cachedQuota.UpdatedAt).Seconds() < 600 {
result.Quotas[provider.ID] = h.domainQuotaToResponse(cachedQuota)
continue
}
result.Quotas[provider.ID] = h.domainQuotaToResponse(cachedQuota)
continue
}
}

// 缓存过期或不存在,从 API 获取最新配额
// 数据库没有缓存,尝试从 API 获取
quota, err := antigravity.FetchQuotaForProvider(ctx, config.RefreshToken, config.ProjectID)
if err != nil {
// 如果 API 失败,尝试使用过期的缓存数据
if email != "" && h.quotaRepo != nil {
cachedQuota, _ := h.quotaRepo.GetByEmail(email)
if cachedQuota != nil {
result.Quotas[provider.ID] = h.domainQuotaToResponse(cachedQuota)
continue
}
}
// 跳过此 provider,不中断整体查询
// API 失败,跳过此 provider
continue
}

Expand Down Expand Up @@ -442,6 +453,33 @@ func (h *AntigravityHandler) handleGetBatchQuotas(w http.ResponseWriter, r *http
writeJSON(w, http.StatusOK, result)
}

// handleForceRefreshQuotas 强制刷新所有 Antigravity 配额
func (h *AntigravityHandler) handleForceRefreshQuotas(w http.ResponseWriter, r *http.Request) {
if h.taskSvc == nil {
writeJSON(w, http.StatusServiceUnavailable, map[string]string{"error": "task service not available"})
return
}

refreshed := h.taskSvc.ForceRefreshQuotas(r.Context())
writeJSON(w, http.StatusOK, map[string]interface{}{
"success": true,
"refreshed": refreshed,
})
}

// handleSortRoutes 手动排序 Antigravity 路由
func (h *AntigravityHandler) handleSortRoutes(w http.ResponseWriter, r *http.Request) {
if h.taskSvc == nil {
writeJSON(w, http.StatusServiceUnavailable, map[string]string{"error": "task service not available"})
return
}

h.taskSvc.SortRoutes(r.Context())
writeJSON(w, http.StatusOK, map[string]interface{}{
"success": true,
})
}

// ============================================================================
// OAuth 授权处理函数
// ============================================================================
Expand Down
2 changes: 2 additions & 0 deletions internal/repository/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ type ProxyRequestRepository interface {
MarkStaleAsFailed(currentInstanceID string) (int64, error)
// DeleteOlderThan 删除指定时间之前的请求记录
DeleteOlderThan(before time.Time) (int64, error)
// HasRecentRequests 检查指定时间之后是否有请求记录
HasRecentRequests(since time.Time) (bool, error)
}

type ProxyUpstreamAttemptRepository interface {
Expand Down
10 changes: 10 additions & 0 deletions internal/repository/sqlite/proxy_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,16 @@ func (r *ProxyRequestRepository) DeleteOlderThan(before time.Time) (int64, error
return affected, nil
}

// HasRecentRequests 检查指定时间之后是否有请求记录
func (r *ProxyRequestRepository) HasRecentRequests(since time.Time) (bool, error) {
sinceTs := toTimestamp(since)
var count int64
if err := r.db.gorm.Model(&ProxyRequest{}).Where("created_at >= ?", sinceTs).Limit(1).Count(&count).Error; err != nil {
return false, err
}
return count > 0, nil
}

func (r *ProxyRequestRepository) toModel(p *domain.ProxyRequest) *ProxyRequest {
return &ProxyRequest{
BaseModel: BaseModel{
Expand Down
Loading