diff --git a/cmd/maxx/main.go b/cmd/maxx/main.go index 0febad64..c5dfac0a 100644 --- a/cmd/maxx/main.go +++ b/cmd/maxx/main.go @@ -375,7 +375,7 @@ func main() { // Use already-created cached project repository for project proxy handler modelsHandler := handler.NewModelsHandler(responseModelRepo, cachedProviderRepo, cachedModelMappingRepo) - projectProxyHandler := handler.NewProjectProxyHandler(proxyHandler, modelsHandler, cachedProjectRepo) + projectProxyHandler := handler.NewProjectProxyHandler(proxyHandler, modelsHandler, cachedProjectRepo, tokenAuthMiddleware) // Setup routes mux := http.NewServeMux() @@ -417,6 +417,10 @@ func main() { w.Write([]byte(`{"status":"ok"}`)) }) + if authMiddleware != nil { + wsHub.SetAuthMiddleware(authMiddleware) + } + // WebSocket endpoint mux.HandleFunc("/ws", wsHub.HandleWebSocket) diff --git a/internal/adapter/provider/claude/service.go b/internal/adapter/provider/claude/service.go index b3504300..0fd7a955 100644 --- a/internal/adapter/provider/claude/service.go +++ b/internal/adapter/provider/claude/service.go @@ -68,6 +68,7 @@ func ValidateRefreshToken(ctx context.Context, refreshToken string) (*ClaudeToke // OAuthSession represents an OAuth authorization session type OAuthSession struct { State string + TenantID uint64 CodeVerifier string CreatedAt time.Time ExpiresAt time.Time @@ -113,7 +114,7 @@ func (m *OAuthManager) GenerateState() (string, error) { } // CreateSession creates a new OAuth session with PKCE -func (m *OAuthManager) CreateSession(state string) (*OAuthSession, *PKCEChallenge, error) { +func (m *OAuthManager) CreateSession(state string, tenantID uint64) (*OAuthSession, *PKCEChallenge, error) { pkce, err := GeneratePKCEChallenge() if err != nil { return nil, nil, fmt.Errorf("failed to generate PKCE challenge: %w", err) @@ -121,6 +122,7 @@ func (m *OAuthManager) CreateSession(state string) (*OAuthSession, *PKCEChalleng session := &OAuthSession{ State: state, + TenantID: tenantID, CodeVerifier: pkce.CodeVerifier, CreatedAt: time.Now(), ExpiresAt: time.Now().Add(5 * time.Minute), @@ -153,10 +155,22 @@ func (m *OAuthManager) GetSession(state string) (*OAuthSession, bool) { // CompleteSession completes the OAuth session and broadcasts the result func (m *OAuthManager) CompleteSession(state string, result *OAuthResult) { result.State = state + + var tenantID uint64 + if val, ok := m.sessions.Load(state); ok { + if session, ok := val.(*OAuthSession); ok { + tenantID = session.TenantID + } + } + m.sessions.Delete(state) if m.broadcaster != nil { - m.broadcaster.BroadcastMessage("claude_oauth_result", result) + if tenantID != 0 { + m.broadcaster.BroadcastMessageToTenant(tenantID, "claude_oauth_result", result) + } else { + m.broadcaster.BroadcastMessage("claude_oauth_result", result) + } } } diff --git a/internal/adapter/provider/codex/service.go b/internal/adapter/provider/codex/service.go index 1a4df0d5..eea49e86 100644 --- a/internal/adapter/provider/codex/service.go +++ b/internal/adapter/provider/codex/service.go @@ -86,6 +86,7 @@ func ValidateRefreshToken(ctx context.Context, refreshToken string) (*CodexToken // OAuthSession represents an OAuth authorization session type OAuthSession struct { State string + TenantID uint64 CodeVerifier string CreatedAt time.Time ExpiresAt time.Time @@ -137,7 +138,7 @@ func (m *OAuthManager) GenerateState() (string, error) { } // CreateSession creates a new OAuth session with PKCE -func (m *OAuthManager) CreateSession(state string) (*OAuthSession, *PKCEChallenge, error) { +func (m *OAuthManager) CreateSession(state string, tenantID uint64) (*OAuthSession, *PKCEChallenge, error) { // Generate PKCE challenge pkce, err := GeneratePKCEChallenge() if err != nil { @@ -146,6 +147,7 @@ func (m *OAuthManager) CreateSession(state string) (*OAuthSession, *PKCEChalleng session := &OAuthSession{ State: state, + TenantID: tenantID, CodeVerifier: pkce.CodeVerifier, CreatedAt: time.Now(), ExpiresAt: time.Now().Add(5 * time.Minute), // 5 minute timeout @@ -181,12 +183,23 @@ func (m *OAuthManager) CompleteSession(state string, result *OAuthResult) { // Ensure state matches result.State = state + var tenantID uint64 + if val, ok := m.sessions.Load(state); ok { + if session, ok := val.(*OAuthSession); ok { + tenantID = session.TenantID + } + } + // Delete session m.sessions.Delete(state) // Broadcast result via WebSocket if m.broadcaster != nil { - m.broadcaster.BroadcastMessage("codex_oauth_result", result) + if tenantID != 0 { + m.broadcaster.BroadcastMessageToTenant(tenantID, "codex_oauth_result", result) + } else { + m.broadcaster.BroadcastMessage("codex_oauth_result", result) + } } } diff --git a/internal/event/broadcaster.go b/internal/event/broadcaster.go index 326ef352..e4554ab8 100644 --- a/internal/event/broadcaster.go +++ b/internal/event/broadcaster.go @@ -9,6 +9,7 @@ type Broadcaster interface { BroadcastProxyUpstreamAttempt(attempt *domain.ProxyUpstreamAttempt) BroadcastLog(message string) BroadcastMessage(messageType string, data interface{}) + BroadcastMessageToTenant(tenantID uint64, messageType string, data interface{}) } // NopBroadcaster 空实现,用于测试或不需要广播的场景 @@ -18,6 +19,7 @@ func (n *NopBroadcaster) BroadcastProxyRequest(req *domain.ProxyRequest) func (n *NopBroadcaster) BroadcastProxyUpstreamAttempt(attempt *domain.ProxyUpstreamAttempt) {} func (n *NopBroadcaster) BroadcastLog(message string) {} func (n *NopBroadcaster) BroadcastMessage(messageType string, data interface{}) {} +func (n *NopBroadcaster) BroadcastMessageToTenant(tenantID uint64, messageType string, data interface{}) {} // SanitizeProxyRequestForBroadcast 用于“实时广播”场景瘦身 payload: // 去掉 request/response 大字段,避免 WebSocket 消息动辄几十/几百 KB,导致前端 JSON.parse / GC 卡死。 diff --git a/internal/event/wails_broadcaster_desktop.go b/internal/event/wails_broadcaster_desktop.go index 377c6989..1e51278b 100644 --- a/internal/event/wails_broadcaster_desktop.go +++ b/internal/event/wails_broadcaster_desktop.go @@ -73,6 +73,14 @@ func (w *WailsBroadcaster) BroadcastLog(message string) { w.emitWailsEvent("log_message", message) } +// BroadcastMessageToTenant broadcasts a tenant-scoped custom message +func (w *WailsBroadcaster) BroadcastMessageToTenant(tenantID uint64, messageType string, data interface{}) { + if w.inner != nil { + w.inner.BroadcastMessageToTenant(tenantID, messageType, data) + } + w.emitWailsEvent(messageType, data) +} + // BroadcastMessage broadcasts a custom message func (w *WailsBroadcaster) BroadcastMessage(messageType string, data interface{}) { if w.inner != nil { diff --git a/internal/event/wails_broadcaster_http.go b/internal/event/wails_broadcaster_http.go index d7715470..6595737c 100644 --- a/internal/event/wails_broadcaster_http.go +++ b/internal/event/wails_broadcaster_http.go @@ -52,6 +52,13 @@ func (w *WailsBroadcaster) BroadcastLog(message string) { } } +// BroadcastMessageToTenant broadcasts a tenant-scoped custom message +func (w *WailsBroadcaster) BroadcastMessageToTenant(tenantID uint64, messageType string, data interface{}) { + if w.inner != nil { + w.inner.BroadcastMessageToTenant(tenantID, messageType, data) + } +} + // BroadcastMessage broadcasts a custom message func (w *WailsBroadcaster) BroadcastMessage(messageType string, data interface{}) { if w.inner != nil { diff --git a/internal/handler/claude.go b/internal/handler/claude.go index 325fdf63..920f64ab 100644 --- a/internal/handler/claude.go +++ b/internal/handler/claude.go @@ -102,13 +102,13 @@ type ClaudeOAuthStartResult struct { } // StartOAuth starts the OAuth authorization flow -func (h *ClaudeHandler) StartOAuth() (*ClaudeOAuthStartResult, error) { +func (h *ClaudeHandler) StartOAuth(tenantID uint64) (*ClaudeOAuthStartResult, error) { state, err := h.oauthManager.GenerateState() if err != nil { return nil, fmt.Errorf("failed to generate state: %w", err) } - _, pkce, err := h.oauthManager.CreateSession(state) + _, pkce, err := h.oauthManager.CreateSession(state, tenantID) if err != nil { return nil, fmt.Errorf("failed to create session: %w", err) } @@ -158,7 +158,8 @@ func (h *ClaudeHandler) handleOAuthStart(w http.ResponseWriter, r *http.Request) cancel() } - result, err := h.StartOAuth() + tenantID := maxxctx.GetTenantID(r.Context()) + result, err := h.StartOAuth(tenantID) if err != nil { writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return diff --git a/internal/handler/codex.go b/internal/handler/codex.go index ce8ef4e9..d1566d5a 100644 --- a/internal/handler/codex.go +++ b/internal/handler/codex.go @@ -146,7 +146,7 @@ type CodexOAuthStartResult struct { } // StartOAuth starts the OAuth authorization flow -func (h *CodexHandler) StartOAuth() (*CodexOAuthStartResult, error) { +func (h *CodexHandler) StartOAuth(tenantID uint64) (*CodexOAuthStartResult, error) { // Generate random state token state, err := h.oauthManager.GenerateState() if err != nil { @@ -154,7 +154,7 @@ func (h *CodexHandler) StartOAuth() (*CodexOAuthStartResult, error) { } // Create OAuth session with PKCE - _, pkce, err := h.oauthManager.CreateSession(state) + _, pkce, err := h.oauthManager.CreateSession(state, tenantID) if err != nil { return nil, fmt.Errorf("failed to create session: %w", err) } @@ -207,7 +207,13 @@ func (h *CodexHandler) handleOAuthStart(w http.ResponseWriter, r *http.Request) cancel() } - result, err := h.StartOAuth() + tenantID := maxxctx.GetTenantID(r.Context()) + if tenantID == 0 { + writeJSON(w, http.StatusUnauthorized, map[string]string{"error": "tenant context is required"}) + return + } + + result, err := h.StartOAuth(tenantID) if err != nil { writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return diff --git a/internal/handler/websocket.go b/internal/handler/websocket.go index 681ef476..7c903f3c 100644 --- a/internal/handler/websocket.go +++ b/internal/handler/websocket.go @@ -6,6 +6,7 @@ import ( "io" "log" "net/http" + "net/url" "os" "strconv" "strings" @@ -29,11 +30,26 @@ type WSMessage struct { Data interface{} `json:"data"` } +type wsClient struct { + tenantID uint64 + userID uint64 + role string +} + +type queuedWSMessage struct { + message WSMessage + tenantID uint64 + scoped bool + defaultOnly bool +} + type WebSocketHub struct { - clients map[*websocket.Conn]bool - broadcast chan WSMessage + clients map[*websocket.Conn]wsClient + broadcast chan queuedWSMessage mu sync.RWMutex + authMiddleware *AuthMiddleware + // broadcast channel 满时的丢弃计数(热路径:只做原子累加) broadcastDroppedTotal atomic.Uint64 } @@ -42,43 +58,61 @@ const websocketWriteTimeout = 5 * time.Second func NewWebSocketHub() *WebSocketHub { hub := &WebSocketHub{ - clients: make(map[*websocket.Conn]bool), - broadcast: make(chan WSMessage, 100), + clients: make(map[*websocket.Conn]wsClient), + broadcast: make(chan queuedWSMessage, 100), } go hub.run() return hub } +func (h *WebSocketHub) SetAuthMiddleware(auth *AuthMiddleware) { + h.mu.Lock() + defer h.mu.Unlock() + h.authMiddleware = auth +} + func (h *WebSocketHub) run() { - for msg := range h.broadcast { + for queued := range h.broadcast { // 避免在持锁状态下进行网络写入;同时修复 RLock 下 delete map 的数据竞争风险 h.mu.RLock() - clients := make([]*websocket.Conn, 0, len(h.clients)) - for client := range h.clients { - clients = append(clients, client) + targets := make([]struct { + conn *websocket.Conn + client wsClient + }, 0, len(h.clients)) + for conn, client := range h.clients { + if queued.defaultOnly && client.tenantID != domain.DefaultTenantID { + continue + } + if queued.scoped && client.tenantID != queued.tenantID { + continue + } + targets = append(targets, struct { + conn *websocket.Conn + client wsClient + }{conn: conn, client: client}) } h.mu.RUnlock() var toRemove []*websocket.Conn - for _, client := range clients { - _ = client.SetWriteDeadline(time.Now().Add(websocketWriteTimeout)) - if err := client.WriteJSON(msg); err != nil { - _ = client.Close() - toRemove = append(toRemove, client) + for _, target := range targets { + _ = target.conn.SetWriteDeadline(time.Now().Add(websocketWriteTimeout)) + if err := target.conn.WriteJSON(queued.message); err != nil { + _ = target.conn.Close() + toRemove = append(toRemove, target.conn) } } if len(toRemove) > 0 { h.mu.Lock() - for _, client := range toRemove { - delete(h.clients, client) + for _, conn := range toRemove { + delete(h.clients, conn) } h.mu.Unlock() } } } -func (h *WebSocketHub) tryEnqueueBroadcast(msg WSMessage, meta string) { +func (h *WebSocketHub) tryEnqueueBroadcast(msg queuedWSMessage, meta string) { select { case h.broadcast <- msg: default: @@ -90,7 +124,41 @@ func (h *WebSocketHub) tryEnqueueBroadcast(msg WSMessage, meta string) { } } +func (h *WebSocketHub) authenticateRequest(r *http.Request) (wsClient, bool) { + h.mu.RLock() + auth := h.authMiddleware + h.mu.RUnlock() + + if auth == nil { + return wsClient{}, false + } + + token := strings.TrimSpace(r.URL.Query().Get("access_token")) + if token == "" { + authHeader := r.Header.Get(AuthHeader) + if strings.HasPrefix(authHeader, "Bearer ") { + token = strings.TrimSpace(strings.TrimPrefix(authHeader, "Bearer ")) + } + } + if token == "" { + return wsClient{}, false + } + + claims, valid := auth.ValidateToken(token) + if !valid { + return wsClient{}, false + } + + return wsClient{tenantID: claims.TenantID, userID: claims.UserID, role: claims.Role}, true +} + func (h *WebSocketHub) HandleWebSocket(w http.ResponseWriter, r *http.Request) { + clientMeta, ok := h.authenticateRequest(r) + if !ok { + writeUnauthorized(w) + return + } + conn, err := upgrader.Upgrade(w, r, nil) if err != nil { log.Printf("WebSocket upgrade error: %v", err) @@ -98,7 +166,7 @@ func (h *WebSocketHub) HandleWebSocket(w http.ResponseWriter, r *http.Request) { } h.mu.Lock() - h.clients[conn] = true + h.clients[conn] = clientMeta h.mu.Unlock() defer func() { @@ -130,9 +198,15 @@ func (h *WebSocketHub) BroadcastProxyRequest(req *domain.ProxyRequest) { meta += " requestDbID=" + strconv.FormatUint(snapshot.ID, 10) } } - msg := WSMessage{ - Type: "proxy_request_update", - Data: data, + msg := queuedWSMessage{ + message: WSMessage{ + Type: "proxy_request_update", + Data: data, + }, + } + if sanitized != nil && sanitized.TenantID != 0 { + msg.tenantID = sanitized.TenantID + msg.scoped = true } h.tryEnqueueBroadcast(msg, meta) } @@ -154,15 +228,30 @@ func (h *WebSocketHub) BroadcastProxyUpstreamAttempt(attempt *domain.ProxyUpstre meta += "attemptDbID=" + strconv.FormatUint(snapshot.ID, 10) } } - msg := WSMessage{ - Type: "proxy_upstream_attempt_update", - Data: data, + msg := queuedWSMessage{ + message: WSMessage{ + Type: "proxy_upstream_attempt_update", + Data: data, + }, + } + if sanitized != nil && sanitized.TenantID != 0 { + msg.tenantID = sanitized.TenantID + msg.scoped = true } h.tryEnqueueBroadcast(msg, meta) } // BroadcastMessage sends a custom message with specified type to all connected clients func (h *WebSocketHub) BroadcastMessage(messageType string, data interface{}) { + h.broadcastJSONMessage(0, false, false, messageType, data) +} + +// BroadcastMessageToTenant sends a custom message to a specific tenant only. +func (h *WebSocketHub) BroadcastMessageToTenant(tenantID uint64, messageType string, data interface{}) { + h.broadcastJSONMessage(tenantID, true, false, messageType, data) +} + +func (h *WebSocketHub) broadcastJSONMessage(tenantID uint64, scoped bool, defaultOnly bool, messageType string, data interface{}) { // 约定:BroadcastMessage 允许调用方传入 map/struct/指针等可变对象。 // // 但由于实际发送是异步的(入队后由 run() 写到各连接),如果这里直接把可变指针放进 channel, @@ -182,20 +271,21 @@ func (h *WebSocketHub) BroadcastMessage(messageType string, data interface{}) { snapshot = json.RawMessage(b) } } - msg := WSMessage{ - Type: messageType, - Data: snapshot, + msg := queuedWSMessage{ + message: WSMessage{ + Type: messageType, + Data: snapshot, + }, + tenantID: tenantID, + scoped: scoped, + defaultOnly: defaultOnly, } h.tryEnqueueBroadcast(msg, "") } -// BroadcastLog sends a log message to all connected clients +// BroadcastLog sends a log message to all connected clients. func (h *WebSocketHub) BroadcastLog(message string) { - msg := WSMessage{ - Type: "log_message", - Data: message, - } - h.tryEnqueueBroadcast(msg, "") + h.broadcastJSONMessage(0, false, false, "log_message", message) } // WebSocketLogWriter implements io.Writer to capture logs and broadcast via WebSocket @@ -333,3 +423,17 @@ func countNewlines(chunks [][]byte) int { } return count } + +func appendAccessToken(rawURL string, token string) string { + if strings.TrimSpace(token) == "" { + return rawURL + } + parsed, err := url.Parse(rawURL) + if err != nil { + return rawURL + } + q := parsed.Query() + q.Set("access_token", token) + parsed.RawQuery = q.Encode() + return parsed.String() +} diff --git a/internal/handler/websocket_test.go b/internal/handler/websocket_test.go index af24d36e..7a8fd24a 100644 --- a/internal/handler/websocket_test.go +++ b/internal/handler/websocket_test.go @@ -12,7 +12,7 @@ import ( func TestWebSocketHub_BroadcastProxyRequest_SendsSnapshot(t *testing.T) { hub := &WebSocketHub{ - broadcast: make(chan WSMessage, 1), + broadcast: make(chan queuedWSMessage, 1), } req := &domain.ProxyRequest{ @@ -26,7 +26,8 @@ func TestWebSocketHub_BroadcastProxyRequest_SendsSnapshot(t *testing.T) { // 如果 Broadcast 发送的是同一个指针,那么这里对原对象的修改会“污染”队列中的消息。 req.Status = "COMPLETED" - msg := <-hub.broadcast + queued := <-hub.broadcast + msg := queued.message if msg.Type != "proxy_request_update" { t.Fatalf("unexpected message type: %s", msg.Type) } @@ -50,7 +51,7 @@ func TestWebSocketHub_BroadcastProxyRequest_SendsSnapshot(t *testing.T) { func TestWebSocketHub_BroadcastProxyUpstreamAttempt_SendsSnapshot(t *testing.T) { hub := &WebSocketHub{ - broadcast: make(chan WSMessage, 1), + broadcast: make(chan queuedWSMessage, 1), } attempt := &domain.ProxyUpstreamAttempt{ @@ -62,7 +63,8 @@ func TestWebSocketHub_BroadcastProxyUpstreamAttempt_SendsSnapshot(t *testing.T) hub.BroadcastProxyUpstreamAttempt(attempt) attempt.Status = "COMPLETED" - msg := <-hub.broadcast + queued := <-hub.broadcast + msg := queued.message if msg.Type != "proxy_upstream_attempt_update" { t.Fatalf("unexpected message type: %s", msg.Type) } @@ -86,9 +88,9 @@ func TestWebSocketHub_BroadcastProxyUpstreamAttempt_SendsSnapshot(t *testing.T) func TestWebSocketHub_BroadcastDrop_IncrementsCounter(t *testing.T) { hub := &WebSocketHub{ - broadcast: make(chan WSMessage, 1), + broadcast: make(chan queuedWSMessage, 1), } - hub.broadcast <- WSMessage{Type: "dummy", Data: nil} + hub.broadcast <- queuedWSMessage{message: WSMessage{Type: "dummy", Data: nil}} before := hub.broadcastDroppedTotal.Load() @@ -108,12 +110,12 @@ func TestWebSocketHub_BroadcastDrop_IncrementsCounter(t *testing.T) { func TestWebSocketLogWriter_NoDeadlockOnFullChannel(t *testing.T) { // Create hub WITHOUT starting run() goroutine, so channel stays full hub := &WebSocketHub{ - broadcast: make(chan WSMessage, 100), + broadcast: make(chan queuedWSMessage, 100), } // Fill broadcast channel completely for i := 0; i < 100; i++ { - hub.broadcast <- WSMessage{Type: "fill", Data: i} + hub.broadcast <- queuedWSMessage{message: WSMessage{Type: "fill", Data: i}} } // Create WebSocketLogWriter pointing to this hub @@ -149,7 +151,7 @@ func TestWebSocketLogWriter_NoDeadlockOnFullChannel(t *testing.T) { func TestWebSocketHub_BroadcastMessage_SendsSnapshot(t *testing.T) { hub := &WebSocketHub{ - broadcast: make(chan WSMessage, 1), + broadcast: make(chan queuedWSMessage, 1), } type payload struct { @@ -162,7 +164,8 @@ func TestWebSocketHub_BroadcastMessage_SendsSnapshot(t *testing.T) { // 如果 BroadcastMessage 直接把指针放进队列,这里修改会污染后续消费者看到的数据。 p.A = 2 - msg := <-hub.broadcast + queued := <-hub.broadcast + msg := queued.message if msg.Type != "custom_event" { t.Fatalf("unexpected message type: %s", msg.Type) } diff --git a/internal/repository/interfaces.go b/internal/repository/interfaces.go index 0490ec5b..8f72d180 100644 --- a/internal/repository/interfaces.go +++ b/internal/repository/interfaces.go @@ -157,9 +157,12 @@ type ProxyUpstreamAttemptRepository interface { ListAll() ([]*domain.ProxyUpstreamAttempt, error) // CountAll returns total count of attempts CountAll() (int64, error) + // CountByTenant returns total count of attempts for a tenant (or all tenants when using TenantIDAll) + CountByTenant(tenantID uint64) (int64, error) // StreamForCostCalc iterates through all attempts for cost calculation // Calls the callback with batches of minimal data, returns early if callback returns error StreamForCostCalc(batchSize int, callback func(batch []*domain.AttemptCostData) error) error + StreamForCostCalcByTenant(tenantID uint64, batchSize int, callback func(batch []*domain.AttemptCostData) error) error // UpdateCost updates only the cost field of an attempt UpdateCost(id uint64, cost uint64) error // BatchUpdateCosts updates costs for multiple attempts in a single transaction diff --git a/internal/repository/sqlite/proxy_upstream_attempt.go b/internal/repository/sqlite/proxy_upstream_attempt.go index 6fa7d04f..71fda5b1 100644 --- a/internal/repository/sqlite/proxy_upstream_attempt.go +++ b/internal/repository/sqlite/proxy_upstream_attempt.go @@ -53,8 +53,13 @@ func (r *ProxyUpstreamAttemptRepository) ListAll() ([]*domain.ProxyUpstreamAttem } func (r *ProxyUpstreamAttemptRepository) CountAll() (int64, error) { + return r.CountByTenant(domain.TenantIDAll) +} + +func (r *ProxyUpstreamAttemptRepository) CountByTenant(tenantID uint64) (int64, error) { var count int64 - if err := r.db.gorm.Model(&ProxyUpstreamAttempt{}).Count(&count).Error; err != nil { + query := tenantScope(r.db.gorm.Model(&ProxyUpstreamAttempt{}), tenantID) + if err := query.Count(&count).Error; err != nil { return 0, err } return count, nil @@ -63,6 +68,10 @@ func (r *ProxyUpstreamAttemptRepository) CountAll() (int64, error) { // StreamForCostCalc iterates through all attempts in batches for cost calculation // Only fetches fields needed for cost calculation, avoiding expensive JSON parsing func (r *ProxyUpstreamAttemptRepository) StreamForCostCalc(batchSize int, callback func(batch []*domain.AttemptCostData) error) error { + return r.StreamForCostCalcByTenant(domain.TenantIDAll, batchSize, callback) +} + +func (r *ProxyUpstreamAttemptRepository) StreamForCostCalcByTenant(tenantID uint64, batchSize int, callback func(batch []*domain.AttemptCostData) error) error { var lastID uint64 = 0 for { @@ -81,10 +90,14 @@ func (r *ProxyUpstreamAttemptRepository) StreamForCostCalc(batchSize int, callba Cost uint64 `gorm:"column:cost"` } - err := r.db.gorm.Table("proxy_upstream_attempts"). - Select("id, proxy_request_id, response_model, mapped_model, request_model, input_token_count, output_token_count, cache_read_count, cache_write_count, cache_5m_write_count, cache_1h_write_count, cost"). - Where("id > ?", lastID). - Order("id"). + query := tenantScope( + r.db.gorm.Table("proxy_upstream_attempts"). + Select("id, proxy_request_id, response_model, mapped_model, request_model, input_token_count, output_token_count, cache_read_count, cache_write_count, cache_5m_write_count, cache_1h_write_count, cost"). + Where("id > ?", lastID), + tenantID, + ) + + err := query.Order("id"). Limit(batchSize). Find(&results).Error diff --git a/tests/e2e/playwright/requests-table-alignment.spec.ts b/tests/e2e/playwright/requests-table-alignment.spec.ts index 49e8145a..443b31c1 100644 --- a/tests/e2e/playwright/requests-table-alignment.spec.ts +++ b/tests/e2e/playwright/requests-table-alignment.spec.ts @@ -206,6 +206,7 @@ async function openRequestsPage(page: Page, providerId?: number) { } test('virtualized requests table keeps header and body columns aligned', async ({ page }, testInfo) => { + test.setTimeout(90_000); const mock = await startMockClaudeServer(); let jwt: string | undefined; let providerId: number | null = null; @@ -267,9 +268,9 @@ test('virtualized requests table keeps header and body columns aligned', async ( const requests = await adminAPI('GET', '/requests?limit=100', undefined, jwt); return requests.items?.filter((item: any) => item.providerID === providerId).length ?? 0; }, - { timeout: 15000 }, + { timeout: 20000 }, ) - .toBeGreaterThanOrEqual(40); + .toBeGreaterThanOrEqual(24); await openRequestsPage(page, provider.id); await expect(page.locator('table thead th').first()).toBeVisible({ timeout: 30_000 }); @@ -315,3 +316,4 @@ test('virtualized requests table keeps header and body columns aligned', async ( await closeServer(mock.server); } }); +