diff --git a/cmd/maxx/main.go b/cmd/maxx/main.go index 9eef3bb7..77b69d0f 100644 --- a/cmd/maxx/main.go +++ b/cmd/maxx/main.go @@ -1,12 +1,15 @@ package main import ( + "context" "flag" "fmt" "log" "net/http" "os" + "os/signal" "path/filepath" + "syscall" "time" "github.com/awsl-project/maxx/internal/adapter/client" @@ -18,9 +21,9 @@ import ( "github.com/awsl-project/maxx/internal/handler" "github.com/awsl-project/maxx/internal/repository/cached" "github.com/awsl-project/maxx/internal/repository/sqlite" - "github.com/awsl-project/maxx/internal/stats" "github.com/awsl-project/maxx/internal/router" "github.com/awsl-project/maxx/internal/service" + "github.com/awsl-project/maxx/internal/stats" "github.com/awsl-project/maxx/internal/version" "github.com/awsl-project/maxx/internal/waiter" ) @@ -118,6 +121,23 @@ func main() { } else if count > 0 { log.Printf("Marked %d stale requests as failed", count) } + // Also mark stale upstream attempts as failed + if count, err := attemptRepo.MarkStaleAttemptsFailed(); err != nil { + log.Printf("Warning: Failed to mark stale attempts: %v", err) + } else if count > 0 { + log.Printf("Marked %d stale upstream attempts as failed", count) + } + // Fix legacy failed requests/attempts without end_time + if count, err := proxyRequestRepo.FixFailedRequestsWithoutEndTime(); err != nil { + log.Printf("Warning: Failed to fix failed requests without end_time: %v", err) + } else if count > 0 { + log.Printf("Fixed %d failed requests without end_time", count) + } + if count, err := attemptRepo.FixFailedAttemptsWithoutEndTime(); err != nil { + log.Printf("Warning: Failed to fix failed attempts without end_time: %v", err) + } else if count > 0 { + log.Printf("Fixed %d failed attempts without end_time", count) + } // Create cached repositories cachedProviderRepo := cached.NewProviderRepository(providerRepo) @@ -228,6 +248,7 @@ func main() { responseModelRepo, *addr, r, // Router implements ProviderAdapterRefresher interface + wsHub, ) // Create backup service @@ -257,8 +278,12 @@ func main() { log.Println("Proxy token authentication is enabled") } + // Create request tracker for graceful shutdown + requestTracker := core.NewRequestTracker() + // Create handlers proxyHandler := handler.NewProxyHandler(clientAdapter, exec, cachedSessionRepo, tokenAuthMiddleware) + proxyHandler.SetRequestTracker(requestTracker) adminHandler := handler.NewAdminHandler(adminService, backupService, logPath) authHandler := handler.NewAuthHandler(authMiddleware) antigravityHandler := handler.NewAntigravityHandler(adminService, antigravityQuotaRepo, wsHub) @@ -309,22 +334,55 @@ func main() { // Wrap with logging middleware loggedMux := handler.LoggingMiddleware(mux) - // Start server + // Create HTTP server + server := &http.Server{ + Addr: *addr, + Handler: loggedMux, + } + + // Start server in goroutine log.Printf("Starting Maxx server %s on %s", version.Info(), *addr) log.Printf("Data directory: %s", dataDirPath) - log.Printf(" Database: %s", dbPath) - log.Printf(" Log file: %s", logPath) - log.Printf("Admin API: http://localhost%s/api/admin/", *addr) - log.Printf("WebSocket: ws://localhost%s/ws", *addr) - log.Printf("Proxy endpoints:") - log.Printf(" Claude: http://localhost%s/v1/messages", *addr) - log.Printf(" OpenAI: http://localhost%s/v1/chat/completions", *addr) - log.Printf(" Codex: http://localhost%s/v1/responses", *addr) - log.Printf(" Gemini: http://localhost%s/v1beta/models/{model}:generateContent", *addr) - log.Printf("Project proxy: http://localhost%s/{project-slug}/v1/messages (etc.)", *addr) - - if err := http.ListenAndServe(*addr, loggedMux); err != nil { - log.Printf("Server error: %v", err) - os.Exit(1) + + go func() { + if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + log.Printf("Server error: %v", err) + os.Exit(1) + } + }() + + // Wait for interrupt signal (SIGINT or SIGTERM) + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) + sig := <-sigCh + log.Printf("Received signal %v, initiating graceful shutdown...", sig) + + // Step 1: Wait for active proxy requests to complete + activeCount := requestTracker.ActiveCount() + if activeCount > 0 { + log.Printf("Waiting for %d active proxy requests to complete...", activeCount) + completed := requestTracker.GracefulShutdown(core.GracefulShutdownTimeout) + if !completed { + log.Printf("Graceful shutdown timeout, some requests may be interrupted") + } else { + log.Printf("All proxy requests completed successfully") + } + } else { + // Mark as shutting down to reject new requests + requestTracker.GracefulShutdown(0) + log.Printf("No active proxy requests") + } + + // Step 2: Shutdown HTTP server + shutdownCtx, cancel := context.WithTimeout(context.Background(), core.HTTPShutdownTimeout) + defer cancel() + + if err := server.Shutdown(shutdownCtx); err != nil { + log.Printf("HTTP server graceful shutdown failed: %v, forcing close", err) + if closeErr := server.Close(); closeErr != nil { + log.Printf("Force close error: %v", closeErr) + } } + + log.Printf("Server stopped") } diff --git a/coverage.out b/coverage.out new file mode 100644 index 00000000..7e5a102f --- /dev/null +++ b/coverage.out @@ -0,0 +1,64 @@ +mode: set +github.com/awsl-project/maxx/internal/stats/aggregator.go:14.90,18.2 1 0 +github.com/awsl-project/maxx/internal/stats/aggregator.go:21.46,23.2 1 0 +github.com/awsl-project/maxx/internal/stats/pure.go:35.93,37.11 2 1 +github.com/awsl-project/maxx/internal/stats/pure.go:38.32,39.33 1 1 +github.com/awsl-project/maxx/internal/stats/pure.go:40.30,41.31 1 1 +github.com/awsl-project/maxx/internal/stats/pure.go:42.29,43.66 1 1 +github.com/awsl-project/maxx/internal/stats/pure.go:44.30,47.19 2 1 +github.com/awsl-project/maxx/internal/stats/pure.go:47.19,49.4 1 1 +github.com/awsl-project/maxx/internal/stats/pure.go:50.3,50.78 1 1 +github.com/awsl-project/maxx/internal/stats/pure.go:51.31,52.60 1 1 +github.com/awsl-project/maxx/internal/stats/pure.go:53.30,54.52 1 1 +github.com/awsl-project/maxx/internal/stats/pure.go:55.10,56.31 1 1 +github.com/awsl-project/maxx/internal/stats/pure.go:63.90,64.23 1 1 +github.com/awsl-project/maxx/internal/stats/pure.go:64.23,66.3 1 1 +github.com/awsl-project/maxx/internal/stats/pure.go:68.2,79.28 3 1 +github.com/awsl-project/maxx/internal/stats/pure.go:79.28,93.21 4 1 +github.com/awsl-project/maxx/internal/stats/pure.go:93.21,95.4 1 1 +github.com/awsl-project/maxx/internal/stats/pure.go:96.3,96.17 1 1 +github.com/awsl-project/maxx/internal/stats/pure.go:96.17,98.4 1 1 +github.com/awsl-project/maxx/internal/stats/pure.go:100.3,100.33 1 1 +github.com/awsl-project/maxx/internal/stats/pure.go:100.33,110.4 9 1 +github.com/awsl-project/maxx/internal/stats/pure.go:110.9,130.4 1 1 +github.com/awsl-project/maxx/internal/stats/pure.go:133.2,134.29 2 1 +github.com/awsl-project/maxx/internal/stats/pure.go:134.29,136.3 1 1 +github.com/awsl-project/maxx/internal/stats/pure.go:137.2,137.15 1 1 +github.com/awsl-project/maxx/internal/stats/pure.go:143.105,144.21 1 1 +github.com/awsl-project/maxx/internal/stats/pure.go:144.21,146.3 1 1 +github.com/awsl-project/maxx/internal/stats/pure.go:148.2,159.26 3 1 +github.com/awsl-project/maxx/internal/stats/pure.go:159.26,172.40 3 1 +github.com/awsl-project/maxx/internal/stats/pure.go:172.40,182.4 9 1 +github.com/awsl-project/maxx/internal/stats/pure.go:182.9,202.4 1 1 +github.com/awsl-project/maxx/internal/stats/pure.go:205.2,206.29 2 1 +github.com/awsl-project/maxx/internal/stats/pure.go:206.29,208.3 1 1 +github.com/awsl-project/maxx/internal/stats/pure.go:209.2,209.15 1 1 +github.com/awsl-project/maxx/internal/stats/pure.go:214.73,227.34 3 1 +github.com/awsl-project/maxx/internal/stats/pure.go:227.34,228.27 1 1 +github.com/awsl-project/maxx/internal/stats/pure.go:228.27,240.39 2 1 +github.com/awsl-project/maxx/internal/stats/pure.go:240.39,250.5 9 1 +github.com/awsl-project/maxx/internal/stats/pure.go:250.10,254.5 2 1 +github.com/awsl-project/maxx/internal/stats/pure.go:258.2,259.27 2 1 +github.com/awsl-project/maxx/internal/stats/pure.go:259.27,261.3 1 1 +github.com/awsl-project/maxx/internal/stats/pure.go:262.2,262.15 1 1 +github.com/awsl-project/maxx/internal/stats/pure.go:268.140,269.26 1 1 +github.com/awsl-project/maxx/internal/stats/pure.go:269.26,278.3 8 1 +github.com/awsl-project/maxx/internal/stats/pure.go:279.2,279.8 1 1 +github.com/awsl-project/maxx/internal/stats/pure.go:284.83,287.26 2 1 +github.com/awsl-project/maxx/internal/stats/pure.go:287.26,288.24 1 1 +github.com/awsl-project/maxx/internal/stats/pure.go:288.24,289.12 1 1 +github.com/awsl-project/maxx/internal/stats/pure.go:292.3,292.47 1 1 +github.com/awsl-project/maxx/internal/stats/pure.go:292.47,301.4 8 1 +github.com/awsl-project/maxx/internal/stats/pure.go:301.9,313.4 1 1 +github.com/awsl-project/maxx/internal/stats/pure.go:317.2,317.28 1 1 +github.com/awsl-project/maxx/internal/stats/pure.go:317.28,318.27 1 1 +github.com/awsl-project/maxx/internal/stats/pure.go:318.27,320.4 1 1 +github.com/awsl-project/maxx/internal/stats/pure.go:323.2,323.15 1 1 +github.com/awsl-project/maxx/internal/stats/pure.go:327.97,329.26 2 1 +github.com/awsl-project/maxx/internal/stats/pure.go:329.26,330.25 1 1 +github.com/awsl-project/maxx/internal/stats/pure.go:330.25,332.4 1 1 +github.com/awsl-project/maxx/internal/stats/pure.go:334.2,334.15 1 1 +github.com/awsl-project/maxx/internal/stats/pure.go:339.95,341.26 2 1 +github.com/awsl-project/maxx/internal/stats/pure.go:341.26,342.62 1 1 +github.com/awsl-project/maxx/internal/stats/pure.go:342.62,344.4 1 1 +github.com/awsl-project/maxx/internal/stats/pure.go:346.2,346.15 1 1 diff --git a/internal/adapter/provider/antigravity/adapter.go b/internal/adapter/provider/antigravity/adapter.go index 0964c927..853602c9 100644 --- a/internal/adapter/provider/antigravity/adapter.go +++ b/internal/adapter/provider/antigravity/adapter.go @@ -111,7 +111,8 @@ func (a *AntigravityAdapter) Execute(ctx context.Context, w http.ResponseWriter, // Transform request based on client type var geminiBody []byte - if clientType == domain.ClientTypeClaude { + switch clientType { + case domain.ClientTypeClaude: // Use direct transformation (no converter dependency) // This combines cache control cleanup, thinking filter, tool loop recovery, // system instruction building, content transformation, tool building, and generation config @@ -127,10 +128,10 @@ func (a *AntigravityAdapter) Execute(ctx context.Context, w http.ResponseWriter, // Apply minimal post-processing for features not yet fully integrated geminiBody = applyClaudePostProcess(geminiBody, sessionID, hasThinking, requestBody, mappedModel) - } else if clientType == domain.ClientTypeOpenAI { + case domain.ClientTypeOpenAI: // TODO: Implement OpenAI transformation in the future return domain.NewProxyErrorWithMessage(domain.ErrFormatConversion, true, "OpenAI transformation not yet implemented") - } else { + default: // For Gemini, unwrap CLI envelope if present geminiBody = unwrapGeminiCLIEnvelope(requestBody) } @@ -439,12 +440,9 @@ func applyClaudePostProcess(geminiBody []byte, sessionID string, hasThinking boo return geminiBody } - modified := false + modified := InjectToolConfig(request) // 1. Inject toolConfig with VALIDATED mode when tools exist - if InjectToolConfig(request) { - modified = true - } // 2. Process contents for additional signature validation if contents, ok := request["contents"].([]interface{}); ok { @@ -545,16 +543,17 @@ func (a *AntigravityAdapter) handleNonStreamResponse(ctx context.Context, w http var responseBody []byte // Transform response based on client type - if clientType == domain.ClientTypeClaude { + switch clientType { + case domain.ClientTypeClaude: requestModel := ctxutil.GetRequestModel(ctx) responseBody, err = convertGeminiToClaudeResponse(unwrappedBody, requestModel) if err != nil { return domain.NewProxyErrorWithMessage(domain.ErrFormatConversion, false, "failed to transform response") } - } else if clientType == domain.ClientTypeOpenAI { + case domain.ClientTypeOpenAI: // TODO: Implement OpenAI response transformation return domain.NewProxyErrorWithMessage(domain.ErrFormatConversion, false, "OpenAI response transformation not yet implemented") - } else { + default: // Gemini native responseBody = unwrappedBody } @@ -648,6 +647,7 @@ func (a *AntigravityAdapter) handleStreamResponse(ctx context.Context, w http.Re // Read chunks and accumulate until we have complete lines var lineBuffer bytes.Buffer buf := make([]byte, 4096) + firstChunkSent := false // Track TTFT for { // Check context before reading @@ -700,6 +700,12 @@ func (a *AntigravityAdapter) handleStreamResponse(ctx context.Context, w http.Re return domain.NewProxyErrorWithMessage(writeErr, false, "client disconnected") } flusher.Flush() + + // Track TTFT: send first token time on first successful write + if !firstChunkSent { + firstChunkSent = true + eventChan.SendFirstToken(time.Now().UnixMilli()) + } } } } diff --git a/internal/adapter/provider/antigravity/claude_request_postprocess.go b/internal/adapter/provider/antigravity/claude_request_postprocess.go index f18fe701..026ccefc 100644 --- a/internal/adapter/provider/antigravity/claude_request_postprocess.go +++ b/internal/adapter/provider/antigravity/claude_request_postprocess.go @@ -31,12 +31,9 @@ func PostProcessClaudeRequest(geminiBody []byte, sessionID string, hasThinking b return geminiBody } - modified := false + modified := injectAntigravityIdentity(request) // 1. Inject Antigravity identity into system instruction (like Antigravity-Manager) - if injectAntigravityIdentity(request) { - modified = true - } // 2. Clean tool input schemas for Gemini compatibility (like Antigravity-Manager) if cleanToolInputSchemas(request) { diff --git a/internal/adapter/provider/antigravity/response.go b/internal/adapter/provider/antigravity/response.go index 0f2712cf..b8ef97c7 100644 --- a/internal/adapter/provider/antigravity/response.go +++ b/internal/adapter/provider/antigravity/response.go @@ -304,7 +304,6 @@ func convertGeminiToClaudeResponse(geminiBody []byte, requestModel string) ([]by "thinking": "", "signature": trailingSignature, }) - trailingSignature = "" } } diff --git a/internal/adapter/provider/antigravity/retry_delay.go b/internal/adapter/provider/antigravity/retry_delay.go index a3ca2054..058ff760 100644 --- a/internal/adapter/provider/antigravity/retry_delay.go +++ b/internal/adapter/provider/antigravity/retry_delay.go @@ -47,7 +47,7 @@ func ParseRetryInfo(statusCode int, body []byte) *RetryInfo { bodyStr := string(body) // Parse reason - reason := RateLimitReasonUnknown + var reason RateLimitReason if statusCode == 429 { reason = parseRateLimitReason(bodyStr) } else { diff --git a/internal/adapter/provider/antigravity/transform_request.go b/internal/adapter/provider/antigravity/transform_request.go index fc3531ce..8a9609ac 100644 --- a/internal/adapter/provider/antigravity/transform_request.go +++ b/internal/adapter/provider/antigravity/transform_request.go @@ -270,7 +270,7 @@ func removeTrailingUnsignedThinking(messages *[]ClaudeMessage) { } blocks := parseContentBlocks((*messages)[i].Content) - if blocks == nil || len(blocks) == 0 { + if len(blocks) == 0 { continue } diff --git a/internal/adapter/provider/antigravity/transform_tools.go b/internal/adapter/provider/antigravity/transform_tools.go index 520493f1..298f2e5d 100644 --- a/internal/adapter/provider/antigravity/transform_tools.go +++ b/internal/adapter/provider/antigravity/transform_tools.go @@ -8,7 +8,7 @@ import ( // buildTools converts Claude tools to Gemini tools format // Reference: Antigravity-Manager's build_tools func buildTools(claudeReq *ClaudeRequest) interface{} { - if claudeReq.Tools == nil || len(claudeReq.Tools) == 0 { + if len(claudeReq.Tools) == 0 { return nil } diff --git a/internal/adapter/provider/custom/adapter.go b/internal/adapter/provider/custom/adapter.go index c0f928e5..f1e8a2ed 100644 --- a/internal/adapter/provider/custom/adapter.go +++ b/internal/adapter/provider/custom/adapter.go @@ -316,6 +316,7 @@ func (a *CustomAdapter) handleStreamResponse(ctx context.Context, w http.Respons // Use buffer-based approach to handle incomplete lines properly var lineBuffer bytes.Buffer buf := make([]byte, 4096) + firstChunkSent := false // Track TTFT for { // Check context before reading @@ -361,6 +362,12 @@ func (a *CustomAdapter) handleStreamResponse(ctx context.Context, w http.Respons return domain.NewProxyErrorWithMessage(writeErr, false, "client disconnected") } flusher.Flush() + + // Track TTFT: send first token time on first successful write + if !firstChunkSent { + firstChunkSent = true + eventChan.SendFirstToken(time.Now().UnixMilli()) + } } } } @@ -579,7 +586,7 @@ func copyResponseHeaders(dst, src http.Header) { // Supports multiple API formats: OpenAI, Anthropic, Gemini, etc. func parseRateLimitInfo(resp *http.Response, body []byte, clientType domain.ClientType) *domain.RateLimitInfo { var resetTime time.Time - var rateLimitType string = "rate_limit_exceeded" + var rateLimitType = "rate_limit_exceeded" // Method 1: Parse Retry-After header if retryAfter := resp.Header.Get("Retry-After"); retryAfter != "" { @@ -698,7 +705,6 @@ func extractResponseModel(body []byte, targetType domain.ClientType) string { return "" } - // extractResponseModelFromSSE extracts the model name from SSE content based on target type func extractResponseModelFromSSE(sseContent string, targetType domain.ClientType) string { var lastModel string diff --git a/internal/adapter/provider/kiro/adapter.go b/internal/adapter/provider/kiro/adapter.go index 5cedecd7..93dc1923 100644 --- a/internal/adapter/provider/kiro/adapter.go +++ b/internal/adapter/provider/kiro/adapter.go @@ -362,7 +362,7 @@ func (a *KiroAdapter) handleStreamResponse(ctx context.Context, w http.ResponseW if err := streamCtx.sendInitialEvents(); err != nil { inTok, outTok := streamCtx.GetTokenCounts() - a.sendFinalEvents(ctx, sseBuffer.String(), inTok, outTok, requestModel) + a.sendFinalEvents(ctx, sseBuffer.String(), inTok, outTok, requestModel, streamCtx.GetFirstTokenTimeMs()) return domain.NewProxyErrorWithMessage(err, false, "failed to send initial events") } @@ -370,34 +370,39 @@ func (a *KiroAdapter) handleStreamResponse(ctx context.Context, w http.ResponseW if err != nil { if ctx.Err() != nil { inTok, outTok := streamCtx.GetTokenCounts() - a.sendFinalEvents(ctx, sseBuffer.String(), inTok, outTok, requestModel) + a.sendFinalEvents(ctx, sseBuffer.String(), inTok, outTok, requestModel, streamCtx.GetFirstTokenTimeMs()) return domain.NewProxyErrorWithMessage(ctx.Err(), false, "client disconnected") } _ = streamCtx.sendFinalEvents() inTok, outTok := streamCtx.GetTokenCounts() - a.sendFinalEvents(ctx, sseBuffer.String(), inTok, outTok, requestModel) + a.sendFinalEvents(ctx, sseBuffer.String(), inTok, outTok, requestModel, streamCtx.GetFirstTokenTimeMs()) return nil } if err := streamCtx.sendFinalEvents(); err != nil { inTok, outTok := streamCtx.GetTokenCounts() - a.sendFinalEvents(ctx, sseBuffer.String(), inTok, outTok, requestModel) + a.sendFinalEvents(ctx, sseBuffer.String(), inTok, outTok, requestModel, streamCtx.GetFirstTokenTimeMs()) return domain.NewProxyErrorWithMessage(err, false, "failed to send final events") } inTok, outTok := streamCtx.GetTokenCounts() - a.sendFinalEvents(ctx, sseBuffer.String(), inTok, outTok, requestModel) + a.sendFinalEvents(ctx, sseBuffer.String(), inTok, outTok, requestModel, streamCtx.GetFirstTokenTimeMs()) return nil } // sendFinalEvents sends final events via EventChannel -func (a *KiroAdapter) sendFinalEvents(ctx context.Context, body string, inputTokens, outputTokens int, requestModel string) { +func (a *KiroAdapter) sendFinalEvents(ctx context.Context, body string, inputTokens, outputTokens int, requestModel string, firstTokenTimeMs int64) { eventChan := ctxutil.GetEventChan(ctx) if eventChan == nil { return } + // Send first token time if available (for TTFT tracking) + if firstTokenTimeMs > 0 { + eventChan.SendFirstToken(firstTokenTimeMs) + } + // Send response info with body eventChan.SendResponseInfo(&domain.ResponseInfo{ Status: 200, // streaming always returns 200 at this point diff --git a/internal/adapter/provider/kiro/compliant_event_stream_parser.go b/internal/adapter/provider/kiro/compliant_event_stream_parser.go index da7579a7..85520a41 100644 --- a/internal/adapter/provider/kiro/compliant_event_stream_parser.go +++ b/internal/adapter/provider/kiro/compliant_event_stream_parser.go @@ -31,7 +31,7 @@ func (cesp *CompliantEventStreamParser) Reset() { func (cesp *CompliantEventStreamParser) ParseResponse(streamData []byte) (*ParseResult, error) { messages, err := cesp.robustParser.ParseStream(streamData) if err != nil { - // Continue with partial messages. + _ = err // Continue with partial messages. } var allEvents []SSEEvent @@ -63,7 +63,7 @@ func (cesp *CompliantEventStreamParser) ParseResponse(streamData []byte) (*Parse func (cesp *CompliantEventStreamParser) ParseStream(data []byte) ([]SSEEvent, error) { messages, err := cesp.robustParser.ParseStream(data) if err != nil { - // Continue with partial messages. + _ = err // Continue with partial messages. } var allEvents []SSEEvent diff --git a/internal/adapter/provider/kiro/robust_parser.go b/internal/adapter/provider/kiro/robust_parser.go index 4267e456..7bef1527 100644 --- a/internal/adapter/provider/kiro/robust_parser.go +++ b/internal/adapter/provider/kiro/robust_parser.go @@ -9,10 +9,10 @@ import ( // RobustEventStreamParser parses AWS EventStream frames with error recovery. type RobustEventStreamParser struct { - buffer *bytes.Buffer + buffer *bytes.Buffer errorCount int - maxErrors int - mu sync.Mutex + maxErrors int + mu sync.Mutex } // NewRobustEventStreamParser creates a parser instance. @@ -47,10 +47,7 @@ func (rp *RobustEventStreamParser) ParseStream(data []byte) ([]*EventStreamMessa messages := make([]*EventStreamMessage, 0, 8) - for { - if rp.buffer.Len() < EventStreamMinMessageSize { - break - } + for rp.buffer.Len() >= EventStreamMinMessageSize { bufferBytes := rp.buffer.Bytes() if len(bufferBytes) < EventStreamMinMessageSize { diff --git a/internal/adapter/provider/kiro/streaming.go b/internal/adapter/provider/kiro/streaming.go index 7eea351f..765fadf9 100644 --- a/internal/adapter/provider/kiro/streaming.go +++ b/internal/adapter/provider/kiro/streaming.go @@ -25,6 +25,7 @@ type streamProcessorContext struct { toolUseIdByBlockIndex map[int]string completedToolUseIds map[string]bool jsonBytesByBlockIndex map[int]int + firstTokenTimeMs int64 // Unix milliseconds of first token sent (for TTFT tracking) } func newStreamProcessorContext(w http.ResponseWriter, model string, inputTokens int, writer io.Writer) (*streamProcessorContext, error) { @@ -169,6 +170,12 @@ func (ctx *streamProcessorContext) processEvent(event SSEEvent) error { return err } ctx.flusher.Flush() + + // Track TTFT: record first token time on first successful send + if ctx.firstTokenTimeMs == 0 { + ctx.firstTokenTimeMs = time.Now().UnixMilli() + } + return nil } @@ -321,3 +328,8 @@ func (ctx *streamProcessorContext) GetTokenCounts() (inputTokens int, outputToke } return ctx.inputTokens, outputTokens } + +// GetFirstTokenTimeMs returns the first token time in Unix milliseconds (for TTFT tracking) +func (ctx *streamProcessorContext) GetFirstTokenTimeMs() int64 { + return ctx.firstTokenTimeMs +} diff --git a/internal/converter/claude_to_codex.go b/internal/converter/claude_to_codex.go index 4ddb333d..3be81d3f 100644 --- a/internal/converter/claude_to_codex.go +++ b/internal/converter/claude_to_codex.go @@ -66,7 +66,7 @@ func (c *claudeToCodexRequest) Transform(body []byte, model string, stream bool) // Convert tool use to function_call output name, _ := m["name"].(string) id, _ := m["id"].(string) - inputData, _ := m["input"] + inputData := m["input"] argJSON, _ := json.Marshal(inputData) input = append(input, CodexInputItem{ Type: "function_call", diff --git a/internal/converter/claude_to_openai.go b/internal/converter/claude_to_openai.go index 2d03ac42..c9608ad7 100644 --- a/internal/converter/claude_to_openai.go +++ b/internal/converter/claude_to_openai.go @@ -74,11 +74,11 @@ func (c *claudeToOpenAIRequest) Transform(body []byte, model string, stream bool case "tool_use": id, _ := m["id"].(string) name, _ := m["name"].(string) - input, _ := m["input"] + input := m["input"] inputJSON, _ := json.Marshal(input) toolCalls = append(toolCalls, OpenAIToolCall{ - ID: id, - Type: "function", + ID: id, + Type: "function", Function: OpenAIFunctionCall{Name: name, Arguments: string(inputJSON)}, }) case "tool_result": @@ -155,8 +155,8 @@ func (c *claudeToOpenAIResponse) Transform(body []byte) ([]byte, error) { case "tool_use": inputJSON, _ := json.Marshal(block.Input) toolCalls = append(toolCalls, OpenAIToolCall{ - ID: block.ID, - Type: "function", + ID: block.ID, + Type: "function", Function: OpenAIFunctionCall{Name: block.Name, Arguments: string(inputJSON)}, }) } diff --git a/internal/converter/codex_to_gemini.go b/internal/converter/codex_to_gemini.go index 65cbbfe0..736ef424 100644 --- a/internal/converter/codex_to_gemini.go +++ b/internal/converter/codex_to_gemini.go @@ -52,7 +52,7 @@ func (c *codexToGeminiRequest) Transform(body []byte, model string, stream bool) switch itemType { case "message": role := mapCodexRoleToGemini(m["role"]) - content, _ := m["content"] + content := m["content"] var parts []GeminiPart switch c := content.(type) { case string: diff --git a/internal/converter/gemini_to_openai.go b/internal/converter/gemini_to_openai.go index 44dff8d7..f43ae1a1 100644 --- a/internal/converter/gemini_to_openai.go +++ b/internal/converter/gemini_to_openai.go @@ -105,12 +105,8 @@ func (c *geminiToOpenAIRequest) Transform(body []byte, model string, stream bool for _, tool := range req.Tools { for _, decl := range tool.FunctionDeclarations { openaiReq.Tools = append(openaiReq.Tools, OpenAITool{ - Type: "function", - Function: OpenAIFunction{ - Name: decl.Name, - Description: decl.Description, - Parameters: decl.Parameters, - }, + Type: "function", + Function: OpenAIFunction(decl), }) } } diff --git a/internal/core/database.go b/internal/core/database.go index 1f96e127..cc43170f 100644 --- a/internal/core/database.go +++ b/internal/core/database.go @@ -70,6 +70,7 @@ type ServerComponents struct { AntigravityHandler *handler.AntigravityHandler KiroHandler *handler.KiroHandler ProjectProxyHandler *handler.ProjectProxyHandler + RequestTracker *RequestTracker } // InitializeDatabase 初始化数据库和所有仓库 @@ -171,6 +172,23 @@ func InitializeServerComponents( } else if count > 0 { log.Printf("[Core] Marked %d stale requests as failed", count) } + // Also mark stale upstream attempts as failed + if count, err := repos.AttemptRepo.MarkStaleAttemptsFailed(); err != nil { + log.Printf("[Core] Warning: Failed to mark stale attempts: %v", err) + } else if count > 0 { + log.Printf("[Core] Marked %d stale upstream attempts as failed", count) + } + // Fix legacy failed requests/attempts without end_time + if count, err := repos.ProxyRequestRepo.FixFailedRequestsWithoutEndTime(); err != nil { + log.Printf("[Core] Warning: Failed to fix failed requests without end_time: %v", err) + } else if count > 0 { + log.Printf("[Core] Fixed %d failed requests without end_time", count) + } + if count, err := repos.AttemptRepo.FixFailedAttemptsWithoutEndTime(); err != nil { + log.Printf("[Core] Warning: Failed to fix failed attempts without end_time: %v", err) + } else if count > 0 { + log.Printf("[Core] Fixed %d failed attempts without end_time", count) + } log.Printf("[Core] Loading cached data") if err := repos.CachedProviderRepo.Load(); err != nil { @@ -275,6 +293,7 @@ func InitializeServerComponents( repos.ResponseModelRepo, addr, r, + wailsBroadcaster, ) log.Printf("[Core] Creating backup service") @@ -298,6 +317,10 @@ func InitializeServerComponents( kiroHandler := handler.NewKiroHandler(adminService) projectProxyHandler := handler.NewProjectProxyHandler(proxyHandler, repos.CachedProjectRepo) + log.Printf("[Core] Creating request tracker for graceful shutdown") + requestTracker := NewRequestTracker() + proxyHandler.SetRequestTracker(requestTracker) + components := &ServerComponents{ Router: r, WebSocketHub: wsHub, @@ -310,6 +333,7 @@ func InitializeServerComponents( AntigravityHandler: antigravityHandler, KiroHandler: kiroHandler, ProjectProxyHandler: projectProxyHandler, + RequestTracker: requestTracker, } log.Printf("[Core] Server components initialized successfully") diff --git a/internal/core/request_tracker.go b/internal/core/request_tracker.go new file mode 100644 index 00000000..363eacb9 --- /dev/null +++ b/internal/core/request_tracker.go @@ -0,0 +1,162 @@ +package core + +import ( + "context" + "log" + "sync" + "sync/atomic" + "time" +) + +// RequestTracker tracks active proxy requests for graceful shutdown +type RequestTracker struct { + activeCount int64 + wg sync.WaitGroup + shutdownCh chan struct{} + isShutdown atomic.Bool + // notifyCh is used to notify when a request completes during shutdown + notifyCh chan struct{} + notifyMu sync.Mutex +} + +// NewRequestTracker creates a new request tracker +func NewRequestTracker() *RequestTracker { + return &RequestTracker{ + shutdownCh: make(chan struct{}), + } +} + +// Add increments the active request count +// Returns false if shutdown is in progress (request should be rejected) +func (t *RequestTracker) Add() bool { + if t.isShutdown.Load() { + return false + } + t.wg.Add(1) + atomic.AddInt64(&t.activeCount, 1) + return true +} + +// Done decrements the active request count +func (t *RequestTracker) Done() { + remaining := atomic.AddInt64(&t.activeCount, -1) + t.wg.Done() + + // Notify shutdown goroutine if shutting down + if t.isShutdown.Load() { + t.notifyMu.Lock() + ch := t.notifyCh + t.notifyMu.Unlock() + if ch != nil { + select { + case ch <- struct{}{}: + default: + // Non-blocking send, channel might be full or closed + } + } + log.Printf("[RequestTracker] Request completed, %d remaining", remaining) + } +} + +// ActiveCount returns the current number of active requests +func (t *RequestTracker) ActiveCount() int64 { + return atomic.LoadInt64(&t.activeCount) +} + +// WaitWithTimeout waits for all active requests to complete with a timeout +// Returns true if all requests completed, false if timeout occurred +func (t *RequestTracker) WaitWithTimeout(timeout time.Duration) bool { + t.isShutdown.Store(true) + close(t.shutdownCh) + + done := make(chan struct{}) + go func() { + t.wg.Wait() + close(done) + }() + + select { + case <-done: + return true + case <-time.After(timeout): + return false + } +} + +// WaitWithContext waits for all active requests to complete or context cancellation +// Returns true if all requests completed, false if context was cancelled +func (t *RequestTracker) WaitWithContext(ctx context.Context) bool { + t.isShutdown.Store(true) + close(t.shutdownCh) + + done := make(chan struct{}) + go func() { + t.wg.Wait() + close(done) + }() + + select { + case <-done: + return true + case <-ctx.Done(): + return false + } +} + +// IsShuttingDown returns true if shutdown has been initiated +func (t *RequestTracker) IsShuttingDown() bool { + return t.isShutdown.Load() +} + +// ShutdownCh returns a channel that is closed when shutdown begins +func (t *RequestTracker) ShutdownCh() <-chan struct{} { + return t.shutdownCh +} + +// GracefulShutdown initiates graceful shutdown and waits for requests to complete +// maxWait: maximum time to wait for requests to complete +func (t *RequestTracker) GracefulShutdown(maxWait time.Duration) bool { + // Setup notify channel before marking shutdown + t.notifyMu.Lock() + t.notifyCh = make(chan struct{}, 100) // Buffered to avoid blocking Done() + t.notifyMu.Unlock() + + t.isShutdown.Store(true) + close(t.shutdownCh) + + activeCount := t.ActiveCount() + if activeCount == 0 { + log.Printf("[RequestTracker] No active requests, shutdown immediate") + return true + } + + log.Printf("[RequestTracker] Graceful shutdown initiated, waiting for %d active requests", activeCount) + + done := make(chan struct{}) + go func() { + t.wg.Wait() + close(done) + }() + + deadline := time.After(maxWait) + + for { + select { + case <-done: + log.Printf("[RequestTracker] All requests completed, shutdown clean") + return true + case <-t.notifyCh: + // Request completed notification received, log is printed in Done() + // Check if all done + if t.ActiveCount() == 0 { + <-done // Wait for wg.Wait() to complete + log.Printf("[RequestTracker] All requests completed, shutdown clean") + return true + } + case <-deadline: + remaining := t.ActiveCount() + log.Printf("[RequestTracker] Timeout reached, %d requests still active, forcing shutdown", remaining) + return false + } + } +} diff --git a/internal/core/server.go b/internal/core/server.go index 3c16e08d..5a4eca1d 100644 --- a/internal/core/server.go +++ b/internal/core/server.go @@ -9,6 +9,14 @@ import ( "github.com/awsl-project/maxx/internal/handler" ) +// Graceful shutdown configuration +const ( + // GracefulShutdownTimeout is the maximum time to wait for active requests + GracefulShutdownTimeout = 2 * time.Minute + // HTTPShutdownTimeout is the timeout for HTTP server shutdown after requests complete + HTTPShutdownTimeout = 5 * time.Second +) + // ServerConfig 服务器配置 type ServerConfig struct { Addr string @@ -118,12 +126,33 @@ func (s *ManagedServer) Stop(ctx context.Context) error { log.Printf("[Server] Stopping HTTP server on %s", s.config.Addr) - // 使用较短的超时时间,超时后强制关闭 - shutdownCtx, cancel := context.WithTimeout(ctx, 3*time.Second) + // Step 1: Wait for active proxy requests to complete (graceful shutdown) + if s.config.Components != nil && s.config.Components.RequestTracker != nil { + tracker := s.config.Components.RequestTracker + activeCount := tracker.ActiveCount() + + if activeCount > 0 { + log.Printf("[Server] Waiting for %d active proxy requests to complete...", activeCount) + + completed := tracker.GracefulShutdown(GracefulShutdownTimeout) + if !completed { + log.Printf("[Server] Graceful shutdown timeout, some requests may be interrupted") + } else { + log.Printf("[Server] All proxy requests completed successfully") + } + } else { + // Mark as shutting down to reject new requests + tracker.GracefulShutdown(0) + log.Printf("[Server] No active proxy requests") + } + } + + // Step 2: Shutdown HTTP server (with shorter timeout since requests should be done) + shutdownCtx, cancel := context.WithTimeout(ctx, HTTPShutdownTimeout) defer cancel() if err := s.httpServer.Shutdown(shutdownCtx); err != nil { - log.Printf("[Server] Graceful shutdown failed: %v, forcing close", err) + log.Printf("[Server] HTTP server graceful shutdown failed: %v, forcing close", err) // 强制关闭 if closeErr := s.httpServer.Close(); closeErr != nil { log.Printf("[Server] Force close error: %v", closeErr) diff --git a/internal/core/task.go b/internal/core/task.go index 3993cd31..4bb42a33 100644 --- a/internal/core/task.go +++ b/internal/core/task.go @@ -25,36 +25,18 @@ type BackgroundTaskDeps struct { // StartBackgroundTasks 启动所有后台任务 func StartBackgroundTasks(deps BackgroundTaskDeps) { - // 分钟级聚合任务(每 30 秒)- 实时聚合原始数据到分钟 + // 统计聚合任务(每 30 秒)- 聚合原始数据并自动 rollup 到各粒度 go func() { time.Sleep(5 * time.Second) // 初始延迟 - deps.runMinuteAggregation() - - ticker := time.NewTicker(30 * time.Second) - for range ticker.C { - deps.runMinuteAggregation() - } - }() - - // 小时级 Roll-up(每分钟)- 分钟 → 小时 - go func() { - time.Sleep(10 * time.Second) // 初始延迟 - deps.runHourlyRollup() - - ticker := time.NewTicker(1 * time.Minute) - for range ticker.C { - deps.runHourlyRollup() + for range deps.UsageStats.AggregateAndRollUp() { + // drain the channel to wait for completion } - }() - // 天级 Roll-up(每 5 分钟)- 小时 → 天/周/月 - go func() { - time.Sleep(15 * time.Second) // 初始延迟 - deps.runDailyRollup() - - ticker := time.NewTicker(5 * time.Minute) + ticker := time.NewTicker(30 * time.Second) for range ticker.C { - deps.runDailyRollup() + for range deps.UsageStats.AggregateAndRollUp() { + // drain the channel to wait for completion + } } }() @@ -74,27 +56,7 @@ func StartBackgroundTasks(deps BackgroundTaskDeps) { go deps.runAntigravityQuotaRefresh() } - log.Println("[Task] Background tasks started (minute:30s, hour:1m, day:5m, cleanup:1h)") -} - -// runMinuteAggregation 分钟级聚合:从原始数据聚合到分钟 -func (d *BackgroundTaskDeps) runMinuteAggregation() { - _, _ = d.UsageStats.AggregateMinute() -} - -// runHourlyRollup 小时级 Roll-up:分钟 → 小时 -func (d *BackgroundTaskDeps) runHourlyRollup() { - _, _ = d.UsageStats.RollUp(domain.GranularityMinute, domain.GranularityHour) -} - -// runDailyRollup 天级 Roll-up:小时 → 天/周/月 -func (d *BackgroundTaskDeps) runDailyRollup() { - // 小时 → 天 - _, _ = d.UsageStats.RollUp(domain.GranularityHour, domain.GranularityDay) - // 天 → 周 - _, _ = d.UsageStats.RollUp(domain.GranularityDay, domain.GranularityWeek) - // 天 → 月 - _, _ = d.UsageStats.RollUp(domain.GranularityDay, domain.GranularityMonth) + log.Println("[Task] Background tasks started (aggregation:30s, cleanup:1h)") } // runCleanupTasks 清理任务:清理过期数据 diff --git a/internal/domain/adapter_event.go b/internal/domain/adapter_event.go index c6aadec1..7faeb7d0 100644 --- a/internal/domain/adapter_event.go +++ b/internal/domain/adapter_event.go @@ -12,6 +12,8 @@ const ( EventMetrics // EventResponseModel is sent when response model is extracted EventResponseModel + // EventFirstToken is sent when the first token/chunk is received (for TTFT tracking) + EventFirstToken ) // AdapterMetrics contains token usage metrics (avoids import cycle with usage package) @@ -26,11 +28,12 @@ type AdapterMetrics struct { // AdapterEvent represents an event from adapter to executor type AdapterEvent struct { - Type AdapterEventType - RequestInfo *RequestInfo // for EventRequestInfo - ResponseInfo *ResponseInfo // for EventResponseInfo - Metrics *AdapterMetrics // for EventMetrics - ResponseModel string // for EventResponseModel + Type AdapterEventType + RequestInfo *RequestInfo // for EventRequestInfo + ResponseInfo *ResponseInfo // for EventResponseInfo + Metrics *AdapterMetrics // for EventMetrics + ResponseModel string // for EventResponseModel + FirstTokenTime int64 // for EventFirstToken (Unix milliseconds) } // AdapterEventChan is used by adapters to send events to executor @@ -86,6 +89,17 @@ func (ch AdapterEventChan) SendResponseModel(model string) { } } +// SendFirstToken sends first token event with the time when first token was received +func (ch AdapterEventChan) SendFirstToken(timeMs int64) { + if ch == nil || timeMs == 0 { + return + } + select { + case ch <- &AdapterEvent{Type: EventFirstToken, FirstTokenTime: timeMs}: + default: + } +} + // Close closes the event channel func (ch AdapterEventChan) Close() { if ch != nil { diff --git a/internal/domain/model.go b/internal/domain/model.go index 948e5130..aef1cd52 100644 --- a/internal/domain/model.go +++ b/internal/domain/model.go @@ -201,6 +201,9 @@ type ProxyRequest struct { EndTime time.Time `json:"endTime"` Duration time.Duration `json:"duration"` + // TTFT (Time To First Token) 首字时长,流式接口第一条数据返回的延迟 + TTFT time.Duration `json:"ttft"` + // 是否为 SSE 流式请求 IsStream bool `json:"isStream"` @@ -239,7 +242,7 @@ type ProxyRequest struct { Cache5mWriteCount uint64 `json:"cache5mWriteCount"` Cache1hWriteCount uint64 `json:"cache1hWriteCount"` - // 成本 (微美元,1 USD = 1,000,000) + // 成本 (纳美元,1 USD = 1,000,000,000 nanoUSD) Cost uint64 `json:"cost"` // 使用的 API Token ID,0 表示未使用 Token @@ -256,6 +259,9 @@ type ProxyUpstreamAttempt struct { EndTime time.Time `json:"endTime"` Duration time.Duration `json:"duration"` + // TTFT (Time To First Token) 首字时长,流式接口第一条数据返回的延迟 + TTFT time.Duration `json:"ttft"` + // PENDING, IN_PROGRESS, COMPLETED, FAILED Status string `json:"status"` @@ -295,6 +301,22 @@ type ProxyUpstreamAttempt struct { Cost uint64 `json:"cost"` } +// AttemptCostData contains minimal data needed for cost recalculation +type AttemptCostData struct { + ID uint64 + ProxyRequestID uint64 + ResponseModel string + MappedModel string + RequestModel string + InputTokenCount uint64 + OutputTokenCount uint64 + CacheReadCount uint64 + CacheWriteCount uint64 + Cache5mWriteCount uint64 + Cache1hWriteCount uint64 + Cost uint64 +} + // 重试配置 type RetryConfig struct { ID uint64 `json:"id"` @@ -432,7 +454,7 @@ type ProviderStats struct { TotalCacheRead uint64 `json:"totalCacheRead"` TotalCacheWrite uint64 `json:"totalCacheWrite"` - // 成本 (微美元) + // 成本 (纳美元) TotalCost uint64 `json:"totalCost"` } @@ -443,7 +465,6 @@ const ( GranularityMinute Granularity = "minute" GranularityHour Granularity = "hour" GranularityDay Granularity = "day" - GranularityWeek Granularity = "week" GranularityMonth Granularity = "month" ) @@ -469,6 +490,7 @@ type UsageStats struct { SuccessfulRequests uint64 `json:"successfulRequests"` FailedRequests uint64 `json:"failedRequests"` TotalDurationMs uint64 `json:"totalDurationMs"` // 累计请求耗时(毫秒) + TotalTTFTMs uint64 `json:"totalTtftMs"` // 累计首字时长(毫秒) // Token 统计 InputTokens uint64 `json:"inputTokens"` @@ -476,7 +498,7 @@ type UsageStats struct { CacheRead uint64 `json:"cacheRead"` CacheWrite uint64 `json:"cacheWrite"` - // 成本 (微美元) + // 成本 (纳美元) Cost uint64 `json:"cost"` } @@ -756,3 +778,25 @@ type DashboardData struct { ProviderStats map[uint64]DashboardProviderStats `json:"providerStats"` Timezone string `json:"timezone"` // 配置的时区,如 "Asia/Shanghai" } + +// ===== Progress Reporting ===== + +// Progress represents a progress update for long-running operations +type Progress struct { + Phase string `json:"phase"` // Current phase of the operation + Current int `json:"current"` // Current item being processed + Total int `json:"total"` // Total items to process + Percentage int `json:"percentage"` // 0-100 + Message string `json:"message"` // Human-readable message +} + +// AggregateEvent represents a progress event during stats aggregation +type AggregateEvent struct { + Phase string `json:"phase"` // "aggregate_minute", "rollup_hour", "rollup_day", "rollup_month" + From Granularity `json:"from"` // Source granularity (for rollup) + To Granularity `json:"to"` // Target granularity + StartTime int64 `json:"start_time"` // Start of time range (unix ms) + EndTime int64 `json:"end_time"` // End of time range (unix ms) + Count int `json:"count"` // Number of records created/updated + Error error `json:"-"` // Error if any (not serialized) +} diff --git a/internal/executor/executor.go b/internal/executor/executor.go index eea6af07..cb9d92a6 100644 --- a/internal/executor/executor.go +++ b/internal/executor/executor.go @@ -234,6 +234,8 @@ func (e *Executor) Execute(ctx context.Context, w http.ResponseWriter, req *http // If current attempt is still IN_PROGRESS, mark as cancelled/failed if currentAttempt != nil && currentAttempt.Status == "IN_PROGRESS" { + currentAttempt.EndTime = time.Now() + currentAttempt.Duration = currentAttempt.EndTime.Sub(currentAttempt.StartTime) if ctx.Err() != nil { currentAttempt.Status = "CANCELLED" } else { @@ -316,7 +318,7 @@ func (e *Executor) Execute(ctx context.Context, w http.ResponseWriter, req *http return ctx.Err() } - // Create attempt record with start time + // Create attempt record with start time and request info attemptStartTime := time.Now() attemptRecord := &domain.ProxyUpstreamAttempt{ ProxyRequestID: proxyReq.ID, @@ -327,6 +329,7 @@ func (e *Executor) Execute(ctx context.Context, w http.ResponseWriter, req *http StartTime: attemptStartTime, RequestModel: requestModel, MappedModel: mappedModel, + RequestInfo: proxyReq.RequestInfo, // Use original request info initially } if err := e.attemptRepo.Create(attemptRecord); err != nil { log.Printf("[Executor] Failed to create attempt record: %v", err) @@ -404,7 +407,12 @@ func (e *Executor) Execute(ctx context.Context, w http.ResponseWriter, req *http Cache5mCreationCount: attemptRecord.Cache5mWriteCount, Cache1hCreationCount: attemptRecord.Cache1hWriteCount, } - attemptRecord.Cost = pricing.GlobalCalculator().Calculate(attemptRecord.MappedModel, metrics) + // Use ResponseModel for pricing (actual model from API response), fallback to MappedModel + pricingModel := attemptRecord.ResponseModel + if pricingModel == "" { + pricingModel = attemptRecord.MappedModel + } + attemptRecord.Cost = pricing.GlobalCalculator().Calculate(pricingModel, metrics) } _ = e.attemptRepo.Update(attemptRecord) @@ -443,6 +451,7 @@ func (e *Executor) Execute(ctx context.Context, w http.ResponseWriter, req *http proxyReq.Cache1hWriteCount = metrics.Cache1hCreationCount } proxyReq.Cost = attemptRecord.Cost + proxyReq.TTFT = attemptRecord.TTFT _ = e.proxyRequestRepo.Update(proxyReq) @@ -476,7 +485,12 @@ func (e *Executor) Execute(ctx context.Context, w http.ResponseWriter, req *http Cache5mCreationCount: attemptRecord.Cache5mWriteCount, Cache1hCreationCount: attemptRecord.Cache1hWriteCount, } - attemptRecord.Cost = pricing.GlobalCalculator().Calculate(attemptRecord.MappedModel, metrics) + // Use ResponseModel for pricing (actual model from API response), fallback to MappedModel + pricingModel := attemptRecord.ResponseModel + if pricingModel == "" { + pricingModel = attemptRecord.MappedModel + } + attemptRecord.Cost = pricing.GlobalCalculator().Calculate(pricingModel, metrics) } _ = e.attemptRepo.Update(attemptRecord) @@ -508,6 +522,7 @@ func (e *Executor) Execute(ctx context.Context, w http.ResponseWriter, req *http } } proxyReq.Cost = attemptRecord.Cost + proxyReq.TTFT = attemptRecord.TTFT _ = e.proxyRequestRepo.Update(proxyReq) if e.broadcaster != nil { @@ -800,6 +815,11 @@ func (e *Executor) processAdapterEvents(eventChan domain.AdapterEventChan, attem if event.ResponseModel != "" { attempt.ResponseModel = event.ResponseModel } + case domain.EventFirstToken: + if event.FirstTokenTime > 0 { + firstTokenTime := time.UnixMilli(event.FirstTokenTime) + attempt.TTFT = firstTokenTime.Sub(attempt.StartTime) + } } default: // No more events @@ -850,6 +870,13 @@ func (e *Executor) processAdapterEventsRealtime(eventChan domain.AdapterEventCha attempt.ResponseModel = event.ResponseModel needsBroadcast = true } + case domain.EventFirstToken: + if event.FirstTokenTime > 0 { + // Calculate TTFT as duration from start time to first token time + firstTokenTime := time.UnixMilli(event.FirstTokenTime) + attempt.TTFT = firstTokenTime.Sub(attempt.StartTime) + needsBroadcast = true + } } // Broadcast update immediately for real-time visibility diff --git a/internal/handler/admin.go b/internal/handler/admin.go index c2e020f3..35dc321f 100644 --- a/internal/handler/admin.go +++ b/internal/handler/admin.go @@ -9,6 +9,7 @@ import ( "github.com/awsl-project/maxx/internal/cooldown" "github.com/awsl-project/maxx/internal/domain" + "github.com/awsl-project/maxx/internal/pricing" "github.com/awsl-project/maxx/internal/repository" "github.com/awsl-project/maxx/internal/service" ) @@ -88,6 +89,8 @@ func (h *AdminHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { h.handleResponseModels(w, r) case "backup": h.handleBackup(w, r, parts) + case "pricing": + h.handlePricing(w, r) default: writeJSON(w, http.StatusNotFound, map[string]string{"error": "not found"}) } @@ -650,7 +653,7 @@ func (h *AdminHandler) handleRoutingStrategies(w http.ResponseWriter, r *http.Re } // ProxyRequest handlers -// Routes: /admin/requests, /admin/requests/count, /admin/requests/active, /admin/requests/{id}, /admin/requests/{id}/attempts +// Routes: /admin/requests, /admin/requests/count, /admin/requests/active, /admin/requests/{id}, /admin/requests/{id}/attempts, /admin/requests/{id}/recalculate-cost func (h *AdminHandler) handleProxyRequests(w http.ResponseWriter, r *http.Request, id uint64, parts []string) { // Check for count endpoint: /admin/requests/count if len(parts) > 2 && parts[2] == "count" { @@ -670,6 +673,12 @@ func (h *AdminHandler) handleProxyRequests(w http.ResponseWriter, r *http.Reques return } + // Check for sub-resource: /admin/requests/{id}/recalculate-cost + if len(parts) > 3 && parts[3] == "recalculate-cost" && id > 0 { + h.handleRecalculateRequestCost(w, r, id) + return + } + switch r.Method { case http.MethodGet: if id > 0 { @@ -691,7 +700,25 @@ func (h *AdminHandler) handleProxyRequests(w http.ResponseWriter, r *http.Reques if a := r.URL.Query().Get("after"); a != "" { after, _ = strconv.ParseUint(a, 10, 64) } - result, err := h.svc.GetProxyRequestsCursor(limit, before, after) + + // 构建过滤条件 + var filter *repository.ProxyRequestFilter + providerIDStr := r.URL.Query().Get("providerId") + statusStr := r.URL.Query().Get("status") + + if providerIDStr != "" || statusStr != "" { + filter = &repository.ProxyRequestFilter{} + if providerIDStr != "" { + if providerID, err := strconv.ParseUint(providerIDStr, 10, 64); err == nil { + filter.ProviderID = &providerID + } + } + if statusStr != "" { + filter.Status = &statusStr + } + } + + result, err := h.svc.GetProxyRequestsCursor(limit, before, after, filter) if err != nil { writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return @@ -710,7 +737,27 @@ func (h *AdminHandler) handleProxyRequestsCount(w http.ResponseWriter, r *http.R return } - count, err := h.svc.GetProxyRequestsCount() + // 解析过滤参数 + var filter *repository.ProxyRequestFilter + providerIDStr := r.URL.Query().Get("providerId") + statusStr := r.URL.Query().Get("status") + + if providerIDStr != "" || statusStr != "" { + filter = &repository.ProxyRequestFilter{} + if providerIDStr != "" { + providerID, err := strconv.ParseUint(providerIDStr, 10, 64) + if err != nil { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid providerId"}) + return + } + filter.ProviderID = &providerID + } + if statusStr != "" { + filter.Status = &statusStr + } + } + + count, err := h.svc.GetProxyRequestsCountWithFilter(filter) if err != nil { writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return @@ -748,6 +795,21 @@ func (h *AdminHandler) handleProxyUpstreamAttempts(w http.ResponseWriter, r *htt writeJSON(w, http.StatusOK, attempts) } +// handleRecalculateRequestCost handles POST /admin/requests/{id}/recalculate-cost +func (h *AdminHandler) handleRecalculateRequestCost(w http.ResponseWriter, r *http.Request, requestID uint64) { + if r.Method != http.MethodPost { + writeJSON(w, http.StatusMethodNotAllowed, map[string]string{"error": "method not allowed"}) + return + } + + result, err := h.svc.RecalculateRequestCost(requestID) + if err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) + return + } + writeJSON(w, http.StatusOK, result) +} + // Settings handlers func (h *AdminHandler) handleSettings(w http.ResponseWriter, r *http.Request, parts []string) { var key string @@ -1168,6 +1230,11 @@ func (h *AdminHandler) handleUsageStats(w http.ResponseWriter, r *http.Request) h.handleRecalculateUsageStats(w, r) return } + // Check for recalculate-costs endpoint: /admin/usage-stats/recalculate-costs + if strings.HasSuffix(path, "/recalculate-costs") { + h.handleRecalculateCosts(w, r) + return + } if r.Method != http.MethodGet { writeJSON(w, http.StatusMethodNotAllowed, map[string]string{"error": "method not allowed"}) @@ -1187,8 +1254,6 @@ func (h *AdminHandler) handleUsageStats(w http.ResponseWriter, r *http.Request) filter.Granularity = domain.GranularityHour case "day": filter.Granularity = domain.GranularityDay - case "week": - filter.Granularity = domain.GranularityWeek case "month": filter.Granularity = domain.GranularityMonth default: @@ -1259,6 +1324,22 @@ func (h *AdminHandler) handleRecalculateUsageStats(w http.ResponseWriter, r *htt writeJSON(w, http.StatusOK, map[string]string{"message": "usage stats recalculated successfully"}) } +// handleRecalculateCosts handles POST /admin/usage-stats/recalculate-costs +// Recalculates cost for all attempts using the current price table +func (h *AdminHandler) handleRecalculateCosts(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + writeJSON(w, http.StatusMethodNotAllowed, map[string]string{"error": "method not allowed"}) + return + } + + result, err := h.svc.RecalculateCosts() + if err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) + return + } + writeJSON(w, http.StatusOK, result) +} + // handleResponseModels handles GET /admin/response-models func (h *AdminHandler) handleResponseModels(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { @@ -1359,6 +1440,18 @@ func (h *AdminHandler) handleBackupImport(w http.ResponseWriter, r *http.Request writeJSON(w, http.StatusOK, result) } +// handlePricing handles GET /admin/pricing +// Returns the default price table for cost calculation display +func (h *AdminHandler) handlePricing(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + writeJSON(w, http.StatusMethodNotAllowed, map[string]string{"error": "method not allowed"}) + return + } + + priceTable := pricing.DefaultPriceTable() + writeJSON(w, http.StatusOK, priceTable) +} + func writeJSON(w http.ResponseWriter, status int, data interface{}) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(status) diff --git a/internal/handler/proxy.go b/internal/handler/proxy.go index 53191ab3..9ea62317 100644 --- a/internal/handler/proxy.go +++ b/internal/handler/proxy.go @@ -6,6 +6,7 @@ import ( "log" "net/http" "strconv" + "sync" "github.com/awsl-project/maxx/internal/adapter/client" ctxutil "github.com/awsl-project/maxx/internal/context" @@ -14,12 +15,21 @@ import ( "github.com/awsl-project/maxx/internal/repository/cached" ) +// RequestTracker interface for tracking active requests +type RequestTracker interface { + Add() bool + Done() + IsShuttingDown() bool +} + // ProxyHandler handles AI API proxy requests type ProxyHandler struct { clientAdapter *client.Adapter executor *executor.Executor sessionRepo *cached.SessionRepository tokenAuth *TokenAuthMiddleware + tracker RequestTracker + trackerMu sync.RWMutex } // NewProxyHandler creates a new proxy handler @@ -37,10 +47,32 @@ func NewProxyHandler( } } +// SetRequestTracker sets the request tracker for graceful shutdown +func (h *ProxyHandler) SetRequestTracker(tracker RequestTracker) { + h.trackerMu.Lock() + defer h.trackerMu.Unlock() + h.tracker = tracker +} + // ServeHTTP handles proxy requests func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { log.Printf("[Proxy] Received request: %s %s", r.Method, r.URL.Path) + // Track request for graceful shutdown + h.trackerMu.RLock() + tracker := h.tracker + h.trackerMu.RUnlock() + + if tracker != nil { + if !tracker.Add() { + // Server is shutting down, reject new requests + log.Printf("[Proxy] Rejecting request during shutdown: %s %s", r.Method, r.URL.Path) + writeError(w, http.StatusServiceUnavailable, "server is shutting down") + return + } + defer tracker.Done() + } + if r.Method != http.MethodPost { writeError(w, http.StatusMethodNotAllowed, "method not allowed") return diff --git a/internal/pricing/calculator.go b/internal/pricing/calculator.go index 5320dc6a..c9f9e54f 100644 --- a/internal/pricing/calculator.go +++ b/internal/pricing/calculator.go @@ -34,7 +34,7 @@ func NewCalculator(pt *PriceTable) *Calculator { } } -// Calculate 计算成本,返回微美元 (1 USD = 1,000,000 microUSD) +// Calculate 计算成本,返回纳美元 (1 USD = 1,000,000,000 nanoUSD) // model: 模型名称 // metrics: token使用指标 // 如果模型未找到,返回0并记录警告日志 @@ -56,6 +56,7 @@ func (c *Calculator) Calculate(model string, metrics *usage.Metrics) uint64 { } // CalculateWithPricing 使用指定价格计算成本(纯整数运算) +// 返回: 纳美元成本 (nanoUSD) func (c *Calculator) CalculateWithPricing(pricing *ModelPricing, metrics *usage.Metrics) uint64 { if pricing == nil || metrics == nil { return 0 @@ -67,14 +68,14 @@ func (c *Calculator) CalculateWithPricing(pricing *ModelPricing, metrics *usage. if metrics.InputTokens > 0 { if pricing.Has1MContext { inputNum, inputDenom := pricing.GetInputPremiumFraction() - totalCost += CalculateTieredCostMicro( + totalCost += CalculateTieredCost( metrics.InputTokens, pricing.InputPriceMicro, inputNum, inputDenom, pricing.GetContext1MThreshold(), ) } else { - totalCost += CalculateLinearCostMicro(metrics.InputTokens, pricing.InputPriceMicro) + totalCost += CalculateLinearCost(metrics.InputTokens, pricing.InputPriceMicro) } } @@ -82,20 +83,20 @@ func (c *Calculator) CalculateWithPricing(pricing *ModelPricing, metrics *usage. if metrics.OutputTokens > 0 { if pricing.Has1MContext { outputNum, outputDenom := pricing.GetOutputPremiumFraction() - totalCost += CalculateTieredCostMicro( + totalCost += CalculateTieredCost( metrics.OutputTokens, pricing.OutputPriceMicro, outputNum, outputDenom, pricing.GetContext1MThreshold(), ) } else { - totalCost += CalculateLinearCostMicro(metrics.OutputTokens, pricing.OutputPriceMicro) + totalCost += CalculateLinearCost(metrics.OutputTokens, pricing.OutputPriceMicro) } } // 3. 缓存读取成本(使用 input 价格的 10%) if metrics.CacheReadCount > 0 { - totalCost += CalculateLinearCostMicro( + totalCost += CalculateLinearCost( metrics.CacheReadCount, pricing.GetEffectiveCacheReadPriceMicro(), ) @@ -103,7 +104,7 @@ func (c *Calculator) CalculateWithPricing(pricing *ModelPricing, metrics *usage. // 4. 5分钟缓存写入成本(使用 input 价格的 125%) if metrics.Cache5mCreationCount > 0 { - totalCost += CalculateLinearCostMicro( + totalCost += CalculateLinearCost( metrics.Cache5mCreationCount, pricing.GetEffectiveCache5mWritePriceMicro(), ) @@ -111,12 +112,20 @@ func (c *Calculator) CalculateWithPricing(pricing *ModelPricing, metrics *usage. // 5. 1小时缓存写入成本(使用 input 价格的 200%) if metrics.Cache1hCreationCount > 0 { - totalCost += CalculateLinearCostMicro( + totalCost += CalculateLinearCost( metrics.Cache1hCreationCount, pricing.GetEffectiveCache1hWritePriceMicro(), ) } + // 6. Fallback: 如果没有 5m/1h 细分但有总缓存写入数 + if metrics.Cache5mCreationCount == 0 && metrics.Cache1hCreationCount == 0 && metrics.CacheCreationCount > 0 { + totalCost += CalculateLinearCost( + metrics.CacheCreationCount, + pricing.GetEffectiveCache5mWritePriceMicro(), // 使用 5m 价格作为默认 + ) + } + return totalCost } diff --git a/internal/pricing/tiered.go b/internal/pricing/tiered.go index 3f54338e..d68fc505 100644 --- a/internal/pricing/tiered.go +++ b/internal/pricing/tiered.go @@ -1,37 +1,88 @@ package pricing +import "math/big" + // 价格单位常量 const ( - // MicroUSDPerUSD 1美元 = 1,000,000 微美元 + // MicroUSDPerUSD 1美元 = 1,000,000 微美元 (用于价格表存储) MicroUSDPerUSD = 1_000_000 + // NanoUSDPerUSD 1美元 = 1,000,000,000 纳美元 (用于成本存储,提供更高精度) + NanoUSDPerUSD = 1_000_000_000 // TokensPerMillion 百万tokens TokensPerMillion = 1_000_000 + // MicroToNano 微美元转纳美元的倍数 + MicroToNano = 1000 +) + +var ( + bigTokensPerMillion = big.NewInt(TokensPerMillion) + bigMicroToNano = big.NewInt(MicroToNano) ) -// CalculateTieredCostMicro 计算分层定价成本(整数运算) +// CalculateTieredCost 计算分层定价成本(使用 big.Int 防止溢出) // tokens: token数量 // basePriceMicro: 基础价格 (microUSD/M tokens) // premiumNum, premiumDenom: 超阈值倍率(分数表示,如 2.0 = 2/1, 1.5 = 3/2) // threshold: 阈值 token 数 -// 返回: 微美元成本 -func CalculateTieredCostMicro(tokens uint64, basePriceMicro uint64, premiumNum, premiumDenom, threshold uint64) uint64 { +// 返回: 纳美元成本 (nanoUSD) +func CalculateTieredCost(tokens uint64, basePriceMicro uint64, premiumNum, premiumDenom, threshold uint64) uint64 { if tokens <= threshold { - return tokens * basePriceMicro / TokensPerMillion + return calculateLinearCostBig(tokens, basePriceMicro) } - baseCost := threshold * basePriceMicro / TokensPerMillion + + baseCostNano := calculateLinearCostBig(threshold, basePriceMicro) premiumTokens := tokens - threshold - // premiumCost = premiumTokens * basePriceMicro * (premiumNum/premiumDenom) / TokensPerMillion - // 重排以避免溢出: (premiumTokens * basePriceMicro / TokensPerMillion) * premiumNum / premiumDenom - premiumCost := premiumTokens * basePriceMicro / TokensPerMillion * premiumNum / premiumDenom - return baseCost + premiumCost + + // premiumCost = premiumTokens * basePriceMicro * MicroToNano / TokensPerMillion * premiumNum / premiumDenom + t := big.NewInt(0).SetUint64(premiumTokens) + p := big.NewInt(0).SetUint64(basePriceMicro) + num := big.NewInt(0).SetUint64(premiumNum) + denom := big.NewInt(0).SetUint64(premiumDenom) + + // t * p * MicroToNano * num / TokensPerMillion / denom + t.Mul(t, p) + t.Mul(t, bigMicroToNano) + t.Mul(t, num) + t.Div(t, bigTokensPerMillion) + t.Div(t, denom) + + return baseCostNano + t.Uint64() } -// CalculateLinearCostMicro 计算线性定价成本(整数运算) +// CalculateLinearCost 计算线性定价成本(使用 big.Int 防止溢出) // tokens: token数量 // priceMicro: 价格 (microUSD/M tokens) -// 返回: 微美元成本 +// 返回: 纳美元成本 (nanoUSD) +func CalculateLinearCost(tokens, priceMicro uint64) uint64 { + return calculateLinearCostBig(tokens, priceMicro) +} + +// calculateLinearCostBig 使用 big.Int 计算线性成本 +func calculateLinearCostBig(tokens, priceMicro uint64) uint64 { + // cost = tokens * priceMicro * MicroToNano / TokensPerMillion + t := big.NewInt(0).SetUint64(tokens) + p := big.NewInt(0).SetUint64(priceMicro) + + t.Mul(t, p) + t.Mul(t, bigMicroToNano) + t.Div(t, bigTokensPerMillion) + + return t.Uint64() +} + +// Deprecated: 使用 CalculateTieredCost 代替 +func CalculateTieredCostMicro(tokens uint64, basePriceMicro uint64, premiumNum, premiumDenom, threshold uint64) uint64 { + return CalculateTieredCost(tokens, basePriceMicro, premiumNum, premiumDenom, threshold) / MicroToNano +} + +// Deprecated: 使用 CalculateLinearCost 代替 func CalculateLinearCostMicro(tokens, priceMicro uint64) uint64 { - return tokens * priceMicro / TokensPerMillion + return CalculateLinearCost(tokens, priceMicro) / MicroToNano +} + +// NanoToUSD 将纳美元转换为美元(用于显示) +func NanoToUSD(nanoUSD uint64) float64 { + return float64(nanoUSD) / NanoUSDPerUSD } // MicroToUSD 将微美元转换为美元(用于显示) diff --git a/internal/repository/interfaces.go b/internal/repository/interfaces.go index 1f4d1317..00b29201 100644 --- a/internal/repository/interfaces.go +++ b/internal/repository/interfaces.go @@ -59,6 +59,12 @@ type SessionRepository interface { List() ([]*domain.Session, error) } +// ProxyRequestFilter 请求列表过滤条件 +type ProxyRequestFilter struct { + ProviderID *uint64 // Provider ID,nil 表示不过滤 + Status *string // 状态,nil 表示不过滤 +} + type ProxyRequestRepository interface { Create(req *domain.ProxyRequest) error Update(req *domain.ProxyRequest) error @@ -67,25 +73,55 @@ type ProxyRequestRepository interface { // ListCursor 基于游标的分页查询 // before: 获取 id < before 的记录 (向后翻页) // after: 获取 id > after 的记录 (向前翻页/获取新数据) - ListCursor(limit int, before, after uint64) ([]*domain.ProxyRequest, error) + // filter: 可选的过滤条件 + ListCursor(limit int, before, after uint64, filter *ProxyRequestFilter) ([]*domain.ProxyRequest, error) // ListActive 获取所有活跃请求 (PENDING 或 IN_PROGRESS 状态) ListActive() ([]*domain.ProxyRequest, error) Count() (int64, error) + // CountWithFilter 带过滤条件的计数 + CountWithFilter(filter *ProxyRequestFilter) (int64, error) // UpdateProjectIDBySessionID 批量更新指定 sessionID 的所有请求的 projectID UpdateProjectIDBySessionID(sessionID string, projectID uint64) (int64, error) // MarkStaleAsFailed marks all IN_PROGRESS/PENDING requests from other instances as FAILED // Also marks requests that have been IN_PROGRESS for too long (> 30 minutes) as timed out MarkStaleAsFailed(currentInstanceID string) (int64, error) + // FixFailedRequestsWithoutEndTime fixes FAILED requests that have no end_time set + FixFailedRequestsWithoutEndTime() (int64, error) // DeleteOlderThan 删除指定时间之前的请求记录 DeleteOlderThan(before time.Time) (int64, error) // HasRecentRequests 检查指定时间之后是否有请求记录 HasRecentRequests(since time.Time) (bool, error) + // UpdateCost updates only the cost field of a request + UpdateCost(id uint64, cost uint64) error + // AddCost adds a delta to the cost field of a request (can be negative) + AddCost(id uint64, delta int64) error + // BatchUpdateCosts updates costs for multiple requests in a single transaction + BatchUpdateCosts(updates map[uint64]uint64) error + // RecalculateCostsFromAttempts recalculates all request costs by summing their attempt costs + RecalculateCostsFromAttempts() (int64, error) + // RecalculateCostsFromAttemptsWithProgress recalculates all request costs with progress reporting via channel + RecalculateCostsFromAttemptsWithProgress(progress chan<- domain.Progress) (int64, error) } type ProxyUpstreamAttemptRepository interface { Create(attempt *domain.ProxyUpstreamAttempt) error Update(attempt *domain.ProxyUpstreamAttempt) error ListByProxyRequestID(proxyRequestID uint64) ([]*domain.ProxyUpstreamAttempt, error) + // ListAll returns all attempts (for cost recalculation) + ListAll() ([]*domain.ProxyUpstreamAttempt, error) + // CountAll returns total count of attempts + CountAll() (int64, error) + // StreamForCostCalc iterates through all attempts for cost calculation + // Calls the callback with batches of minimal data, returns early if callback returns error + StreamForCostCalc(batchSize int, callback func(batch []*domain.AttemptCostData) error) error + // UpdateCost updates only the cost field of an attempt + UpdateCost(id uint64, cost uint64) error + // BatchUpdateCosts updates costs for multiple attempts in a single transaction + BatchUpdateCosts(updates map[uint64]uint64) error + // MarkStaleAttemptsFailed marks stale attempts as failed with proper end_time and duration + MarkStaleAttemptsFailed() (int64, error) + // FixFailedAttemptsWithoutEndTime fixes FAILED attempts that have no end_time set + FixFailedAttemptsWithoutEndTime() (int64, error) } type SystemSettingRepository interface { @@ -111,10 +147,8 @@ type UsageStatsRepository interface { Upsert(stats *domain.UsageStats) error // BatchUpsert 批量更新或插入统计记录 BatchUpsert(stats []*domain.UsageStats) error - // Query 查询统计数据,支持按粒度、时间范围、路由、Provider、项目过滤 + // Query 查询统计数据(包含当前时间桶的实时数据补全) Query(filter UsageStatsFilter) ([]*domain.UsageStats, error) - // QueryWithRealtime 查询统计数据并合并当前周期的实时数据 - QueryWithRealtime(filter UsageStatsFilter) ([]*domain.UsageStats, error) // QueryDashboardData 查询 Dashboard 所需的所有数据(单次请求,并发执行) QueryDashboardData() (*domain.DashboardData, error) // GetSummary 获取汇总统计数据(总计) @@ -135,12 +169,14 @@ type UsageStatsRepository interface { GetLatestTimeBucket(granularity domain.Granularity) (*time.Time, error) // GetProviderStats 获取 Provider 统计数据 GetProviderStats(clientType string, projectID uint64) (map[uint64]*domain.ProviderStats, error) - // AggregateMinute 从原始数据聚合到分钟级别 - AggregateMinute() (int, error) - // RollUp 从细粒度上卷到粗粒度 - RollUp(from, to domain.Granularity) (int, error) + // AggregateAndRollUp 聚合原始数据到分钟级别,并自动 rollup 到各个粗粒度 + // 返回一个 channel,发送每个阶段的进度事件,channel 会在完成后关闭 + // 调用者可以 range 遍历 channel 获取进度,或直接忽略(异步执行) + AggregateAndRollUp() <-chan domain.AggregateEvent // ClearAndRecalculate 清空统计数据并重新从原始数据计算 ClearAndRecalculate() error + // ClearAndRecalculateWithProgress 清空统计数据并重新计算,通过 channel 报告进度 + ClearAndRecalculateWithProgress(progress chan<- domain.Progress) error } // UsageStatsFilter 统计查询过滤条件 diff --git a/internal/repository/sqlite/migrations.go b/internal/repository/sqlite/migrations.go index 4188a996..e28dfc65 100644 --- a/internal/repository/sqlite/migrations.go +++ b/internal/repository/sqlite/migrations.go @@ -18,7 +18,40 @@ type Migration struct { // 所有迁移按版本号注册 // 注意:GORM AutoMigrate 会自动处理新增列,这里只需要处理特殊情况(重命名、数据迁移等) -var migrations = []Migration{} +var migrations = []Migration{ + { + Version: 1, + Description: "Convert cost from microUSD to nanoUSD (multiply by 1000)", + Up: func(db *gorm.DB) error { + // Convert cost in proxy_requests table + if err := db.Exec("UPDATE proxy_requests SET cost = cost * 1000 WHERE cost > 0").Error; err != nil { + return err + } + // Convert cost in proxy_upstream_attempts table + if err := db.Exec("UPDATE proxy_upstream_attempts SET cost = cost * 1000 WHERE cost > 0").Error; err != nil { + return err + } + // Convert cost in usage_stats table + if err := db.Exec("UPDATE usage_stats SET cost = cost * 1000 WHERE cost > 0").Error; err != nil { + return err + } + return nil + }, + Down: func(db *gorm.DB) error { + // Rollback: divide by 1000 + if err := db.Exec("UPDATE proxy_requests SET cost = cost / 1000").Error; err != nil { + return err + } + if err := db.Exec("UPDATE proxy_upstream_attempts SET cost = cost / 1000").Error; err != nil { + return err + } + if err := db.Exec("UPDATE usage_stats SET cost = cost / 1000").Error; err != nil { + return err + } + return nil + }, + }, +} // RunMigrations 运行所有待执行的迁移 func (d *DB) RunMigrations() error { diff --git a/internal/repository/sqlite/models.go b/internal/repository/sqlite/models.go index 70907745..818f20aa 100644 --- a/internal/repository/sqlite/models.go +++ b/internal/repository/sqlite/models.go @@ -18,7 +18,7 @@ type LongText string // GormDBDataType returns the database-specific data type func (LongText) GormDBDataType(db *gorm.DB, _ *schema.Field) string { - switch db.Dialector.Name() { + switch db.Name() { case "mysql": return "longtext" default: @@ -62,8 +62,8 @@ func (m *BaseModel) BeforeUpdate(tx *gorm.DB) error { // Provider model type Provider struct { SoftDeleteModel - Type string `gorm:"size:64"` - Name string `gorm:"size:255"` + Type string `gorm:"size:64"` + Name string `gorm:"size:255"` Config LongText SupportedClientTypes LongText SupportModels LongText @@ -74,8 +74,8 @@ func (Provider) TableName() string { return "providers" } // Project model type Project struct { SoftDeleteModel - Name string `gorm:"size:255"` - Slug string `gorm:"size:128"` + Name string `gorm:"size:255"` + Slug string `gorm:"size:128"` EnabledCustomRoutes LongText } @@ -95,8 +95,8 @@ func (Session) TableName() string { return "sessions" } // Route model type Route struct { SoftDeleteModel - IsEnabled int `gorm:"default:1"` - IsNative int `gorm:"default:1"` + IsEnabled int `gorm:"default:1"` + IsNative int `gorm:"default:1"` ProjectID uint64 ClientType string `gorm:"size:64"` ProviderID uint64 @@ -123,7 +123,7 @@ func (RetryConfig) TableName() string { return "retry_configs" } type RoutingStrategy struct { SoftDeleteModel ProjectID uint64 - Type string `gorm:"size:64"` + Type string `gorm:"size:64"` Config LongText } @@ -132,9 +132,9 @@ func (RoutingStrategy) TableName() string { return "routing_strategies" } // APIToken model type APIToken struct { SoftDeleteModel - Token string `gorm:"size:255;uniqueIndex"` - TokenPrefix string `gorm:"size:32"` - Name string `gorm:"size:255"` + Token string `gorm:"size:255;uniqueIndex"` + TokenPrefix string `gorm:"size:32"` + Name string `gorm:"size:255"` Description LongText ProjectID uint64 IsEnabled int `gorm:"default:1"` @@ -165,13 +165,13 @@ func (ModelMapping) TableName() string { return "model_mappings" } // AntigravityQuota model type AntigravityQuota struct { SoftDeleteModel - Email string `gorm:"size:255;uniqueIndex"` - SubscriptionTier string `gorm:"size:64;default:'FREE'"` + Email string `gorm:"size:255;uniqueIndex"` + SubscriptionTier string `gorm:"size:64;default:'FREE'"` IsForbidden int Models LongText - Name string `gorm:"size:255"` + Name string `gorm:"size:255"` Picture LongText - GCPProjectID string `gorm:"size:128;column:gcp_project_id"` + GCPProjectID string `gorm:"size:128;column:gcp_project_id"` } func (AntigravityQuota) TableName() string { return "antigravity_quotas" } @@ -181,16 +181,17 @@ func (AntigravityQuota) TableName() string { return "antigravity_quotas" } // ProxyRequest model type ProxyRequest struct { BaseModel - InstanceID string `gorm:"size:64"` - RequestID string `gorm:"size:64"` - SessionID string `gorm:"size:255;index"` - ClientType string `gorm:"size:64"` - RequestModel string `gorm:"size:128"` - ResponseModel string `gorm:"size:128"` + InstanceID string `gorm:"size:64"` + RequestID string `gorm:"size:64"` + SessionID string `gorm:"size:255;index"` + ClientType string `gorm:"size:64"` + RequestModel string `gorm:"size:128"` + ResponseModel string `gorm:"size:128"` StartTime int64 - EndTime int64 + EndTime int64 `gorm:"index"` DurationMs int64 - Status string `gorm:"size:64"` + TTFTMs int64 + Status string `gorm:"size:64"` RequestInfo LongText ResponseInfo LongText Error LongText @@ -216,8 +217,8 @@ func (ProxyRequest) TableName() string { return "proxy_requests" } // ProxyUpstreamAttempt model type ProxyUpstreamAttempt struct { BaseModel - Status string `gorm:"size:64"` - ProxyRequestID uint64 `gorm:"index"` + Status string `gorm:"size:64"` + ProxyRequestID uint64 `gorm:"index"` RequestInfo LongText ResponseInfo LongText RouteID uint64 @@ -233,6 +234,7 @@ type ProxyUpstreamAttempt struct { StartTime int64 EndTime int64 DurationMs int64 + TTFTMs int64 RequestModel string `gorm:"size:128"` MappedModel string `gorm:"size:128"` ResponseModel string `gorm:"size:128"` @@ -242,7 +244,7 @@ func (ProxyUpstreamAttempt) TableName() string { return "proxy_upstream_attempts // SystemSetting model type SystemSetting struct { - Key string `gorm:"column:setting_key;size:255;primaryKey"` + Key string `gorm:"column:setting_key;size:255;primaryKey"` Value LongText CreatedAt int64 UpdatedAt int64 @@ -289,6 +291,7 @@ type UsageStats struct { SuccessfulRequests uint64 FailedRequests uint64 TotalDurationMs uint64 + TotalTTFTMs uint64 InputTokens uint64 OutputTokens uint64 CacheRead uint64 diff --git a/internal/repository/sqlite/proxy_request.go b/internal/repository/sqlite/proxy_request.go index bebcb2a5..aa437dfa 100644 --- a/internal/repository/sqlite/proxy_request.go +++ b/internal/repository/sqlite/proxy_request.go @@ -2,10 +2,13 @@ package sqlite import ( "errors" + "fmt" + "strings" "sync/atomic" "time" "github.com/awsl-project/maxx/internal/domain" + "github.com/awsl-project/maxx/internal/repository" "gorm.io/gorm" ) @@ -74,11 +77,12 @@ func (r *ProxyRequestRepository) List(limit, offset int) ([]*domain.ProxyRequest // ListCursor 基于游标的分页查询,比 OFFSET 更高效 // before: 获取 id < before 的记录 (向后翻页) // after: 获取 id > after 的记录 (向前翻页/获取新数据) +// filter: 可选的过滤条件 // 注意:列表查询不返回 request_info 和 response_info 大字段 -func (r *ProxyRequestRepository) ListCursor(limit int, before, after uint64) ([]*domain.ProxyRequest, error) { +func (r *ProxyRequestRepository) ListCursor(limit int, before, after uint64, filter *repository.ProxyRequestFilter) ([]*domain.ProxyRequest, error) { // 使用 Select 排除大字段 query := r.db.gorm.Model(&ProxyRequest{}). - Select("id, created_at, updated_at, instance_id, request_id, session_id, client_type, request_model, response_model, start_time, end_time, duration_ms, is_stream, status, status_code, error, proxy_upstream_attempt_count, final_proxy_upstream_attempt_id, route_id, provider_id, project_id, input_token_count, output_token_count, cache_read_count, cache_write_count, cache_5m_write_count, cache_1h_write_count, cost, api_token_id") + Select("id, created_at, updated_at, instance_id, request_id, session_id, client_type, request_model, response_model, start_time, end_time, duration_ms, ttft_ms, is_stream, status, status_code, error, proxy_upstream_attempt_count, final_proxy_upstream_attempt_id, route_id, provider_id, project_id, input_token_count, output_token_count, cache_read_count, cache_write_count, cache_5m_write_count, cache_1h_write_count, cost, api_token_id") if after > 0 { query = query.Where("id > ?", after) @@ -86,8 +90,20 @@ func (r *ProxyRequestRepository) ListCursor(limit int, before, after uint64) ([] query = query.Where("id < ?", before) } + // 应用过滤条件 + if filter != nil { + if filter.ProviderID != nil { + query = query.Where("provider_id = ?", *filter.ProviderID) + } + if filter.Status != nil { + query = query.Where("status = ?", *filter.Status) + } + } + var models []ProxyRequest - if err := query.Order("id DESC").Limit(limit).Find(&models).Error; err != nil { + // 按结束时间排序:未完成的请求(end_time=0)在最前面,已完成的按 end_time DESC 排序 + // SQLite 不支持 NULLS FIRST,使用 CASE WHEN 实现 + if err := query.Order("CASE WHEN end_time = 0 THEN 0 ELSE 1 END, end_time DESC, id DESC").Limit(limit).Find(&models).Error; err != nil { return nil, err } return r.toDomainList(models), nil @@ -110,13 +126,37 @@ func (r *ProxyRequestRepository) Count() (int64, error) { return atomic.LoadInt64(&r.count), nil } +// CountWithFilter 带过滤条件的计数 +func (r *ProxyRequestRepository) CountWithFilter(filter *repository.ProxyRequestFilter) (int64, error) { + // 如果没有过滤条件,使用缓存的总数 + if filter == nil || (filter.ProviderID == nil && filter.Status == nil) { + return atomic.LoadInt64(&r.count), nil + } + + // 有过滤条件时需要查询数据库 + var count int64 + query := r.db.gorm.Model(&ProxyRequest{}) + if filter.ProviderID != nil { + query = query.Where("provider_id = ?", *filter.ProviderID) + } + if filter.Status != nil { + query = query.Where("status = ?", *filter.Status) + } + if err := query.Count(&count).Error; err != nil { + return 0, err + } + return count, nil +} + // MarkStaleAsFailed marks all IN_PROGRESS/PENDING requests from other instances as FAILED // Also marks requests that have been IN_PROGRESS for too long (> 30 minutes) as timed out +// Sets proper end_time and duration_ms for complete failure handling func (r *ProxyRequestRepository) MarkStaleAsFailed(currentInstanceID string) (int64, error) { timeoutThreshold := time.Now().Add(-30 * time.Minute).UnixMilli() now := time.Now().UnixMilli() // Use raw SQL for complex CASE expression + // Sets end_time = now and calculates duration_ms = now - start_time result := r.db.gorm.Exec(` UPDATE proxy_requests SET status = 'FAILED', @@ -124,13 +164,41 @@ func (r *ProxyRequestRepository) MarkStaleAsFailed(currentInstanceID string) (in WHEN instance_id IS NULL OR instance_id != ? THEN 'Server restarted' ELSE 'Request timed out (stuck in progress)' END, + end_time = ?, + duration_ms = CASE + WHEN start_time > 0 THEN ? - start_time + ELSE 0 + END, updated_at = ? WHERE status IN ('PENDING', 'IN_PROGRESS') AND ( (instance_id IS NULL OR instance_id != ?) OR (start_time < ? AND start_time > 0) )`, - currentInstanceID, now, currentInstanceID, timeoutThreshold, + currentInstanceID, now, now, now, currentInstanceID, timeoutThreshold, + ) + if result.Error != nil { + return 0, result.Error + } + return result.RowsAffected, nil +} + +// FixFailedRequestsWithoutEndTime fixes FAILED requests that have no end_time set +// This handles legacy data where end_time was not properly set +func (r *ProxyRequestRepository) FixFailedRequestsWithoutEndTime() (int64, error) { + now := time.Now().UnixMilli() + + result := r.db.gorm.Exec(` + UPDATE proxy_requests + SET end_time = CASE + WHEN start_time > 0 THEN start_time + ELSE ? + END, + duration_ms = 0, + updated_at = ? + WHERE status = 'FAILED' + AND end_time = 0`, + now, now, ) if result.Error != nil { return 0, result.Error @@ -197,6 +265,153 @@ func (r *ProxyRequestRepository) HasRecentRequests(since time.Time) (bool, error return count > 0, nil } +// UpdateCost updates only the cost field of a request +func (r *ProxyRequestRepository) UpdateCost(id uint64, cost uint64) error { + return r.db.gorm.Model(&ProxyRequest{}).Where("id = ?", id).Update("cost", cost).Error +} + +// AddCost adds a delta to the cost field of a request (can be negative) +func (r *ProxyRequestRepository) AddCost(id uint64, delta int64) error { + return r.db.gorm.Model(&ProxyRequest{}).Where("id = ?", id). + Update("cost", gorm.Expr("cost + ?", delta)).Error +} + +// BatchUpdateCosts updates costs for multiple requests in a single transaction +func (r *ProxyRequestRepository) BatchUpdateCosts(updates map[uint64]uint64) error { + if len(updates) == 0 { + return nil + } + + return r.db.gorm.Transaction(func(tx *gorm.DB) error { + // Use CASE WHEN for batch update + const batchSize = 500 + ids := make([]uint64, 0, len(updates)) + for id := range updates { + ids = append(ids, id) + } + + for i := 0; i < len(ids); i += batchSize { + end := i + batchSize + if end > len(ids) { + end = len(ids) + } + batchIDs := ids[i:end] + + // Build CASE WHEN statement + var cases strings.Builder + cases.WriteString("CASE id ") + args := make([]interface{}, 0, len(batchIDs)*3+1) + + // First: CASE WHEN pairs (id, cost) + for _, id := range batchIDs { + cases.WriteString("WHEN ? THEN ? ") + args = append(args, id, updates[id]) + } + cases.WriteString("END") + + // Second: timestamp for updated_at + args = append(args, time.Now().UnixMilli()) + + // Third: WHERE IN ids + for _, id := range batchIDs { + args = append(args, id) + } + + sql := fmt.Sprintf("UPDATE proxy_requests SET cost = %s, updated_at = ? WHERE id IN (?%s)", + cases.String(), strings.Repeat(",?", len(batchIDs)-1)) + + if err := tx.Exec(sql, args...).Error; err != nil { + return err + } + } + return nil + }) +} + +// RecalculateCostsFromAttempts recalculates all request costs by summing their attempt costs +func (r *ProxyRequestRepository) RecalculateCostsFromAttempts() (int64, error) { + return r.RecalculateCostsFromAttemptsWithProgress(nil) +} + +// RecalculateCostsFromAttemptsWithProgress recalculates all request costs with progress reporting via channel +func (r *ProxyRequestRepository) RecalculateCostsFromAttemptsWithProgress(progress chan<- domain.Progress) (int64, error) { + sendProgress := func(current, total int, message string) { + if progress == nil { + return + } + percentage := 0 + if total > 0 { + percentage = current * 100 / total + } + progress <- domain.Progress{ + Phase: "updating_requests", + Current: current, + Total: total, + Percentage: percentage, + Message: message, + } + } + + // 1. 获取所有 request IDs + var requestIDs []uint64 + err := r.db.gorm.Model(&ProxyRequest{}).Pluck("id", &requestIDs).Error + if err != nil { + return 0, err + } + + total := len(requestIDs) + if total == 0 { + return 0, nil + } + + // 报告初始进度 + sendProgress(0, total, fmt.Sprintf("Updating %d requests...", total)) + + // 2. 分批处理 + const batchSize = 100 + now := time.Now().UnixMilli() + var totalUpdated int64 + + for i := 0; i < total; i += batchSize { + end := i + batchSize + if end > total { + end = total + } + batchIDs := requestIDs[i:end] + + // 使用子查询批量更新 + placeholders := make([]string, len(batchIDs)) + args := make([]interface{}, 0, len(batchIDs)+1) + args = append(args, now) + for j, id := range batchIDs { + placeholders[j] = "?" + args = append(args, id) + } + + sql := fmt.Sprintf(` + UPDATE proxy_requests + SET cost = ( + SELECT COALESCE(SUM(cost), 0) + FROM proxy_upstream_attempts + WHERE proxy_request_id = proxy_requests.id + ), + updated_at = ? + WHERE id IN (%s) + `, strings.Join(placeholders, ",")) + + result := r.db.gorm.Exec(sql, args...) + if result.Error != nil { + return totalUpdated, result.Error + } + totalUpdated += result.RowsAffected + + // 报告进度 + sendProgress(end, total, fmt.Sprintf("Updating requests: %d/%d", end, total)) + } + + return totalUpdated, nil +} + func (r *ProxyRequestRepository) toModel(p *domain.ProxyRequest) *ProxyRequest { return &ProxyRequest{ BaseModel: BaseModel{ @@ -213,6 +428,7 @@ func (r *ProxyRequestRepository) toModel(p *domain.ProxyRequest) *ProxyRequest { StartTime: toTimestamp(p.StartTime), EndTime: toTimestamp(p.EndTime), DurationMs: p.Duration.Milliseconds(), + TTFTMs: p.TTFT.Milliseconds(), IsStream: boolToInt(p.IsStream), Status: p.Status, StatusCode: p.StatusCode, @@ -249,6 +465,7 @@ func (r *ProxyRequestRepository) toDomain(m *ProxyRequest) *domain.ProxyRequest StartTime: fromTimestamp(m.StartTime), EndTime: fromTimestamp(m.EndTime), Duration: time.Duration(m.DurationMs) * time.Millisecond, + TTFT: time.Duration(m.TTFTMs) * time.Millisecond, IsStream: m.IsStream == 1, Status: m.Status, StatusCode: m.StatusCode, diff --git a/internal/repository/sqlite/proxy_upstream_attempt.go b/internal/repository/sqlite/proxy_upstream_attempt.go index 3fdff810..899709fa 100644 --- a/internal/repository/sqlite/proxy_upstream_attempt.go +++ b/internal/repository/sqlite/proxy_upstream_attempt.go @@ -1,9 +1,12 @@ package sqlite import ( + "fmt" + "strings" "time" "github.com/awsl-project/maxx/internal/domain" + "gorm.io/gorm" ) type ProxyUpstreamAttemptRepository struct { @@ -41,6 +44,198 @@ func (r *ProxyUpstreamAttemptRepository) ListByProxyRequestID(proxyRequestID uin return r.toDomainList(models), nil } +func (r *ProxyUpstreamAttemptRepository) ListAll() ([]*domain.ProxyUpstreamAttempt, error) { + var models []ProxyUpstreamAttempt + if err := r.db.gorm.Order("id").Find(&models).Error; err != nil { + return nil, err + } + return r.toDomainList(models), nil +} + +func (r *ProxyUpstreamAttemptRepository) CountAll() (int64, error) { + var count int64 + if err := r.db.gorm.Model(&ProxyUpstreamAttempt{}).Count(&count).Error; err != nil { + return 0, err + } + return count, nil +} + +// StreamForCostCalc iterates through all attempts in batches for cost calculation +// Only fetches fields needed for cost calculation, avoiding expensive JSON parsing +func (r *ProxyUpstreamAttemptRepository) StreamForCostCalc(batchSize int, callback func(batch []*domain.AttemptCostData) error) error { + var lastID uint64 = 0 + + for { + var results []struct { + ID uint64 `gorm:"column:id"` + ProxyRequestID uint64 `gorm:"column:proxy_request_id"` + ResponseModel string `gorm:"column:response_model"` + MappedModel string `gorm:"column:mapped_model"` + RequestModel string `gorm:"column:request_model"` + InputTokenCount uint64 `gorm:"column:input_token_count"` + OutputTokenCount uint64 `gorm:"column:output_token_count"` + CacheReadCount uint64 `gorm:"column:cache_read_count"` + CacheWriteCount uint64 `gorm:"column:cache_write_count"` + Cache5mWriteCount uint64 `gorm:"column:cache_5m_write_count"` + Cache1hWriteCount uint64 `gorm:"column:cache_1h_write_count"` + Cost uint64 `gorm:"column:cost"` + } + + err := r.db.gorm.Table("proxy_upstream_attempts"). + Select("id, proxy_request_id, response_model, mapped_model, request_model, input_token_count, output_token_count, cache_read_count, cache_write_count, cache_5m_write_count, cache_1h_write_count, cost"). + Where("id > ?", lastID). + Order("id"). + Limit(batchSize). + Find(&results).Error + + if err != nil { + return err + } + + if len(results) == 0 { + break + } + + // Convert to domain type + batch := make([]*domain.AttemptCostData, len(results)) + for i, r := range results { + batch[i] = &domain.AttemptCostData{ + ID: r.ID, + ProxyRequestID: r.ProxyRequestID, + ResponseModel: r.ResponseModel, + MappedModel: r.MappedModel, + RequestModel: r.RequestModel, + InputTokenCount: r.InputTokenCount, + OutputTokenCount: r.OutputTokenCount, + CacheReadCount: r.CacheReadCount, + CacheWriteCount: r.CacheWriteCount, + Cache5mWriteCount: r.Cache5mWriteCount, + Cache1hWriteCount: r.Cache1hWriteCount, + Cost: r.Cost, + } + } + + if err := callback(batch); err != nil { + return err + } + + lastID = results[len(results)-1].ID + + if len(results) < batchSize { + break + } + } + + return nil +} + +func (r *ProxyUpstreamAttemptRepository) UpdateCost(id uint64, cost uint64) error { + return r.db.gorm.Model(&ProxyUpstreamAttempt{}).Where("id = ?", id).Update("cost", cost).Error +} + +// MarkStaleAttemptsFailed marks all IN_PROGRESS/PENDING attempts belonging to stale requests as FAILED +// This should be called after MarkStaleAsFailed on proxy_requests to clean up orphaned attempts +// Sets proper end_time and duration_ms for complete failure handling +func (r *ProxyUpstreamAttemptRepository) MarkStaleAttemptsFailed() (int64, error) { + now := time.Now().UnixMilli() + + // Update attempts that belong to FAILED requests but are still in progress + result := r.db.gorm.Exec(` + UPDATE proxy_upstream_attempts + SET status = 'FAILED', + end_time = ?, + duration_ms = CASE + WHEN start_time > 0 THEN ? - start_time + ELSE 0 + END, + updated_at = ? + WHERE status IN ('PENDING', 'IN_PROGRESS') + AND proxy_request_id IN ( + SELECT id FROM proxy_requests WHERE status = 'FAILED' + )`, + now, now, now, + ) + if result.Error != nil { + return 0, result.Error + } + return result.RowsAffected, nil +} + +// FixFailedAttemptsWithoutEndTime fixes FAILED attempts that have no end_time set +// This handles legacy data where end_time was not properly set +func (r *ProxyUpstreamAttemptRepository) FixFailedAttemptsWithoutEndTime() (int64, error) { + now := time.Now().UnixMilli() + + result := r.db.gorm.Exec(` + UPDATE proxy_upstream_attempts + SET end_time = CASE + WHEN start_time > 0 THEN start_time + ELSE ? + END, + duration_ms = 0, + updated_at = ? + WHERE status = 'FAILED' + AND end_time = 0`, + now, now, + ) + if result.Error != nil { + return 0, result.Error + } + return result.RowsAffected, nil +} + +// BatchUpdateCosts updates costs for multiple attempts in a single transaction +func (r *ProxyUpstreamAttemptRepository) BatchUpdateCosts(updates map[uint64]uint64) error { + if len(updates) == 0 { + return nil + } + + return r.db.gorm.Transaction(func(tx *gorm.DB) error { + // Use CASE WHEN for batch update + const batchSize = 500 + ids := make([]uint64, 0, len(updates)) + for id := range updates { + ids = append(ids, id) + } + + for i := 0; i < len(ids); i += batchSize { + end := i + batchSize + if end > len(ids) { + end = len(ids) + } + batchIDs := ids[i:end] + + // Build CASE WHEN statement + var cases strings.Builder + cases.WriteString("CASE id ") + args := make([]interface{}, 0, len(batchIDs)*3+1) + + // First: CASE WHEN pairs (id, cost) + for _, id := range batchIDs { + cases.WriteString("WHEN ? THEN ? ") + args = append(args, id, updates[id]) + } + cases.WriteString("END") + + // Second: timestamp for updated_at + args = append(args, time.Now().UnixMilli()) + + // Third: WHERE IN ids + for _, id := range batchIDs { + args = append(args, id) + } + + sql := fmt.Sprintf("UPDATE proxy_upstream_attempts SET cost = %s, updated_at = ? WHERE id IN (?%s)", + cases.String(), strings.Repeat(",?", len(batchIDs)-1)) + + if err := tx.Exec(sql, args...).Error; err != nil { + return err + } + } + return nil + }) +} + func (r *ProxyUpstreamAttemptRepository) toModel(a *domain.ProxyUpstreamAttempt) *ProxyUpstreamAttempt { return &ProxyUpstreamAttempt{ BaseModel: BaseModel{ @@ -51,6 +246,7 @@ func (r *ProxyUpstreamAttemptRepository) toModel(a *domain.ProxyUpstreamAttempt) StartTime: toTimestamp(a.StartTime), EndTime: toTimestamp(a.EndTime), DurationMs: a.Duration.Milliseconds(), + TTFTMs: a.TTFT.Milliseconds(), Status: a.Status, ProxyRequestID: a.ProxyRequestID, IsStream: boolToInt(a.IsStream), @@ -79,6 +275,7 @@ func (r *ProxyUpstreamAttemptRepository) toDomain(m *ProxyUpstreamAttempt) *doma StartTime: fromTimestamp(m.StartTime), EndTime: fromTimestamp(m.EndTime), Duration: time.Duration(m.DurationMs) * time.Millisecond, + TTFT: time.Duration(m.TTFTMs) * time.Millisecond, Status: m.Status, ProxyRequestID: m.ProxyRequestID, IsStream: m.IsStream == 1, diff --git a/internal/repository/sqlite/usage_stats.go b/internal/repository/sqlite/usage_stats.go index c61e1818..f609446c 100644 --- a/internal/repository/sqlite/usage_stats.go +++ b/internal/repository/sqlite/usage_stats.go @@ -9,6 +9,7 @@ import ( "github.com/awsl-project/maxx/internal/domain" "github.com/awsl-project/maxx/internal/repository" + "github.com/awsl-project/maxx/internal/stats" "golang.org/x/sync/errgroup" "gorm.io/gorm/clause" ) @@ -40,54 +41,6 @@ func (r *UsageStatsRepository) getConfiguredTimezone() *time.Location { return loc } -// TruncateToGranularity 将时间截断到指定粒度的时间桶(使用 UTC) -func TruncateToGranularity(t time.Time, g domain.Granularity) time.Time { - t = t.UTC() - switch g { - case domain.GranularityMinute: - return t.Truncate(time.Minute) - case domain.GranularityHour: - return t.Truncate(time.Hour) - case domain.GranularityDay: - return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC) - case domain.GranularityWeek: - // 截断到周一 - weekday := int(t.Weekday()) - if weekday == 0 { - weekday = 7 - } - return time.Date(t.Year(), t.Month(), t.Day()-(weekday-1), 0, 0, 0, 0, time.UTC) - case domain.GranularityMonth: - return time.Date(t.Year(), t.Month(), 1, 0, 0, 0, 0, time.UTC) - default: - return t.Truncate(time.Hour) - } -} - -// TruncateToGranularityInTimezone 将时间截断到指定粒度的时间桶(使用指定时区) -func TruncateToGranularityInTimezone(t time.Time, g domain.Granularity, loc *time.Location) time.Time { - t = t.In(loc) - switch g { - case domain.GranularityMinute: - return t.Truncate(time.Minute) - case domain.GranularityHour: - return t.Truncate(time.Hour) - case domain.GranularityDay: - return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, loc) - case domain.GranularityWeek: - // 截断到周一 - weekday := int(t.Weekday()) - if weekday == 0 { - weekday = 7 - } - return time.Date(t.Year(), t.Month(), t.Day()-(weekday-1), 0, 0, 0, 0, loc) - case domain.GranularityMonth: - return time.Date(t.Year(), t.Month(), 1, 0, 0, 0, 0, loc) - default: - return t.Truncate(time.Hour) - } -} - // Upsert 更新或插入统计记录 func (r *UsageStatsRepository) Upsert(stats *domain.UsageStats) error { now := time.Now() @@ -110,6 +63,7 @@ func (r *UsageStatsRepository) Upsert(stats *domain.UsageStats) error { "successful_requests": stats.SuccessfulRequests, "failed_requests": stats.FailedRequests, "total_duration_ms": stats.TotalDurationMs, + "total_ttft_ms": stats.TotalTTFTMs, "input_tokens": stats.InputTokens, "output_tokens": stats.OutputTokens, "cache_read": stats.CacheRead, @@ -131,8 +85,8 @@ func (r *UsageStatsRepository) BatchUpsert(stats []*domain.UsageStats) error { return nil } -// Query 查询统计数据 -func (r *UsageStatsRepository) Query(filter repository.UsageStatsFilter) ([]*domain.UsageStats, error) { +// queryHistorical 查询预聚合的历史统计数据(内部方法) +func (r *UsageStatsRepository) queryHistorical(filter repository.UsageStatsFilter) ([]*domain.UsageStats, error) { var conditions []string var args []interface{} @@ -183,49 +137,48 @@ func (r *UsageStatsRepository) Query(filter repository.UsageStatsFilter) ([]*dom return r.toDomainList(models), nil } -// QueryWithRealtime 查询统计数据并补全当前时间桶的数据 +// Query 查询统计数据并补全当前时间桶的数据 // 策略(分层查询,每层用最粗粒度的预聚合数据): // - 历史时间桶:使用目标粒度的预聚合数据 -// - 当前时间桶:week → day → hour → minute → 最近 2 分钟实时 +// - 当前时间桶:day → hour → minute → 最近 2 分钟实时 // // 示例(查询 month 粒度,当前是 1月17日 10:30): -// - 1月1日-1月5日(第1周): usage_stats (granularity='week') -// - 1月6日-1月12日(第2周): usage_stats (granularity='week') -// - 1月13日-1月16日: usage_stats (granularity='day') +// - 1月1日-1月16日: usage_stats (granularity='day') // - 1月17日 00:00-09:00: usage_stats (granularity='hour') // - 1月17日 10:00-10:28: usage_stats (granularity='minute') // - 1月17日 10:29-10:30: proxy_upstream_attempts (实时) -func (r *UsageStatsRepository) QueryWithRealtime(filter repository.UsageStatsFilter) ([]*domain.UsageStats, error) { - now := time.Now().UTC() - currentBucket := TruncateToGranularity(now, filter.Granularity) - currentWeek := TruncateToGranularity(now, domain.GranularityWeek) - currentDay := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, time.UTC) +func (r *UsageStatsRepository) Query(filter repository.UsageStatsFilter) ([]*domain.UsageStats, error) { + loc := r.getConfiguredTimezone() + now := time.Now().In(loc) + currentBucket := stats.TruncateToGranularity(now, filter.Granularity, loc) + currentMonth := stats.TruncateToGranularity(now, domain.GranularityMonth, loc) + currentDay := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, loc) currentHour := now.Truncate(time.Hour) currentMinute := now.Truncate(time.Minute) twoMinutesAgo := currentMinute.Add(-time.Minute) - // 判断是否需要补全当前时间桶 - needCurrentBucket := filter.EndTime == nil || !filter.EndTime.Before(currentBucket) + // 判断是否需要补全实时数据(仅当查询范围包含最近 2 分钟内的数据) + // 如果 EndTime 在 2 分钟之前,说明是纯历史查询,预聚合数据已完整覆盖 + needRealtimeData := filter.EndTime == nil || !filter.EndTime.Before(twoMinutesAgo) // 1. 查询历史数据(使用目标粒度的预聚合数据) - // 如果需要补全当前时间桶,则排除当前时间桶(避免查出会被替换的数据) + // 如果需要补全实时数据,则排除当前时间桶(避免查出会被替换的数据) historyFilter := filter - if needCurrentBucket { + if needRealtimeData { endTime := currentBucket.Add(-time.Millisecond) // 排除当前时间桶 historyFilter.EndTime = &endTime } - results, err := r.Query(historyFilter) + results, err := r.queryHistorical(historyFilter) if err != nil { return nil, err } - if !needCurrentBucket { + if !needRealtimeData { return results, nil } // 2. 对于当前时间桶,并发分层查询(每层用最粗粒度的预聚合数据): - // - 已完成的周: usage_stats (granularity='week') [仅 month 粒度] - // - 已完成的天: usage_stats (granularity='day') [week/month 粒度] + // - 已完成的天: usage_stats (granularity='day') [month 粒度] // - 已完成的小时: usage_stats (granularity='hour') // - 已完成的分钟: usage_stats (granularity='minute') // - 最近 2 分钟: proxy_upstream_attempts (实时) @@ -236,24 +189,10 @@ func (r *UsageStatsRepository) QueryWithRealtime(filter repository.UsageStatsFil g errgroup.Group ) - // 2a. 查询当前时间桶内已完成的周数据 (仅 month 粒度需要) - if filter.Granularity == domain.GranularityMonth && currentWeek.After(currentBucket) { - g.Go(func() error { - weekStats, err := r.queryStatsInRange(domain.GranularityWeek, currentBucket, currentWeek, filter) - if err != nil { - return err - } - mu.Lock() - allStats = append(allStats, weekStats...) - mu.Unlock() - return nil - }) - } - - // 2b. 查询当前周(或当前时间桶)内已完成的天数据 (week/month 粒度需要) - if filter.Granularity == domain.GranularityWeek || filter.Granularity == domain.GranularityMonth { - dayStart := currentWeek - if currentBucket.After(currentWeek) { + // 2a. 查询当前月(或当前时间桶)内已完成的天数据 (month 粒度需要) + if filter.Granularity == domain.GranularityMonth { + dayStart := currentMonth + if currentBucket.After(currentMonth) { dayStart = currentBucket } if currentDay.After(dayStart) { @@ -270,7 +209,7 @@ func (r *UsageStatsRepository) QueryWithRealtime(filter repository.UsageStatsFil } } - // 2c. 查询今天(或当前时间桶)内已完成的小时数据 + // 2b. 查询今天(或当前时间桶)内已完成的小时数据 hourStart := currentDay if currentBucket.After(currentDay) { hourStart = currentBucket @@ -288,7 +227,7 @@ func (r *UsageStatsRepository) QueryWithRealtime(filter repository.UsageStatsFil }) } - // 2d. 查询当前小时内已完成的分钟数据(不包括最近 2 分钟) + // 2c. 查询当前小时内已完成的分钟数据(不包括最近 2 分钟) minuteStart := currentHour if currentBucket.After(currentHour) { minuteStart = currentBucket @@ -306,7 +245,7 @@ func (r *UsageStatsRepository) QueryWithRealtime(filter repository.UsageStatsFil }) } - // 2e. 查询最近 2 分钟的实时数据 + // 2d. 查询最近 2 分钟的实时数据 g.Go(func() error { realtimeStats, err := r.queryRecentMinutesStats(twoMinutesAgo, filter) if err != nil { @@ -323,11 +262,16 @@ func (r *UsageStatsRepository) QueryWithRealtime(filter repository.UsageStatsFil return nil, err } - // 3. 将所有数据聚合为当前时间桶 - currentBucketStats := r.aggregateToTargetBucket(allStats, currentBucket, filter.Granularity) - - // 4. 将当前时间桶数据合并到结果中(替换预聚合数据) - results = r.mergeCurrentBucketStats(results, currentBucketStats, currentBucket, filter.Granularity) + // 3. 对于分钟粒度,直接将实时数据合并(保留各分钟的独立数据) + // 对于其他粒度,将所有数据聚合为当前时间桶 + if filter.Granularity == domain.GranularityMinute { + // 分钟粒度:直接合并实时分钟数据,每个分钟保持独立 + results = r.mergeRealtimeMinuteStats(results, allStats, currentBucket) + } else { + // 其他粒度:聚合到当前时间桶 + currentBucketStats := r.aggregateToTargetBucket(allStats, currentBucket, filter.Granularity) + results = r.mergeCurrentBucketStats(results, currentBucketStats, currentBucket, filter.Granularity) + } return results, nil } @@ -404,6 +348,7 @@ func (r *UsageStatsRepository) aggregateToTargetBucket( existing.SuccessfulRequests += s.SuccessfulRequests existing.FailedRequests += s.FailedRequests existing.TotalDurationMs += s.TotalDurationMs + existing.TotalTTFTMs += s.TotalTTFTMs existing.InputTokens += s.InputTokens existing.OutputTokens += s.OutputTokens existing.CacheRead += s.CacheRead @@ -423,6 +368,7 @@ func (r *UsageStatsRepository) aggregateToTargetBucket( SuccessfulRequests: s.SuccessfulRequests, FailedRequests: s.FailedRequests, TotalDurationMs: s.TotalDurationMs, + TotalTTFTMs: s.TotalTTFTMs, InputTokens: s.InputTokens, OutputTokens: s.OutputTokens, CacheRead: s.CacheRead, @@ -449,7 +395,7 @@ func (r *UsageStatsRepository) mergeCurrentBucketStats( // 移除结果中已有的当前时间桶数据(预聚合的可能不完整) filtered := make([]*domain.UsageStats, 0, len(results)) for _, s := range results { - if !(s.TimeBucket.Equal(targetBucket) && s.Granularity == granularity) { + if !s.TimeBucket.Equal(targetBucket) || s.Granularity != granularity { filtered = append(filtered, s) } } @@ -458,8 +404,49 @@ func (r *UsageStatsRepository) mergeCurrentBucketStats( return append(currentBucketStats, filtered...) } +// mergeRealtimeMinuteStats 合并实时分钟数据到结果中(分钟粒度专用) +// 保留各分钟的独立数据,替换预聚合中对应分钟桶的数据 +func (r *UsageStatsRepository) mergeRealtimeMinuteStats( + results []*domain.UsageStats, + realtimeStats []*domain.UsageStats, + currentBucket time.Time, +) []*domain.UsageStats { + if len(realtimeStats) == 0 { + return results + } + + // 收集实时数据中的所有分钟桶时间 + realtimeBuckets := make(map[int64]bool) + for _, s := range realtimeStats { + realtimeBuckets[s.TimeBucket.UnixMilli()] = true + } + + // 从历史结果中移除这些分钟桶的数据(将被实时数据替换) + filtered := make([]*domain.UsageStats, 0, len(results)) + for _, s := range results { + if s.Granularity != domain.GranularityMinute || !realtimeBuckets[s.TimeBucket.UnixMilli()] { + filtered = append(filtered, s) + } + } + + // 合并实时数据和历史数据,按时间倒序排列 + merged := append(realtimeStats, filtered...) + + // 按 TimeBucket 倒序排列 + for i := 0; i < len(merged)-1; i++ { + for j := i + 1; j < len(merged); j++ { + if merged[j].TimeBucket.After(merged[i].TimeBucket) { + merged[i], merged[j] = merged[j], merged[i] + } + } + } + + return merged +} + // queryRecentMinutesStats 查询最近 2 分钟的实时统计数据 // 只查询已完成的请求,使用 end_time 作为时间条件 +// 返回按分钟桶分组的数据,每个分钟桶的数据独立返回 func (r *UsageStatsRepository) queryRecentMinutesStats(startMinute time.Time, filter repository.UsageStatsFilter) ([]*domain.UsageStats, error) { var conditions []string var args []interface{} @@ -494,50 +481,75 @@ func (r *UsageStatsRepository) queryRecentMinutesStats(startMinute time.Time, fi args = append(args, *filter.Model) } + // 查询原始数据,在 Go 中聚合(避免 SQLite 类型问题,性能更好) query := ` SELECT + a.end_time, COALESCE(r.route_id, 0), COALESCE(a.provider_id, 0), COALESCE(r.project_id, 0), COALESCE(r.api_token_id, 0), COALESCE(r.client_type, ''), COALESCE(a.response_model, ''), - COUNT(*), - SUM(CASE WHEN a.status = 'COMPLETED' THEN 1 ELSE 0 END), - SUM(CASE WHEN a.status IN ('FAILED', 'CANCELLED') THEN 1 ELSE 0 END), - COALESCE(SUM(a.duration_ms), 0), - COALESCE(SUM(a.input_token_count), 0), - COALESCE(SUM(a.output_token_count), 0), - COALESCE(SUM(a.cache_read_count), 0), - COALESCE(SUM(a.cache_write_count), 0), - COALESCE(SUM(a.cost), 0) + a.status, + COALESCE(a.duration_ms, 0), + COALESCE(a.ttft_ms, 0), + COALESCE(a.input_token_count, 0), + COALESCE(a.output_token_count, 0), + COALESCE(a.cache_read_count, 0), + COALESCE(a.cache_write_count, 0), + COALESCE(a.cost, 0) FROM proxy_upstream_attempts a LEFT JOIN proxy_requests r ON a.proxy_request_id = r.id - WHERE ` + strings.Join(conditions, " AND ") + ` - GROUP BY r.route_id, a.provider_id, r.project_id, r.api_token_id, r.client_type, a.response_model - ` + WHERE ` + strings.Join(conditions, " AND ") rows, err := r.db.gorm.Raw(query, args...).Rows() if err != nil { return nil, err } - defer rows.Close() + defer func() { _ = rows.Close() }() - var results []*domain.UsageStats + // 收集所有记录,使用 stats.AggregateAttempts 聚合 + var records []stats.AttemptRecord for rows.Next() { - s := &domain.UsageStats{ - TimeBucket: startMinute, // 会在合并时被替换为目标时间桶 - Granularity: domain.GranularityMinute, - } + var endTime int64 + var routeID, providerID, projectID, apiTokenID uint64 + var clientType, model, status string + var durationMs, ttftMs, inputTokens, outputTokens, cacheRead, cacheWrite, cost uint64 + err := rows.Scan( - &s.RouteID, &s.ProviderID, &s.ProjectID, &s.APITokenID, &s.ClientType, - &s.Model, - &s.TotalRequests, &s.SuccessfulRequests, &s.FailedRequests, &s.TotalDurationMs, - &s.InputTokens, &s.OutputTokens, &s.CacheRead, &s.CacheWrite, &s.Cost, + &endTime, &routeID, &providerID, &projectID, &apiTokenID, &clientType, + &model, &status, &durationMs, &ttftMs, + &inputTokens, &outputTokens, &cacheRead, &cacheWrite, &cost, ) if err != nil { - return nil, err + continue } - results = append(results, s) + + records = append(records, stats.AttemptRecord{ + EndTime: fromTimestamp(endTime), + RouteID: routeID, + ProviderID: providerID, + ProjectID: projectID, + APITokenID: apiTokenID, + ClientType: clientType, + Model: model, + IsSuccessful: status == "COMPLETED", + IsFailed: status == "FAILED" || status == "CANCELLED", + DurationMs: durationMs, + TTFTMs: ttftMs, + InputTokens: inputTokens, + OutputTokens: outputTokens, + CacheRead: cacheRead, + CacheWrite: cacheWrite, + Cost: cost, + }) } - return results, rows.Err() + + if err := rows.Err(); err != nil { + return nil, err + } + + // 使用配置的时区进行分钟聚合 + loc := r.getConfiguredTimezone() + return stats.AggregateAttempts(records, loc), nil } // GetSummary 获取汇总统计数据(总计) @@ -690,7 +702,7 @@ func (r *UsageStatsRepository) getSummaryByDimension(filter repository.UsageStat if err != nil { return nil, err } - defer rows.Close() + defer func() { _ = rows.Close() }() results := make(map[uint64]*domain.UsageStatsSummary) for rows.Next() { @@ -774,7 +786,7 @@ func (r *UsageStatsRepository) GetSummaryByClientType(filter repository.UsageSta if err != nil { return nil, err } - defer rows.Close() + defer func() { _ = rows.Close() }() results := make(map[string]*domain.UsageStatsSummary) for rows.Next() { @@ -822,22 +834,29 @@ func (r *UsageStatsRepository) GetLatestTimeBucket(granularity domain.Granularit } // GetProviderStats 获取 Provider 统计数据 +// 使用分层查询策略:历史月数据 + 当前月的分层实时数据 func (r *UsageStatsRepository) GetProviderStats(clientType string, projectID uint64) (map[uint64]*domain.ProviderStats, error) { - stats := make(map[uint64]*domain.ProviderStats) + result := make(map[uint64]*domain.ProviderStats) - conditions := []string{"provider_id > 0"} - var args []any + // 获取配置的时区 + loc := r.getConfiguredTimezone() + now := time.Now().In(loc) + currentMonth := stats.TruncateToGranularity(now, domain.GranularityMonth, loc) + + // 1. 查询历史月数据(当前月之前) + historyConditions := []string{"provider_id > 0", "granularity = ?", "time_bucket < ?"} + historyArgs := []any{domain.GranularityMonth, toTimestamp(currentMonth)} if clientType != "" { - conditions = append(conditions, "client_type = ?") - args = append(args, clientType) + historyConditions = append(historyConditions, "client_type = ?") + historyArgs = append(historyArgs, clientType) } if projectID > 0 { - conditions = append(conditions, "project_id = ?") - args = append(args, projectID) + historyConditions = append(historyConditions, "project_id = ?") + historyArgs = append(historyArgs, projectID) } - query := ` + historyQuery := ` SELECT provider_id, COALESCE(SUM(total_requests), 0), @@ -849,15 +868,14 @@ func (r *UsageStatsRepository) GetProviderStats(clientType string, projectID uin COALESCE(SUM(cache_write), 0), COALESCE(SUM(cost), 0) FROM usage_stats - WHERE ` + strings.Join(conditions, " AND ") + ` + WHERE ` + strings.Join(historyConditions, " AND ") + ` GROUP BY provider_id ` - rows, err := r.db.gorm.Raw(query, args...).Rows() + rows, err := r.db.gorm.Raw(historyQuery, historyArgs...).Rows() if err != nil { return nil, err } - defer rows.Close() for rows.Next() { var s domain.ProviderStats @@ -873,27 +891,112 @@ func (r *UsageStatsRepository) GetProviderStats(clientType string, projectID uin &s.TotalCost, ) if err != nil { + _ = rows.Close() return nil, err } + result[s.ProviderID] = &s + } + _ = rows.Close() + + // 2. 查询当前月的 day 粒度数据(使用 day 作为当前月的最粗粒度) + currentMonthConditions := []string{"provider_id > 0", "granularity = ?", "time_bucket >= ?"} + currentMonthArgs := []any{domain.GranularityDay, toTimestamp(currentMonth)} + + if clientType != "" { + currentMonthConditions = append(currentMonthConditions, "client_type = ?") + currentMonthArgs = append(currentMonthArgs, clientType) + } + if projectID > 0 { + currentMonthConditions = append(currentMonthConditions, "project_id = ?") + currentMonthArgs = append(currentMonthArgs, projectID) + } + + currentMonthQuery := ` + SELECT + provider_id, + COALESCE(SUM(total_requests), 0), + COALESCE(SUM(successful_requests), 0), + COALESCE(SUM(failed_requests), 0), + COALESCE(SUM(input_tokens), 0), + COALESCE(SUM(output_tokens), 0), + COALESCE(SUM(cache_read), 0), + COALESCE(SUM(cache_write), 0), + COALESCE(SUM(cost), 0) + FROM usage_stats + WHERE ` + strings.Join(currentMonthConditions, " AND ") + ` + GROUP BY provider_id + ` + + rows, err = r.db.gorm.Raw(currentMonthQuery, currentMonthArgs...).Rows() + if err != nil { + return nil, err + } + + for rows.Next() { + var providerID uint64 + var totalRequests, successfulRequests, failedRequests uint64 + var inputTokens, outputTokens, cacheRead, cacheWrite, cost uint64 + err := rows.Scan( + &providerID, + &totalRequests, + &successfulRequests, + &failedRequests, + &inputTokens, + &outputTokens, + &cacheRead, + &cacheWrite, + &cost, + ) + if err != nil { + _ = rows.Close() + return nil, err + } + + // 累加到已有数据 + if existing, ok := result[providerID]; ok { + existing.TotalRequests += totalRequests + existing.SuccessfulRequests += successfulRequests + existing.FailedRequests += failedRequests + existing.TotalInputTokens += inputTokens + existing.TotalOutputTokens += outputTokens + existing.TotalCacheRead += cacheRead + existing.TotalCacheWrite += cacheWrite + existing.TotalCost += cost + } else { + result[providerID] = &domain.ProviderStats{ + ProviderID: providerID, + TotalRequests: totalRequests, + SuccessfulRequests: successfulRequests, + FailedRequests: failedRequests, + TotalInputTokens: inputTokens, + TotalOutputTokens: outputTokens, + TotalCacheRead: cacheRead, + TotalCacheWrite: cacheWrite, + TotalCost: cost, + } + } + } + _ = rows.Close() + + // 计算成功率 + for _, s := range result { if s.TotalRequests > 0 { s.SuccessRate = float64(s.SuccessfulRequests) / float64(s.TotalRequests) * 100 } - stats[s.ProviderID] = &s } - return stats, rows.Err() + return result, nil } -// AggregateMinute 从原始数据聚合到分钟级别 -// 只聚合已完成的请求(COMPLETED/FAILED/CANCELLED),使用 end_time 作为时间桶 -func (r *UsageStatsRepository) AggregateMinute() (int, error) { +// aggregateMinute 从原始数据聚合到分钟级别(内部方法) +// 返回:聚合数量、开始时间、结束时间、错误 +func (r *UsageStatsRepository) aggregateMinute() (count int, startTime, endTime time.Time, err error) { now := time.Now().UTC() - currentMinute := now.Truncate(time.Minute) + endTime = now.Truncate(time.Minute) // 获取最新的聚合分钟 - latestMinute, err := r.GetLatestTimeBucket(domain.GranularityMinute) - var startTime time.Time - if err != nil || latestMinute == nil { + latestMinute, e := r.GetLatestTimeBucket(domain.GranularityMinute) + if e != nil || latestMinute == nil { // 如果没有历史数据,从 2 小时前开始 startTime = now.Add(-2 * time.Hour).Truncate(time.Minute) } else { @@ -909,9 +1012,9 @@ func (r *UsageStatsRepository) AggregateMinute() (int, error) { COALESCE(r.route_id, 0), COALESCE(a.provider_id, 0), COALESCE(r.project_id, 0), COALESCE(r.api_token_id, 0), COALESCE(r.client_type, ''), COALESCE(a.response_model, ''), - CASE WHEN a.status = 'COMPLETED' THEN 1 ELSE 0 END, - CASE WHEN a.status IN ('FAILED', 'CANCELLED') THEN 1 ELSE 0 END, + a.status, COALESCE(a.duration_ms, 0), + COALESCE(a.ttft_ms, 0), COALESCE(a.input_token_count, 0), COALESCE(a.output_token_count, 0), COALESCE(a.cache_read_count, 0), @@ -923,36 +1026,25 @@ func (r *UsageStatsRepository) AggregateMinute() (int, error) { AND a.status IN ('COMPLETED', 'FAILED', 'CANCELLED') ` - rows, err := r.db.gorm.Raw(query, toTimestamp(startTime), toTimestamp(currentMinute)).Rows() + rows, err := r.db.gorm.Raw(query, toTimestamp(startTime), toTimestamp(endTime)).Rows() if err != nil { - return 0, err + return 0, startTime, endTime, err } - defer rows.Close() + defer func() { _ = rows.Close() }() - // 使用 map 聚合数据 - type aggKey struct { - minuteBucket int64 - routeID uint64 - providerID uint64 - projectID uint64 - apiTokenID uint64 - clientType string - model string - } - statsMap := make(map[aggKey]*domain.UsageStats) + // 收集所有记录,使用 stats.AggregateAttempts 聚合 + var records []stats.AttemptRecord responseModels := make(map[string]bool) for rows.Next() { var endTime int64 var routeID, providerID, projectID, apiTokenID uint64 - var clientType, model string - var successful, failed int - var durationMs, inputTokens, outputTokens, cacheRead, cacheWrite, cost uint64 + var clientType, model, status string + var durationMs, ttftMs, inputTokens, outputTokens, cacheRead, cacheWrite, cost uint64 err := rows.Scan( &endTime, &routeID, &providerID, &projectID, &apiTokenID, &clientType, - &model, - &successful, &failed, &durationMs, + &model, &status, &durationMs, &ttftMs, &inputTokens, &outputTokens, &cacheRead, &cacheWrite, &cost, ) if err != nil { @@ -964,50 +1056,24 @@ func (r *UsageStatsRepository) AggregateMinute() (int, error) { responseModels[model] = true } - // 截断到分钟(使用 end_time) - minuteBucket := fromTimestamp(endTime).Truncate(time.Minute).UnixMilli() - - key := aggKey{ - minuteBucket: minuteBucket, - routeID: routeID, - providerID: providerID, - projectID: projectID, - apiTokenID: apiTokenID, - clientType: clientType, - model: model, - } - - if s, ok := statsMap[key]; ok { - s.TotalRequests++ - s.SuccessfulRequests += uint64(successful) - s.FailedRequests += uint64(failed) - s.TotalDurationMs += durationMs - s.InputTokens += inputTokens - s.OutputTokens += outputTokens - s.CacheRead += cacheRead - s.CacheWrite += cacheWrite - s.Cost += cost - } else { - statsMap[key] = &domain.UsageStats{ - Granularity: domain.GranularityMinute, - TimeBucket: time.UnixMilli(minuteBucket), - RouteID: routeID, - ProviderID: providerID, - ProjectID: projectID, - APITokenID: apiTokenID, - ClientType: clientType, - Model: model, - TotalRequests: 1, - SuccessfulRequests: uint64(successful), - FailedRequests: uint64(failed), - TotalDurationMs: durationMs, - InputTokens: inputTokens, - OutputTokens: outputTokens, - CacheRead: cacheRead, - CacheWrite: cacheWrite, - Cost: cost, - } - } + records = append(records, stats.AttemptRecord{ + EndTime: fromTimestamp(endTime), + RouteID: routeID, + ProviderID: providerID, + ProjectID: projectID, + APITokenID: apiTokenID, + ClientType: clientType, + Model: model, + IsSuccessful: status == "COMPLETED", + IsFailed: status == "FAILED" || status == "CANCELLED", + DurationMs: durationMs, + TTFTMs: ttftMs, + InputTokens: inputTokens, + OutputTokens: outputTokens, + CacheRead: cacheRead, + CacheWrite: cacheWrite, + Cost: cost, + }) } // 记录 response models 到独立表 @@ -1020,40 +1086,92 @@ func (r *UsageStatsRepository) AggregateMinute() (int, error) { _ = responseModelRepo.BatchUpsert(models) } - if len(statsMap) == 0 { - return 0, nil + if len(records) == 0 { + return 0, startTime, endTime, nil } - statsList := make([]*domain.UsageStats, 0, len(statsMap)) - for _, s := range statsMap { - statsList = append(statsList, s) + // 使用配置的时区进行分钟聚合 + loc := r.getConfiguredTimezone() + statsList := stats.AggregateAttempts(records, loc) + + if len(statsList) == 0 { + return 0, startTime, endTime, nil } - return len(statsList), r.BatchUpsert(statsList) + err = r.BatchUpsert(statsList) + return len(statsList), startTime, endTime, err +} + +// AggregateAndRollUp 聚合原始数据到分钟级别,并自动 rollup 到各个粗粒度 +// 返回一个 channel,发送每个阶段的进度事件,channel 会在完成后关闭 +// 调用者可以 range 遍历 channel 获取进度,或直接忽略(异步执行) +func (r *UsageStatsRepository) AggregateAndRollUp() <-chan domain.AggregateEvent { + ch := make(chan domain.AggregateEvent, 5) // buffered to avoid blocking + + go func() { + defer close(ch) + + // 1. 聚合原始数据到分钟级别 + count, startTime, endTime, err := r.aggregateMinute() + ch <- domain.AggregateEvent{ + Phase: "aggregate_minute", + To: domain.GranularityMinute, + StartTime: startTime.UnixMilli(), + EndTime: endTime.UnixMilli(), + Count: count, + Error: err, + } + if err != nil { + return + } + + // 2. 自动 rollup 到各个粒度 + rollups := []struct { + from domain.Granularity + to domain.Granularity + phase string + }{ + {domain.GranularityMinute, domain.GranularityHour, "rollup_hour"}, + {domain.GranularityHour, domain.GranularityDay, "rollup_day"}, + {domain.GranularityDay, domain.GranularityMonth, "rollup_month"}, + } + + for _, ru := range rollups { + count, startTime, endTime, err := r.rollUp(ru.from, ru.to) + ch <- domain.AggregateEvent{ + Phase: ru.phase, + From: ru.from, + To: ru.to, + StartTime: startTime.UnixMilli(), + EndTime: endTime.UnixMilli(), + Count: count, + Error: err, + } + if err != nil { + return + } + } + }() + + return ch } -// RollUp 从细粒度上卷到粗粒度 -// 对于 day/week/month 粒度,使用配置的时区来划分边界 -func (r *UsageStatsRepository) RollUp(from, to domain.Granularity) (int, error) { +// rollUp 从细粒度上卷到粗粒度(内部方法) +// 返回:聚合数量、开始时间、结束时间、错误 +func (r *UsageStatsRepository) rollUp(from, to domain.Granularity) (count int, startTime, endTime time.Time, err error) { now := time.Now().UTC() - // 对于 day 及以上粒度,使用配置的时区 - var loc *time.Location - if to == domain.GranularityDay || to == domain.GranularityWeek || to == domain.GranularityMonth { + // 对于 day 及以上粒度,使用配置的时区,否则使用 UTC + loc := time.UTC + if to == domain.GranularityDay || to == domain.GranularityMonth { loc = r.getConfiguredTimezone() } // 计算当前时间桶 - var currentBucket time.Time - if loc != nil { - currentBucket = TruncateToGranularityInTimezone(now, to, loc) - } else { - currentBucket = TruncateToGranularity(now, to) - } + endTime = stats.TruncateToGranularity(now, to, loc) // 获取目标粒度的最新时间桶 latestBucket, _ := r.GetLatestTimeBucket(to) - var startTime time.Time if latestBucket == nil { // 如果没有历史数据,根据源粒度的保留时间决定 switch from { @@ -1072,108 +1190,47 @@ func (r *UsageStatsRepository) RollUp(from, to domain.Granularity) (int, error) // 查询源粒度数据 var models []UsageStats - err := r.db.gorm.Where("granularity = ? AND time_bucket >= ? AND time_bucket < ?", - from, toTimestamp(startTime), toTimestamp(currentBucket)). + err = r.db.gorm.Where("granularity = ? AND time_bucket >= ? AND time_bucket < ?", + from, toTimestamp(startTime), toTimestamp(endTime)). Find(&models).Error if err != nil { - return 0, err + return 0, startTime, endTime, err } - // 使用 map 聚合数据 - type rollupKey struct { - targetBucket int64 - routeID uint64 - providerID uint64 - projectID uint64 - apiTokenID uint64 - clientType string - model string - } - statsMap := make(map[rollupKey]*domain.UsageStats) - - for _, m := range models { - // 截断到目标粒度(使用配置的时区) - t := fromTimestamp(m.TimeBucket) - var targetBucket int64 - if loc != nil { - targetBucket = TruncateToGranularityInTimezone(t, to, loc).UnixMilli() - } else { - targetBucket = TruncateToGranularity(t, to).UnixMilli() - } - - key := rollupKey{ - targetBucket: targetBucket, - routeID: m.RouteID, - providerID: m.ProviderID, - projectID: m.ProjectID, - apiTokenID: m.APITokenID, - clientType: m.ClientType, - model: m.Model, - } - - if s, ok := statsMap[key]; ok { - s.TotalRequests += m.TotalRequests - s.SuccessfulRequests += m.SuccessfulRequests - s.FailedRequests += m.FailedRequests - s.TotalDurationMs += m.TotalDurationMs - s.InputTokens += m.InputTokens - s.OutputTokens += m.OutputTokens - s.CacheRead += m.CacheRead - s.CacheWrite += m.CacheWrite - s.Cost += m.Cost - } else { - statsMap[key] = &domain.UsageStats{ - Granularity: to, - TimeBucket: time.UnixMilli(targetBucket), - RouteID: m.RouteID, - ProviderID: m.ProviderID, - ProjectID: m.ProjectID, - APITokenID: m.APITokenID, - ClientType: m.ClientType, - Model: m.Model, - TotalRequests: m.TotalRequests, - SuccessfulRequests: m.SuccessfulRequests, - FailedRequests: m.FailedRequests, - TotalDurationMs: m.TotalDurationMs, - InputTokens: m.InputTokens, - OutputTokens: m.OutputTokens, - CacheRead: m.CacheRead, - CacheWrite: m.CacheWrite, - Cost: m.Cost, - } - } + if len(models) == 0 { + return 0, startTime, endTime, nil } - if len(statsMap) == 0 { - return 0, nil - } + // 转换为 domain 对象并使用 stats.RollUp 聚合 + domainStats := r.toDomainList(models) + rolledUp := stats.RollUp(domainStats, to, loc) - statsList := make([]*domain.UsageStats, 0, len(statsMap)) - for _, s := range statsMap { - statsList = append(statsList, s) + if len(rolledUp) == 0 { + return 0, startTime, endTime, nil } - return len(statsList), r.BatchUpsert(statsList) + err = r.BatchUpsert(rolledUp) + return len(rolledUp), startTime, endTime, err } // RollUpAll 从细粒度上卷到粗粒度(处理所有历史数据,用于重新计算) -// 对于 day/week/month 粒度,使用配置的时区来划分边界 +// 对于 day/month 粒度,使用配置的时区来划分边界 func (r *UsageStatsRepository) RollUpAll(from, to domain.Granularity) (int, error) { + return r.RollUpAllWithProgress(from, to, nil) +} + +// RollUpAllWithProgress 从细粒度上卷到粗粒度,带进度报告 +func (r *UsageStatsRepository) RollUpAllWithProgress(from, to domain.Granularity, progressFn func(current, total int)) (int, error) { now := time.Now().UTC() - // 对于 day 及以上粒度,使用配置的时区 - var loc *time.Location - if to == domain.GranularityDay || to == domain.GranularityWeek || to == domain.GranularityMonth { + // 对于 day 及以上粒度,使用配置的时区,否则使用 UTC + loc := time.UTC + if to == domain.GranularityDay || to == domain.GranularityMonth { loc = r.getConfiguredTimezone() } // 计算当前时间桶 - var currentBucket time.Time - if loc != nil { - currentBucket = TruncateToGranularityInTimezone(now, to, loc) - } else { - currentBucket = TruncateToGranularity(now, to) - } + currentBucket := stats.TruncateToGranularity(now, to, loc) // 查询所有源粒度数据 var models []UsageStats @@ -1183,120 +1240,121 @@ func (r *UsageStatsRepository) RollUpAll(from, to domain.Granularity) (int, erro return 0, err } - // 使用 map 聚合数据 - type rollupKey struct { - targetBucket int64 - routeID uint64 - providerID uint64 - projectID uint64 - apiTokenID uint64 - clientType string - model string - } - statsMap := make(map[rollupKey]*domain.UsageStats) - - for _, m := range models { - // 截断到目标粒度(使用配置的时区) - t := fromTimestamp(m.TimeBucket) - var targetBucket int64 - if loc != nil { - targetBucket = TruncateToGranularityInTimezone(t, to, loc).UnixMilli() - } else { - targetBucket = TruncateToGranularity(t, to).UnixMilli() - } - - key := rollupKey{ - targetBucket: targetBucket, - routeID: m.RouteID, - providerID: m.ProviderID, - projectID: m.ProjectID, - apiTokenID: m.APITokenID, - clientType: m.ClientType, - model: m.Model, - } + total := len(models) + if total == 0 { + return 0, nil + } - if s, ok := statsMap[key]; ok { - s.TotalRequests += m.TotalRequests - s.SuccessfulRequests += m.SuccessfulRequests - s.FailedRequests += m.FailedRequests - s.TotalDurationMs += m.TotalDurationMs - s.InputTokens += m.InputTokens - s.OutputTokens += m.OutputTokens - s.CacheRead += m.CacheRead - s.CacheWrite += m.CacheWrite - s.Cost += m.Cost - } else { - statsMap[key] = &domain.UsageStats{ - Granularity: to, - TimeBucket: time.UnixMilli(targetBucket), - RouteID: m.RouteID, - ProviderID: m.ProviderID, - ProjectID: m.ProjectID, - APITokenID: m.APITokenID, - ClientType: m.ClientType, - Model: m.Model, - TotalRequests: m.TotalRequests, - SuccessfulRequests: m.SuccessfulRequests, - FailedRequests: m.FailedRequests, - TotalDurationMs: m.TotalDurationMs, - InputTokens: m.InputTokens, - OutputTokens: m.OutputTokens, - CacheRead: m.CacheRead, - CacheWrite: m.CacheWrite, - Cost: m.Cost, - } - } + // 报告初始进度 + if progressFn != nil { + progressFn(0, total) } - if len(statsMap) == 0 { - return 0, nil + // 转换为 domain 对象并使用 stats.RollUp 聚合 + domainStats := r.toDomainList(models) + rolledUp := stats.RollUp(domainStats, to, loc) + + // 报告最终进度 + if progressFn != nil { + progressFn(total, total) } - statsList := make([]*domain.UsageStats, 0, len(statsMap)) - for _, s := range statsMap { - statsList = append(statsList, s) + if len(rolledUp) == 0 { + return 0, nil } - return len(statsList), r.BatchUpsert(statsList) + return len(rolledUp), r.BatchUpsert(rolledUp) } // ClearAndRecalculate 清空统计数据并重新从原始数据计算 func (r *UsageStatsRepository) ClearAndRecalculate() error { + return r.ClearAndRecalculateWithProgress(nil) +} + +// ClearAndRecalculateWithProgress 清空统计数据并重新计算,通过 channel 报告进度 +func (r *UsageStatsRepository) ClearAndRecalculateWithProgress(progress chan<- domain.Progress) error { + sendProgress := func(phase string, current, total int, message string) { + if progress == nil { + return + } + percentage := 0 + if total > 0 { + percentage = current * 100 / total + } + progress <- domain.Progress{ + Phase: phase, + Current: current, + Total: total, + Percentage: percentage, + Message: message, + } + } + // 1. 清空所有统计数据 + sendProgress("clearing", 0, 100, "Clearing existing stats...") if err := r.db.gorm.Exec(`DELETE FROM usage_stats`).Error; err != nil { return fmt.Errorf("failed to clear usage_stats: %w", err) } - // 2. 重新聚合分钟级数据(从所有历史数据) - _, err := r.aggregateAllMinutes() + // 2. 重新聚合分钟级数据(从所有历史数据)- 带进度 + _, err := r.aggregateAllMinutesWithProgress(func(current, total int) { + sendProgress("aggregating", current, total, fmt.Sprintf("Aggregating attempts: %d/%d", current, total)) + }) if err != nil { return fmt.Errorf("failed to aggregate minutes: %w", err) } - // 3. Roll-up 到各个粒度(使用完整时间范围) - _, _ = r.RollUpAll(domain.GranularityMinute, domain.GranularityHour) - _, _ = r.RollUpAll(domain.GranularityHour, domain.GranularityDay) - _, _ = r.RollUpAll(domain.GranularityDay, domain.GranularityWeek) - _, _ = r.RollUpAll(domain.GranularityDay, domain.GranularityMonth) + // 3. Roll-up 到各个粒度(使用完整时间范围)- 带进度 + _, _ = r.RollUpAllWithProgress(domain.GranularityMinute, domain.GranularityHour, func(current, total int) { + sendProgress("rollup", current, total, fmt.Sprintf("Rolling up to hourly: %d/%d", current, total)) + }) + + _, _ = r.RollUpAllWithProgress(domain.GranularityHour, domain.GranularityDay, func(current, total int) { + sendProgress("rollup", current, total, fmt.Sprintf("Rolling up to daily: %d/%d", current, total)) + }) + + _, _ = r.RollUpAllWithProgress(domain.GranularityDay, domain.GranularityMonth, func(current, total int) { + sendProgress("rollup", current, total, fmt.Sprintf("Rolling up to monthly: %d/%d", current, total)) + }) + sendProgress("completed", 100, 100, "Stats recalculation completed") return nil } -// aggregateAllMinutes 从所有历史数据聚合分钟级统计 -// 只聚合已完成的请求,使用 end_time 作为时间桶 -func (r *UsageStatsRepository) aggregateAllMinutes() (int, error) { +// aggregateAllMinutesWithProgress 从所有历史数据聚合分钟级统计,带进度回调 +// progressFn 会在每处理一定数量的记录后调用,参数为 (current, total) +func (r *UsageStatsRepository) aggregateAllMinutesWithProgress(progressFn func(current, total int)) (int, error) { now := time.Now().UTC() currentMinute := now.Truncate(time.Minute) + // 1. 首先获取总数以便报告进度 + var totalCount int64 + countQuery := `SELECT COUNT(*) FROM proxy_upstream_attempts WHERE end_time < ? AND status IN ('COMPLETED', 'FAILED', 'CANCELLED')` + if err := r.db.gorm.Raw(countQuery, toTimestamp(currentMinute)).Scan(&totalCount).Error; err != nil { + return 0, err + } + + if totalCount == 0 { + if progressFn != nil { + progressFn(0, 0) + } + return 0, nil + } + + // 报告初始进度 + if progressFn != nil { + progressFn(0, int(totalCount)) + } + query := ` SELECT a.end_time, COALESCE(r.route_id, 0), COALESCE(a.provider_id, 0), COALESCE(r.project_id, 0), COALESCE(r.api_token_id, 0), COALESCE(r.client_type, ''), COALESCE(a.response_model, ''), - CASE WHEN a.status = 'COMPLETED' THEN 1 ELSE 0 END, - CASE WHEN a.status IN ('FAILED', 'CANCELLED') THEN 1 ELSE 0 END, + a.status, COALESCE(a.duration_ms, 0), + COALESCE(a.ttft_ms, 0), COALESCE(a.input_token_count, 0), COALESCE(a.output_token_count, 0), COALESCE(a.cache_read_count, 0), @@ -1311,32 +1369,25 @@ func (r *UsageStatsRepository) aggregateAllMinutes() (int, error) { if err != nil { return 0, err } - defer rows.Close() + defer func() { _ = rows.Close() }() - // 使用 map 聚合数据 - type aggKey struct { - minuteBucket int64 - routeID uint64 - providerID uint64 - projectID uint64 - apiTokenID uint64 - clientType string - model string - } - statsMap := make(map[aggKey]*domain.UsageStats) + // 收集所有记录,使用 stats.AggregateAttempts 聚合 + var records []stats.AttemptRecord responseModels := make(map[string]bool) + // 进度跟踪 + processedCount := 0 + const progressInterval = 100 // 每处理100条报告一次进度 + for rows.Next() { var endTime int64 var routeID, providerID, projectID, apiTokenID uint64 - var clientType, model string - var successful, failed int - var durationMs, inputTokens, outputTokens, cacheRead, cacheWrite, cost uint64 + var clientType, model, status string + var durationMs, ttftMs, inputTokens, outputTokens, cacheRead, cacheWrite, cost uint64 err := rows.Scan( &endTime, &routeID, &providerID, &projectID, &apiTokenID, &clientType, - &model, - &successful, &failed, &durationMs, + &model, &status, &durationMs, &ttftMs, &inputTokens, &outputTokens, &cacheRead, &cacheWrite, &cost, ) if err != nil { @@ -1344,55 +1395,40 @@ func (r *UsageStatsRepository) aggregateAllMinutes() (int, error) { continue } + processedCount++ + // 定期报告进度 + if progressFn != nil && processedCount%progressInterval == 0 { + progressFn(processedCount, int(totalCount)) + } + // 记录 response model if model != "" { responseModels[model] = true } - // 截断到分钟(使用 end_time) - minuteBucket := fromTimestamp(endTime).Truncate(time.Minute).UnixMilli() - - key := aggKey{ - minuteBucket: minuteBucket, - routeID: routeID, - providerID: providerID, - projectID: projectID, - apiTokenID: apiTokenID, - clientType: clientType, - model: model, - } + records = append(records, stats.AttemptRecord{ + EndTime: fromTimestamp(endTime), + RouteID: routeID, + ProviderID: providerID, + ProjectID: projectID, + APITokenID: apiTokenID, + ClientType: clientType, + Model: model, + IsSuccessful: status == "COMPLETED", + IsFailed: status == "FAILED" || status == "CANCELLED", + DurationMs: durationMs, + TTFTMs: ttftMs, + InputTokens: inputTokens, + OutputTokens: outputTokens, + CacheRead: cacheRead, + CacheWrite: cacheWrite, + Cost: cost, + }) + } - if s, ok := statsMap[key]; ok { - s.TotalRequests++ - s.SuccessfulRequests += uint64(successful) - s.FailedRequests += uint64(failed) - s.TotalDurationMs += durationMs - s.InputTokens += inputTokens - s.OutputTokens += outputTokens - s.CacheRead += cacheRead - s.CacheWrite += cacheWrite - s.Cost += cost - } else { - statsMap[key] = &domain.UsageStats{ - Granularity: domain.GranularityMinute, - TimeBucket: time.UnixMilli(minuteBucket), - RouteID: routeID, - ProviderID: providerID, - ProjectID: projectID, - APITokenID: apiTokenID, - ClientType: clientType, - Model: model, - TotalRequests: 1, - SuccessfulRequests: uint64(successful), - FailedRequests: uint64(failed), - TotalDurationMs: durationMs, - InputTokens: inputTokens, - OutputTokens: outputTokens, - CacheRead: cacheRead, - CacheWrite: cacheWrite, - Cost: cost, - } - } + // 报告最终进度 + if progressFn != nil { + progressFn(processedCount, int(totalCount)) } // 记录 response models 到独立表 @@ -1407,13 +1443,16 @@ func (r *UsageStatsRepository) aggregateAllMinutes() (int, error) { } } - if len(statsMap) == 0 { + if len(records) == 0 { return 0, nil } - statsList := make([]*domain.UsageStats, 0, len(statsMap)) - for _, s := range statsMap { - statsList = append(statsList, s) + // 使用配置的时区进行分钟聚合 + loc := r.getConfiguredTimezone() + statsList := stats.AggregateAttempts(records, loc) + + if len(statsList) == 0 { + return 0, nil } return len(statsList), r.BatchUpsert(statsList) @@ -1435,6 +1474,7 @@ func (r *UsageStatsRepository) toModel(s *domain.UsageStats) *UsageStats { SuccessfulRequests: s.SuccessfulRequests, FailedRequests: s.FailedRequests, TotalDurationMs: s.TotalDurationMs, + TotalTTFTMs: s.TotalTTFTMs, InputTokens: s.InputTokens, OutputTokens: s.OutputTokens, CacheRead: s.CacheRead, @@ -1459,6 +1499,7 @@ func (r *UsageStatsRepository) toDomain(m *UsageStats) *domain.UsageStats { SuccessfulRequests: m.SuccessfulRequests, FailedRequests: m.FailedRequests, TotalDurationMs: m.TotalDurationMs, + TotalTTFTMs: m.TotalTTFTMs, InputTokens: m.InputTokens, OutputTokens: m.OutputTokens, CacheRead: m.CacheRead, @@ -1477,9 +1518,9 @@ func (r *UsageStatsRepository) toDomainList(models []UsageStats) []*domain.Usage // QueryDashboardData 查询 Dashboard 所需的所有数据(单次请求) // 优化:只执行 3 次主查询 -// 1. 历史 day 粒度数据 (371天) → 热力图、昨日、Provider统计(30天) -// 2. 今日实时 hour 粒度 (QueryWithRealtime) → 今日统计、24h趋势、今日热力图 -// 3. 全量 month 粒度 (QueryWithRealtime) → 全量统计、Top模型(全量) +// 1. 历史 day 粒度数据 (371天) → 热力图、昨日、Provider统计(30天) +// 2. 今日实时 hour 粒度 (Query) → 今日统计、24h趋势、今日热力图 +// 3. 全量 month 粒度 (Query) → 全量统计、Top模型(全量) func (r *UsageStatsRepository) QueryDashboardData() (*domain.DashboardData, error) { // 获取配置的时区 loc := r.getConfiguredTimezone() @@ -1518,7 +1559,7 @@ func (r *UsageStatsRepository) QueryDashboardData() (*domain.DashboardData, erro if err != nil { return err } - defer rows.Close() + defer func() { _ = rows.Close() }() // 初始化热力图(使用配置的时区格式化日期) days := int(now.Sub(days371Ago).Hours()/24) + 1 @@ -1602,14 +1643,14 @@ func (r *UsageStatsRepository) QueryDashboardData() (*domain.DashboardData, erro return nil }) - // 查询2: 今日实时 hour 粒度 (QueryWithRealtime) + // 查询2: 今日实时 hour 粒度 (Query) // 用于:今日统计、24h趋势、今日热力图、Provider今日RPM/TPM g.Go(func() error { filter := repository.UsageStatsFilter{ Granularity: domain.GranularityHour, StartTime: &hours24Ago, } - stats, err := r.QueryWithRealtime(filter) + stats, err := r.Query(filter) if err != nil { return err } @@ -1733,13 +1774,13 @@ func (r *UsageStatsRepository) QueryDashboardData() (*domain.DashboardData, erro return nil }) - // 查询3: 全量 month 粒度 (QueryWithRealtime) + // 查询3: 全量 month 粒度 (Query) // 用于:全量统计、Top模型(全量) g.Go(func() error { filter := repository.UsageStatsFilter{ Granularity: domain.GranularityMonth, } - stats, err := r.QueryWithRealtime(filter) + stats, err := r.Query(filter) if err != nil { return err } @@ -1828,259 +1869,3 @@ func (r *UsageStatsRepository) getTopModels(modelData map[string]*struct { } return result } - -// aggregateToSummary 将 UsageStats 列表聚合为 DashboardDaySummary -func (r *UsageStatsRepository) aggregateToSummary(stats []*domain.UsageStats) domain.DashboardDaySummary { - var result domain.DashboardDaySummary - var successfulRequests uint64 - - for _, s := range stats { - result.Requests += s.TotalRequests - successfulRequests += s.SuccessfulRequests - result.Tokens += s.InputTokens + s.OutputTokens + s.CacheRead + s.CacheWrite - result.Cost += s.Cost - } - - if result.Requests > 0 { - result.SuccessRate = float64(successfulRequests) / float64(result.Requests) * 100 - } - - return result -} - -// statsToHeatmap 将 UsageStats 列表转换为热力图数据 -func (r *UsageStatsRepository) statsToHeatmap(stats []*domain.UsageStats, start, end time.Time) []domain.DashboardHeatmapPoint { - // 初始化所有日期 - days := int(end.Sub(start).Hours() / 24) - dateMap := make(map[string]uint64, days) - for i := 0; i < days; i++ { - date := start.Add(time.Duration(i) * 24 * time.Hour) - dateStr := date.Format("2006-01-02") - dateMap[dateStr] = 0 - } - - // 按天聚合 - for _, s := range stats { - dateStr := s.TimeBucket.Format("2006-01-02") - dateMap[dateStr] += s.TotalRequests - } - - // 转换为有序数组 - result := make([]domain.DashboardHeatmapPoint, 0, days) - for i := 0; i < days; i++ { - date := start.Add(time.Duration(i) * 24 * time.Hour) - dateStr := date.Format("2006-01-02") - result = append(result, domain.DashboardHeatmapPoint{ - Date: dateStr, - Count: dateMap[dateStr], - }) - } - - return result -} - -// statsToTrend24h 将 UsageStats 列表转换为 24 小时趋势数据 -func (r *UsageStatsRepository) statsToTrend24h(stats []*domain.UsageStats, start, end time.Time) []domain.DashboardTrendPoint { - // 初始化 24 小时 - hourMap := make(map[string]uint64, 24) - for i := 0; i < 24; i++ { - hour := start.Add(time.Duration(i) * time.Hour).Truncate(time.Hour) - hourStr := hour.Format("15:04") - hourMap[hourStr] = 0 - } - - // 按小时聚合 - for _, s := range stats { - hourStr := s.TimeBucket.Format("15:04") - hourMap[hourStr] += s.TotalRequests - } - - // 转换为有序数组 - result := make([]domain.DashboardTrendPoint, 0, 24) - for i := 0; i < 24; i++ { - hour := start.Add(time.Duration(i) * time.Hour).Truncate(time.Hour) - hourStr := hour.Format("15:04") - result = append(result, domain.DashboardTrendPoint{ - Hour: hourStr, - Requests: hourMap[hourStr], - }) - } - - return result -} - -// statsToProviderStats 将 UsageStats 列表转换为 Provider 统计 -func (r *UsageStatsRepository) statsToProviderStats(stats []*domain.UsageStats) map[uint64]domain.DashboardProviderStats { - // 按 Provider 聚合 - providerMap := make(map[uint64]*struct { - requests uint64 - successful uint64 - }) - - for _, s := range stats { - if s.ProviderID == 0 { - continue - } - if _, ok := providerMap[s.ProviderID]; !ok { - providerMap[s.ProviderID] = &struct { - requests uint64 - successful uint64 - }{} - } - providerMap[s.ProviderID].requests += s.TotalRequests - providerMap[s.ProviderID].successful += s.SuccessfulRequests - } - - // 转换为结果 - result := make(map[uint64]domain.DashboardProviderStats) - for providerID, data := range providerMap { - var successRate float64 - if data.requests > 0 { - successRate = float64(data.successful) / float64(data.requests) * 100 - } - result[providerID] = domain.DashboardProviderStats{ - Requests: data.requests, - SuccessRate: successRate, - } - } - - return result -} - -// queryDashboardAllTimeStats 查询全量统计和首次使用日期 -func (r *UsageStatsRepository) queryDashboardAllTimeStats() (domain.DashboardDaySummary, *time.Time, error) { - var result domain.DashboardDaySummary - - // 查询全量统计(使用 month 粒度) - query := ` - SELECT - COALESCE(SUM(total_requests), 0), - COALESCE(SUM(input_tokens + output_tokens + cache_read + cache_write), 0), - COALESCE(SUM(cost), 0), - MIN(time_bucket) - FROM usage_stats - WHERE granularity = 'month' - ` - - var totalRequests, tokens, cost uint64 - var minBucket *int64 - err := r.db.gorm.Raw(query).Row().Scan(&totalRequests, &tokens, &cost, &minBucket) - if err != nil { - return result, nil, err - } - - result.Requests = totalRequests - result.Tokens = tokens - result.Cost = cost - - var firstUse *time.Time - if minBucket != nil && *minBucket > 0 { - t := fromTimestamp(*minBucket) - firstUse = &t - } - - return result, firstUse, nil -} - -// queryDashboardHeatmap 查询热力图数据 -// 历史数据用 day 粒度预聚合,今天用 QueryWithRealtime 获取实时数据 -func (r *UsageStatsRepository) queryDashboardHeatmap(start, todayStart, end time.Time) ([]domain.DashboardHeatmapPoint, error) { - // 初始化所有日期 - days := int(end.Sub(start).Hours()/24) + 1 - dateMap := make(map[string]uint64, days) - for i := 0; i < days; i++ { - date := start.Add(time.Duration(i) * 24 * time.Hour) - dateStr := date.Format("2006-01-02") - dateMap[dateStr] = 0 - } - - // 1. 查询历史天数据(今天之前,使用 day 粒度预聚合) - if todayStart.After(start) { - query := ` - SELECT time_bucket, SUM(total_requests) as count - FROM usage_stats - WHERE granularity = 'day' - AND time_bucket >= ? AND time_bucket < ? - GROUP BY time_bucket - ` - rows, err := r.db.gorm.Raw(query, toTimestamp(start), toTimestamp(todayStart)).Rows() - if err != nil { - return nil, err - } - defer rows.Close() - - for rows.Next() { - var bucket int64 - var count uint64 - if err := rows.Scan(&bucket, &count); err != nil { - continue - } - dateStr := fromTimestamp(bucket).Format("2006-01-02") - dateMap[dateStr] += count - } - } - - // 2. 查询今天的实时数据(使用 QueryWithRealtime) - todayFilter := repository.UsageStatsFilter{ - Granularity: domain.GranularityDay, - StartTime: &todayStart, - } - todayStats, err := r.QueryWithRealtime(todayFilter) - if err != nil { - return nil, err - } - - // 聚合今天的数据 - todayDateStr := todayStart.Format("2006-01-02") - for _, s := range todayStats { - dateMap[todayDateStr] += s.TotalRequests - } - - // 转换为有序数组 - result := make([]domain.DashboardHeatmapPoint, 0, days) - for i := 0; i < days; i++ { - date := start.Add(time.Duration(i) * 24 * time.Hour) - dateStr := date.Format("2006-01-02") - result = append(result, domain.DashboardHeatmapPoint{ - Date: dateStr, - Count: dateMap[dateStr], - }) - } - - return result, nil -} - -// queryDashboardTopModels 查询 Top N 模型 -func (r *UsageStatsRepository) queryDashboardTopModels(limit int) ([]domain.DashboardModelStats, error) { - query := ` - SELECT - model, - SUM(total_requests) as requests - FROM usage_stats - WHERE granularity = 'month' AND model != '' - GROUP BY model - ORDER BY requests DESC - LIMIT ? - ` - - rows, err := r.db.gorm.Raw(query, limit).Rows() - if err != nil { - return nil, err - } - defer rows.Close() - - var result []domain.DashboardModelStats - for rows.Next() { - var model string - var requests uint64 - if err := rows.Scan(&model, &requests); err != nil { - continue - } - result = append(result, domain.DashboardModelStats{ - Model: model, - Requests: requests, - }) - } - - return result, nil -} diff --git a/internal/service/admin.go b/internal/service/admin.go index a6f08c00..30198f6f 100644 --- a/internal/service/admin.go +++ b/internal/service/admin.go @@ -4,6 +4,7 @@ import ( "crypto/rand" "encoding/hex" "fmt" + "log" "net" "net/http" "strconv" @@ -11,7 +12,10 @@ import ( "time" "github.com/awsl-project/maxx/internal/domain" + "github.com/awsl-project/maxx/internal/event" + "github.com/awsl-project/maxx/internal/pricing" "github.com/awsl-project/maxx/internal/repository" + "github.com/awsl-project/maxx/internal/usage" "github.com/awsl-project/maxx/internal/version" ) @@ -40,6 +44,7 @@ type AdminService struct { responseModelRepo repository.ResponseModelRepository serverAddr string adapterRefresher ProviderAdapterRefresher + broadcaster event.Broadcaster } // NewAdminService creates a new admin service @@ -59,6 +64,7 @@ func NewAdminService( responseModelRepo repository.ResponseModelRepository, serverAddr string, adapterRefresher ProviderAdapterRefresher, + broadcaster event.Broadcaster, ) *AdminService { return &AdminService{ providerRepo: providerRepo, @@ -76,6 +82,7 @@ func NewAdminService( responseModelRepo: responseModelRepo, serverAddr: serverAddr, adapterRefresher: adapterRefresher, + broadcaster: broadcaster, } } @@ -360,8 +367,8 @@ type CursorPaginationResult struct { LastID uint64 `json:"lastId,omitempty"` } -func (s *AdminService) GetProxyRequestsCursor(limit int, before, after uint64) (*CursorPaginationResult, error) { - items, err := s.proxyRequestRepo.ListCursor(limit+1, before, after) +func (s *AdminService) GetProxyRequestsCursor(limit int, before, after uint64, filter *repository.ProxyRequestFilter) (*CursorPaginationResult, error) { + items, err := s.proxyRequestRepo.ListCursor(limit+1, before, after, filter) if err != nil { return nil, err } @@ -388,6 +395,10 @@ func (s *AdminService) GetProxyRequestsCount() (int64, error) { return s.proxyRequestRepo.Count() } +func (s *AdminService) GetProxyRequestsCountWithFilter(filter *repository.ProxyRequestFilter) (int64, error) { + return s.proxyRequestRepo.CountWithFilter(filter) +} + func (s *AdminService) GetProxyRequest(id uint64) (*domain.ProxyRequest, error) { return s.proxyRequestRepo.GetByID(id) } @@ -471,10 +482,8 @@ func (s *AdminService) GetProxyStatus(r *http.Request) *ProxyStatus { port = p } // displayAddr 保持 host:port 格式不变 - } else { - // 地址不包含端口,说明是标准端口 80 - // displayAddr 保持原样(不带端口) } + // else: 地址不包含端口,说明是标准端口 80,displayAddr 保持原样 return &ProxyStatus{ Running: true, @@ -652,9 +661,8 @@ func (s *AdminService) GetAvailableClientTypes() []domain.ClientType { // ===== Usage Stats API ===== // GetUsageStats queries usage statistics with optional filters -// Uses QueryWithRealtime to include current period's real-time data func (s *AdminService) GetUsageStats(filter repository.UsageStatsFilter) ([]*domain.UsageStats, error) { - return s.usageStatsRepo.QueryWithRealtime(filter) + return s.usageStatsRepo.Query(filter) } // GetDashboardData returns all dashboard data in a single query @@ -662,7 +670,269 @@ func (s *AdminService) GetDashboardData() (*domain.DashboardData, error) { return s.usageStatsRepo.QueryDashboardData() } +// RecalculateUsageStatsProgress represents progress update for usage stats recalculation +type RecalculateUsageStatsProgress struct { + Phase string `json:"phase"` // "clearing", "aggregating", "rollup", "completed" + Current int `json:"current"` // Current step being processed + Total int `json:"total"` // Total steps to process + Percentage int `json:"percentage"` // 0-100 + Message string `json:"message"` // Human-readable message +} + // RecalculateUsageStats clears all usage stats and recalculates from raw data +// This only re-aggregates usage stats, it does NOT recalculate costs func (s *AdminService) RecalculateUsageStats() error { - return s.usageStatsRepo.ClearAndRecalculate() + // Create progress channel + progressChan := make(chan domain.Progress, 10) + + // Start goroutine to listen to progress and broadcast via WebSocket + go func() { + for progress := range progressChan { + if s.broadcaster != nil { + s.broadcaster.BroadcastMessage("recalculate_stats_progress", RecalculateUsageStatsProgress{ + Phase: progress.Phase, + Current: progress.Current, + Total: progress.Total, + Percentage: progress.Percentage, + Message: progress.Message, + }) + } + } + }() + + // Call repository method with progress channel + err := s.usageStatsRepo.ClearAndRecalculateWithProgress(progressChan) + + // Close channel when done + close(progressChan) + + return err +} + +// RecalculateCostsResult holds the result of cost recalculation +type RecalculateCostsResult struct { + TotalAttempts int `json:"totalAttempts"` + UpdatedAttempts int `json:"updatedAttempts"` + UpdatedRequests int `json:"updatedRequests"` + Message string `json:"message"` +} + +// RecalculateCostsProgress represents progress update for cost recalculation +type RecalculateCostsProgress struct { + Phase string `json:"phase"` // "calculating", "updating_attempts", "updating_requests", "aggregating_stats", "completed" + Current int `json:"current"` // Current item being processed + Total int `json:"total"` // Total items to process + Percentage int `json:"percentage"` // 0-100 + Message string `json:"message"` // Human-readable message +} + +// RecalculateCosts recalculates cost for all attempts using the current price table +// and updates the parent requests' cost accordingly (with streaming batch processing) +func (s *AdminService) RecalculateCosts() (*RecalculateCostsResult, error) { + result := &RecalculateCostsResult{} + + // Helper to broadcast progress + broadcastProgress := func(phase string, current, total int, message string) { + if s.broadcaster == nil { + return + } + percentage := 0 + if total > 0 { + percentage = current * 100 / total + } + s.broadcaster.BroadcastMessage("recalculate_costs_progress", RecalculateCostsProgress{ + Phase: phase, + Current: current, + Total: total, + Percentage: percentage, + Message: message, + }) + } + + // 1. Get total count first + broadcastProgress("calculating", 0, 0, "Counting attempts...") + totalCount, err := s.attemptRepo.CountAll() + if err != nil { + return nil, fmt.Errorf("failed to count attempts: %w", err) + } + result.TotalAttempts = int(totalCount) + + if totalCount == 0 { + result.Message = "No attempts to recalculate" + broadcastProgress("completed", 0, 0, result.Message) + return result, nil + } + + broadcastProgress("calculating", 0, int(totalCount), fmt.Sprintf("Processing %d attempts...", totalCount)) + + calculator := pricing.GlobalCalculator() + processedCount := 0 + const batchSize = 100 + affectedRequestIDs := make(map[uint64]struct{}) + + // 2. Stream through attempts, process and update each batch immediately + err = s.attemptRepo.StreamForCostCalc(batchSize, func(batch []*domain.AttemptCostData) error { + attemptUpdates := make(map[uint64]uint64, len(batch)) + + for _, attempt := range batch { + // Use responseModel if available, otherwise use mappedModel or requestModel + model := attempt.ResponseModel + if model == "" { + model = attempt.MappedModel + } + if model == "" { + model = attempt.RequestModel + } + + // Build metrics from attempt data + metrics := &usage.Metrics{ + InputTokens: attempt.InputTokenCount, + OutputTokens: attempt.OutputTokenCount, + CacheReadCount: attempt.CacheReadCount, + CacheCreationCount: attempt.CacheWriteCount, + Cache5mCreationCount: attempt.Cache5mWriteCount, + Cache1hCreationCount: attempt.Cache1hWriteCount, + } + + // Calculate new cost + newCost := calculator.Calculate(model, metrics) + + // Track affected request IDs + affectedRequestIDs[attempt.ProxyRequestID] = struct{}{} + + // Track if attempt needs update + if newCost != attempt.Cost { + attemptUpdates[attempt.ID] = newCost + } + + processedCount++ + } + + // Batch update attempt costs immediately + if len(attemptUpdates) > 0 { + if err := s.attemptRepo.BatchUpdateCosts(attemptUpdates); err != nil { + log.Printf("[RecalculateCosts] Failed to batch update attempts: %v", err) + } else { + result.UpdatedAttempts += len(attemptUpdates) + } + } + + // Broadcast progress + broadcastProgress("calculating", processedCount, int(totalCount), + fmt.Sprintf("Processed %d/%d attempts", processedCount, totalCount)) + + // Small delay to allow UI to update (WebSocket messages need time to be processed) + time.Sleep(50 * time.Millisecond) + + return nil + }) + + if err != nil { + return nil, fmt.Errorf("failed to stream attempts: %w", err) + } + + // 3. Recalculate request costs from attempts (with progress via channel) + progressChan := make(chan domain.Progress, 10) + go func() { + for progress := range progressChan { + broadcastProgress(progress.Phase, progress.Current, progress.Total, progress.Message) + } + }() + + updatedRequests, err := s.proxyRequestRepo.RecalculateCostsFromAttemptsWithProgress(progressChan) + close(progressChan) + + if err != nil { + log.Printf("[RecalculateCosts] Failed to recalculate request costs: %v", err) + } else { + result.UpdatedRequests = int(updatedRequests) + } + + broadcastProgress("updating_requests", result.UpdatedRequests, result.UpdatedRequests, + fmt.Sprintf("Updated %d requests", result.UpdatedRequests)) + + result.Message = fmt.Sprintf("Recalculated %d attempts, updated %d attempts and %d requests", + result.TotalAttempts, result.UpdatedAttempts, result.UpdatedRequests) + + broadcastProgress("completed", 100, 100, result.Message) + + log.Printf("[RecalculateCosts] %s", result.Message) + return result, nil +} + +// RecalculateRequestCostResult holds the result of single request cost recalculation +type RecalculateRequestCostResult struct { + RequestID uint64 `json:"requestId"` + OldCost uint64 `json:"oldCost"` + NewCost uint64 `json:"newCost"` + UpdatedAttempts int `json:"updatedAttempts"` + Message string `json:"message"` +} + +// RecalculateRequestCost recalculates cost for a single request and its attempts +func (s *AdminService) RecalculateRequestCost(requestID uint64) (*RecalculateRequestCostResult, error) { + result := &RecalculateRequestCostResult{RequestID: requestID} + + // 1. Get the request + request, err := s.proxyRequestRepo.GetByID(requestID) + if err != nil { + return nil, fmt.Errorf("failed to get request: %w", err) + } + result.OldCost = request.Cost + + // 2. Get all attempts for this request + attempts, err := s.attemptRepo.ListByProxyRequestID(requestID) + if err != nil { + return nil, fmt.Errorf("failed to list attempts: %w", err) + } + + calculator := pricing.GlobalCalculator() + var totalCost uint64 + + // 3. Recalculate cost for each attempt + for _, attempt := range attempts { + // Use responseModel if available, otherwise use mappedModel or requestModel + model := attempt.ResponseModel + if model == "" { + model = attempt.MappedModel + } + if model == "" { + model = attempt.RequestModel + } + + // Build metrics from attempt data + metrics := &usage.Metrics{ + InputTokens: attempt.InputTokenCount, + OutputTokens: attempt.OutputTokenCount, + CacheReadCount: attempt.CacheReadCount, + CacheCreationCount: attempt.CacheWriteCount, + Cache5mCreationCount: attempt.Cache5mWriteCount, + Cache1hCreationCount: attempt.Cache1hWriteCount, + } + + // Calculate new cost + newCost := calculator.Calculate(model, metrics) + totalCost += newCost + + // Update attempt cost if changed + if newCost != attempt.Cost { + if err := s.attemptRepo.UpdateCost(attempt.ID, newCost); err != nil { + log.Printf("[RecalculateRequestCost] Failed to update attempt %d cost: %v", attempt.ID, err) + continue + } + result.UpdatedAttempts++ + } + } + + // 4. Update request cost + result.NewCost = totalCost + if err := s.proxyRequestRepo.UpdateCost(requestID, totalCost); err != nil { + return nil, fmt.Errorf("failed to update request cost: %w", err) + } + + result.Message = fmt.Sprintf("Recalculated request %d: %d -> %d (updated %d attempts)", + requestID, result.OldCost, result.NewCost, result.UpdatedAttempts) + + log.Printf("[RecalculateRequestCost] %s", result.Message) + return result, nil } diff --git a/internal/service/backup.go b/internal/service/backup.go index 4aaee1a4..8da6ed9a 100644 --- a/internal/service/backup.go +++ b/internal/service/backup.go @@ -589,7 +589,7 @@ func (s *BackupService) importRoutes(routes []domain.BackupRoute, opts domain.Im continue case "overwrite": summary.Skipped++ - result.Warnings = append(result.Warnings, fmt.Sprintf("Route overwrite not supported, skipped")) + result.Warnings = append(result.Warnings, "Route overwrite not supported, skipped") continue case "error": result.Success = false diff --git a/internal/stats/aggregator.go b/internal/stats/aggregator.go index 8aead1d1..e7561c43 100644 --- a/internal/stats/aggregator.go +++ b/internal/stats/aggregator.go @@ -5,7 +5,7 @@ import ( ) // StatsAggregator 统计数据聚合器 -// 仅支持定时同步模式,实时数据由 QueryWithRealtime 直接查询 +// 仅支持定时同步模式,实时数据由 Query 方法直接查询 type StatsAggregator struct { usageStatsRepo repository.UsageStatsRepository } @@ -17,7 +17,10 @@ func NewStatsAggregator(usageStatsRepo repository.UsageStatsRepository) *StatsAg } } -// RunPeriodicSync 定期同步分钟级数据 +// RunPeriodicSync 定期同步统计数据(聚合 + rollup) +// 通过 range channel 等待所有阶段完成 func (sa *StatsAggregator) RunPeriodicSync() { - _, _ = sa.usageStatsRepo.AggregateMinute() + for range sa.usageStatsRepo.AggregateAndRollUp() { + // drain the channel to wait for completion + } } diff --git a/internal/stats/pure.go b/internal/stats/pure.go new file mode 100644 index 00000000..de86d0fa --- /dev/null +++ b/internal/stats/pure.go @@ -0,0 +1,344 @@ +// Package stats provides pure functions for usage statistics aggregation and rollup. +// These functions are separated from the repository layer to enable easier testing +// and to ensure the aggregation logic is correct and predictable. +package stats + +import ( + "time" + + "github.com/awsl-project/maxx/internal/domain" +) + +// AttemptRecord represents a single upstream attempt record for aggregation. +// This is a simplified representation of the data needed for minute-level aggregation. +type AttemptRecord struct { + EndTime time.Time + RouteID uint64 + ProviderID uint64 + ProjectID uint64 + APITokenID uint64 + ClientType string + Model string // response_model + IsSuccessful bool + IsFailed bool + DurationMs uint64 + TTFTMs uint64 // Time To First Token (milliseconds) + InputTokens uint64 + OutputTokens uint64 + CacheRead uint64 + CacheWrite uint64 + Cost uint64 +} + +// TruncateToGranularity truncates a time to the start of its time bucket +// based on granularity using the specified timezone. +// The loc parameter is required and must not be nil. +func TruncateToGranularity(t time.Time, g domain.Granularity, loc *time.Location) time.Time { + t = t.In(loc) + switch g { + case domain.GranularityMinute: + return t.Truncate(time.Minute) + case domain.GranularityHour: + return t.Truncate(time.Hour) + case domain.GranularityDay: + return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, loc) + case domain.GranularityMonth: + return time.Date(t.Year(), t.Month(), 1, 0, 0, 0, 0, loc) + default: + return t.Truncate(time.Hour) + } +} + +// AggregateAttempts aggregates a list of attempt records into UsageStats by minute. +// This is a pure function that takes raw attempt data and returns aggregated stats. +// The loc parameter specifies the timezone for time bucket calculation. +func AggregateAttempts(records []AttemptRecord, loc *time.Location) []*domain.UsageStats { + if len(records) == 0 { + return nil + } + + type aggKey struct { + minuteBucket int64 + routeID uint64 + providerID uint64 + projectID uint64 + apiTokenID uint64 + clientType string + model string + } + statsMap := make(map[aggKey]*domain.UsageStats) + + for _, r := range records { + minuteBucket := TruncateToGranularity(r.EndTime, domain.GranularityMinute, loc).UnixMilli() + + key := aggKey{ + minuteBucket: minuteBucket, + routeID: r.RouteID, + providerID: r.ProviderID, + projectID: r.ProjectID, + apiTokenID: r.APITokenID, + clientType: r.ClientType, + model: r.Model, + } + + var successful, failed uint64 + if r.IsSuccessful { + successful = 1 + } + if r.IsFailed { + failed = 1 + } + + if s, ok := statsMap[key]; ok { + s.TotalRequests++ + s.SuccessfulRequests += successful + s.FailedRequests += failed + s.TotalDurationMs += r.DurationMs + s.TotalTTFTMs += r.TTFTMs + s.InputTokens += r.InputTokens + s.OutputTokens += r.OutputTokens + s.CacheRead += r.CacheRead + s.CacheWrite += r.CacheWrite + s.Cost += r.Cost + } else { + statsMap[key] = &domain.UsageStats{ + Granularity: domain.GranularityMinute, + TimeBucket: time.UnixMilli(minuteBucket), + RouteID: r.RouteID, + ProviderID: r.ProviderID, + ProjectID: r.ProjectID, + APITokenID: r.APITokenID, + ClientType: r.ClientType, + Model: r.Model, + TotalRequests: 1, + SuccessfulRequests: successful, + FailedRequests: failed, + TotalDurationMs: r.DurationMs, + TotalTTFTMs: r.TTFTMs, + InputTokens: r.InputTokens, + OutputTokens: r.OutputTokens, + CacheRead: r.CacheRead, + CacheWrite: r.CacheWrite, + Cost: r.Cost, + } + } + } + + result := make([]*domain.UsageStats, 0, len(statsMap)) + for _, s := range statsMap { + result = append(result, s) + } + return result +} + +// RollUp aggregates stats from a finer granularity to a coarser granularity. +// It takes a list of source stats and returns aggregated stats at the target granularity. +// The loc parameter specifies the timezone for time bucket calculation. +func RollUp(stats []*domain.UsageStats, to domain.Granularity, loc *time.Location) []*domain.UsageStats { + if len(stats) == 0 { + return nil + } + + type rollupKey struct { + targetBucket int64 + routeID uint64 + providerID uint64 + projectID uint64 + apiTokenID uint64 + clientType string + model string + } + statsMap := make(map[rollupKey]*domain.UsageStats) + + for _, s := range stats { + targetBucket := TruncateToGranularity(s.TimeBucket, to, loc) + + key := rollupKey{ + targetBucket: targetBucket.UnixMilli(), + routeID: s.RouteID, + providerID: s.ProviderID, + projectID: s.ProjectID, + apiTokenID: s.APITokenID, + clientType: s.ClientType, + model: s.Model, + } + + if existing, ok := statsMap[key]; ok { + existing.TotalRequests += s.TotalRequests + existing.SuccessfulRequests += s.SuccessfulRequests + existing.FailedRequests += s.FailedRequests + existing.TotalDurationMs += s.TotalDurationMs + existing.TotalTTFTMs += s.TotalTTFTMs + existing.InputTokens += s.InputTokens + existing.OutputTokens += s.OutputTokens + existing.CacheRead += s.CacheRead + existing.CacheWrite += s.CacheWrite + existing.Cost += s.Cost + } else { + statsMap[key] = &domain.UsageStats{ + Granularity: to, + TimeBucket: targetBucket, + RouteID: s.RouteID, + ProviderID: s.ProviderID, + ProjectID: s.ProjectID, + APITokenID: s.APITokenID, + ClientType: s.ClientType, + Model: s.Model, + TotalRequests: s.TotalRequests, + SuccessfulRequests: s.SuccessfulRequests, + FailedRequests: s.FailedRequests, + TotalDurationMs: s.TotalDurationMs, + TotalTTFTMs: s.TotalTTFTMs, + InputTokens: s.InputTokens, + OutputTokens: s.OutputTokens, + CacheRead: s.CacheRead, + CacheWrite: s.CacheWrite, + Cost: s.Cost, + } + } + } + + result := make([]*domain.UsageStats, 0, len(statsMap)) + for _, s := range statsMap { + result = append(result, s) + } + return result +} + +// MergeStats merges multiple UsageStats slices into one, combining stats with matching keys. +// This is useful for combining pre-aggregated data with real-time data. +func MergeStats(statsList ...[]*domain.UsageStats) []*domain.UsageStats { + type mergeKey struct { + granularity domain.Granularity + timeBucket int64 + routeID uint64 + providerID uint64 + projectID uint64 + apiTokenID uint64 + clientType string + model string + } + merged := make(map[mergeKey]*domain.UsageStats) + + for _, stats := range statsList { + for _, s := range stats { + key := mergeKey{ + granularity: s.Granularity, + timeBucket: s.TimeBucket.UnixMilli(), + routeID: s.RouteID, + providerID: s.ProviderID, + projectID: s.ProjectID, + apiTokenID: s.APITokenID, + clientType: s.ClientType, + model: s.Model, + } + + if existing, ok := merged[key]; ok { + existing.TotalRequests += s.TotalRequests + existing.SuccessfulRequests += s.SuccessfulRequests + existing.FailedRequests += s.FailedRequests + existing.TotalDurationMs += s.TotalDurationMs + existing.TotalTTFTMs += s.TotalTTFTMs + existing.InputTokens += s.InputTokens + existing.OutputTokens += s.OutputTokens + existing.CacheRead += s.CacheRead + existing.CacheWrite += s.CacheWrite + existing.Cost += s.Cost + } else { + // Make a copy to avoid modifying the original + copied := *s + merged[key] = &copied + } + } + } + + result := make([]*domain.UsageStats, 0, len(merged)) + for _, s := range merged { + result = append(result, s) + } + return result +} + +// SumStats calculates the summary of a list of UsageStats. +// Returns total requests, successful requests, failed requests, input tokens, output tokens, +// cache read, cache write, and cost. +func SumStats(stats []*domain.UsageStats) (totalReq, successReq, failedReq, inputTokens, outputTokens, cacheRead, cacheWrite, cost uint64) { + for _, s := range stats { + totalReq += s.TotalRequests + successReq += s.SuccessfulRequests + failedReq += s.FailedRequests + inputTokens += s.InputTokens + outputTokens += s.OutputTokens + cacheRead += s.CacheRead + cacheWrite += s.CacheWrite + cost += s.Cost + } + return +} + +// GroupByProvider groups stats by provider ID and sums them. +// Returns a map of provider ID to aggregated totals. +func GroupByProvider(stats []*domain.UsageStats) map[uint64]*domain.ProviderStats { + result := make(map[uint64]*domain.ProviderStats) + + for _, s := range stats { + if s.ProviderID == 0 { + continue + } + + if existing, ok := result[s.ProviderID]; ok { + existing.TotalRequests += s.TotalRequests + existing.SuccessfulRequests += s.SuccessfulRequests + existing.FailedRequests += s.FailedRequests + existing.TotalInputTokens += s.InputTokens + existing.TotalOutputTokens += s.OutputTokens + existing.TotalCacheRead += s.CacheRead + existing.TotalCacheWrite += s.CacheWrite + existing.TotalCost += s.Cost + } else { + result[s.ProviderID] = &domain.ProviderStats{ + ProviderID: s.ProviderID, + TotalRequests: s.TotalRequests, + SuccessfulRequests: s.SuccessfulRequests, + FailedRequests: s.FailedRequests, + TotalInputTokens: s.InputTokens, + TotalOutputTokens: s.OutputTokens, + TotalCacheRead: s.CacheRead, + TotalCacheWrite: s.CacheWrite, + TotalCost: s.Cost, + } + } + } + + // Calculate success rate + for _, ps := range result { + if ps.TotalRequests > 0 { + ps.SuccessRate = float64(ps.SuccessfulRequests) / float64(ps.TotalRequests) * 100 + } + } + + return result +} + +// FilterByGranularity filters stats to only include the specified granularity. +func FilterByGranularity(stats []*domain.UsageStats, g domain.Granularity) []*domain.UsageStats { + result := make([]*domain.UsageStats, 0) + for _, s := range stats { + if s.Granularity == g { + result = append(result, s) + } + } + return result +} + +// FilterByTimeRange filters stats to only include those within the specified time range. +// start is inclusive, end is exclusive. +func FilterByTimeRange(stats []*domain.UsageStats, start, end time.Time) []*domain.UsageStats { + result := make([]*domain.UsageStats, 0) + for _, s := range stats { + if !s.TimeBucket.Before(start) && s.TimeBucket.Before(end) { + result = append(result, s) + } + } + return result +} diff --git a/internal/stats/pure_test.go b/internal/stats/pure_test.go new file mode 100644 index 00000000..6898ede7 --- /dev/null +++ b/internal/stats/pure_test.go @@ -0,0 +1,1492 @@ +package stats + +import ( + "testing" + "time" + + "github.com/awsl-project/maxx/internal/domain" +) + +func TestTruncateToGranularity(t *testing.T) { + // 2024-01-17 14:35:42 UTC (Wednesday) + testTime := time.Date(2024, 1, 17, 14, 35, 42, 123456789, time.UTC) + + tests := []struct { + name string + granularity domain.Granularity + expected time.Time + }{ + { + name: "minute", + granularity: domain.GranularityMinute, + expected: time.Date(2024, 1, 17, 14, 35, 0, 0, time.UTC), + }, + { + name: "hour", + granularity: domain.GranularityHour, + expected: time.Date(2024, 1, 17, 14, 0, 0, 0, time.UTC), + }, + { + name: "day", + granularity: domain.GranularityDay, + expected: time.Date(2024, 1, 17, 0, 0, 0, 0, time.UTC), + }, + { + name: "month", + granularity: domain.GranularityMonth, + expected: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), + }, + { + name: "unknown granularity defaults to hour", + granularity: domain.Granularity("unknown"), + expected: time.Date(2024, 1, 17, 14, 0, 0, 0, time.UTC), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := TruncateToGranularity(testTime, tt.granularity, time.UTC) + if !result.Equal(tt.expected) { + t.Errorf("TruncateToGranularity(%v, %v, UTC) = %v, want %v", + testTime, tt.granularity, result, tt.expected) + } + }) + } +} + +func TestTruncateToGranularity_Timezone(t *testing.T) { + shanghai, _ := time.LoadLocation("Asia/Shanghai") + tokyo, _ := time.LoadLocation("Asia/Tokyo") + + // 2024-01-17 02:30:00 UTC = 2024-01-17 10:30:00 Shanghai = 2024-01-17 11:30:00 Tokyo + testTimeUTC := time.Date(2024, 1, 17, 2, 30, 0, 0, time.UTC) + + tests := []struct { + name string + loc *time.Location + granularity domain.Granularity + expected time.Time + }{ + { + name: "day in Shanghai timezone", + loc: shanghai, + granularity: domain.GranularityDay, + expected: time.Date(2024, 1, 17, 0, 0, 0, 0, shanghai), + }, + { + name: "hour in Shanghai timezone", + loc: shanghai, + granularity: domain.GranularityHour, + expected: time.Date(2024, 1, 17, 10, 0, 0, 0, shanghai), + }, + { + name: "minute in Shanghai timezone", + loc: shanghai, + granularity: domain.GranularityMinute, + expected: time.Date(2024, 1, 17, 10, 30, 0, 0, shanghai), + }, + { + name: "day in Tokyo timezone", + loc: tokyo, + granularity: domain.GranularityDay, + expected: time.Date(2024, 1, 17, 0, 0, 0, 0, tokyo), + }, + { + name: "month in Shanghai timezone", + loc: shanghai, + granularity: domain.GranularityMonth, + expected: time.Date(2024, 1, 1, 0, 0, 0, 0, shanghai), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := TruncateToGranularity(testTimeUTC, tt.granularity, tt.loc) + if !result.Equal(tt.expected) { + t.Errorf("TruncateToGranularity(%v, %v, %v) = %v, want %v", + testTimeUTC, tt.granularity, tt.loc, result, tt.expected) + } + }) + } +} + +func TestTruncateToGranularity_DayBoundary(t *testing.T) { + shanghai, _ := time.LoadLocation("Asia/Shanghai") + + // 2024-01-17 23:30:00 UTC = 2024-01-18 07:30:00 Shanghai + // This is a different day in Shanghai than in UTC + testTimeUTC := time.Date(2024, 1, 17, 23, 30, 0, 0, time.UTC) + + utcDay := TruncateToGranularity(testTimeUTC, domain.GranularityDay, time.UTC) + shanghaiDay := TruncateToGranularity(testTimeUTC, domain.GranularityDay, shanghai) + + expectedUTCDay := time.Date(2024, 1, 17, 0, 0, 0, 0, time.UTC) + expectedShanghaiDay := time.Date(2024, 1, 18, 0, 0, 0, 0, shanghai) + + if !utcDay.Equal(expectedUTCDay) { + t.Errorf("UTC day = %v, want %v", utcDay, expectedUTCDay) + } + if !shanghaiDay.Equal(expectedShanghaiDay) { + t.Errorf("Shanghai day = %v, want %v", shanghaiDay, expectedShanghaiDay) + } +} + +func TestAggregateAttempts_Empty(t *testing.T) { + result := AggregateAttempts(nil, time.UTC) + if result != nil { + t.Errorf("expected nil for empty records, got %v", result) + } + + result = AggregateAttempts([]AttemptRecord{}, time.UTC) + if result != nil { + t.Errorf("expected nil for empty slice, got %v", result) + } +} + +func TestAggregateAttempts_Single(t *testing.T) { + baseTime := time.Date(2024, 1, 17, 10, 30, 15, 0, time.UTC) + + records := []AttemptRecord{ + { + EndTime: baseTime, + ProviderID: 1, + ProjectID: 2, + RouteID: 3, + APITokenID: 4, + ClientType: "claude", + Model: "claude-3", + IsSuccessful: true, + InputTokens: 100, + OutputTokens: 50, + DurationMs: 1000, + CacheRead: 10, + CacheWrite: 5, + Cost: 1000, + }, + } + + result := AggregateAttempts(records, time.UTC) + + if len(result) != 1 { + t.Fatalf("expected 1 result, got %d", len(result)) + } + + s := result[0] + if s.TotalRequests != 1 { + t.Errorf("TotalRequests = %d, want 1", s.TotalRequests) + } + if s.SuccessfulRequests != 1 { + t.Errorf("SuccessfulRequests = %d, want 1", s.SuccessfulRequests) + } + if s.FailedRequests != 0 { + t.Errorf("FailedRequests = %d, want 0", s.FailedRequests) + } + if s.InputTokens != 100 { + t.Errorf("InputTokens = %d, want 100", s.InputTokens) + } + if s.OutputTokens != 50 { + t.Errorf("OutputTokens = %d, want 50", s.OutputTokens) + } + if s.TotalDurationMs != 1000 { + t.Errorf("TotalDurationMs = %d, want 1000", s.TotalDurationMs) + } + if s.CacheRead != 10 { + t.Errorf("CacheRead = %d, want 10", s.CacheRead) + } + if s.CacheWrite != 5 { + t.Errorf("CacheWrite = %d, want 5", s.CacheWrite) + } + if s.Cost != 1000 { + t.Errorf("Cost = %d, want 1000", s.Cost) + } + if s.ProviderID != 1 { + t.Errorf("ProviderID = %d, want 1", s.ProviderID) + } + if s.ProjectID != 2 { + t.Errorf("ProjectID = %d, want 2", s.ProjectID) + } + if s.RouteID != 3 { + t.Errorf("RouteID = %d, want 3", s.RouteID) + } + if s.APITokenID != 4 { + t.Errorf("APITokenID = %d, want 4", s.APITokenID) + } + if s.ClientType != "claude" { + t.Errorf("ClientType = %s, want claude", s.ClientType) + } + if s.Model != "claude-3" { + t.Errorf("Model = %s, want claude-3", s.Model) + } + if s.Granularity != domain.GranularityMinute { + t.Errorf("Granularity = %v, want minute", s.Granularity) + } +} + +func TestAggregateAttempts_SameMinute(t *testing.T) { + baseTime := time.Date(2024, 1, 17, 10, 30, 0, 0, time.UTC) + + records := []AttemptRecord{ + { + EndTime: baseTime.Add(10 * time.Second), + ProviderID: 1, + Model: "claude-3", + IsSuccessful: true, + InputTokens: 100, + OutputTokens: 50, + Cost: 1000, + }, + { + EndTime: baseTime.Add(20 * time.Second), + ProviderID: 1, + Model: "claude-3", + IsSuccessful: true, + InputTokens: 200, + OutputTokens: 100, + Cost: 2000, + }, + { + EndTime: baseTime.Add(30 * time.Second), + ProviderID: 1, + Model: "claude-3", + IsFailed: true, + Cost: 0, + }, + } + + result := AggregateAttempts(records, time.UTC) + + if len(result) != 1 { + t.Fatalf("expected 1 aggregated result, got %d", len(result)) + } + + s := result[0] + if s.TotalRequests != 3 { + t.Errorf("TotalRequests = %d, want 3", s.TotalRequests) + } + if s.SuccessfulRequests != 2 { + t.Errorf("SuccessfulRequests = %d, want 2", s.SuccessfulRequests) + } + if s.FailedRequests != 1 { + t.Errorf("FailedRequests = %d, want 1", s.FailedRequests) + } + if s.InputTokens != 300 { + t.Errorf("InputTokens = %d, want 300", s.InputTokens) + } + if s.OutputTokens != 150 { + t.Errorf("OutputTokens = %d, want 150", s.OutputTokens) + } + if s.Cost != 3000 { + t.Errorf("Cost = %d, want 3000", s.Cost) + } +} + +func TestAggregateAttempts_DifferentMinutes(t *testing.T) { + baseTime := time.Date(2024, 1, 17, 10, 30, 0, 0, time.UTC) + + records := []AttemptRecord{ + { + EndTime: baseTime, + ProviderID: 1, + Model: "claude-3", + IsSuccessful: true, + InputTokens: 100, + }, + { + EndTime: baseTime.Add(1 * time.Minute), + ProviderID: 1, + Model: "claude-3", + IsSuccessful: true, + InputTokens: 200, + }, + } + + result := AggregateAttempts(records, time.UTC) + + if len(result) != 2 { + t.Fatalf("expected 2 results for different minutes, got %d", len(result)) + } +} + +func TestAggregateAttempts_DifferentProviders(t *testing.T) { + baseTime := time.Date(2024, 1, 17, 10, 30, 0, 0, time.UTC) + + records := []AttemptRecord{ + { + EndTime: baseTime, + ProviderID: 1, + Model: "claude-3", + IsSuccessful: true, + InputTokens: 100, + }, + { + EndTime: baseTime, + ProviderID: 2, + Model: "claude-3", + IsSuccessful: true, + InputTokens: 200, + }, + } + + result := AggregateAttempts(records, time.UTC) + + if len(result) != 2 { + t.Fatalf("expected 2 results for different providers, got %d", len(result)) + } +} + +func TestAggregateAttempts_DifferentModels(t *testing.T) { + baseTime := time.Date(2024, 1, 17, 10, 30, 0, 0, time.UTC) + + records := []AttemptRecord{ + { + EndTime: baseTime, + ProviderID: 1, + Model: "claude-3", + IsSuccessful: true, + InputTokens: 100, + }, + { + EndTime: baseTime, + ProviderID: 1, + Model: "gpt-4", + IsSuccessful: true, + InputTokens: 200, + }, + } + + result := AggregateAttempts(records, time.UTC) + + if len(result) != 2 { + t.Fatalf("expected 2 results for different models, got %d", len(result)) + } +} + +func TestAggregateAttempts_DifferentDimensions(t *testing.T) { + baseTime := time.Date(2024, 1, 17, 10, 30, 0, 0, time.UTC) + + // Test all dimension variations + records := []AttemptRecord{ + {EndTime: baseTime, ProviderID: 1, ProjectID: 1, RouteID: 1, APITokenID: 1, ClientType: "a", Model: "m", InputTokens: 1}, + {EndTime: baseTime, ProviderID: 1, ProjectID: 2, RouteID: 1, APITokenID: 1, ClientType: "a", Model: "m", InputTokens: 2}, // diff project + {EndTime: baseTime, ProviderID: 1, ProjectID: 1, RouteID: 2, APITokenID: 1, ClientType: "a", Model: "m", InputTokens: 3}, // diff route + {EndTime: baseTime, ProviderID: 1, ProjectID: 1, RouteID: 1, APITokenID: 2, ClientType: "a", Model: "m", InputTokens: 4}, // diff token + {EndTime: baseTime, ProviderID: 1, ProjectID: 1, RouteID: 1, APITokenID: 1, ClientType: "b", Model: "m", InputTokens: 5}, // diff client + } + + result := AggregateAttempts(records, time.UTC) + + if len(result) != 5 { + t.Fatalf("expected 5 results for different dimensions, got %d", len(result)) + } + + var total uint64 + for _, s := range result { + total += s.InputTokens + } + if total != 15 { + t.Errorf("total input tokens = %d, want 15", total) + } +} + +func TestAggregateAttempts_WithTimezone(t *testing.T) { + shanghai, _ := time.LoadLocation("Asia/Shanghai") + + // 2024-01-17 23:30:00 UTC = 2024-01-18 07:30:00 Shanghai + // These should be in different minute buckets when using Shanghai timezone + utcTime := time.Date(2024, 1, 17, 23, 30, 30, 0, time.UTC) + + records := []AttemptRecord{ + { + EndTime: utcTime, + ProviderID: 1, + Model: "claude-3", + IsSuccessful: true, + InputTokens: 100, + }, + } + + result := AggregateAttempts(records, shanghai) + + if len(result) != 1 { + t.Fatalf("expected 1 result, got %d", len(result)) + } + + // The time bucket should be 2024-01-18 07:30:00 Shanghai + expected := time.Date(2024, 1, 18, 7, 30, 0, 0, shanghai) + if !result[0].TimeBucket.Equal(expected) { + t.Errorf("TimeBucket = %v, want %v", result[0].TimeBucket, expected) + } +} + +func TestRollUp_Empty(t *testing.T) { + result := RollUp(nil, domain.GranularityHour, time.UTC) + if result != nil { + t.Errorf("expected nil for empty stats, got %v", result) + } + + result = RollUp([]*domain.UsageStats{}, domain.GranularityHour, time.UTC) + if result != nil { + t.Errorf("expected nil for empty slice, got %v", result) + } +} + +func TestRollUp_MinuteToHour(t *testing.T) { + baseTime := time.Date(2024, 1, 17, 10, 0, 0, 0, time.UTC) + + minuteStats := []*domain.UsageStats{ + { + Granularity: domain.GranularityMinute, + TimeBucket: baseTime, + ProviderID: 1, + Model: "claude-3", + TotalRequests: 10, + SuccessfulRequests: 8, + FailedRequests: 2, + TotalDurationMs: 10000, + InputTokens: 1000, + OutputTokens: 500, + CacheRead: 100, + CacheWrite: 50, + Cost: 10000, + }, + { + Granularity: domain.GranularityMinute, + TimeBucket: baseTime.Add(15 * time.Minute), + ProviderID: 1, + Model: "claude-3", + TotalRequests: 5, + InputTokens: 500, + OutputTokens: 250, + Cost: 5000, + }, + { + Granularity: domain.GranularityMinute, + TimeBucket: baseTime.Add(30 * time.Minute), + ProviderID: 1, + Model: "claude-3", + TotalRequests: 8, + InputTokens: 800, + OutputTokens: 400, + Cost: 8000, + }, + } + + result := RollUp(minuteStats, domain.GranularityHour, time.UTC) + + if len(result) != 1 { + t.Fatalf("expected 1 hour bucket, got %d", len(result)) + } + + h := result[0] + if h.TotalRequests != 23 { + t.Errorf("TotalRequests = %d, want 23", h.TotalRequests) + } + if h.InputTokens != 2300 { + t.Errorf("InputTokens = %d, want 2300", h.InputTokens) + } + if h.OutputTokens != 1150 { + t.Errorf("OutputTokens = %d, want 1150", h.OutputTokens) + } + if h.Cost != 23000 { + t.Errorf("Cost = %d, want 23000", h.Cost) + } + if h.Granularity != domain.GranularityHour { + t.Errorf("Granularity = %v, want hour", h.Granularity) + } +} + +func TestRollUp_MinuteToDay(t *testing.T) { + baseTime := time.Date(2024, 1, 17, 10, 0, 0, 0, time.UTC) + + minuteStats := []*domain.UsageStats{ + {Granularity: domain.GranularityMinute, TimeBucket: baseTime, ProviderID: 1, TotalRequests: 10, InputTokens: 1000}, + {Granularity: domain.GranularityMinute, TimeBucket: baseTime.Add(60 * time.Minute), ProviderID: 1, TotalRequests: 5, InputTokens: 500}, + {Granularity: domain.GranularityMinute, TimeBucket: baseTime.Add(120 * time.Minute), ProviderID: 1, TotalRequests: 8, InputTokens: 800}, + } + + result := RollUp(minuteStats, domain.GranularityDay, time.UTC) + + if len(result) != 1 { + t.Fatalf("expected 1 day bucket, got %d", len(result)) + } + + if result[0].TotalRequests != 23 { + t.Errorf("TotalRequests = %d, want 23", result[0].TotalRequests) + } + if result[0].InputTokens != 2300 { + t.Errorf("InputTokens = %d, want 2300", result[0].InputTokens) + } +} + +func TestRollUp_DayToMonth(t *testing.T) { + day1 := time.Date(2024, 1, 5, 0, 0, 0, 0, time.UTC) + day15 := time.Date(2024, 1, 15, 0, 0, 0, 0, time.UTC) + day25 := time.Date(2024, 1, 25, 0, 0, 0, 0, time.UTC) + + dayStats := []*domain.UsageStats{ + {Granularity: domain.GranularityDay, TimeBucket: day1, ProviderID: 1, TotalRequests: 100, InputTokens: 10000}, + {Granularity: domain.GranularityDay, TimeBucket: day15, ProviderID: 1, TotalRequests: 200, InputTokens: 20000}, + {Granularity: domain.GranularityDay, TimeBucket: day25, ProviderID: 1, TotalRequests: 300, InputTokens: 30000}, + } + + result := RollUp(dayStats, domain.GranularityMonth, time.UTC) + + if len(result) != 1 { + t.Fatalf("expected 1 month bucket, got %d", len(result)) + } + + if result[0].TotalRequests != 600 { + t.Errorf("TotalRequests = %d, want 600", result[0].TotalRequests) + } + if result[0].InputTokens != 60000 { + t.Errorf("InputTokens = %d, want 60000", result[0].InputTokens) + } +} + +func TestRollUp_PreservesAggregationKey(t *testing.T) { + baseTime := time.Date(2024, 1, 17, 10, 0, 0, 0, time.UTC) + + stats := []*domain.UsageStats{ + { + Granularity: domain.GranularityMinute, + TimeBucket: baseTime, + ProviderID: 1, + ProjectID: 1, + RouteID: 1, + APITokenID: 1, + ClientType: "claude", + Model: "claude-3", + InputTokens: 100, + }, + { + Granularity: domain.GranularityMinute, + TimeBucket: baseTime.Add(5 * time.Minute), + ProviderID: 1, + ProjectID: 1, + RouteID: 1, + APITokenID: 1, + ClientType: "claude", + Model: "claude-3", + InputTokens: 100, + }, + { + Granularity: domain.GranularityMinute, + TimeBucket: baseTime, + ProviderID: 2, // Different provider + ProjectID: 1, + RouteID: 1, + APITokenID: 1, + ClientType: "openai", + Model: "gpt-4", + InputTokens: 200, + }, + } + + result := RollUp(stats, domain.GranularityHour, time.UTC) + + if len(result) != 2 { + t.Fatalf("expected 2 results, got %d", len(result)) + } + + var p1, p2 *domain.UsageStats + for _, s := range result { + switch s.ProviderID { + case 1: + p1 = s + case 2: + p2 = s + } + } + + if p1 == nil || p2 == nil { + t.Fatal("missing expected provider stats") + } + + if p1.InputTokens != 200 { + t.Errorf("provider 1 input tokens = %d, want 200", p1.InputTokens) + } + if p2.InputTokens != 200 { + t.Errorf("provider 2 input tokens = %d, want 200", p2.InputTokens) + } +} + +func TestRollUp_WithTimezone(t *testing.T) { + shanghai, _ := time.LoadLocation("Asia/Shanghai") + + // 2024-01-17 23:00:00 UTC = 2024-01-18 07:00:00 Shanghai + // 2024-01-18 01:00:00 UTC = 2024-01-18 09:00:00 Shanghai + // Both should be in the same day in Shanghai, but different days in UTC + time1 := time.Date(2024, 1, 17, 23, 0, 0, 0, time.UTC) + time2 := time.Date(2024, 1, 18, 1, 0, 0, 0, time.UTC) + + hourStats := []*domain.UsageStats{ + {Granularity: domain.GranularityHour, TimeBucket: time1, ProviderID: 1, TotalRequests: 100, InputTokens: 10000}, + {Granularity: domain.GranularityHour, TimeBucket: time2, ProviderID: 1, TotalRequests: 50, InputTokens: 5000}, + } + + // With UTC - should be 2 different days + resultUTC := RollUp(hourStats, domain.GranularityDay, time.UTC) + if len(resultUTC) != 2 { + t.Errorf("expected 2 day buckets in UTC, got %d", len(resultUTC)) + } + + // With Shanghai - should be 1 day + resultShanghai := RollUp(hourStats, domain.GranularityDay, shanghai) + if len(resultShanghai) != 1 { + t.Errorf("expected 1 day bucket in Shanghai, got %d", len(resultShanghai)) + } + if resultShanghai[0].TotalRequests != 150 { + t.Errorf("Shanghai total requests = %d, want 150", resultShanghai[0].TotalRequests) + } +} + +func TestMergeStats_Empty(t *testing.T) { + result := MergeStats() + if len(result) != 0 { + t.Errorf("expected empty result, got %d", len(result)) + } + + result = MergeStats(nil, nil) + if len(result) != 0 { + t.Errorf("expected empty result for nil slices, got %d", len(result)) + } +} + +func TestMergeStats_SingleList(t *testing.T) { + baseTime := time.Date(2024, 1, 17, 10, 0, 0, 0, time.UTC) + + list := []*domain.UsageStats{ + {Granularity: domain.GranularityHour, TimeBucket: baseTime, ProviderID: 1, InputTokens: 100}, + {Granularity: domain.GranularityHour, TimeBucket: baseTime, ProviderID: 2, InputTokens: 200}, + } + + result := MergeStats(list) + + if len(result) != 2 { + t.Fatalf("expected 2 results, got %d", len(result)) + } +} + +func TestMergeStats_MergeMatchingKeys(t *testing.T) { + baseTime := time.Date(2024, 1, 17, 10, 0, 0, 0, time.UTC) + + list1 := []*domain.UsageStats{ + { + Granularity: domain.GranularityHour, + TimeBucket: baseTime, + ProviderID: 1, + Model: "claude-3", + TotalRequests: 10, + SuccessfulRequests: 8, + FailedRequests: 2, + TotalDurationMs: 10000, + InputTokens: 100, + OutputTokens: 50, + CacheRead: 10, + CacheWrite: 5, + Cost: 1000, + }, + } + + list2 := []*domain.UsageStats{ + { + Granularity: domain.GranularityHour, + TimeBucket: baseTime, + ProviderID: 1, + Model: "claude-3", + TotalRequests: 5, + SuccessfulRequests: 5, + FailedRequests: 0, + TotalDurationMs: 5000, + InputTokens: 200, + OutputTokens: 100, + CacheRead: 20, + CacheWrite: 10, + Cost: 2000, + }, + } + + result := MergeStats(list1, list2) + + if len(result) != 1 { + t.Fatalf("expected 1 merged result, got %d", len(result)) + } + + s := result[0] + if s.TotalRequests != 15 { + t.Errorf("TotalRequests = %d, want 15", s.TotalRequests) + } + if s.SuccessfulRequests != 13 { + t.Errorf("SuccessfulRequests = %d, want 13", s.SuccessfulRequests) + } + if s.FailedRequests != 2 { + t.Errorf("FailedRequests = %d, want 2", s.FailedRequests) + } + if s.TotalDurationMs != 15000 { + t.Errorf("TotalDurationMs = %d, want 15000", s.TotalDurationMs) + } + if s.InputTokens != 300 { + t.Errorf("InputTokens = %d, want 300", s.InputTokens) + } + if s.OutputTokens != 150 { + t.Errorf("OutputTokens = %d, want 150", s.OutputTokens) + } + if s.CacheRead != 30 { + t.Errorf("CacheRead = %d, want 30", s.CacheRead) + } + if s.CacheWrite != 15 { + t.Errorf("CacheWrite = %d, want 15", s.CacheWrite) + } + if s.Cost != 3000 { + t.Errorf("Cost = %d, want 3000", s.Cost) + } +} + +func TestMergeStats_DifferentKeys(t *testing.T) { + baseTime := time.Date(2024, 1, 17, 10, 0, 0, 0, time.UTC) + + list1 := []*domain.UsageStats{ + {Granularity: domain.GranularityHour, TimeBucket: baseTime, ProviderID: 1, InputTokens: 100}, + } + + list2 := []*domain.UsageStats{ + {Granularity: domain.GranularityHour, TimeBucket: baseTime, ProviderID: 2, InputTokens: 200}, + } + + list3 := []*domain.UsageStats{ + {Granularity: domain.GranularityDay, TimeBucket: baseTime, ProviderID: 1, InputTokens: 300}, // Different granularity + } + + result := MergeStats(list1, list2, list3) + + if len(result) != 3 { + t.Fatalf("expected 3 results, got %d", len(result)) + } + + var total uint64 + for _, s := range result { + total += s.InputTokens + } + if total != 600 { + t.Errorf("total input tokens = %d, want 600", total) + } +} + +func TestMergeStats_DoesNotModifyOriginal(t *testing.T) { + baseTime := time.Date(2024, 1, 17, 10, 0, 0, 0, time.UTC) + + original := &domain.UsageStats{ + Granularity: domain.GranularityHour, + TimeBucket: baseTime, + ProviderID: 1, + InputTokens: 100, + } + + list1 := []*domain.UsageStats{original} + list2 := []*domain.UsageStats{ + {Granularity: domain.GranularityHour, TimeBucket: baseTime, ProviderID: 1, InputTokens: 200}, + } + + _ = MergeStats(list1, list2) + + // Original should not be modified + if original.InputTokens != 100 { + t.Errorf("original was modified: InputTokens = %d, want 100", original.InputTokens) + } +} + +func TestSumStats_Empty(t *testing.T) { + totalReq, successReq, failedReq, inputTokens, outputTokens, cacheRead, cacheWrite, cost := SumStats(nil) + + if totalReq != 0 || successReq != 0 || failedReq != 0 || inputTokens != 0 || + outputTokens != 0 || cacheRead != 0 || cacheWrite != 0 || cost != 0 { + t.Errorf("expected all zeros for empty stats") + } +} + +func TestSumStats(t *testing.T) { + stats := []*domain.UsageStats{ + { + TotalRequests: 10, + SuccessfulRequests: 8, + FailedRequests: 2, + InputTokens: 1000, + OutputTokens: 500, + CacheRead: 100, + CacheWrite: 50, + Cost: 10000, + }, + { + TotalRequests: 5, + SuccessfulRequests: 5, + FailedRequests: 0, + InputTokens: 500, + OutputTokens: 250, + CacheRead: 50, + CacheWrite: 25, + Cost: 5000, + }, + } + + totalReq, successReq, failedReq, inputTokens, outputTokens, cacheRead, cacheWrite, cost := SumStats(stats) + + if totalReq != 15 { + t.Errorf("totalReq = %d, want 15", totalReq) + } + if successReq != 13 { + t.Errorf("successReq = %d, want 13", successReq) + } + if failedReq != 2 { + t.Errorf("failedReq = %d, want 2", failedReq) + } + if inputTokens != 1500 { + t.Errorf("inputTokens = %d, want 1500", inputTokens) + } + if outputTokens != 750 { + t.Errorf("outputTokens = %d, want 750", outputTokens) + } + if cacheRead != 150 { + t.Errorf("cacheRead = %d, want 150", cacheRead) + } + if cacheWrite != 75 { + t.Errorf("cacheWrite = %d, want 75", cacheWrite) + } + if cost != 15000 { + t.Errorf("cost = %d, want 15000", cost) + } +} + +func TestGroupByProvider_Empty(t *testing.T) { + result := GroupByProvider(nil) + if len(result) != 0 { + t.Errorf("expected empty result, got %d", len(result)) + } +} + +func TestGroupByProvider_SkipsZeroProvider(t *testing.T) { + stats := []*domain.UsageStats{ + {ProviderID: 0, TotalRequests: 100, InputTokens: 10000}, + {ProviderID: 1, TotalRequests: 50, InputTokens: 5000}, + } + + result := GroupByProvider(stats) + + if len(result) != 1 { + t.Fatalf("expected 1 provider (skipping 0), got %d", len(result)) + } + if result[0] != nil { + t.Error("provider 0 should not be in result") + } + if result[1] == nil { + t.Fatal("provider 1 should be in result") + } + if result[1].TotalRequests != 50 { + t.Errorf("provider 1 TotalRequests = %d, want 50", result[1].TotalRequests) + } +} + +func TestGroupByProvider(t *testing.T) { + stats := []*domain.UsageStats{ + { + ProviderID: 1, + TotalRequests: 10, + SuccessfulRequests: 8, + FailedRequests: 2, + InputTokens: 1000, + OutputTokens: 500, + CacheRead: 100, + CacheWrite: 50, + Cost: 10000, + }, + { + ProviderID: 1, + TotalRequests: 5, + SuccessfulRequests: 5, + InputTokens: 500, + OutputTokens: 250, + CacheRead: 50, + CacheWrite: 25, + Cost: 5000, + }, + { + ProviderID: 2, + TotalRequests: 3, + SuccessfulRequests: 3, + InputTokens: 300, + OutputTokens: 150, + CacheRead: 30, + CacheWrite: 15, + Cost: 3000, + }, + } + + result := GroupByProvider(stats) + + if len(result) != 2 { + t.Fatalf("expected 2 providers, got %d", len(result)) + } + + p1 := result[1] + if p1 == nil { + t.Fatal("provider 1 not found") + } + if p1.ProviderID != 1 { + t.Errorf("ProviderID = %d, want 1", p1.ProviderID) + } + if p1.TotalRequests != 15 { + t.Errorf("provider 1 TotalRequests = %d, want 15", p1.TotalRequests) + } + if p1.SuccessfulRequests != 13 { + t.Errorf("provider 1 SuccessfulRequests = %d, want 13", p1.SuccessfulRequests) + } + if p1.FailedRequests != 2 { + t.Errorf("provider 1 FailedRequests = %d, want 2", p1.FailedRequests) + } + if p1.TotalInputTokens != 1500 { + t.Errorf("provider 1 TotalInputTokens = %d, want 1500", p1.TotalInputTokens) + } + if p1.TotalOutputTokens != 750 { + t.Errorf("provider 1 TotalOutputTokens = %d, want 750", p1.TotalOutputTokens) + } + if p1.TotalCacheRead != 150 { + t.Errorf("provider 1 TotalCacheRead = %d, want 150", p1.TotalCacheRead) + } + if p1.TotalCacheWrite != 75 { + t.Errorf("provider 1 TotalCacheWrite = %d, want 75", p1.TotalCacheWrite) + } + if p1.TotalCost != 15000 { + t.Errorf("provider 1 TotalCost = %d, want 15000", p1.TotalCost) + } + + // Success rate: 13/15 * 100 = 86.67% + expectedRate := float64(13) / float64(15) * 100 + if p1.SuccessRate != expectedRate { + t.Errorf("provider 1 SuccessRate = %f, want %f", p1.SuccessRate, expectedRate) + } + + p2 := result[2] + if p2 == nil { + t.Fatal("provider 2 not found") + } + if p2.TotalRequests != 3 { + t.Errorf("provider 2 TotalRequests = %d, want 3", p2.TotalRequests) + } + if p2.SuccessRate != 100 { + t.Errorf("provider 2 SuccessRate = %f, want 100", p2.SuccessRate) + } +} + +func TestGroupByProvider_ZeroRequests(t *testing.T) { + stats := []*domain.UsageStats{ + {ProviderID: 1, TotalRequests: 0, SuccessfulRequests: 0}, + } + + result := GroupByProvider(stats) + + if result[1].SuccessRate != 0 { + t.Errorf("SuccessRate = %f, want 0 for zero requests", result[1].SuccessRate) + } +} + +func TestFilterByGranularity_Empty(t *testing.T) { + result := FilterByGranularity(nil, domain.GranularityHour) + if len(result) != 0 { + t.Errorf("expected empty result, got %d", len(result)) + } +} + +func TestFilterByGranularity(t *testing.T) { + stats := []*domain.UsageStats{ + {Granularity: domain.GranularityMinute, InputTokens: 100}, + {Granularity: domain.GranularityHour, InputTokens: 200}, + {Granularity: domain.GranularityMinute, InputTokens: 300}, + {Granularity: domain.GranularityDay, InputTokens: 400}, + } + + result := FilterByGranularity(stats, domain.GranularityMinute) + + if len(result) != 2 { + t.Fatalf("expected 2 minute stats, got %d", len(result)) + } + + var total uint64 + for _, s := range result { + if s.Granularity != domain.GranularityMinute { + t.Errorf("unexpected granularity: %v", s.Granularity) + } + total += s.InputTokens + } + if total != 400 { + t.Errorf("total input = %d, want 400", total) + } +} + +func TestFilterByGranularity_NoMatch(t *testing.T) { + stats := []*domain.UsageStats{ + {Granularity: domain.GranularityMinute, InputTokens: 100}, + {Granularity: domain.GranularityHour, InputTokens: 200}, + } + + result := FilterByGranularity(stats, domain.GranularityMonth) + + if len(result) != 0 { + t.Errorf("expected empty result, got %d", len(result)) + } +} + +func TestFilterByTimeRange_Empty(t *testing.T) { + baseTime := time.Date(2024, 1, 17, 10, 0, 0, 0, time.UTC) + result := FilterByTimeRange(nil, baseTime, baseTime.Add(time.Hour)) + if len(result) != 0 { + t.Errorf("expected empty result, got %d", len(result)) + } +} + +func TestFilterByTimeRange(t *testing.T) { + baseTime := time.Date(2024, 1, 17, 10, 0, 0, 0, time.UTC) + + stats := []*domain.UsageStats{ + {TimeBucket: baseTime, InputTokens: 100}, + {TimeBucket: baseTime.Add(1 * time.Hour), InputTokens: 200}, + {TimeBucket: baseTime.Add(2 * time.Hour), InputTokens: 300}, + {TimeBucket: baseTime.Add(3 * time.Hour), InputTokens: 400}, + } + + // Filter [10:00, 12:00) - should include 10:00 and 11:00 + result := FilterByTimeRange(stats, baseTime, baseTime.Add(2*time.Hour)) + + if len(result) != 2 { + t.Fatalf("expected 2 stats, got %d", len(result)) + } + + var total uint64 + for _, s := range result { + total += s.InputTokens + } + if total != 300 { + t.Errorf("total input = %d, want 300", total) + } +} + +func TestFilterByTimeRange_InclusiveStart(t *testing.T) { + baseTime := time.Date(2024, 1, 17, 10, 0, 0, 0, time.UTC) + + stats := []*domain.UsageStats{ + {TimeBucket: baseTime, InputTokens: 100}, + } + + result := FilterByTimeRange(stats, baseTime, baseTime.Add(time.Hour)) + + if len(result) != 1 { + t.Errorf("expected 1 stat (start is inclusive), got %d", len(result)) + } +} + +func TestFilterByTimeRange_ExclusiveEnd(t *testing.T) { + baseTime := time.Date(2024, 1, 17, 10, 0, 0, 0, time.UTC) + + stats := []*domain.UsageStats{ + {TimeBucket: baseTime.Add(time.Hour), InputTokens: 100}, + } + + result := FilterByTimeRange(stats, baseTime, baseTime.Add(time.Hour)) + + if len(result) != 0 { + t.Errorf("expected 0 stats (end is exclusive), got %d", len(result)) + } +} + +func TestFilterByTimeRange_NoMatch(t *testing.T) { + baseTime := time.Date(2024, 1, 17, 10, 0, 0, 0, time.UTC) + + stats := []*domain.UsageStats{ + {TimeBucket: baseTime.Add(-1 * time.Hour), InputTokens: 100}, + {TimeBucket: baseTime.Add(3 * time.Hour), InputTokens: 200}, + } + + result := FilterByTimeRange(stats, baseTime, baseTime.Add(2*time.Hour)) + + if len(result) != 0 { + t.Errorf("expected 0 stats, got %d", len(result)) + } +} + +// Integration test: verify full aggregation pipeline +func TestAggregationPipeline_TokensCorrectlyAggregated(t *testing.T) { + baseTime := time.Date(2024, 1, 17, 10, 0, 0, 0, time.UTC) + + // Simulate 100 requests, each with 100 input tokens and 50 output tokens + // spread across 10 minutes in the same hour + var records []AttemptRecord + for i := 0; i < 10; i++ { + for j := 0; j < 10; j++ { + records = append(records, AttemptRecord{ + EndTime: baseTime.Add(time.Duration(i)*time.Minute + time.Duration(j)*time.Second), + ProviderID: 1, + Model: "claude-3", + IsSuccessful: true, + InputTokens: 100, + OutputTokens: 50, + Cost: 1000, + }) + } + } + + // Aggregate to minute + minuteStats := AggregateAttempts(records, time.UTC) + + // Verify minute aggregation + var totalMinuteTokens uint64 + for _, s := range minuteStats { + totalMinuteTokens += s.InputTokens + } + expectedTokens := uint64(100 * 100) // 100 requests * 100 tokens + if totalMinuteTokens != expectedTokens { + t.Errorf("minute input tokens = %d, want %d", totalMinuteTokens, expectedTokens) + } + + // Roll up to hour + hourStats := RollUp(minuteStats, domain.GranularityHour, time.UTC) + + if len(hourStats) != 1 { + t.Fatalf("expected 1 hour bucket, got %d", len(hourStats)) + } + + h := hourStats[0] + if h.InputTokens != expectedTokens { + t.Errorf("hour input tokens = %d, want %d", h.InputTokens, expectedTokens) + } + if h.TotalRequests != 100 { + t.Errorf("hour total requests = %d, want 100", h.TotalRequests) + } + + // Roll up to day + dayStats := RollUp(hourStats, domain.GranularityDay, time.UTC) + + if len(dayStats) != 1 { + t.Fatalf("expected 1 day bucket, got %d", len(dayStats)) + } + + d := dayStats[0] + if d.InputTokens != expectedTokens { + t.Errorf("day input tokens = %d, want %d (no data loss)", d.InputTokens, expectedTokens) + } + + // Roll up to month + monthStats := RollUp(dayStats, domain.GranularityMonth, time.UTC) + + if len(monthStats) != 1 { + t.Fatalf("expected 1 month bucket, got %d", len(monthStats)) + } + + m := monthStats[0] + if m.InputTokens != expectedTokens { + t.Errorf("month input tokens = %d, want %d (no data loss)", m.InputTokens, expectedTokens) + } +} + +// TestFullAggregationPipeline tests the complete aggregation pipeline +// that AggregateAndRollUp performs: minute → hour → day → month +func TestFullAggregationPipeline(t *testing.T) { + baseTime := time.Date(2024, 1, 17, 10, 30, 0, 0, time.UTC) + + // Create test records spanning multiple minutes + records := []AttemptRecord{ + {EndTime: baseTime, ProviderID: 1, Model: "claude-3", IsSuccessful: true, InputTokens: 100, OutputTokens: 50, Cost: 1000, DurationMs: 500}, + {EndTime: baseTime.Add(30 * time.Second), ProviderID: 1, Model: "claude-3", IsSuccessful: true, InputTokens: 200, OutputTokens: 100, Cost: 2000, DurationMs: 600}, + {EndTime: baseTime.Add(1 * time.Minute), ProviderID: 1, Model: "claude-3", IsFailed: true, InputTokens: 50, OutputTokens: 0, Cost: 0, DurationMs: 100}, + {EndTime: baseTime.Add(2 * time.Minute), ProviderID: 2, Model: "gpt-4", IsSuccessful: true, InputTokens: 300, OutputTokens: 150, Cost: 5000, DurationMs: 800}, + } + + // Step 1: Aggregate to minute + minuteStats := AggregateAttempts(records, time.UTC) + + // Verify: should have 3 minute buckets (10:30, 10:31, 10:32) + // But provider/model combinations mean more entries + if len(minuteStats) < 3 { + t.Errorf("expected at least 3 minute stats, got %d", len(minuteStats)) + } + + // Verify totals + totalReq, successReq, failedReq, inputTokens, outputTokens, _, _, cost := SumStats(minuteStats) + if totalReq != 4 { + t.Errorf("total requests = %d, want 4", totalReq) + } + if successReq != 3 { + t.Errorf("successful requests = %d, want 3", successReq) + } + if failedReq != 1 { + t.Errorf("failed requests = %d, want 1", failedReq) + } + if inputTokens != 650 { + t.Errorf("input tokens = %d, want 650", inputTokens) + } + if outputTokens != 300 { + t.Errorf("output tokens = %d, want 300", outputTokens) + } + if cost != 8000 { + t.Errorf("cost = %d, want 8000", cost) + } + + // Step 2: Roll up to hour + hourStats := RollUp(minuteStats, domain.GranularityHour, time.UTC) + + // Verify totals preserved + totalReq2, _, _, inputTokens2, _, _, _, cost2 := SumStats(hourStats) + if totalReq2 != totalReq { + t.Errorf("hour total requests = %d, want %d (data loss)", totalReq2, totalReq) + } + if inputTokens2 != inputTokens { + t.Errorf("hour input tokens = %d, want %d (data loss)", inputTokens2, inputTokens) + } + if cost2 != cost { + t.Errorf("hour cost = %d, want %d (data loss)", cost2, cost) + } + + // Step 3: Roll up to day + dayStats := RollUp(hourStats, domain.GranularityDay, time.UTC) + + totalReq3, _, _, inputTokens3, _, _, _, cost3 := SumStats(dayStats) + if totalReq3 != totalReq { + t.Errorf("day total requests = %d, want %d (data loss)", totalReq3, totalReq) + } + if inputTokens3 != inputTokens { + t.Errorf("day input tokens = %d, want %d (data loss)", inputTokens3, inputTokens) + } + if cost3 != cost { + t.Errorf("day cost = %d, want %d (data loss)", cost3, cost) + } + + // Step 4: Roll up to month + monthStats := RollUp(dayStats, domain.GranularityMonth, time.UTC) + + totalReq4, _, _, inputTokens4, _, _, _, cost4 := SumStats(monthStats) + if totalReq4 != totalReq { + t.Errorf("month total requests = %d, want %d (data loss)", totalReq4, totalReq) + } + if inputTokens4 != inputTokens { + t.Errorf("month input tokens = %d, want %d (data loss)", inputTokens4, inputTokens) + } + if cost4 != cost { + t.Errorf("month cost = %d, want %d (data loss)", cost4, cost) + } +} + +// TestFullAggregationPipeline_PreservesProviderDimension tests that +// provider dimension is preserved through the entire aggregation pipeline +func TestFullAggregationPipeline_PreservesProviderDimension(t *testing.T) { + baseTime := time.Date(2024, 1, 17, 10, 30, 0, 0, time.UTC) + + // Create records for 2 different providers + records := []AttemptRecord{ + {EndTime: baseTime, ProviderID: 1, Model: "claude-3", IsSuccessful: true, InputTokens: 100, Cost: 1000}, + {EndTime: baseTime, ProviderID: 1, Model: "claude-3", IsSuccessful: true, InputTokens: 100, Cost: 1000}, + {EndTime: baseTime, ProviderID: 2, Model: "gpt-4", IsSuccessful: true, InputTokens: 200, Cost: 3000}, + } + + // Aggregate through the entire pipeline + minuteStats := AggregateAttempts(records, time.UTC) + hourStats := RollUp(minuteStats, domain.GranularityHour, time.UTC) + dayStats := RollUp(hourStats, domain.GranularityDay, time.UTC) + monthStats := RollUp(dayStats, domain.GranularityMonth, time.UTC) + + // Group by provider and verify + providerStats := GroupByProvider(monthStats) + + if len(providerStats) != 2 { + t.Fatalf("expected 2 providers, got %d", len(providerStats)) + } + + p1 := providerStats[1] + if p1 == nil { + t.Fatal("provider 1 not found") + } + if p1.TotalRequests != 2 { + t.Errorf("provider 1 requests = %d, want 2", p1.TotalRequests) + } + if p1.TotalInputTokens != 200 { + t.Errorf("provider 1 input tokens = %d, want 200", p1.TotalInputTokens) + } + if p1.TotalCost != 2000 { + t.Errorf("provider 1 cost = %d, want 2000", p1.TotalCost) + } + + p2 := providerStats[2] + if p2 == nil { + t.Fatal("provider 2 not found") + } + if p2.TotalRequests != 1 { + t.Errorf("provider 2 requests = %d, want 1", p2.TotalRequests) + } + if p2.TotalInputTokens != 200 { + t.Errorf("provider 2 input tokens = %d, want 200", p2.TotalInputTokens) + } + if p2.TotalCost != 3000 { + t.Errorf("provider 2 cost = %d, want 3000", p2.TotalCost) + } +} + +// TestFullAggregationPipeline_WithTimezone tests aggregation with timezone +func TestFullAggregationPipeline_WithTimezone(t *testing.T) { + shanghai, _ := time.LoadLocation("Asia/Shanghai") + + // 2024-01-17 23:30 UTC = 2024-01-18 07:30 Shanghai + // 2024-01-18 00:30 UTC = 2024-01-18 08:30 Shanghai + // In UTC these are different days, in Shanghai they're the same day + records := []AttemptRecord{ + {EndTime: time.Date(2024, 1, 17, 23, 30, 0, 0, time.UTC), ProviderID: 1, IsSuccessful: true, InputTokens: 100}, + {EndTime: time.Date(2024, 1, 18, 0, 30, 0, 0, time.UTC), ProviderID: 1, IsSuccessful: true, InputTokens: 200}, + } + + // Aggregate with Shanghai timezone + minuteStats := AggregateAttempts(records, shanghai) + hourStats := RollUp(minuteStats, domain.GranularityHour, shanghai) + dayStats := RollUp(hourStats, domain.GranularityDay, shanghai) + + // In Shanghai timezone, both records should be on 2024-01-18 + if len(dayStats) != 1 { + t.Errorf("expected 1 day bucket in Shanghai timezone, got %d", len(dayStats)) + } + + totalReq, _, _, inputTokens, _, _, _, _ := SumStats(dayStats) + if totalReq != 2 { + t.Errorf("total requests = %d, want 2", totalReq) + } + if inputTokens != 300 { + t.Errorf("input tokens = %d, want 300", inputTokens) + } + + // Now aggregate with UTC - should be 2 different days + minuteStatsUTC := AggregateAttempts(records, time.UTC) + hourStatsUTC := RollUp(minuteStatsUTC, domain.GranularityHour, time.UTC) + dayStatsUTC := RollUp(hourStatsUTC, domain.GranularityDay, time.UTC) + + if len(dayStatsUTC) != 2 { + t.Errorf("expected 2 day buckets in UTC, got %d", len(dayStatsUTC)) + } +} + +// TestFullAggregationPipeline_AllFieldsPreserved tests that all numeric fields +// are correctly summed through the pipeline +func TestFullAggregationPipeline_AllFieldsPreserved(t *testing.T) { + baseTime := time.Date(2024, 1, 17, 10, 30, 0, 0, time.UTC) + + records := []AttemptRecord{ + { + EndTime: baseTime, + ProviderID: 1, + IsSuccessful: true, + DurationMs: 1000, + InputTokens: 100, + OutputTokens: 50, + CacheRead: 10, + CacheWrite: 5, + Cost: 1000, + }, + { + EndTime: baseTime.Add(time.Minute), + ProviderID: 1, + IsSuccessful: true, + DurationMs: 2000, + InputTokens: 200, + OutputTokens: 100, + CacheRead: 20, + CacheWrite: 10, + Cost: 2000, + }, + { + EndTime: baseTime.Add(2 * time.Minute), + ProviderID: 1, + IsFailed: true, + DurationMs: 500, + }, + } + + // Full pipeline + minuteStats := AggregateAttempts(records, time.UTC) + hourStats := RollUp(minuteStats, domain.GranularityHour, time.UTC) + dayStats := RollUp(hourStats, domain.GranularityDay, time.UTC) + monthStats := RollUp(dayStats, domain.GranularityMonth, time.UTC) + + // Check all fields are preserved at month level + totalReq, successReq, failedReq, inputTokens, outputTokens, cacheRead, cacheWrite, cost := SumStats(monthStats) + + if totalReq != 3 { + t.Errorf("totalReq = %d, want 3", totalReq) + } + if successReq != 2 { + t.Errorf("successReq = %d, want 2", successReq) + } + if failedReq != 1 { + t.Errorf("failedReq = %d, want 1", failedReq) + } + if inputTokens != 300 { + t.Errorf("inputTokens = %d, want 300", inputTokens) + } + if outputTokens != 150 { + t.Errorf("outputTokens = %d, want 150", outputTokens) + } + if cacheRead != 30 { + t.Errorf("cacheRead = %d, want 30", cacheRead) + } + if cacheWrite != 15 { + t.Errorf("cacheWrite = %d, want 15", cacheWrite) + } + if cost != 3000 { + t.Errorf("cost = %d, want 3000", cost) + } +} + +// TestFullAggregationPipeline_MultipleModels tests aggregation with multiple models +func TestFullAggregationPipeline_MultipleModels(t *testing.T) { + baseTime := time.Date(2024, 1, 17, 10, 30, 0, 0, time.UTC) + + records := []AttemptRecord{ + {EndTime: baseTime, ProviderID: 1, Model: "claude-3-opus", IsSuccessful: true, InputTokens: 100, Cost: 5000}, + {EndTime: baseTime, ProviderID: 1, Model: "claude-3-sonnet", IsSuccessful: true, InputTokens: 100, Cost: 1000}, + {EndTime: baseTime, ProviderID: 1, Model: "claude-3-opus", IsSuccessful: true, InputTokens: 100, Cost: 5000}, + } + + minuteStats := AggregateAttempts(records, time.UTC) + monthStats := RollUp( + RollUp( + RollUp(minuteStats, domain.GranularityHour, time.UTC), + domain.GranularityDay, time.UTC), + domain.GranularityMonth, time.UTC) + + // Should have 2 entries: one for each model + if len(monthStats) != 2 { + t.Errorf("expected 2 model entries, got %d", len(monthStats)) + } + + // Find opus and sonnet stats + var opusStats, sonnetStats *domain.UsageStats + for _, s := range monthStats { + switch s.Model { + case "claude-3-opus": + opusStats = s + case "claude-3-sonnet": + sonnetStats = s + } + } + + if opusStats == nil { + t.Fatal("opus stats not found") + } + if opusStats.TotalRequests != 2 { + t.Errorf("opus requests = %d, want 2", opusStats.TotalRequests) + } + if opusStats.Cost != 10000 { + t.Errorf("opus cost = %d, want 10000", opusStats.Cost) + } + + if sonnetStats == nil { + t.Fatal("sonnet stats not found") + } + if sonnetStats.TotalRequests != 1 { + t.Errorf("sonnet requests = %d, want 1", sonnetStats.TotalRequests) + } + if sonnetStats.Cost != 1000 { + t.Errorf("sonnet cost = %d, want 1000", sonnetStats.Cost) + } +} diff --git a/main.go b/main.go index e558d30c..866e7ff4 100644 --- a/main.go +++ b/main.go @@ -6,6 +6,7 @@ import ( "io/fs" "log" goruntime "runtime" + "time" "github.com/awsl-project/maxx/internal/desktop" "github.com/awsl-project/maxx/internal/handler" @@ -44,7 +45,7 @@ func main() { go func() { // 等待 app context 初始化 for appCtx == nil { - // 等待 OnStartup 设置 appCtx + time.Sleep(10 * time.Millisecond) // 等待 OnStartup 设置 appCtx } tray := desktop.NewTrayManager(appCtx, app) tray.Start() diff --git a/web/package.json b/web/package.json index 4b2609d9..c47b8c78 100644 --- a/web/package.json +++ b/web/package.json @@ -35,7 +35,7 @@ "react-i18next": "^16.5.3", "react-resizable-panels": "^2.1.7", "react-router-dom": "^7.11.0", - "recharts": "^2.15.4", + "recharts": "^3.6.0", "shadcn": "^3.6.3", "tailwind-merge": "^3.4.0", "tailwindcss": "^4.1.18", diff --git a/web/pnpm-lock.yaml b/web/pnpm-lock.yaml index 58c7e341..b114424b 100644 --- a/web/pnpm-lock.yaml +++ b/web/pnpm-lock.yaml @@ -72,8 +72,8 @@ importers: specifier: ^7.11.0 version: 7.12.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3) recharts: - specifier: ^2.15.4 - version: 2.15.4(react-dom@19.2.3(react@19.2.3))(react@19.2.3) + specifier: ^3.6.0 + version: 3.6.0(@types/react@19.2.8)(react-dom@19.2.3(react@19.2.3))(react-is@18.3.1)(react@19.2.3)(redux@5.0.1) shadcn: specifier: ^3.6.3 version: 3.6.3(@types/node@24.10.8)(hono@4.11.4)(typescript@5.9.3) @@ -687,6 +687,17 @@ packages: '@open-draft/until@2.1.0': resolution: {integrity: sha512-U69T3ItWHvLwGg5eJ0n3I62nWuE6ilHlmz7zM0npLBRvPRd7e6NYmg54vvRtP5mZG7kZqZCFVdsTWo7BPtBujg==} + '@reduxjs/toolkit@2.11.2': + resolution: {integrity: sha512-Kd6kAHTA6/nUpp8mySPqj3en3dm0tdMIgbttnQ1xFMVpufoj+ADi8pXLBsd4xzTRHQa7t/Jv8W5UnCuW4kuWMQ==} + peerDependencies: + react: ^16.9.0 || ^17.0.0 || ^18 || ^19 + react-redux: ^7.2.1 || ^8.1.3 || ^9.0.0 + peerDependenciesMeta: + react: + optional: true + react-redux: + optional: true + '@rolldown/pluginutils@1.0.0-beta.53': resolution: {integrity: sha512-vENRlFU4YbrwVqNDZ7fLvy+JR1CRkyr01jhSiDpE1u6py3OMzQfztQU2jxykW3ALNxO4kSlqIDeYyD0Y9RcQeQ==} @@ -822,6 +833,12 @@ packages: resolution: {integrity: sha512-tlqY9xq5ukxTUZBmoOp+m61cqwQD5pHJtFY3Mn8CA8ps6yghLH/Hw8UPdqg4OLmFW3IFlcXnQNmo/dh8HzXYIQ==} engines: {node: '>=18'} + '@standard-schema/spec@1.1.0': + resolution: {integrity: sha512-l2aFy5jALhniG5HgqrD6jXLi/rUWrKvqN/qJx6yoJsgKhblVd+iqqU4RCXavm/jPityDo5TCvKMnpjKnOriy0w==} + + '@standard-schema/utils@0.3.0': + resolution: {integrity: sha512-e7Mew686owMaPJVNNLs55PUvgz371nKgwsc4vxE49zsODpJEnxgxRo2y/OKrqueavXgZNMDVj3DdHFlaSAeU8g==} + '@tailwindcss/node@4.1.18': resolution: {integrity: sha512-DoR7U1P7iYhw16qJ49fgXUlry1t4CpXeErJHnQ44JgTSKMaZUdf17cfn5mHchfJ4KRBZRFA/Coo+MUF5+gOaCQ==} @@ -988,6 +1005,9 @@ packages: '@types/statuses@2.0.6': resolution: {integrity: sha512-xMAgYwceFhRA2zY+XbEA7mxYbA093wdiW8Vu6gZPGWy9cmOyU9XesH1tNcEWsKFd5Vzrqx5T3D38PWx1FIIXkA==} + '@types/use-sync-external-store@0.0.6': + resolution: {integrity: sha512-zFDAD+tlpf2r4asuHEj0XH6pY6i0g5NeAHPn+15wk3BV6JA69eERFXC1gyGThDkVa1zCyKr5jox1+2LbV/AMLg==} + '@types/validate-npm-package-name@4.0.2': resolution: {integrity: sha512-lrpDziQipxCEeK5kWxvljWYhUvOiB2A9izZd9B2AFarYAkqZshb4lPbRs7zKEic6eGtH8V/2qJW+dPp9OtF6bw==} @@ -1383,9 +1403,6 @@ packages: resolution: {integrity: sha512-qejHi7bcSD4hQAZE0tNAawRK1ZtafHDmMTMkrrIGgSLl7hTnQHmKCeB45xAcbfTqK2zowkM3j3bHt/4b/ARbYQ==} engines: {node: '>=0.3.1'} - dom-helpers@5.2.1: - resolution: {integrity: sha512-nRCa7CK3VTrM2NmGkIy4cbK7IZlgBE/PYMn55rrXefr5xXDP0LdtfPnblFDoVdcAfslJ7or6iqAUnx0CCGIWQA==} - dotenv@17.2.3: resolution: {integrity: sha512-JVUnt+DUIzu87TABbhPmNfVdBDt18BLOWjMUFJMSi/Qqg7NTYtabbvSNJGOJ7afbRuv9D/lngizHtP7QyLQ+9w==} engines: {node: '>=12'} @@ -1441,6 +1458,9 @@ packages: resolution: {integrity: sha512-j6vWzfrGVfyXxge+O0x5sh6cvxAog0a/4Rdd2K36zCMV5eJ+/+tOAngRO8cODMNWbVRdVlmGZQL2YS3yR8bIUA==} engines: {node: '>= 0.4'} + es-toolkit@1.44.0: + resolution: {integrity: sha512-6penXeZalaV88MM3cGkFZZfOoLGWshWWfdy0tWw/RlVVyhvMaWSBTOvXNeiW3e5FwdS5ePW0LGEu17zT139ktg==} + esbuild@0.27.2: resolution: {integrity: sha512-HyNQImnsOC7X9PMNaCIeAm4ISCQXs5a5YasTXVliKv4uuBo1dKrG0A+uQS8M5eXjVMnLg3WgXaKvprHlFJQffw==} engines: {node: '>=18'} @@ -1519,8 +1539,8 @@ packages: resolution: {integrity: sha512-aIL5Fx7mawVa300al2BnEE4iNvo1qETxLrPI/o05L7z6go7fCw1J6EQmbK4FmJ2AS7kgVF/KEZWufBfdClMcPg==} engines: {node: '>= 0.6'} - eventemitter3@4.0.7: - resolution: {integrity: sha512-8guHBZCwKnFhYdHr2ysuRWErTwhoN2X8XELRlrRwpmfeY2jjuUN4taQMsULKUVo1K4DvZl+0pgfyoysHxvmvEw==} + eventemitter3@5.0.4: + resolution: {integrity: sha512-mlsTRyGaPBjPedk6Bvw+aqbsXDtoAyAzm5MO7JgU+yVRyMQ5O8bD4Kcci7BS85f93veegeCPkL8R4GLClnjLFw==} eventsource-parser@3.0.6: resolution: {integrity: sha512-Vo1ab+QXPzZ4tCa8SwIHJFaSzy4R6SHf7BY79rFBDf0idraZWAkYrDjDj8uWaSm3S2TK+hJ7/t1CEmZ7jXw+pg==} @@ -1551,10 +1571,6 @@ packages: fast-deep-equal@3.1.3: resolution: {integrity: sha512-f3qQ9oQy9j2AhBe/H9VC91wLmKBCCU/gDOnKNAYG5hswO7BLKj09Hc5HYNz9cGI++xlpDCIgDaitVs03ATR84Q==} - fast-equals@5.4.0: - resolution: {integrity: sha512-jt2DW/aNFNwke7AUd+Z+e6pz39KO5rzdbbFCg2sGafS4mk13MI7Z8O5z9cADNn5lhGODIgLwug6TZO2ctf7kcw==} - engines: {node: '>=6.0.0'} - fast-glob@3.3.3: resolution: {integrity: sha512-7MptL8U0cqcFdzIzwOTHoilX9x5BrNqye7Z/LuC7kCMRio1EMSyqRK3BEAUD7sXRq4iT4AzTVuZdhgQ2TCvYLg==} engines: {node: '>=8.6.0'} @@ -1789,6 +1805,9 @@ packages: resolution: {integrity: sha512-Hs59xBNfUIunMFgWAbGX5cq6893IbWg4KnrjbYwX3tx0ztorVgTDA6B2sxf8ejHJ4wz8BqGUMYlnzNBer5NvGg==} engines: {node: '>= 4'} + immer@10.2.0: + resolution: {integrity: sha512-d/+XTN3zfODyjr89gM3mPq1WNX2B8pYsu7eORitdwyA2sBubnTl3laYlBk4sXY5FUa5qTZGBDPJICVbvqzjlbw==} + immer@11.1.3: resolution: {integrity: sha512-6jQTc5z0KJFtr1UgFpIL3N9XSC3saRaI9PwWtzM2pSqkNGtiNkYY2OSwkOGDK2XcTRcLb1pi/aNkKZz0nxVH4Q==} @@ -2033,17 +2052,10 @@ packages: lodash.merge@4.6.2: resolution: {integrity: sha512-0KpjqXRVvrYyCsX1swR/XTK0va6VQkQM6MNo7PqW77ByjAhoARA8EfrP1N4+KlKj8YS0ZUCtRT/YUuhyYDujIQ==} - lodash@4.17.21: - resolution: {integrity: sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg==} - log-symbols@6.0.0: resolution: {integrity: sha512-i24m8rpwhmPIS4zscNzK6MSEhk0DUWa/8iYQWxhffV8jkI4Phvs3F+quL5xvS0gdQR0FyTCMMH33Y78dDTzzIw==} engines: {node: '>=18'} - loose-envify@1.4.0: - resolution: {integrity: sha512-lyuxPGr/Wfhrlem2CL/UcnUc1zcqKAImBDzukY7Y5F/yQiNdko6+fRLevlw1HgMySw7f611UIY408EtxRSoK3Q==} - hasBin: true - lru-cache@5.1.1: resolution: {integrity: sha512-KpNARQA3Iwv+jTA0utUVVbrh+Jlrr1Fv0e56GGzAFOXN7dk/FviaDW8LHmK52DlcH4WP2n6gI8vN1aesBFgo9w==} @@ -2302,9 +2314,6 @@ packages: resolution: {integrity: sha512-NxNv/kLguCA7p3jE8oL2aEBsrJWgAakBpgmgK6lpPWV+WuOmY6r2/zbAVnP+T8bQlA0nzHXSJSJW0Hq7ylaD2Q==} engines: {node: '>= 6'} - prop-types@15.8.1: - resolution: {integrity: sha512-oj87CgZICdulUohogVAR7AjlC0327U4el4L6eAvOqCeudMDVU0NThNaV+b9Df4dXgSP1gXMTnPdhfe/2qDH5cg==} - proxy-addr@2.0.7: resolution: {integrity: sha512-llQsMLSUDUPT44jdrU/O37qlnifitDP+ZwrmmZcoSKyLKvtZxpyV0n2/bD/N4tBAAZ/gJEdZU7KMraoK1+XYAg==} engines: {node: '>= 0.10'} @@ -2352,12 +2361,21 @@ packages: typescript: optional: true - react-is@16.13.1: - resolution: {integrity: sha512-24e6ynE2H+OKt4kqsOvNd8kBpV65zoxbA4BVsEOB3ARVWQki/DHzaUoC5KuON/BiccDaCCTZBuOcfZs70kR8bQ==} - react-is@18.3.1: resolution: {integrity: sha512-/LLMVyas0ljjAtoYiPqYiL8VWXzUUdThrmU5+n20DZv+a+ClRoevUzw5JxU+Ieh5/c87ytoTBV9G1FiKfNJdmg==} + react-redux@9.2.0: + resolution: {integrity: sha512-ROY9fvHhwOD9ySfrF0wmvu//bKCQ6AeZZq1nJNtbDC+kk5DuSuNX/n6YWYF/SYy7bSba4D4FSz8DJeKY/S/r+g==} + peerDependencies: + '@types/react': ^18.2.25 || ^19 + react: ^18.0 || ^19 + redux: ^5.0.0 + peerDependenciesMeta: + '@types/react': + optional: true + redux: + optional: true + react-refresh@0.18.0: resolution: {integrity: sha512-QgT5//D3jfjJb6Gsjxv0Slpj23ip+HtOpnNgnb2S5zU3CB26G/IDPGoy4RJB42wzFE46DRsstbW6tKHoKbhAxw==} engines: {node: '>=0.10.0'} @@ -2385,18 +2403,6 @@ packages: react-dom: optional: true - react-smooth@4.0.4: - resolution: {integrity: sha512-gnGKTpYwqL0Iii09gHobNolvX4Kiq4PKx6eWBCYYix+8cdw+cGo3do906l1NBPKkSWx1DghC1dlWG9L2uGd61Q==} - peerDependencies: - react: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 - react-dom: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 - - react-transition-group@4.4.5: - resolution: {integrity: sha512-pZcd1MCJoiKiBR2NRxeCRg13uCXbydPnmB4EOeRrY7480qNWO8IIgQG6zlDkm6uRMsURXPuKq0GWtiM59a5Q6g==} - peerDependencies: - react: '>=16.6.0' - react-dom: '>=16.6.0' - react@19.2.3: resolution: {integrity: sha512-Ku/hhYbVjOQnXDZFv2+RibmLFGwFdeeKHFcOTlrt7xplBnya5OGn/hIRDsqDiSUcfORsDC7MPxwork8jBwsIWA==} engines: {node: '>=0.10.0'} @@ -2405,15 +2411,21 @@ packages: resolution: {integrity: sha512-YTUo+Flmw4ZXiWfQKGcwwc11KnoRAYgzAE2E7mXKCjSviTKShtxBsN6YUUBB2gtaBzKzeKunxhUwNHQuRryhWA==} engines: {node: '>= 4'} - recharts-scale@0.4.5: - resolution: {integrity: sha512-kivNFO+0OcUNu7jQquLXAxz1FIwZj8nrj+YkOKc5694NbjCvcT6aSZiIzNzd2Kul4o4rTto8QVR9lMNtxD4G1w==} - - recharts@2.15.4: - resolution: {integrity: sha512-UT/q6fwS3c1dHbXv2uFgYJ9BMFHu3fwnd7AYZaEQhXuYQ4hgsxLvsUXzGdKeZrW5xopzDCvuA2N41WJ88I7zIw==} - engines: {node: '>=14'} + recharts@3.6.0: + resolution: {integrity: sha512-L5bjxvQRAe26RlToBAziKUB7whaGKEwD3znoM6fz3DrTowCIC/FnJYnuq1GEzB8Zv2kdTfaxQfi5GoH0tBinyg==} + engines: {node: '>=18'} peerDependencies: - react: ^16.0.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 + react: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 react-dom: ^16.0.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 + react-is: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 + + redux-thunk@3.1.0: + resolution: {integrity: sha512-NW2r5T6ksUKXCabzhL9z+h206HQw/NJkcLm1GPImRQ8IzfXwRGqjVhKJGauHirT0DAuyy6hjdnMZaRoAcy0Klw==} + peerDependencies: + redux: ^5.0.0 + + redux@5.0.1: + resolution: {integrity: sha512-M9/ELqF6fy8FwmkpnF0S3YKOqMyoWJ4+CS5Efg2ct3oY9daQvd/Pc71FpGZsVsbl3Cpb+IIcjBDUnnyBdQbq4w==} require-directory@2.1.1: resolution: {integrity: sha512-fGxEI7+wsG9xrvdjsrlmL22OMTTiHRwAMroiEeMgq8gzoLC/PQr7RsRDSTLUg/bZAZtF+TVIkHc6/4RIKrui+Q==} @@ -2716,8 +2728,8 @@ packages: resolution: {integrity: sha512-BNGbWLfd0eUPabhkXUVm0j8uuvREyTh5ovRa/dyow/BqAbZJyC+5fU+IzQOzmAKzYqYRAISoRhdQr3eIZ/PXqg==} engines: {node: '>= 0.8'} - victory-vendor@36.9.2: - resolution: {integrity: sha512-PnpQQMuxlwYdocC8fIJqVXvkeViHYzotI+NJrCuav0ZYFoq912ZHBk3mCeuj+5/VpodOjPe1z0Fk2ihgzlXqjQ==} + victory-vendor@37.3.6: + resolution: {integrity: sha512-SbPDPdDBYp+5MJHhBCAyI7wKM3d5ivekigc2Dk2s7pgbZ9wIgIBYGVw4zGHBml/qTFbexrofXW6Gu4noGxrOwQ==} vite@7.3.1: resolution: {integrity: sha512-w+N7Hifpc3gRjZ63vYBXA56dvvRlNWRczTdmCBBa+CotUzAPf5b7YMdMR/8CQoeYE5LX3W4wj6RYTgonm1b9DA==} @@ -3404,6 +3416,18 @@ snapshots: '@open-draft/until@2.1.0': {} + '@reduxjs/toolkit@2.11.2(react-redux@9.2.0(@types/react@19.2.8)(react@19.2.3)(redux@5.0.1))(react@19.2.3)': + dependencies: + '@standard-schema/spec': 1.1.0 + '@standard-schema/utils': 0.3.0 + immer: 11.1.3 + redux: 5.0.1 + redux-thunk: 3.1.0(redux@5.0.1) + reselect: 5.1.1 + optionalDependencies: + react: 19.2.3 + react-redux: 9.2.0(@types/react@19.2.8)(react@19.2.3)(redux@5.0.1) + '@rolldown/pluginutils@1.0.0-beta.53': {} '@rollup/rollup-android-arm-eabi@4.55.1': @@ -3485,6 +3509,10 @@ snapshots: '@sindresorhus/merge-streams@4.0.0': {} + '@standard-schema/spec@1.1.0': {} + + '@standard-schema/utils@0.3.0': {} + '@tailwindcss/node@4.1.18': dependencies: '@jridgewell/remapping': 2.3.5 @@ -3639,6 +3667,8 @@ snapshots: '@types/statuses@2.0.6': {} + '@types/use-sync-external-store@0.0.6': {} + '@types/validate-npm-package-name@4.0.2': {} '@typescript-eslint/eslint-plugin@8.53.0(@typescript-eslint/parser@8.53.0(eslint@9.39.2(jiti@2.6.1))(typescript@5.9.3))(eslint@9.39.2(jiti@2.6.1))(typescript@5.9.3)': @@ -4021,11 +4051,6 @@ snapshots: diff@8.0.3: {} - dom-helpers@5.2.1: - dependencies: - '@babel/runtime': 7.28.6 - csstype: 3.2.3 - dotenv@17.2.3: {} dunder-proto@1.0.1: @@ -4077,6 +4102,8 @@ snapshots: has-tostringtag: 1.0.2 hasown: 2.0.2 + es-toolkit@1.44.0: {} + esbuild@0.27.2: optionalDependencies: '@esbuild/aix-ppc64': 0.27.2 @@ -4199,7 +4226,7 @@ snapshots: etag@1.8.1: {} - eventemitter3@4.0.7: {} + eventemitter3@5.0.4: {} eventsource-parser@3.0.6: {} @@ -4273,8 +4300,6 @@ snapshots: fast-deep-equal@3.1.3: {} - fast-equals@5.4.0: {} - fast-glob@3.3.3: dependencies: '@nodelib/fs.stat': 2.0.5 @@ -4484,8 +4509,9 @@ snapshots: ignore@7.0.5: {} - immer@11.1.3: - optional: true + immer@10.2.0: {} + + immer@11.1.3: {} import-fresh@3.3.1: dependencies: @@ -4650,17 +4676,11 @@ snapshots: lodash.merge@4.6.2: {} - lodash@4.17.21: {} - log-symbols@6.0.0: dependencies: chalk: 5.6.2 is-unicode-supported: 1.3.0 - loose-envify@1.4.0: - dependencies: - js-tokens: 4.0.0 - lru-cache@5.1.1: dependencies: yallist: 3.1.1 @@ -4899,12 +4919,6 @@ snapshots: kleur: 3.0.3 sisteransi: 1.0.5 - prop-types@15.8.1: - dependencies: - loose-envify: 1.4.0 - object-assign: 4.1.1 - react-is: 16.13.1 - proxy-addr@2.0.7: dependencies: forwarded: 0.2.0 @@ -4945,10 +4959,17 @@ snapshots: react-dom: 19.2.3(react@19.2.3) typescript: 5.9.3 - react-is@16.13.1: {} - react-is@18.3.1: {} + react-redux@9.2.0(@types/react@19.2.8)(react@19.2.3)(redux@5.0.1): + dependencies: + '@types/use-sync-external-store': 0.0.6 + react: 19.2.3 + use-sync-external-store: 1.6.0(react@19.2.3) + optionalDependencies: + '@types/react': 19.2.8 + redux: 5.0.1 + react-refresh@0.18.0: {} react-resizable-panels@2.1.9(react-dom@19.2.3(react@19.2.3))(react@19.2.3): @@ -4970,23 +4991,6 @@ snapshots: optionalDependencies: react-dom: 19.2.3(react@19.2.3) - react-smooth@4.0.4(react-dom@19.2.3(react@19.2.3))(react@19.2.3): - dependencies: - fast-equals: 5.4.0 - prop-types: 15.8.1 - react: 19.2.3 - react-dom: 19.2.3(react@19.2.3) - react-transition-group: 4.4.5(react-dom@19.2.3(react@19.2.3))(react@19.2.3) - - react-transition-group@4.4.5(react-dom@19.2.3(react@19.2.3))(react@19.2.3): - dependencies: - '@babel/runtime': 7.28.6 - dom-helpers: 5.2.1 - loose-envify: 1.4.0 - prop-types: 15.8.1 - react: 19.2.3 - react-dom: 19.2.3(react@19.2.3) - react@19.2.3: {} recast@0.23.11: @@ -4997,22 +5001,31 @@ snapshots: tiny-invariant: 1.3.3 tslib: 2.8.1 - recharts-scale@0.4.5: - dependencies: - decimal.js-light: 2.5.1 - - recharts@2.15.4(react-dom@19.2.3(react@19.2.3))(react@19.2.3): + recharts@3.6.0(@types/react@19.2.8)(react-dom@19.2.3(react@19.2.3))(react-is@18.3.1)(react@19.2.3)(redux@5.0.1): dependencies: + '@reduxjs/toolkit': 2.11.2(react-redux@9.2.0(@types/react@19.2.8)(react@19.2.3)(redux@5.0.1))(react@19.2.3) clsx: 2.1.1 - eventemitter3: 4.0.7 - lodash: 4.17.21 + decimal.js-light: 2.5.1 + es-toolkit: 1.44.0 + eventemitter3: 5.0.4 + immer: 10.2.0 react: 19.2.3 react-dom: 19.2.3(react@19.2.3) react-is: 18.3.1 - react-smooth: 4.0.4(react-dom@19.2.3(react@19.2.3))(react@19.2.3) - recharts-scale: 0.4.5 + react-redux: 9.2.0(@types/react@19.2.8)(react@19.2.3)(redux@5.0.1) + reselect: 5.1.1 tiny-invariant: 1.3.3 - victory-vendor: 36.9.2 + use-sync-external-store: 1.6.0(react@19.2.3) + victory-vendor: 37.3.6 + transitivePeerDependencies: + - '@types/react' + - redux + + redux-thunk@3.1.0(redux@5.0.1): + dependencies: + redux: 5.0.1 + + redux@5.0.1: {} require-directory@2.1.1: {} @@ -5358,7 +5371,7 @@ snapshots: vary@1.1.2: {} - victory-vendor@36.9.2: + victory-vendor@37.3.6: dependencies: '@types/d3-array': 3.2.2 '@types/d3-ease': 3.0.2 diff --git a/web/src/components/provider-details-dialog.tsx b/web/src/components/provider-details-dialog.tsx index b51490f9..df3ff690 100644 --- a/web/src/components/provider-details-dialog.tsx +++ b/web/src/components/provider-details-dialog.tsx @@ -109,16 +109,17 @@ function formatTokens(count: number): string { return count.toString(); } -// 格式化成本 (微美元 → 美元) -function formatCost(microUsd: number): string { - const usd = microUsd / 1_000_000; +// 格式化成本 (纳美元 → 美元,向下取整到 6 位) +function formatCost(nanoUsd: number): string { + // 向下取整到 6 位小数 (microUSD 精度) + const usd = Math.floor(nanoUsd / 1000) / 1_000_000; if (usd >= 1) { return `$${usd.toFixed(2)}`; } if (usd >= 0.01) { return `$${usd.toFixed(3)}`; } - return `$${usd.toFixed(4)}`; + return `$${usd.toFixed(6).replace(/\.?0+$/, '')}`; } // 计算缓存利用率 diff --git a/web/src/components/routes/ClientTypeRoutesContent.tsx b/web/src/components/routes/ClientTypeRoutesContent.tsx index d17c77c4..3f97b167 100644 --- a/web/src/components/routes/ClientTypeRoutesContent.tsx +++ b/web/src/components/routes/ClientTypeRoutesContent.tsx @@ -43,13 +43,20 @@ import { } from '@/pages/client-routes/components/provider-row'; import type { ProviderConfigItem } from '@/pages/client-routes/types'; import { Button } from '../ui'; -import { - AntigravityQuotasProvider, - useAntigravityQuotasContext, -} from '@/contexts/antigravity-quotas-context'; +import { AntigravityQuotasProvider, useAntigravityQuotasContext } from '@/contexts/antigravity-quotas-context'; import { CooldownsProvider } from '@/contexts/cooldowns-context'; import { useTranslation } from 'react-i18next'; +type ProviderTypeKey = 'antigravity' | 'kiro' | 'custom'; + +const PROVIDER_TYPE_ORDER: ProviderTypeKey[] = ['antigravity', 'kiro', 'custom']; + +const PROVIDER_TYPE_LABELS: Record = { + antigravity: 'Antigravity', + kiro: 'Kiro', + custom: 'Custom', +}; + interface ClientTypeRoutesContentProps { clientType: ClientType; projectID: number; // 0 for global routes @@ -143,8 +150,14 @@ function ClientTypeRoutesContentInner({ }); }, [providers, clientRoutes, clientType, searchQuery]); - // Get available providers (without routes yet) - const availableProviders = useMemo((): Provider[] => { + // Get available providers (without routes yet), grouped by type and sorted alphabetically + const groupedAvailableProviders = useMemo((): Record => { + const groups: Record = { + antigravity: [], + kiro: [], + custom: [], + }; + let available = providers.filter((p) => { const hasRoute = clientRoutes.some((r) => Number(r.providerID) === Number(p.id)); return !hasRoute; @@ -158,9 +171,29 @@ function ClientTypeRoutesContentInner({ ); } - return available; + // Group by type + available.forEach((p) => { + const type = p.type as ProviderTypeKey; + if (groups[type]) { + groups[type].push(p); + } else { + groups.custom.push(p); + } + }); + + // Sort alphabetically within each group + for (const key of Object.keys(groups) as ProviderTypeKey[]) { + groups[key].sort((a, b) => a.name.localeCompare(b.name)); + } + + return groups; }, [providers, clientRoutes, searchQuery]); + // Check if there are any available providers + const hasAvailableProviders = useMemo(() => { + return PROVIDER_TYPE_ORDER.some((type) => groupedAvailableProviders[type].length > 0); + }, [groupedAvailableProviders]); + // Check if there are any Antigravity routes const hasAntigravityRoutes = useMemo(() => { return items.some((item) => item.provider.type === 'antigravity' && item.route); @@ -209,7 +242,9 @@ function ClientTypeRoutesContentInner({ }); // Collect original positions (sorted by position value) - const originalPositions = antigravityItems.map((a) => a.position).sort((a, b) => a - b); + const originalPositions = antigravityItems + .map((a) => a.position) + .sort((a, b) => a - b); // Build updates: assign sorted items to the original position slots // Only update Antigravity routes, leaving Other types unchanged @@ -395,7 +430,8 @@ function ClientTypeRoutesContentInner({ index={items.findIndex((i) => i.id === activeItem.id)} clientType={clientType} streamingCount={ - countsByProviderAndClient.get(`${activeItem.provider.id}:${clientType}`) || 0 + countsByProviderAndClient.get(`${activeItem.provider.id}:${clientType}`) || + 0 } stats={providerStats[activeItem.provider.id]} isToggling={false} @@ -412,8 +448,8 @@ function ClientTypeRoutesContentInner({ )} - {/* Add Route Section - Card Style */} - {availableProviders.length > 0 && ( + {/* Add Route Section - Grouped by Type */} + {hasAvailableProviders && (
@@ -421,61 +457,75 @@ function ClientTypeRoutesContentInner({ Available Providers
-
- {availableProviders.map((provider) => { - const isNative = (provider.supportedClientTypes || []).includes(clientType); - const providerColor = getProviderColor(provider.type as ProviderType); +
+ {PROVIDER_TYPE_ORDER.map((typeKey) => { + const typeProviders = groupedAvailableProviders[typeKey]; + if (typeProviders.length === 0) return null; + return ( - +
+ {typeProviders.map((provider) => { + const isNative = (provider.supportedClientTypes || []).includes(clientType); + const providerColor = getProviderColor(provider.type as ProviderType); + return ( + + ); + })} +
+
); })}
diff --git a/web/src/components/ui/chart.tsx b/web/src/components/ui/chart.tsx index 854c61fe..5e5e5fe2 100644 --- a/web/src/components/ui/chart.tsx +++ b/web/src/components/ui/chart.tsx @@ -1,5 +1,8 @@ import * as React from 'react'; import * as RechartsPrimitive from 'recharts'; +import type { TooltipContentProps } from 'recharts'; +import type { NameType, ValueType } from 'recharts/types/component/DefaultTooltipContent'; +import type { LegendPayload, Props as DefaultLegendContentProps } from 'recharts/types/component/DefaultLegendContent'; import { cn } from '@/lib/utils'; @@ -108,7 +111,7 @@ function ChartTooltipContent({ color, nameKey, labelKey, -}: React.ComponentProps & +}: Partial> & React.ComponentProps<'div'> & { hideLabel?: boolean; hideIndicator?: boolean; @@ -239,7 +242,7 @@ function ChartLegendContent({ verticalAlign = 'bottom', nameKey, }: React.ComponentProps<'div'> & - Pick & { + Pick & { hideIcon?: boolean; nameKey?: string; }) { @@ -258,8 +261,8 @@ function ChartLegendContent({ )} > {payload - .filter((item) => item.type !== 'none') - .map((item) => { + .filter((item: LegendPayload) => item.type !== 'none') + .map((item: LegendPayload) => { const key = `${nameKey || item.dataKey || 'value'}`; const itemConfig = getPayloadConfigFromPayload(config, item, key); diff --git a/web/src/components/ui/index.ts b/web/src/components/ui/index.ts index 3614a365..620352aa 100644 --- a/web/src/components/ui/index.ts +++ b/web/src/components/ui/index.ts @@ -23,7 +23,15 @@ export { Badge, badgeVariants } from './badge'; export { Input } from './input'; // Select -export { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from './select'; +export { + Select, + SelectContent, + SelectGroup, + SelectItem, + SelectLabel, + SelectTrigger, + SelectValue, +} from './select'; // Dialog export { @@ -50,3 +58,18 @@ export { MarqueeBackground } from './marquee-background'; // Activity Heatmap export { ActivityHeatmap } from './activity-heatmap'; + +// Progress +export { Progress } from './progress'; + +// Sheet +export { + Sheet, + SheetTrigger, + SheetClose, + SheetContent, + SheetHeader, + SheetFooter, + SheetTitle, + SheetDescription, +} from './sheet'; diff --git a/web/src/components/ui/progress.tsx b/web/src/components/ui/progress.tsx new file mode 100644 index 00000000..bd4b0ccd --- /dev/null +++ b/web/src/components/ui/progress.tsx @@ -0,0 +1,78 @@ +import { Progress as ProgressPrimitive } from "@base-ui/react/progress" + +import { cn } from "@/lib/utils" + +function Progress({ + className, + children, + value, + ...props +}: ProgressPrimitive.Root.Props) { + return ( + + {children} + + + + + ) +} + +function ProgressTrack({ className, ...props }: ProgressPrimitive.Track.Props) { + return ( + + ) +} + +function ProgressIndicator({ + className, + ...props +}: ProgressPrimitive.Indicator.Props) { + return ( + + ) +} + +function ProgressLabel({ className, ...props }: ProgressPrimitive.Label.Props) { + return ( + + ) +} + +function ProgressValue({ className, ...props }: ProgressPrimitive.Value.Props) { + return ( + + ) +} + +export { + Progress, + ProgressTrack, + ProgressIndicator, + ProgressLabel, + ProgressValue, +} diff --git a/web/src/components/ui/table.tsx b/web/src/components/ui/table.tsx index a3af1721..bb7415cd 100644 --- a/web/src/components/ui/table.tsx +++ b/web/src/components/ui/table.tsx @@ -7,7 +7,7 @@ function Table({ className, ...props }: React.ComponentProps<"table">) {
@@ -48,7 +48,7 @@ function TableRow({ className, ...props }: React.ComponentProps<"tr">) { return ( ) diff --git a/web/src/hooks/queries/index.ts b/web/src/hooks/queries/index.ts index 4a60a044..b18fb54b 100644 --- a/web/src/hooks/queries/index.ts +++ b/web/src/hooks/queries/index.ts @@ -112,6 +112,7 @@ export { useUsageStats, useUsageStatsWithPreset, useRecalculateUsageStats, + useRecalculateCosts, selectGranularity, getTimeRange, type TimeRangePreset, @@ -142,3 +143,6 @@ export { type HeatmapDataPoint, type ModelRanking, } from './use-dashboard-stats'; + +// Pricing hooks +export { pricingKeys, usePricing } from './use-pricing'; diff --git a/web/src/hooks/queries/use-pricing.ts b/web/src/hooks/queries/use-pricing.ts new file mode 100644 index 00000000..fc5de905 --- /dev/null +++ b/web/src/hooks/queries/use-pricing.ts @@ -0,0 +1,22 @@ +/** + * Pricing API Hooks + */ + +import { useQuery } from '@tanstack/react-query'; +import { getTransport } from '@/lib/transport'; + +export const pricingKeys = { + all: ['pricing'] as const, +}; + +/** + * 获取价格表 + * 价格表较少变化,使用较长的 staleTime + */ +export function usePricing() { + return useQuery({ + queryKey: pricingKeys.all, + queryFn: () => getTransport().getPricing(), + staleTime: 1000 * 60 * 60, // 1 hour + }); +} diff --git a/web/src/hooks/queries/use-requests.ts b/web/src/hooks/queries/use-requests.ts index 0994079e..bec62d6d 100644 --- a/web/src/hooks/queries/use-requests.ts +++ b/web/src/hooks/queries/use-requests.ts @@ -31,10 +31,10 @@ export function useProxyRequests(params?: CursorPaginationParams) { } // 获取 ProxyRequests 总数 -export function useProxyRequestsCount() { +export function useProxyRequestsCount(providerId?: number, status?: string) { return useQuery({ - queryKey: ['requestsCount'] as const, - queryFn: () => getTransport().getProxyRequestsCount(), + queryKey: ['requestsCount', providerId, status] as const, + queryFn: () => getTransport().getProxyRequestsCount(providerId, status), }); } @@ -75,28 +75,82 @@ export function useProxyRequestUpdates() { queryClient.setQueryData(requestKeys.detail(updatedRequest.id), updatedRequest); // 更新列表缓存(乐观更新)- 适配 CursorPaginationResult 结构 - queryClient.setQueriesData>( - { queryKey: requestKeys.lists() }, - (old) => { + // 使用 queryCache 遍历所有匹配的查询,以获取每个查询的过滤参数 + const queryCache = queryClient.getQueryCache(); + const listQueries = queryCache.findAll({ queryKey: requestKeys.lists() }); + + for (const query of listQueries) { + const queryKey = query.queryKey as ReturnType; + // 从 queryKey 中提取过滤参数: ['requests', 'list', params] + const params = queryKey[2] as CursorPaginationParams | undefined; + const filterProviderId = params?.providerId; + const filterStatus = params?.status; + + // 检查是否匹配过滤条件的辅助函数 + const matchesFilter = (request: ProxyRequest) => { + if (filterProviderId !== undefined && request.providerID !== filterProviderId) { + return false; + } + if (filterStatus !== undefined && request.status !== filterStatus) { + return false; + } + return true; + }; + + queryClient.setQueryData>(queryKey, (old) => { if (!old || !old.items) return old; + const index = old.items.findIndex((r) => r.id === updatedRequest.id); if (index >= 0) { + // 已存在的请求:检查是否仍然匹配过滤条件 + if (!matchesFilter(updatedRequest)) { + // 不再匹配过滤条件,从列表中移除 + const newItems = old.items.filter((r) => r.id !== updatedRequest.id); + return { ...old, items: newItems }; + } + // 仍然匹配,更新 const newItems = [...old.items]; newItems[index] = updatedRequest; return { ...old, items: newItems }; } + + // 新请求:检查是否匹配过滤条件 + if (!matchesFilter(updatedRequest)) { + // 不匹配过滤条件,不添加 + return old; + } + // 新请求添加到列表开头(只在首页,即没有 before 参数的查询) + if (params?.before) { + // 不是首页,不添加新请求 + return old; + } + return { ...old, items: [updatedRequest, ...old.items], firstId: updatedRequest.id, }; - }, - ); + }); + } - // 新请求时乐观更新 count + // 新请求时乐观更新 count(需要考虑每个 count 查询的过滤条件) if (isNewRequest) { - queryClient.setQueryData(['requestsCount'], (old) => (old ?? 0) + 1); + // 遍历所有 requestsCount 缓存 + const countQueries = queryCache.findAll({ queryKey: ['requestsCount'] }); + for (const query of countQueries) { + // queryKey: ['requestsCount', providerId, status] + const filterProviderId = query.queryKey[1] as number | undefined; + const filterStatus = query.queryKey[2] as string | undefined; + // 如果有过滤条件且不匹配,不更新计数 + if (filterProviderId !== undefined && updatedRequest.providerID !== filterProviderId) { + continue; + } + if (filterStatus !== undefined && updatedRequest.status !== filterStatus) { + continue; + } + queryClient.setQueryData(query.queryKey, (old) => (old ?? 0) + 1); + } } // 请求完成或失败时刷新相关数据 diff --git a/web/src/hooks/queries/use-usage-stats.ts b/web/src/hooks/queries/use-usage-stats.ts index 7fbf445a..60b24db1 100644 --- a/web/src/hooks/queries/use-usage-stats.ts +++ b/web/src/hooks/queries/use-usage-stats.ts @@ -140,7 +140,7 @@ export function useUsageStatsWithPreset( } /** - * 清空并重新计算统计数据 + * 清空并重新聚合统计数据(不重算成本) */ export function useRecalculateUsageStats() { const queryClient = useQueryClient(); @@ -153,3 +153,18 @@ export function useRecalculateUsageStats() { }, }); } + +/** + * 重新计算所有请求的成本 + */ +export function useRecalculateCosts() { + const queryClient = useQueryClient(); + + return useMutation({ + mutationFn: () => getTransport().recalculateCosts(), + onSuccess: () => { + // 使所有 usageStats 查询失效,触发重新获取 + queryClient.invalidateQueries({ queryKey: usageStatsKeys.all }); + }, + }); +} diff --git a/web/src/index.css b/web/src/index.css index b660cf1c..eba6b713 100644 --- a/web/src/index.css +++ b/web/src/index.css @@ -5,11 +5,6 @@ @custom-variant dark (&:is(.dark *)); -/* Theme transition animations */ -* { - transition: background-color 0.3s ease, color 0.3s ease, border-color 0.3s ease; -} - :root { --background: oklch(1 0 0); --foreground: oklch(0.2644 0 0); @@ -29,18 +24,12 @@ --border: oklch(0.9404 0 0); --input: oklch(0.9404 0 0); --ring: oklch(0.7716 0 0); - --chart-1: oklch(0.6447 0.1234 166.1234); - --chart-2: oklch(0.6983 0.1337 142.5); - --chart-3: oklch(0.7232 0.15 60.6307); - --chart-4: oklch(0.6192 0.2037 312.7283); - --chart-5: oklch(0.6123 0.2093 6.3856); + --chart-1: oklch(0.8241 0.1251 84.4866); + --chart-2: oklch(0.8006 0.1116 203.6044); + --chart-3: oklch(0.4198 0.1693 266.7798); + --chart-4: oklch(0.9214 0.0762 125.5777); + --chart-5: oklch(0.9151 0.1032 116.1913); - /* 语义化图表颜色 */ - --chart-success: oklch(0.6447 0.1234 166.1234); /* 绿色 - 成功 */ - --chart-error: oklch(0.6123 0.2093 6.3856); /* 红色 - 失败 */ - --chart-warning: oklch(0.7232 0.15 60.6307); /* 黄色 - 警告 */ - --chart-info: oklch(0.6983 0.1337 142.5); /* 青色 - 信息 */ - --chart-primary: oklch(0.6192 0.2037 312.7283); /* 紫色 - 主要 */ --radius: 0.625rem; --sidebar: oklch(0.9886 0 0); --sidebar-foreground: oklch(0.2644 0 0); @@ -122,12 +111,6 @@ --chart-4: oklch(0.6192 0.2037 312.7283); --chart-5: oklch(0.6123 0.2093 6.3856); - /* 语义化图表颜色 - 深色主题 */ - --chart-success: oklch(0.6983 0.1337 165.4626); /* 青绿色 - 成功 */ - --chart-error: oklch(0.6123 0.2093 6.3856); /* 红色 - 失败 */ - --chart-warning: oklch(0.7232 0.15 60.6307); /* 黄色 - 警告 */ - --chart-info: oklch(0.5292 0.1931 262.1292); /* 紫色 - 信息 */ - --chart-primary: oklch(0.6192 0.2037 312.7283); /* 粉紫色 - 主要 */ --sidebar: oklch(0.2103 0.0059 285.8835); --sidebar-foreground: oklch(0.9676 0.0013 286.3752); --sidebar-primary: oklch(0.4878 0.217 264.3876); @@ -164,468 +147,6 @@ --shadow-2xl: 0 1px 3px 0px hsl(0 0% 0% / 0.25); } -/* Luxury Theme: Hermès Orange - Warm, sophisticated, generous spacing */ -.theme-hermes { - --background: oklch(0.98 0.01 60); - --foreground: oklch(0.22 0.03 45); - --card: oklch(0.97 0.015 58); - --card-foreground: oklch(0.22 0.03 45); - --popover: oklch(0.98 0.01 60); - --popover-foreground: oklch(0.22 0.03 45); - --primary: oklch(0.62 0.16 50); - --primary-foreground: oklch(0.98 0.005 60); - --secondary: oklch(0.88 0.03 55); - --secondary-foreground: oklch(0.25 0.03 45); - --muted: oklch(0.93 0.015 58); - --muted-foreground: oklch(0.48 0.04 48); - --accent: oklch(0.72 0.14 80); - --accent-foreground: oklch(0.20 0.03 45); - --destructive: oklch(0.52 0.22 25); - --destructive-foreground: oklch(0.98 0.005 60); - --border: oklch(0.86 0.025 55); - --input: oklch(0.90 0.02 56); - --ring: oklch(0.62 0.16 50); - --chart-1: oklch(0.62 0.16 50); - --chart-2: oklch(0.72 0.14 80); - --chart-3: oklch(0.58 0.15 38); - --chart-4: oklch(0.68 0.12 65); - --chart-5: oklch(0.48 0.10 45); - - /* 语义化图表颜色 - Hermès 主题 */ - --chart-success: oklch(0.72 0.14 80); /* 金黄色 - 成功 */ - --chart-error: oklch(0.48 0.10 45); /* 深棕色 - 失败 */ - --chart-warning: oklch(0.62 0.16 50); /* 主橙色 - 警告 */ - --chart-info: oklch(0.68 0.12 65); /* 浅橙色 - 信息 */ - --chart-primary: oklch(0.58 0.15 38); /* 深橙色 - 主要 */ - --radius: 1.25rem; - --sidebar: oklch(0.96 0.018 58); - --sidebar-foreground: oklch(0.28 0.03 45); - --sidebar-primary: oklch(0.62 0.16 50); - --sidebar-primary-foreground: oklch(0.98 0.005 60); - --sidebar-accent: oklch(0.91 0.025 56); - --sidebar-accent-foreground: oklch(0.28 0.03 45); - --sidebar-border: oklch(0.86 0.025 55); - --sidebar-ring: oklch(0.62 0.16 50); - - /* Typography */ - --font-sans: 'Crimson Text', 'Source Serif Pro', Georgia, serif; - --font-serif: 'Cormorant Garamond', 'Playfair Display', Georgia, serif; - --font-mono: 'IBM Plex Mono', 'SF Mono', Monaco, monospace; - - /* Spacing - Generous, luxurious */ - --spacing-multiplier: 1.2; - - /* Border Radius - Soft, rounded */ - --radius-sm: 12px; - --radius-md: 18px; - --radius-lg: 28px; - --radius-xl: 40px; - - /* Shadows - Warm, soft */ - --shadow-color: oklch(0.42 0.10 48); - --shadow-opacity: 0.15; - - /* Animation - Elegant, slower */ - --animation-multiplier: 1.35; - --animation-easing: cubic-bezier(0.34, 0, 0.15, 1); -} - -/* Luxury Theme: Tiffany Blue - Clean, elegant, modern */ -.theme-tiffany { - --background: oklch(0.99 0.003 200); - --foreground: oklch(0.24 0.015 215); - --card: oklch(0.98 0.005 198); - --card-foreground: oklch(0.24 0.015 215); - --popover: oklch(0.99 0.003 200); - --popover-foreground: oklch(0.24 0.015 215); - --primary: oklch(0.68 0.11 192); - --primary-foreground: oklch(0.99 0.003 200); - --secondary: oklch(0.92 0.008 205); - --secondary-foreground: oklch(0.26 0.015 215); - --muted: oklch(0.95 0.008 200); - --muted-foreground: oklch(0.52 0.025 208); - --accent: oklch(0.72 0.13 188); - --accent-foreground: oklch(0.99 0.003 200); - --destructive: oklch(0.54 0.21 22); - --destructive-foreground: oklch(0.99 0.003 200); - --border: oklch(0.90 0.015 196); - --input: oklch(0.92 0.012 197); - --ring: oklch(0.68 0.11 192); - --chart-1: oklch(0.68 0.11 192); - --chart-2: oklch(0.72 0.13 188); - --chart-3: oklch(0.58 0.10 184); - --chart-4: oklch(0.64 0.09 198); - --chart-5: oklch(0.76 0.008 210); - - /* 语义化图表颜色 - Tiffany 主题 */ - --chart-success: oklch(0.72 0.13 188); /* 浅蓝色 - 成功 */ - --chart-error: oklch(0.58 0.10 184); /* 深青色 - 失败 */ - --chart-warning: oklch(0.76 0.008 210); /* 浅灰蓝 - 警告 */ - --chart-info: oklch(0.68 0.11 192); /* 主蓝色 - 信息 */ - --chart-primary: oklch(0.64 0.09 198); /* 中蓝色 - 主要 */ - --radius: 1.125rem; - --sidebar: oklch(0.97 0.006 198); - --sidebar-foreground: oklch(0.28 0.015 215); - --sidebar-primary: oklch(0.68 0.11 192); - --sidebar-primary-foreground: oklch(0.99 0.003 200); - --sidebar-accent: oklch(0.93 0.012 197); - --sidebar-accent-foreground: oklch(0.28 0.015 215); - --sidebar-border: oklch(0.90 0.015 196); - --sidebar-ring: oklch(0.68 0.11 192); - - /* Typography */ - --font-sans: 'Inter Variable', 'Inter', system-ui, sans-serif; - --font-serif: 'Montserrat', 'Inter', -apple-system, sans-serif; - --font-mono: 'JetBrains Mono', 'SF Mono', monospace; - - /* Spacing - Comfortable, balanced */ - --spacing-multiplier: 1.12; - - /* Border Radius - Very rounded, friendly */ - --radius-sm: 14px; - --radius-md: 20px; - --radius-lg: 30px; - --radius-xl: 42px; - - /* Shadows - Light, airy */ - --shadow-color: oklch(0.68 0.11 192); - --shadow-opacity: 0.06; - - /* Animation - Smooth, refined */ - --animation-multiplier: 1.18; - --animation-easing: cubic-bezier(0.22, 0.08, 0.22, 1); -} - -/* Luxury Theme: Chanel Black - Minimalist, sharp, timeless */ -.theme-chanel { - --background: oklch(0.12 0.003 280); - --foreground: oklch(0.98 0.002 280); - --card: oklch(0.14 0.004 280); - --card-foreground: oklch(0.98 0.002 280); - --popover: oklch(0.12 0.003 280); - --popover-foreground: oklch(0.98 0.002 280); - --primary: oklch(0.98 0.002 280); - --primary-foreground: oklch(0.12 0.003 280); - --secondary: oklch(0.22 0.006 280); - --secondary-foreground: oklch(0.98 0.002 280); - --muted: oklch(0.20 0.005 280); - --muted-foreground: oklch(0.68 0.008 280); - --accent: oklch(0.78 0.13 82); - --accent-foreground: oklch(0.12 0.003 280); - --destructive: oklch(0.56 0.21 24); - --destructive-foreground: oklch(0.98 0.002 280); - --border: oklch(0.28 0.008 280); - --input: oklch(0.25 0.007 280); - --ring: oklch(0.78 0.13 82); - --chart-1: oklch(0.78 0.13 82); - --chart-2: oklch(0.98 0.002 280); - --chart-3: oklch(0.62 0.11 78); - --chart-4: oklch(0.68 0.008 280); - --chart-5: oklch(0.52 0.09 86); - - /* 语义化图表颜色 - Chanel 主题 */ - --chart-success: oklch(0.78 0.13 82); /* 金色 - 成功 */ - --chart-error: oklch(0.52 0.09 86); /* 深金色 - 失败 */ - --chart-warning: oklch(0.62 0.11 78); /* 浅金色 - 警告 */ - --chart-info: oklch(0.98 0.002 280); /* 白色 - 信息 */ - --chart-primary: oklch(0.68 0.008 280); /* 灰色 - 主要 */ - --radius: 0.25rem; - --sidebar: oklch(0.16 0.004 280); - --sidebar-foreground: oklch(0.96 0.002 280); - --sidebar-primary: oklch(0.98 0.002 280); - --sidebar-primary-foreground: oklch(0.12 0.003 280); - --sidebar-accent: oklch(0.22 0.006 280); - --sidebar-accent-foreground: oklch(0.96 0.002 280); - --sidebar-border: oklch(0.28 0.008 280); - --sidebar-ring: oklch(0.78 0.13 82); - - /* Typography */ - --font-sans: 'Helvetica Neue', 'Arial', sans-serif; - --font-serif: 'Futura', 'Helvetica Neue', Arial, sans-serif; - --font-mono: 'Fira Code', 'Consolas', monospace; - - /* Spacing - Compact, precise */ - --spacing-multiplier: 0.92; - - /* Border Radius - Sharp, minimal */ - --radius-sm: 2px; - --radius-md: 3px; - --radius-lg: 5px; - --radius-xl: 8px; - - /* Shadows - Strong, defined */ - --shadow-color: oklch(0.08 0.002 280); - --shadow-opacity: 0.35; - - /* Animation - Quick, snappy */ - --animation-multiplier: 0.8; - --animation-easing: cubic-bezier(0.45, 0, 0.65, 1); -} - -/* Luxury Theme: Cartier Red - Rich, opulent, dramatic */ -.theme-cartier { - --background: oklch(0.18 0.045 18); - --foreground: oklch(0.96 0.015 45); - --card: oklch(0.20 0.042 17); - --card-foreground: oklch(0.96 0.015 45); - --popover: oklch(0.18 0.045 18); - --popover-foreground: oklch(0.96 0.015 45); - --primary: oklch(0.42 0.20 18); - --primary-foreground: oklch(0.98 0.008 45); - --secondary: oklch(0.68 0.14 78); - --secondary-foreground: oklch(0.18 0.045 18); - --muted: oklch(0.26 0.035 19); - --muted-foreground: oklch(0.62 0.04 28); - --accent: oklch(0.74 0.15 82); - --accent-foreground: oklch(0.16 0.045 18); - --destructive: oklch(0.54 0.24 22); - --destructive-foreground: oklch(0.98 0.008 45); - --border: oklch(0.32 0.055 19); - --input: oklch(0.29 0.05 19); - --ring: oklch(0.68 0.14 78); - --chart-1: oklch(0.42 0.20 18); - --chart-2: oklch(0.74 0.15 82); - --chart-3: oklch(0.68 0.14 78); - --chart-4: oklch(0.58 0.12 68); - --chart-5: oklch(0.52 0.17 26); - - /* 语义化图表颜色 - Cartier 主题 */ - --chart-success: oklch(0.74 0.15 82); /* 金色 - 成功 */ - --chart-error: oklch(0.42 0.20 18); /* 深红色 - 失败 */ - --chart-warning: oklch(0.68 0.14 78); /* 玫瑰金 - 警告 */ - --chart-info: oklch(0.58 0.12 68); /* 暖棕色 - 信息 */ - --chart-primary: oklch(0.52 0.17 26); /* 深棕红 - 主要 */ - --radius: 0.5rem; - --sidebar: oklch(0.22 0.042 18); - --sidebar-foreground: oklch(0.94 0.015 45); - --sidebar-primary: oklch(0.42 0.20 18); - --sidebar-primary-foreground: oklch(0.98 0.008 45); - --sidebar-accent: oklch(0.29 0.045 19); - --sidebar-accent-foreground: oklch(0.94 0.015 45); - --sidebar-border: oklch(0.32 0.055 19); - --sidebar-ring: oklch(0.68 0.14 78); - - /* Typography */ - --font-sans: 'Lora', 'Georgia', serif; - --font-serif: 'Libre Baskerville', 'Baskerville', Georgia, serif; - --font-mono: 'Source Code Pro', Monaco, monospace; - - /* Spacing - Luxurious, spacious */ - --spacing-multiplier: 1.25; - - /* Border Radius - Classic, refined */ - --radius-sm: 7px; - --radius-md: 11px; - --radius-lg: 20px; - --radius-xl: 32px; - - /* Shadows - Deep, rich */ - --shadow-color: oklch(0.38 0.20 18); - --shadow-opacity: 0.22; - - /* Animation - Dramatic, slower */ - --animation-multiplier: 1.45; - --animation-easing: cubic-bezier(0.28, 0, 0.18, 1); -} - -/* Luxury Theme: Burberry Beige - Classic, heritage, warm */ -.theme-burberry { - --background: oklch(0.96 0.018 68); - --foreground: oklch(0.18 0.008 72); - --card: oklch(0.95 0.02 67); - --card-foreground: oklch(0.18 0.008 72); - --popover: oklch(0.96 0.018 68); - --popover-foreground: oklch(0.18 0.008 72); - --primary: oklch(0.58 0.09 62); - --primary-foreground: oklch(0.18 0.008 72); - --secondary: oklch(0.89 0.018 66); - --secondary-foreground: oklch(0.20 0.008 72); - --muted: oklch(0.90 0.015 67); - --muted-foreground: oklch(0.48 0.035 64); - --accent: oklch(0.48 0.20 22); - --accent-foreground: oklch(0.98 0.008 68); - --destructive: oklch(0.54 0.21 24); - --destructive-foreground: oklch(0.98 0.008 68); - --border: oklch(0.82 0.028 66); - --input: oklch(0.85 0.025 66); - --ring: oklch(0.58 0.09 62); - --chart-1: oklch(0.58 0.09 62); - --chart-2: oklch(0.48 0.20 22); - --chart-3: oklch(0.68 0.07 68); - --chart-4: oklch(0.54 0.11 58); - --chart-5: oklch(0.22 0.008 280); - - /* 语义化图表颜色 - Burberry 主题 */ - --chart-success: oklch(0.68 0.07 68); /* 浅卡其色 - 成功 */ - --chart-error: oklch(0.48 0.20 22); /* 深红色 - 失败 */ - --chart-warning: oklch(0.58 0.09 62); /* 卡其色 - 警告 */ - --chart-info: oklch(0.54 0.11 58); /* 暖棕色 - 信息 */ - --chart-primary: oklch(0.22 0.008 280); /* 黑色 - 主要 */ - --radius: 0.625rem; - --sidebar: oklch(0.93 0.02 67); - --sidebar-foreground: oklch(0.23 0.008 72); - --sidebar-primary: oklch(0.58 0.09 62); - --sidebar-primary-foreground: oklch(0.18 0.008 72); - --sidebar-accent: oklch(0.87 0.022 66); - --sidebar-accent-foreground: oklch(0.23 0.008 72); - --sidebar-border: oklch(0.82 0.028 66); - --sidebar-ring: oklch(0.58 0.09 62); - - /* Typography */ - --font-sans: 'Merriweather', 'Georgia', serif; - --font-serif: 'Merriweather', 'Times New Roman', serif; - --font-mono: 'Courier Prime', 'Courier New', monospace; - - /* Spacing - Traditional, balanced */ - --spacing-multiplier: 1.08; - - /* Border Radius - Classic, moderate */ - --radius-sm: 9px; - --radius-md: 13px; - --radius-lg: 20px; - --radius-xl: 30px; - - /* Shadows - Soft, natural */ - --shadow-color: oklch(0.48 0.04 64); - --shadow-opacity: 0.12; - - /* Animation - Steady, traditional */ - --animation-multiplier: 1.12; - --animation-easing: cubic-bezier(0.32, 0, 0.22, 1); -} - -/* Luxury Theme: Gucci Green - Bold, dramatic, rich contrast */ -.theme-gucci { - --background: oklch(0.16 0.045 158); - --foreground: oklch(0.96 0.018 82); - --card: oklch(0.18 0.042 157); - --card-foreground: oklch(0.96 0.018 82); - --popover: oklch(0.16 0.045 158); - --popover-foreground: oklch(0.96 0.018 82); - --primary: oklch(0.38 0.14 152); - --primary-foreground: oklch(0.98 0.008 82); - --secondary: oklch(0.42 0.20 18); - --secondary-foreground: oklch(0.98 0.008 82); - --muted: oklch(0.26 0.045 157); - --muted-foreground: oklch(0.64 0.05 148); - --accent: oklch(0.70 0.15 78); - --accent-foreground: oklch(0.14 0.045 158); - --destructive: oklch(0.54 0.24 22); - --destructive-foreground: oklch(0.98 0.008 82); - --border: oklch(0.30 0.065 157); - --input: oklch(0.28 0.06 157); - --ring: oklch(0.70 0.15 78); - --chart-1: oklch(0.38 0.14 152); - --chart-2: oklch(0.70 0.15 78); - --chart-3: oklch(0.42 0.20 18); - --chart-4: oklch(0.58 0.10 142); - --chart-5: oklch(0.48 0.12 162); - - /* 语义化图表颜色 - Gucci 主题 */ - --chart-success: oklch(0.58 0.10 142); /* 翠绿色 - 成功 */ - --chart-error: oklch(0.42 0.20 18); /* 深红色 - 失败 */ - --chart-warning: oklch(0.70 0.15 78); /* 金色 - 警告 */ - --chart-info: oklch(0.38 0.14 152); /* 主绿色 - 信息 */ - --chart-primary: oklch(0.48 0.12 162); /* 青绿色 - 主要 */ - --radius: 0.75rem; - --sidebar: oklch(0.20 0.042 157); - --sidebar-foreground: oklch(0.94 0.018 82); - --sidebar-primary: oklch(0.38 0.14 152); - --sidebar-primary-foreground: oklch(0.98 0.008 82); - --sidebar-accent: oklch(0.28 0.055 157); - --sidebar-accent-foreground: oklch(0.94 0.018 82); - --sidebar-border: oklch(0.30 0.065 157); - --sidebar-ring: oklch(0.70 0.15 78); - - /* Typography */ - --font-sans: 'Raleway', 'Helvetica', sans-serif; - --font-serif: 'Cinzel', 'Palatino', serif; - --font-mono: 'Roboto Mono', monospace; - - /* Spacing - Balanced, confident */ - --spacing-multiplier: 1.05; - - /* Border Radius - Moderate, bold */ - --radius-sm: 11px; - --radius-md: 16px; - --radius-lg: 24px; - --radius-xl: 34px; - - /* Shadows - Rich, prominent */ - --shadow-color: oklch(0.32 0.14 152); - --shadow-opacity: 0.20; - - /* Animation - Confident, steady */ - --animation-multiplier: 1.05; - --animation-easing: cubic-bezier(0.42, 0, 0.12, 1); -} - -/* Luxury Theme: Dior Gray - Refined, subtle, sophisticated */ -.theme-dior { - --background: oklch(0.97 0.003 260); - --foreground: oklch(0.26 0.008 262); - --card: oklch(0.96 0.004 260); - --card-foreground: oklch(0.26 0.008 262); - --popover: oklch(0.97 0.003 260); - --popover-foreground: oklch(0.26 0.008 262); - --primary: oklch(0.52 0.018 258); - --primary-foreground: oklch(0.98 0.002 260); - --secondary: oklch(0.88 0.006 255); - --secondary-foreground: oklch(0.28 0.008 262); - --muted: oklch(0.91 0.006 260); - --muted-foreground: oklch(0.50 0.015 260); - --accent: oklch(0.66 0.09 22); - --accent-foreground: oklch(0.98 0.002 260); - --destructive: oklch(0.54 0.21 24); - --destructive-foreground: oklch(0.98 0.002 260); - --border: oklch(0.86 0.008 260); - --input: oklch(0.89 0.006 260); - --ring: oklch(0.66 0.09 22); - --chart-1: oklch(0.52 0.018 258); - --chart-2: oklch(0.66 0.09 22); - --chart-3: oklch(0.58 0.025 256); - --chart-4: oklch(0.64 0.045 268); - --chart-5: oklch(0.72 0.008 252); - - /* 语义化图表颜色 - Dior 主题 */ - --chart-success: oklch(0.64 0.045 268); /* 紫蓝色 - 成功 */ - --chart-error: oklch(0.66 0.09 22); /* 玫瑰金 - 失败 */ - --chart-warning: oklch(0.58 0.025 256); /* 浅紫色 - 警告 */ - --chart-info: oklch(0.52 0.018 258); /* 主灰紫 - 信息 */ - --chart-primary: oklch(0.72 0.008 252); /* 浅灰色 - 主要 */ - --radius: 0.875rem; - --sidebar: oklch(0.95 0.004 260); - --sidebar-foreground: oklch(0.30 0.008 262); - --sidebar-primary: oklch(0.52 0.018 258); - --sidebar-primary-foreground: oklch(0.98 0.002 260); - --sidebar-accent: oklch(0.89 0.006 260); - --sidebar-accent-foreground: oklch(0.30 0.008 262); - --sidebar-border: oklch(0.86 0.008 260); - --sidebar-ring: oklch(0.66 0.09 22); - - /* Typography */ - --font-sans: 'Lato', 'Helvetica Neue', sans-serif; - --font-serif: 'Didot', 'Bodoni Moda', 'Playfair Display', serif; - --font-mono: 'Anonymous Pro', monospace; - - /* Spacing - Comfortable, refined */ - --spacing-multiplier: 1.1; - - /* Border Radius - Elegant, smooth */ - --radius-sm: 10px; - --radius-md: 16px; - --radius-lg: 24px; - --radius-xl: 36px; - - /* Shadows - Subtle, delicate */ - --shadow-color: oklch(0.50 0.015 260); - --shadow-opacity: 0.05; - - /* Animation - Graceful, smooth */ - --animation-multiplier: 1.25; - --animation-easing: cubic-bezier(0.26, 0.08, 0.18, 1); -} - @theme inline { --font-sans: ui-sans-serif, system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, @@ -644,13 +165,6 @@ --color-chart-3: var(--chart-3); --color-chart-2: var(--chart-2); --color-chart-1: var(--chart-1); - - /* 语义化图表颜色映射 (Tailwind 可用) */ - --color-chart-success: var(--chart-success); - --color-chart-error: var(--chart-error); - --color-chart-warning: var(--chart-warning); - --color-chart-info: var(--chart-info); - --color-chart-primary: var(--chart-primary); --color-ring: var(--ring); --color-input: var(--input); --color-border: var(--border); diff --git a/web/src/lib/transport/http-transport.ts b/web/src/lib/transport/http-transport.ts index 27abc26f..4ba158a6 100644 --- a/web/src/lib/transport/http-transport.ts +++ b/web/src/lib/transport/http-transport.ts @@ -44,10 +44,13 @@ import type { RoutePositionUpdate, UsageStats, UsageStatsFilter, + RecalculateCostsResult, + RecalculateRequestCostResult, DashboardData, BackupFile, BackupImportOptions, BackupImportResult, + PriceTable, } from './types'; export class HttpTransport implements Transport { @@ -273,8 +276,15 @@ export class HttpTransport implements Transport { return data ?? { items: [], hasMore: false }; } - async getProxyRequestsCount(): Promise { - const { data } = await this.client.get('/requests/count'); + async getProxyRequestsCount(providerId?: number, status?: string): Promise { + const params: Record = {}; + if (providerId !== undefined) { + params.providerId = String(providerId); + } + if (status !== undefined) { + params.status = status; + } + const { data } = await this.client.get('/requests/count', { params }); return data ?? 0; } @@ -525,6 +535,7 @@ export class HttpTransport implements Transport { if (filter?.projectId) params.set('projectId', String(filter.projectId)); if (filter?.clientType) params.set('clientType', filter.clientType); if (filter?.apiTokenId) params.set('apiTokenId', String(filter.apiTokenId)); + if (filter?.model) params.set('model', filter.model); const query = params.toString(); const url = query ? `/usage-stats?${query}` : '/usage-stats'; @@ -536,6 +547,20 @@ export class HttpTransport implements Transport { await this.client.post('/usage-stats/recalculate'); } + async recalculateCosts(): Promise { + const { data } = await this.client.post( + '/usage-stats/recalculate-costs', + ); + return data; + } + + async recalculateRequestCost(requestId: number): Promise { + const { data } = await this.client.post( + `/requests/${requestId}/recalculate-cost`, + ); + return data; + } + // ===== Dashboard API ===== async getDashboardData(): Promise { @@ -568,6 +593,13 @@ export class HttpTransport implements Transport { return data; } + // ===== Pricing API ===== + + async getPricing(): Promise { + const { data } = await this.client.get('/pricing'); + return data; + } + // ===== WebSocket 订阅 ===== subscribe(eventType: WSMessageType, callback: EventCallback): UnsubscribeFn { diff --git a/web/src/lib/transport/index.ts b/web/src/lib/transport/index.ts index ec8983c2..60d7727e 100644 --- a/web/src/lib/transport/index.ts +++ b/web/src/lib/transport/index.ts @@ -66,6 +66,10 @@ export type { UsageStats, UsageStatsFilter, StatsGranularity, + RecalculateRequestCostResult, + RecalculateCostsResult, + RecalculateCostsProgress, + RecalculateStatsProgress, // Dashboard DashboardData, DashboardDaySummary, @@ -74,6 +78,9 @@ export type { DashboardModelStats, DashboardTrendPoint, DashboardProviderStats, + // Pricing + ModelPricing, + PriceTable, } from './types'; export type { Transport, TransportType, TransportConfig } from './interface'; diff --git a/web/src/lib/transport/interface.ts b/web/src/lib/transport/interface.ts index 0361ced9..5468fa1d 100644 --- a/web/src/lib/transport/interface.ts +++ b/web/src/lib/transport/interface.ts @@ -41,10 +41,13 @@ import type { RoutePositionUpdate, UsageStats, UsageStatsFilter, + RecalculateCostsResult, + RecalculateRequestCostResult, DashboardData, BackupFile, BackupImportOptions, BackupImportResult, + PriceTable, } from './types'; /** @@ -100,7 +103,7 @@ export interface Transport { // ===== ProxyRequest API (只读) ===== getProxyRequests(params?: CursorPaginationParams): Promise>; - getProxyRequestsCount(): Promise; + getProxyRequestsCount(providerId?: number, status?: string): Promise; getActiveProxyRequests(): Promise; getProxyRequest(id: number): Promise; getProxyUpstreamAttempts(proxyRequestId: number): Promise; @@ -164,6 +167,8 @@ export interface Transport { // ===== Usage Stats API ===== getUsageStats(filter?: UsageStatsFilter): Promise; recalculateUsageStats(): Promise; + recalculateCosts(): Promise; + recalculateRequestCost(requestId: number): Promise; // ===== Dashboard API ===== getDashboardData(): Promise; @@ -175,6 +180,9 @@ export interface Transport { exportBackup(): Promise; importBackup(backup: BackupFile, options?: BackupImportOptions): Promise; + // ===== Pricing API ===== + getPricing(): Promise; + // ===== 实时订阅 ===== subscribe(eventType: WSMessageType, callback: EventCallback): UnsubscribeFn; diff --git a/web/src/lib/transport/types.ts b/web/src/lib/transport/types.ts index 37c25849..05003040 100644 --- a/web/src/lib/transport/types.ts +++ b/web/src/lib/transport/types.ts @@ -181,6 +181,7 @@ export interface ProxyRequest { startTime: string; endTime: string; duration: number; // nanoseconds + ttft: number; // nanoseconds - Time To First Token (首字时长) isStream: boolean; // 是否为 SSE 流式请求 status: ProxyRequestStatus; statusCode: number; // HTTP 状态码(冗余存储,用于列表查询优化) @@ -220,6 +221,7 @@ export interface ProxyUpstreamAttempt { startTime: string; endTime: string; duration: number; // nanoseconds + ttft: number; // nanoseconds - Time To First Token (首字时长) status: ProxyUpstreamAttemptStatus; proxyRequestID: number; isStream: boolean; // 是否为 SSE 流式请求 @@ -254,6 +256,10 @@ export interface CursorPaginationParams { before?: number; /** 获取 id 大于此值的记录 (向前翻页/获取新数据) */ after?: number; + /** 按 Provider ID 过滤 */ + providerId?: number; + /** 按状态过滤 */ + status?: string; } /** 游标分页响应 */ @@ -277,6 +283,8 @@ export type WSMessageType = | 'new_session_pending' | 'session_pending_cancelled' | 'cooldown_update' + | 'recalculate_costs_progress' + | 'recalculate_stats_progress' | '_ws_reconnected'; // 内部事件:WebSocket 重连成功 export interface WSMessage { @@ -523,7 +531,7 @@ export interface CreateAPITokenData { // ===== Usage Stats ===== /** 统计数据时间粒度 */ -export type StatsGranularity = 'minute' | 'hour' | 'day' | 'week' | 'month'; +export type StatsGranularity = 'minute' | 'hour' | 'day' | 'week' | 'month' | 'year'; export interface UsageStats { id: number; @@ -540,6 +548,7 @@ export interface UsageStats { successfulRequests: number; failedRequests: number; totalDurationMs: number; // 累计请求耗时(毫秒) + totalTtftMs: number; // 累计首字时长(毫秒) inputTokens: number; outputTokens: number; cacheRead: number; @@ -572,6 +581,41 @@ export interface UsageStatsFilter { model?: string; // 模型名称 } +/** RecalculateCostsResult - 全量成本重算结果 */ +export interface RecalculateCostsResult { + totalAttempts: number; + updatedAttempts: number; + updatedRequests: number; + message: string; +} + +/** RecalculateCostsProgress - 成本重算进度更新 */ +export interface RecalculateCostsProgress { + phase: 'calculating' | 'updating_attempts' | 'updating_requests' | 'completed'; + current: number; + total: number; + percentage: number; + message: string; +} + +/** RecalculateStatsProgress - 统计重算进度更新 */ +export interface RecalculateStatsProgress { + phase: 'clearing' | 'aggregating' | 'rollup' | 'completed'; + current: number; + total: number; + percentage: number; + message: string; +} + +/** RecalculateRequestCostResult - 单条请求成本重算结果 */ +export interface RecalculateRequestCostResult { + requestId: number; + oldCost: number; + newCost: number; + updatedAttempts: number; + message: string; +} + /** Response Model - 记录所有出现过的 response model */ export interface ResponseModel { id: number; @@ -747,3 +791,27 @@ export interface DashboardData { providerStats: Record; timezone: string; // 配置的时区,如 "Asia/Shanghai" } + +// ===== Pricing API Types ===== + +/** 单个模型的价格配置 - 价格单位:微美元/百万tokens */ +export interface ModelPricing { + modelId: string; + inputPriceMicro: number; // 输入价格 (microUSD/M tokens) + outputPriceMicro: number; // 输出价格 (microUSD/M tokens) + cacheReadPriceMicro?: number; // 缓存读取价格,默认 input / 10 + cache5mWritePriceMicro?: number; // 5分钟缓存写入,默认 input * 1.25 + cache1hWritePriceMicro?: number; // 1小时缓存写入,默认 input * 2 + has1mContext?: boolean; // 是否支持 1M context + context1mThreshold?: number; // 1M context 阈值,默认 200000 + inputPremiumNum?: number; // 超阈值 input 倍率分子 + inputPremiumDenom?: number; // 超阈值 input 倍率分母 + outputPremiumNum?: number; // 超阈值 output 倍率分子 + outputPremiumDenom?: number; // 超阈值 output 倍率分母 +} + +/** 完整价格表 */ +export interface PriceTable { + version: string; + models: Record; +} diff --git a/web/src/locales/en.json b/web/src/locales/en.json index 52a19a74..414339ea 100644 --- a/web/src/locales/en.json +++ b/web/src/locales/en.json @@ -21,6 +21,7 @@ "created": "Created", "updated": "Updated", "reset": "Reset", + "apply": "Apply", "import": "Import", "export": "Export", "refresh": "Refresh", @@ -108,6 +109,8 @@ "noRequests": "No requests recorded", "noRequestsHint": "Requests will appear here automatically", "refresh": "Refresh", + "allProviders": "All Providers", + "allStatuses": "All Statuses", "time": "Time", "client": "Client", "model": "Model", @@ -116,6 +119,7 @@ "provider": "Provider", "code": "Code", "duration": "Duration", + "ttft": "Time To First Token", "cost": "Cost", "attempts": "Attempts", "attShort": "Att.", @@ -572,13 +576,22 @@ "title": "Statistics", "description": "Usage statistics and analytics", "filters": "Filters", + "filter": "Filter", + "filterConditions": "Filter Conditions", "timeRange": "Time Range", + "today": "Today", + "yesterday": "Yesterday", + "thisWeek": "This Week", + "lastWeek": "Last Week", + "thisMonth": "This Month", + "lastMonth": "Last Month", "last1h": "Last 1 hour", "last24h": "Last 24 hours", "last7d": "Last 7 days", "last30d": "Last 30 days", "last90d": "Last 90 days", "allTime": "All Time", + "filterSummary": "Filters", "provider": "Provider", "allProviders": "All Providers", "project": "Project", @@ -607,6 +620,7 @@ "outputTokens": "Output Tokens", "cacheRead": "Cache Read", "cacheWrite": "Cache Write", + "cacheHit": "Cache Hit", "totalCost": "Total Cost", "dataPoints": "Data Points", "chart": "Statistics Chart", @@ -616,7 +630,9 @@ "costUSD": "Cost (USD)", "successful": "Successful", "failed": "Failed", - "recalculate": "Recalculate" + "recalculate": "Recalculate", + "recalculateCosts": "Recalculate Costs", + "recalculateStats": "Re-aggregate Stats" }, "addProvider": { "title": "Add Provider", diff --git a/web/src/locales/zh.json b/web/src/locales/zh.json index 01598289..8fb31a10 100644 --- a/web/src/locales/zh.json +++ b/web/src/locales/zh.json @@ -21,6 +21,7 @@ "created": "创建时间", "updated": "更新时间", "reset": "重置", + "apply": "应用", "import": "导入", "export": "导出", "refresh": "刷新", @@ -108,6 +109,8 @@ "noRequests": "暂无请求记录", "noRequestsHint": "请求将自动显示在这里", "refresh": "刷新", + "allProviders": "全部供应商", + "allStatuses": "全部状态", "time": "时间", "client": "客户端", "model": "模型", @@ -116,6 +119,7 @@ "provider": "提供商", "code": "状态码", "duration": "耗时", + "ttft": "首字时长", "cost": "费用", "attempts": "尝试次数", "attShort": "尝试", @@ -571,13 +575,22 @@ "title": "统计", "description": "使用统计和分析", "filters": "筛选", + "filter": "筛选", + "filterConditions": "筛选条件", "timeRange": "时间范围", + "today": "今天", + "yesterday": "昨天", + "thisWeek": "本周", + "lastWeek": "上周", + "thisMonth": "本月", + "lastMonth": "上月", "last1h": "最近 1 小时", "last24h": "最近 24 小时", "last7d": "最近 7 天", "last30d": "最近 30 天", "last90d": "最近 90 天", "allTime": "全部时间", + "filterSummary": "筛选条件", "provider": "提供商", "allProviders": "所有提供商", "project": "项目", @@ -606,6 +619,7 @@ "outputTokens": "输出 Token", "cacheRead": "缓存读取", "cacheWrite": "缓存写入", + "cacheHit": "缓存命中", "totalCost": "总费用", "dataPoints": "数据点", "chart": "统计图表", @@ -615,7 +629,9 @@ "costUSD": "费用 (USD)", "successful": "成功", "failed": "失败", - "recalculate": "重新计算" + "recalculate": "重新计算", + "recalculateCosts": "重算成本", + "recalculateStats": "重新聚合" }, "addProvider": { "title": "添加提供商", diff --git a/web/src/pages/client-routes/components/provider-row.tsx b/web/src/pages/client-routes/components/provider-row.tsx index a6433d4c..276db431 100644 --- a/web/src/pages/client-routes/components/provider-row.tsx +++ b/web/src/pages/client-routes/components/provider-row.tsx @@ -25,16 +25,17 @@ function formatTokens(count: number): string { return count.toString(); } -// 格式化成本 (微美元 → 美元) -function formatCost(microUsd: number): string { - const usd = microUsd / 1_000_000; +// 格式化成本 (纳美元 → 美元,向下取整到 6 位) +function formatCost(nanoUsd: number): string { + // 向下取整到 6 位小数 (microUSD 精度) + const usd = Math.floor(nanoUsd / 1000) / 1_000_000; if (usd >= 1) { return `$${usd.toFixed(2)}`; } if (usd >= 0.01) { return `$${usd.toFixed(3)}`; } - return `$${usd.toFixed(4)}`; + return `$${usd.toFixed(6).replace(/\.?0+$/, '')}`; } // Sortable Provider Row diff --git a/web/src/pages/overview.tsx b/web/src/pages/overview.tsx index be1fe668..1da8d5be 100644 --- a/web/src/pages/overview.tsx +++ b/web/src/pages/overview.tsx @@ -78,16 +78,17 @@ function formatNumber(num: number): string { return num.toLocaleString(); } -// 格式化成本 -function formatCost(microUsd: number): string { - const usd = microUsd / 1000000; +// 格式化成本 (纳美元 → 美元,向下取整到 6 位) +function formatCost(nanoUsd: number): string { + // 向下取整到 6 位小数 (microUSD 精度) + const usd = Math.floor(nanoUsd / 1000) / 1_000_000; if (usd >= 1000) { return '$' + (usd / 1000).toFixed(1) + 'K'; } if (usd >= 1) { return '$' + usd.toFixed(2); } - return '$' + usd.toFixed(4); + return '$' + usd.toFixed(6).replace(/\.?0+$/, ''); } // 格式化相对时间 diff --git a/web/src/pages/providers/components/provider-row.tsx b/web/src/pages/providers/components/provider-row.tsx index d06df46a..d8a2191c 100644 --- a/web/src/pages/providers/components/provider-row.tsx +++ b/web/src/pages/providers/components/provider-row.tsx @@ -20,16 +20,17 @@ function formatTokens(count: number): string { return count.toString(); } -// 格式化成本 (微美元 → 美元) -function formatCost(microUsd: number): string { - const usd = microUsd / 1_000_000; +// 格式化成本 (纳美元 → 美元,向下取整到 6 位) +function formatCost(nanoUsd: number): string { + // 向下取整到 6 位小数 (microUSD 精度) + const usd = Math.floor(nanoUsd / 1000) / 1_000_000; if (usd >= 1) { return `$${usd.toFixed(2)}`; } if (usd >= 0.01) { return `$${usd.toFixed(3)}`; } - return `$${usd.toFixed(4)}`; + return `$${usd.toFixed(6).replace(/\.?0+$/, '')}`; } interface ProviderRowProps { diff --git a/web/src/pages/providers/index.tsx b/web/src/pages/providers/index.tsx index a92b4e6c..9a4dd8ff 100644 --- a/web/src/pages/providers/index.tsx +++ b/web/src/pages/providers/index.tsx @@ -53,9 +53,9 @@ export function ProvidersPage() { } }); - // 按 id 倒序排列(新创建的在前) + // 按名称字母顺序排列 for (const key of Object.keys(groups) as ProviderTypeKey[]) { - groups[key].sort((a, b) => b.id - a.id); + groups[key].sort((a, b) => a.name.localeCompare(b.name)); } return groups; diff --git a/web/src/pages/requests/detail.tsx b/web/src/pages/requests/detail.tsx index 19d2e60d..a0975571 100644 --- a/web/src/pages/requests/detail.tsx +++ b/web/src/pages/requests/detail.tsx @@ -1,6 +1,8 @@ -import { useState, useMemo, useEffect } from 'react'; +import { useState, useMemo, useEffect, useCallback } from 'react'; import { useParams, useNavigate } from 'react-router-dom'; import { AlertCircle, Loader2 } from 'lucide-react'; +import { useMutation, useQueryClient } from '@tanstack/react-query'; +import { useTransport } from '@/lib/transport'; import { useProxyRequest, useProxyUpstreamAttempts, @@ -10,6 +12,7 @@ import { useSessions, useRoutes, useAPITokens, + requestKeys, } from '@/hooks/queries'; import { ResizableHandle, ResizablePanel, ResizablePanelGroup } from '@/components/ui/resizable'; import { RequestHeader } from './detail/RequestHeader'; @@ -22,6 +25,8 @@ type SelectionType = { type: 'request' } | { type: 'attempt'; attemptId: number export function RequestDetailPage() { const { id } = useParams<{ id: string }>(); const navigate = useNavigate(); + const { transport } = useTransport(); + const queryClient = useQueryClient(); const { data: request, isLoading, error } = useProxyRequest(Number(id)); const { data: attempts } = useProxyUpstreamAttempts(Number(id)); const { data: providers } = useProviders(); @@ -34,6 +39,18 @@ export function RequestDetailPage() { }); const [activeTab, setActiveTab] = useState<'request' | 'response' | 'metadata'>('request'); + // Recalculate cost mutation + const recalculateMutation = useMutation({ + mutationFn: () => transport.recalculateRequestCost(Number(id)), + onSuccess: () => { + queryClient.invalidateQueries({ queryKey: requestKeys.detail(Number(id)) }); + }, + }); + + const handleRecalculateCost = useCallback(() => { + recalculateMutation.mutate(); + }, [recalculateMutation]); + useProxyRequestUpdates(); // ESC 键返回列表 @@ -120,7 +137,12 @@ export function RequestDetailPage() { return (
{/* Header */} - navigate('/requests')} /> + navigate('/requests')} + onRecalculateCost={handleRecalculateCost} + isRecalculating={recalculateMutation.isPending} + /> {/* Error Banner */} {request.error && ( diff --git a/web/src/pages/requests/detail/RequestDetailPanel.tsx b/web/src/pages/requests/detail/RequestDetailPanel.tsx index bb507c92..3b75f67a 100644 --- a/web/src/pages/requests/detail/RequestDetailPanel.tsx +++ b/web/src/pages/requests/detail/RequestDetailPanel.tsx @@ -11,28 +11,209 @@ import { } from '@/components/ui'; import { Server, Code, Database, Info, Zap } from 'lucide-react'; import { useTranslation } from 'react-i18next'; -import type { ProxyUpstreamAttempt, ProxyRequest } from '@/lib/transport'; -import { cn } from '@/lib/utils'; +import type { ProxyUpstreamAttempt, ProxyRequest, ModelPricing } from '@/lib/transport'; +import { cn, formatDuration } from '@/lib/utils'; import { CopyButton, CopyAsCurlButton, DiffButton, EmptyState } from './components'; import { RequestDetailView } from './RequestDetailView'; +import { usePricing } from '@/hooks/queries'; // Selection type: either the main request or an attempt type SelectionType = { type: 'request' } | { type: 'attempt'; attemptId: number }; -// 微美元转美元 (1 USD = 1,000,000 microUSD) -const MICRO_USD_PER_USD = 1_000_000; -function microToUSD(microUSD: number): number { - return microUSD / MICRO_USD_PER_USD; +function formatCost(nanoUSD: number): string { + if (nanoUSD === 0) return '-'; + // 向下取整到 6 位小数 (microUSD 精度) + const usd = Math.floor(nanoUSD / 1000) / 1_000_000; + return `$${usd.toFixed(6)}`; } -function formatCost(microUSD: number): string { - if (microUSD === 0) return '-'; - const usd = microToUSD(microUSD); - if (usd < 0.0001) return '<$0.0001'; - if (usd < 0.001) return `$${usd.toFixed(5)}`; - if (usd < 0.01) return `$${usd.toFixed(4)}`; - if (usd < 1) return `$${usd.toFixed(3)}`; - return `$${usd.toFixed(2)}`; +// Cost breakdown item +export interface CostBreakdownItem { + label: string; + tokens: number; + pricePerM: number; // microUSD/M tokens + cost: number; // nanoUSD +} + +// Cost breakdown result +export interface CostBreakdown { + model: string; + pricing?: ModelPricing; + items: CostBreakdownItem[]; + totalCost: number; // nanoUSD +} + +// MicroToNano conversion factor +const MICRO_TO_NANO = 1000; + +// Calculate linear cost (same as backend CalculateLinearCost) +// Returns nanoUSD +function calculateLinearCost(tokens: number, priceMicro: number): number { + // Use BigInt to prevent overflow for large token counts + const t = BigInt(tokens); + const p = BigInt(priceMicro); + const microToNano = BigInt(MICRO_TO_NANO); + const tokensPerMillion = BigInt(1_000_000); + + const result = (t * p * microToNano) / tokensPerMillion; + return Number(result); +} + +// Calculate tiered cost for 1M context models (same as backend CalculateTieredCost) +// Returns nanoUSD +function calculateTieredCost( + tokens: number, + basePriceMicro: number, + premiumNum: number, + premiumDenom: number, + threshold: number, +): number { + if (tokens <= threshold) { + return calculateLinearCost(tokens, basePriceMicro); + } + + const baseCostNano = calculateLinearCost(threshold, basePriceMicro); + const premiumTokens = tokens - threshold; + + // Use BigInt for premium calculation + const t = BigInt(premiumTokens); + const p = BigInt(basePriceMicro); + const microToNano = BigInt(MICRO_TO_NANO); + const tokensPerMillion = BigInt(1_000_000); + const num = BigInt(premiumNum); + const denom = BigInt(premiumDenom); + + const premiumCostNano = (t * p * microToNano * num) / tokensPerMillion / denom; + return baseCostNano + Number(premiumCostNano); +} + +// Calculate cost breakdown from request/attempt data and pricing table +function calculateCostBreakdown( + model: string, + inputTokens: number, + outputTokens: number, + cacheReadTokens: number, + cacheWriteTokens: number, + cache5mWriteTokens: number, + cache1hWriteTokens: number, + priceTable?: Record, +): CostBreakdown { + const items: CostBreakdownItem[] = []; + let pricing: ModelPricing | undefined; + + // Find pricing for model (exact match first, then prefix match) + if (priceTable) { + pricing = priceTable[model]; + if (!pricing) { + // Try prefix match (find longest matching prefix) + let bestMatch: ModelPricing | undefined; + let bestLen = 0; + for (const [key, p] of Object.entries(priceTable)) { + if (model.startsWith(key) && key.length > bestLen) { + bestMatch = p; + bestLen = key.length; + } + } + pricing = bestMatch; + } + } + + if (pricing) { + // Get 1M context settings + const has1MContext = pricing.has1mContext || false; + const threshold = pricing.context1mThreshold || 200_000; + const inputPremiumNum = pricing.inputPremiumNum || 2; + const inputPremiumDenom = pricing.inputPremiumDenom || 1; + const outputPremiumNum = pricing.outputPremiumNum || 3; + const outputPremiumDenom = pricing.outputPremiumDenom || 2; + + // Input tokens + if (inputTokens > 0) { + const cost = has1MContext + ? calculateTieredCost( + inputTokens, + pricing.inputPriceMicro, + inputPremiumNum, + inputPremiumDenom, + threshold, + ) + : calculateLinearCost(inputTokens, pricing.inputPriceMicro); + items.push({ + label: 'Input', + tokens: inputTokens, + pricePerM: pricing.inputPriceMicro, + cost, + }); + } + + // Output tokens + if (outputTokens > 0) { + const cost = has1MContext + ? calculateTieredCost( + outputTokens, + pricing.outputPriceMicro, + outputPremiumNum, + outputPremiumDenom, + threshold, + ) + : calculateLinearCost(outputTokens, pricing.outputPriceMicro); + items.push({ + label: 'Output', + tokens: outputTokens, + pricePerM: pricing.outputPriceMicro, + cost, + }); + } + + // Cache read + if (cacheReadTokens > 0) { + const cacheReadPrice = + pricing.cacheReadPriceMicro || Math.floor(pricing.inputPriceMicro / 10); + items.push({ + label: 'Cache Read', + tokens: cacheReadTokens, + pricePerM: cacheReadPrice, + cost: calculateLinearCost(cacheReadTokens, cacheReadPrice), + }); + } + + // Cache write (5m or 1h) + if (cache5mWriteTokens > 0) { + const cache5mPrice = + pricing.cache5mWritePriceMicro || Math.floor((pricing.inputPriceMicro * 5) / 4); + items.push({ + label: 'Cache Write (5m)', + tokens: cache5mWriteTokens, + pricePerM: cache5mPrice, + cost: calculateLinearCost(cache5mWriteTokens, cache5mPrice), + }); + } + if (cache1hWriteTokens > 0) { + const cache1hPrice = + pricing.cache1hWritePriceMicro || Math.floor(pricing.inputPriceMicro * 2); + items.push({ + label: 'Cache Write (1h)', + tokens: cache1hWriteTokens, + pricePerM: cache1hPrice, + cost: calculateLinearCost(cache1hWriteTokens, cache1hPrice), + }); + } + // Fallback: if no 5m/1h breakdown but has cacheWrite + if (cache5mWriteTokens === 0 && cache1hWriteTokens === 0 && cacheWriteTokens > 0) { + const cacheWritePrice = + pricing.cache5mWritePriceMicro || Math.floor((pricing.inputPriceMicro * 5) / 4); + items.push({ + label: 'Cache Write', + tokens: cacheWriteTokens, + pricePerM: cacheWritePrice, + cost: calculateLinearCost(cacheWriteTokens, cacheWritePrice), + }); + } + } + + const totalCost = items.reduce((sum, item) => sum + item.cost, 0); + + return { model, pricing, items, totalCost }; } function formatJSON(obj: unknown): string { @@ -68,9 +249,24 @@ export function RequestDetailPanel({ tokenMap, }: RequestDetailPanelProps) { const { t } = useTranslation(); + const { data: priceTable } = usePricing(); const selectedAttempt = selection.type === 'attempt' ? attempts?.find((a) => a.id === selection.attemptId) : null; + // Calculate cost breakdown for request + const requestCostBreakdown = priceTable + ? calculateCostBreakdown( + request.responseModel || request.requestModel, + request.inputTokenCount, + request.outputTokenCount, + request.cacheReadCount, + request.cacheWriteCount, + request.cache5mWriteCount || 0, + request.cache1hWriteCount || 0, + priceTable.models, + ) + : undefined; + if (selection.type === 'request') { return ( ); } @@ -407,6 +604,14 @@ export function RequestDetailPanel({
+
+
+ TTFT +
+
+ {selectedAttempt.ttft && selectedAttempt.ttft > 0 ? formatDuration(selectedAttempt.ttft) : '-'} +
+
Input Tokens diff --git a/web/src/pages/requests/detail/RequestDetailView.tsx b/web/src/pages/requests/detail/RequestDetailView.tsx index 1ee804b7..0a9f9f06 100644 --- a/web/src/pages/requests/detail/RequestDetailView.tsx +++ b/web/src/pages/requests/detail/RequestDetailView.tsx @@ -12,20 +12,33 @@ import { import { Code, Database, Info, Zap } from 'lucide-react'; import { useTranslation } from 'react-i18next'; import type { ProxyRequest, ClientType } from '@/lib/transport'; -import { cn } from '@/lib/utils'; +import { cn, formatDuration } from '@/lib/utils'; import { ClientIcon, getClientName, getClientColor } from '@/components/icons/client-icons'; import { CopyButton, CopyAsCurlButton, EmptyState } from './components'; +import type { CostBreakdown } from './RequestDetailPanel'; + +// 微美元转美元 +const MICRO_USD_PER_USD = 1_000_000; + +// 格式化价格 (microUSD/M tokens -> $/M tokens) +function formatPricePerM(priceMicro: number): string { + const usd = priceMicro / MICRO_USD_PER_USD; + if (usd < 0.01) return `$${usd.toFixed(4)}/M`; + if (usd < 1) return `$${usd.toFixed(2)}/M`; + return `$${usd.toFixed(2)}/M`; +} interface RequestDetailViewProps { request: ProxyRequest; activeTab: 'request' | 'response' | 'metadata'; setActiveTab: (tab: 'request' | 'response' | 'metadata') => void; formatJSON: (obj: unknown) => string; - formatCost: (microUSD: number) => string; + formatCost: (nanoUSD: number) => string; projectName?: string; sessionInfo?: { clientType: string; projectID: number }; projectMap: Map; tokenName?: string; + costBreakdown?: CostBreakdown; } export function RequestDetailView({ @@ -38,6 +51,7 @@ export function RequestDetailView({ sessionInfo, projectMap, tokenName, + costBreakdown, }: RequestDetailViewProps) { const { t } = useTranslation(); return ( @@ -345,34 +359,79 @@ export function RequestDetailView({
- Input Tokens + TTFT
- {request.inputTokenCount.toLocaleString()} + {request.ttft && request.ttft > 0 ? formatDuration(request.ttft) : '-'} +
+
+
+
+ Input Tokens +
+
+ {request.inputTokenCount.toLocaleString()} + {costBreakdown?.items.find((i) => i.label === 'Input') && ( + + × {formatPricePerM(costBreakdown.items.find((i) => i.label === 'Input')!.pricePerM)} ={' '} + + {formatCost(costBreakdown.items.find((i) => i.label === 'Input')!.cost)} + + + )}
Output Tokens
-
- {request.outputTokenCount.toLocaleString()} +
+ {request.outputTokenCount.toLocaleString()} + {costBreakdown?.items.find((i) => i.label === 'Output') && ( + + × {formatPricePerM(costBreakdown.items.find((i) => i.label === 'Output')!.pricePerM)} ={' '} + + {formatCost(costBreakdown.items.find((i) => i.label === 'Output')!.cost)} + + + )}
Cache Read
-
- {request.cacheReadCount.toLocaleString()} +
+ {request.cacheReadCount.toLocaleString()} + {costBreakdown?.items.find((i) => i.label === 'Cache Read') && ( + + × {formatPricePerM(costBreakdown.items.find((i) => i.label === 'Cache Read')!.pricePerM)} ={' '} + + {formatCost(costBreakdown.items.find((i) => i.label === 'Cache Read')!.cost)} + + + )}
Cache Write
-
- {request.cacheWriteCount.toLocaleString()} +
+ {request.cacheWriteCount.toLocaleString()} + {(() => { + const cache5m = costBreakdown?.items.find((i) => i.label === 'Cache Write (5m)'); + const cache1h = costBreakdown?.items.find((i) => i.label === 'Cache Write (1h)'); + const cacheWrite = costBreakdown?.items.find((i) => i.label === 'Cache Write'); + const item = cache5m || cache1h || cacheWrite; + if (!item) return null; + return ( + + × {formatPricePerM(item.pricePerM)} ={' '} + {formatCost(item.cost)} + + ); + })()}
{(request.cache5mWriteCount > 0 || request.cache1hWriteCount > 0) && ( @@ -382,14 +441,29 @@ export function RequestDetailView({ | 1h: {request.cache1hWriteCount}
+
+ {(() => { + const cache5m = costBreakdown?.items.find((i) => i.label === 'Cache Write (5m)'); + const cache1h = costBreakdown?.items.find((i) => i.label === 'Cache Write (1h)'); + const parts: string[] = []; + if (cache5m) parts.push(`5m: ${formatCost(cache5m.cost)}`); + if (cache1h) parts.push(`1h: ${formatCost(cache1h.cost)}`); + return parts.length > 0 ? parts.join(' | ') : null; + })()} +
)}
Cost
-
- {formatCost(request.cost)} +
+ {formatCost(request.cost)} + {costBreakdown && costBreakdown.totalCost !== request.cost && ( + + (计算: {formatCost(costBreakdown.totalCost)}) + + )}
diff --git a/web/src/pages/requests/detail/RequestHeader.tsx b/web/src/pages/requests/detail/RequestHeader.tsx index a3fba0e2..1c6a5b91 100644 --- a/web/src/pages/requests/detail/RequestHeader.tsx +++ b/web/src/pages/requests/detail/RequestHeader.tsx @@ -1,24 +1,16 @@ -import { Button, Badge } from '@/components/ui'; -import { ArrowLeft } from 'lucide-react'; +import { Badge, Button } from '@/components/ui'; +import { Tooltip, TooltipContent, TooltipTrigger } from '@/components/ui/tooltip'; +import { ArrowLeft, RefreshCw } from 'lucide-react'; import { statusVariant } from '../index'; import type { ProxyRequest, ClientType } from '@/lib/transport'; import { ClientIcon, getClientName, getClientColor } from '@/components/icons/client-icons'; import { formatDuration } from '@/lib/utils'; -// 微美元转美元 (1 USD = 1,000,000 microUSD) -const MICRO_USD_PER_USD = 1_000_000; -function microToUSD(microUSD: number): number { - return microUSD / MICRO_USD_PER_USD; -} - -function formatCost(microUSD: number): string { - if (microUSD === 0) return '-'; - const usd = microToUSD(microUSD); - if (usd < 0.0001) return '<$0.0001'; - if (usd < 0.001) return `$${usd.toFixed(5)}`; - if (usd < 0.01) return `$${usd.toFixed(4)}`; - if (usd < 1) return `$${usd.toFixed(3)}`; - return `$${usd.toFixed(2)}`; +function formatCost(nanoUSD: number): string { + if (nanoUSD === 0) return '-'; + // 向下取整到 6 位小数 (microUSD 精度) + const usd = Math.floor(nanoUSD / 1000) / 1_000_000; + return `$${usd.toFixed(6)}`; } function formatTime(timestamp: string): string { @@ -35,9 +27,16 @@ function formatTime(timestamp: string): string { interface RequestHeaderProps { request: ProxyRequest; onBack: () => void; + onRecalculateCost?: () => void; + isRecalculating?: boolean; } -export function RequestHeader({ request, onBack }: RequestHeaderProps) { +export function RequestHeader({ + request, + onBack, + onRecalculateCost, + isRecalculating, +}: RequestHeaderProps) { return (
@@ -89,6 +88,15 @@ export function RequestHeader({ request, onBack }: RequestHeaderProps) { {/* Right: Stats Grid */}
+
+
+ TTFT +
+
+ {request.ttft && request.ttft > 0 ? formatDuration(request.ttft) : '-'} +
+
+
Duration @@ -138,8 +146,20 @@ export function RequestHeader({ request, onBack }: RequestHeaderProps) {
Cost
-
+
{formatCost(request.cost)} + {onRecalculateCost && ( + + + + + Recalculate Cost + + )}
diff --git a/web/src/pages/requests/index.tsx b/web/src/pages/requests/index.tsx index 3b87aef0..755e9ed5 100644 --- a/web/src/pages/requests/index.tsx +++ b/web/src/pages/requests/index.tsx @@ -1,4 +1,4 @@ -import { useState, useEffect } from 'react'; +import { useState, useEffect, useMemo } from 'react'; import { useNavigate } from 'react-router-dom'; import { useTranslation } from 'react-i18next'; import { @@ -20,7 +20,7 @@ import { AlertTriangle, Ban, } from 'lucide-react'; -import type { ProxyRequest, ProxyRequestStatus } from '@/lib/transport'; +import type { ProxyRequest, ProxyRequestStatus, Provider } from '@/lib/transport'; import { ClientIcon } from '@/components/icons/client-icons'; import { Table, @@ -30,10 +30,27 @@ import { TableHeader, TableRow, Badge, + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, + SelectGroup, + SelectLabel, } from '@/components/ui'; import { cn } from '@/lib/utils'; import { PageHeader } from '@/components/layout/page-header'; +type ProviderTypeKey = 'antigravity' | 'kiro' | 'custom'; + +const PROVIDER_TYPE_ORDER: ProviderTypeKey[] = ['antigravity', 'kiro', 'custom']; + +const PROVIDER_TYPE_LABELS: Record = { + antigravity: 'Antigravity', + kiro: 'Kiro', + custom: 'Custom', +}; + const PAGE_SIZE = 50; export const statusVariant: Record< @@ -54,13 +71,19 @@ export function RequestsPage() { // 使用游标分页:存储每页的 lastId 用于向后翻页 const [cursors, setCursors] = useState<(number | undefined)[]>([undefined]); const [pageIndex, setPageIndex] = useState(0); + // Provider 过滤器 + const [selectedProviderId, setSelectedProviderId] = useState(undefined); + // Status 过滤器 + const [selectedStatus, setSelectedStatus] = useState(undefined); const currentCursor = cursors[pageIndex]; const { data, isLoading, refetch } = useProxyRequests({ limit: PAGE_SIZE, before: currentCursor, + providerId: selectedProviderId, + status: selectedStatus, }); - const { data: totalCount, refetch: refetchCount } = useProxyRequestsCount(); + const { data: totalCount, refetch: refetchCount } = useProxyRequestsCount(selectedProviderId, selectedStatus); const { data: providers = [] } = useProviders(); const { data: projects = [] } = useProjects(); const { data: apiTokens = [] } = useAPITokens(); @@ -115,6 +138,20 @@ export function RequestsPage() { refetchCount(); }; + // Provider 过滤器变化时重置分页 + const handleProviderFilterChange = (providerId: number | undefined) => { + setSelectedProviderId(providerId); + setCursors([undefined]); + setPageIndex(0); + }; + + // Status 过滤器变化时重置分页 + const handleStatusFilterChange = (status: string | undefined) => { + setSelectedStatus(status); + setCursors([undefined]); + setPageIndex(0); + }; + return (
+ {/* Provider Filter */} + {providers.length > 0 && ( + + )} + {/* Status Filter */} + +
+ + +
} /> -
- {/* 过滤器 */} -
- setTimeRange(v as TimeRange)} - options={[ - { value: '1h', label: t('stats.last1h') }, - { value: '24h', label: t('stats.last24h') }, - { value: '7d', label: t('stats.last7d') }, - { value: '30d', label: t('stats.last30d') }, - { value: '90d', label: t('stats.last90d') }, - { value: 'all', label: t('stats.allTime') }, - ]} - /> - ({ - value: String(p.id), - label: p.name, - })) || []), - ]} - /> - ({ - value: String(p.id), - label: p.name, - })) || []), - ]} - /> - - ({ - value: String(t.id), - label: t.name, - })) || []), - ]} - /> - ({ value: m, label: m })) || []), - ]} - /> + {/* Cost recalculation progress bar */} + {costsProgress && ( +
+
+
+ {costsProgress.message} + {costsProgress.phase !== 'completed' && costsProgress.total > 0 && ( + + {costsProgress.percentage}% + + )} +
+ +
+ )} - {/* 汇总卡片 */} -
- - - 0 ? ((summary.successfulRequests / summary.totalRequests) * 100).toFixed(1) : 0}%`} - className={ - summary.totalRequests > 0 && - summary.successfulRequests / summary.totalRequests >= 0.95 - ? 'text-(--color-chart-1)' - : summary.totalRequests > 0 && - summary.successfulRequests / summary.totalRequests < 0.8 - ? 'text-(--color-chart-2)' - : 'text-(--color-chart-3)' - } - /> - + {/* Stats recalculation progress bar */} + {statsProgress && ( +
+
+
+ {statsProgress.message} + {statsProgress.phase !== 'completed' && statsProgress.total > 0 && ( + + {statsProgress.percentage}% + + )} +
+ +
+ )} + +
+ {/* 左侧筛选栏 */} +
+
+ {/* 标题 */} +
+ + + {t('stats.filterConditions')} + +
- {isLoading ? ( -
{t('common.loading')}
- ) : chartData.length === 0 ? ( -
{t('common.noData')}
- ) : ( - - - {t('stats.chart')} - setChartView(v as ChartView)}> - - {t('stats.requests')} - {t('stats.tokens')} - - - - - - setTimeRange('all')} + > + {[ + { value: 'today', label: t('stats.today') }, + { value: 'yesterday', label: t('stats.yesterday') }, + { value: 'thisWeek', label: t('stats.thisWeek') }, + { value: 'lastWeek', label: t('stats.lastWeek') }, + { value: 'thisMonth', label: t('stats.thisMonth') }, + { value: 'lastMonth', label: t('stats.lastMonth') }, + { value: '1h', label: t('stats.last1h') }, + { value: '24h', label: t('stats.last24h') }, + { value: '7d', label: t('stats.last7d') }, + { value: '30d', label: t('stats.last30d') }, + { value: '90d', label: t('stats.last90d') }, + ].map((item) => ( + setTimeRange(item.value as TimeRange)} > - - - - `$${v.toFixed(2)}`} - /> - value} - formatter={(value, name) => { - const numValue = typeof value === 'number' ? value : 0; - const nameStr = name ?? ''; - if (nameStr === t('stats.costUSD')) - return [`$${numValue.toFixed(4)}`, nameStr]; - return [numValue.toLocaleString(), nameStr]; - }} - /> - } - /> - - {chartView === 'requests' && ( - <> - - - - - )} - {chartView === 'tokens' && ( - <> - - - - - - - )} - - - - - )} + {item.label} + + ))} + + + {/* Provider - 按类型分组,按名称排序 */} + {providers && providers.length > 0 && ( + setProviderId('all')} + > + {(() => { + // 按类型分组 + const grouped = providers.reduce( + (acc, p) => { + const type = p.type || 'other'; + if (!acc[type]) acc[type] = []; + acc[type].push(p); + return acc; + }, + {} as Record, + ); + // 类型排序优先级 + const typeOrder = ['antigravity', 'kiro', 'custom', 'other']; + const sortedTypes = Object.keys(grouped).sort((a, b) => { + const aIndex = typeOrder.indexOf(a); + const bIndex = typeOrder.indexOf(b); + if (aIndex === -1 && bIndex === -1) return a.localeCompare(b); + if (aIndex === -1) return 1; + if (bIndex === -1) return -1; + return aIndex - bIndex; + }); + return sortedTypes.map((type) => ( +
+
+ {type} +
+
+ {grouped[type] + .sort((a, b) => a.name.localeCompare(b.name)) + .map((p) => ( + setProviderId(String(p.id))} + > + {p.name} + + ))} +
+
+ )); + })()} +
+ )} + + {/* Project */} + {projects && projects.length > 0 && ( + setProjectId('all')} + > + {projects.map((p) => ( + setProjectId(String(p.id))} + > + {p.name} + + ))} + + )} + + {/* Client Type */} + setClientType('all')} + > + {[ + { value: 'claude', label: 'Claude' }, + { value: 'openai', label: 'OpenAI' }, + { value: 'codex', label: 'Codex' }, + { value: 'gemini', label: 'Gemini' }, + ].map((item) => ( + setClientType(item.value)} + > + {item.label} + + ))} + + + {/* API Token */} + {apiTokens && apiTokens.length > 0 && ( + setApiTokenId('all')} + > + {apiTokens.map((token) => ( + setApiTokenId(String(token.id))} + > + {token.name} + + ))} + + )} + + {/* Model */} + {responseModels && responseModels.length > 0 && ( + setModel('all')} + > + {responseModels.map((m) => ( + setModel(m)} + title={m} + > + {m} + + ))} + + )} + + {/* 重置按钮 */} + +
+
+ + {/* 右侧内容区 */} +
+
+ {/* 当前筛选条件摘要 */} +
+ {t('stats.filterSummary')}: + + {timeConfig.start + ? `${timeConfig.start.toLocaleString()} - ${timeConfig.end.toLocaleString()}` + : t('stats.allTime')} + + {providerId !== 'all' && ( + + {t('stats.provider')}: {providers?.find((p) => String(p.id) === providerId)?.name || providerId} + + )} + {projectId !== 'all' && ( + + {t('stats.project')}: {projects?.find((p) => String(p.id) === projectId)?.name || projectId} + + )} + {clientType !== 'all' && ( + + {t('stats.clientType')}: {clientType} + + )} + {apiTokenId !== 'all' && ( + + {t('stats.apiToken')}: {apiTokens?.find((t) => String(t.id) === apiTokenId)?.name || apiTokenId} + + )} + {model !== 'all' && ( + + {t('stats.model')}: {model} + + )} +
+ + {/* 汇总卡片 - 与 Dashboard 一致的排列顺序 */} +
+ + + + 0 ? ((summary.successfulRequests / summary.totalRequests) * 100).toFixed(1) : 0}%`} + icon={CheckCircle} + iconClassName={cn( + (summary.successfulRequests / summary.totalRequests) >= 0.95 + ? 'text-emerald-600 dark:text-emerald-400' + : (summary.successfulRequests / summary.totalRequests) >= 0.8 + ? 'text-amber-600 dark:text-amber-400' + : 'text-red-600 dark:text-red-400' + )} + /> +
+ + {isLoading ? ( +
{t('common.loading')}
+ ) : chartData.length === 0 ? ( +
{t('common.noData')}
+ ) : ( + + + + + {t('stats.chart')} + + setChartView(v as ChartView)}> + + {t('stats.requests')} + {t('stats.tokens')} + + + + +
+ + + + + formatNumber(v)} + /> + `${v.toFixed(2)}`} + /> + (a.name === t('stats.costUSD') ? -1 : 0)} + formatter={(value, name) => { + const numValue = typeof value === 'number' ? value : 0; + const nameStr = name ?? ''; + if (nameStr === t('stats.costUSD')) + return [`$${numValue.toFixed(4)}`, nameStr]; + return [numValue.toLocaleString(), nameStr]; + }} + /> + (a.value === t('stats.costUSD') ? -1 : 0)} + /> + {chartView === 'requests' && ( + <> + + + + + )} + {chartView === 'tokens' && ( + <> + + + + + + + )} + +
+
+
+ )} +
+
); } -function FilterSelect({ - label, +function StatCard({ + title, value, - onChange, - options, + subtitle, + icon: Icon, + iconClassName, }: { - label: string; + title: string; value: string; - onChange: (value: string) => void; - options: { value: string; label: string }[]; + subtitle?: string; + icon: React.ElementType; + iconClassName?: string; }) { - const selectedLabel = options.find((opt) => opt.value === value)?.label; return ( -
- - + + +
+
+

+ {title} +

+

+ {value} +

+ {subtitle && ( +
+ {subtitle} +
+ )} +
+
+ +
+
+
+
+ ); +} + +function FilterSection({ + label, + children, + onClear, + showClear, +}: { + label: string; + children: React.ReactNode; + onClear?: () => void; + showClear?: boolean; +}) { + return ( +
+
+ + {onClear && ( + + )} +
+
+ {children} +
); } -function SummaryCard({ +function FilterChip({ + selected, + onClick, + children, title, - value, - subtitle, - className, }: { - title: string; - value: string; - subtitle?: string; - className?: string; + selected: boolean; + onClick: () => void; + children: React.ReactNode; + title?: string; }) { return ( - - -
{title}
-
- {value} - {subtitle && ( - {subtitle} - )} -
-
-
+ ); }