Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 133 additions & 51 deletions internal/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,10 @@ type CalculateRequest struct {
//
//nolint:govet // fieldalignment: API struct field order optimized for readability
type CalculateResponse struct {
Breakdown cost.Breakdown `json:"breakdown"`
Timestamp time.Time `json:"timestamp"`
Commit string `json:"commit"`
Breakdown cost.Breakdown `json:"breakdown"`
Timestamp time.Time `json:"timestamp"`
Commit string `json:"commit"`
// SecondsInState is taken from the turnserver analysis. It is set only when
// the PR data is freshly fetched through turnserver; it stays nil (and is
// omitted from the JSON) for the prx data source and when the PR data is
// served from cache.
SecondsInState map[string]int `json:"seconds_in_state,omitempty"`
}

// RepoSampleRequest represents a request to sample and calculate costs for a repository.
Expand Down Expand Up @@ -162,24 +163,26 @@ type OrgSampleRequest struct {
//
//nolint:govet // fieldalignment: API struct field order optimized for readability
type SampleResponse struct {
Extrapolated cost.ExtrapolatedBreakdown `json:"extrapolated"`
Timestamp time.Time `json:"timestamp"`
Commit string `json:"commit"`
Extrapolated cost.ExtrapolatedBreakdown `json:"extrapolated"`
Timestamp time.Time `json:"timestamp"`
Commit string `json:"commit"`
// SecondsInState is the per-state sum of seconds over the sampled PRs whose
// data was fetched through turnserver while handling the request. Sampled
// PRs served from the PR-data cache contribute nothing, so the totals may
// cover only part of the sample. Left nil (omitted from the JSON) when no
// seconds were collected.
SecondsInState map[string]int `json:"seconds_in_state,omitempty"`
}

// ProgressUpdate represents a progress update for streaming responses.
//
//nolint:govet // fieldalignment: API struct field order optimized for readability
type ProgressUpdate struct {
Type string `json:"type"` // "fetching", "processing", "complete", "error", "done"
PR int `json:"pr,omitempty"`
Owner string `json:"owner,omitempty"`
Repo string `json:"repo,omitempty"`
Progress string `json:"progress,omitempty"` // e.g., "5/15"
Error string `json:"error,omitempty"`
Result *cost.ExtrapolatedBreakdown `json:"result,omitempty"`
Commit string `json:"commit,omitempty"`
R2RCallout bool `json:"r2r_callout,omitempty"`
Type string `json:"type"` // "fetching", "processing", "complete", "error", "done"
PR int `json:"pr,omitempty"`
Owner string `json:"owner,omitempty"`
Repo string `json:"repo,omitempty"`
Progress string `json:"progress,omitempty"` // e.g., "5/15"
Error string `json:"error,omitempty"`
Result *cost.ExtrapolatedBreakdown `json:"result,omitempty"`
Commit string `json:"commit,omitempty"`
R2RCallout bool `json:"r2r_callout,omitempty"`
// SecondsInState carries the aggregated per-state seconds on the final
// "done" message; it is left nil (omitted from the JSON) when the
// aggregation collected no data.
SecondsInState map[string]int `json:"seconds_in_state,omitempty"`
}

// New creates a new Server instance.
Expand Down Expand Up @@ -1046,26 +1049,38 @@ func (s *Server) processRequest(ctx context.Context, req *CalculateRequest, toke
// Cache miss - need to fetch PR data and calculate
cacheKey := fmt.Sprintf("pr:%s", req.URL)
prData, prCached := s.cachedPRData(ctx, cacheKey)
var secondsInState map[string]int
if !prCached {
// Fetch PR data using configured data source
var err error
// For single PR requests, use 1 hour ago as reference time to enable reasonable caching
referenceTime := time.Now().Add(-1 * time.Hour)
if s.dataSource == "turnserver" {
// Use turnserver for PR data
prData, err = github.FetchPRDataViaTurnserver(ctx, req.URL, token, referenceTime)
// Use turnserver for PR data with analysis
prDataWithAnalysis, err := github.FetchPRDataWithAnalysisViaTurnserver(ctx, req.URL, token, referenceTime)
if err != nil {
s.logger.ErrorContext(ctx, "[processRequest] Failed to fetch PR data", "url", req.URL, "source", s.dataSource, errorKey, err)
// Check if it's an access error (404, 403) - return error to client.
if IsAccessError(err) {
s.logger.WarnContext(ctx, "[processRequest] Access denied", "url", req.URL)
return nil, NewAccessError(http.StatusForbidden, "access denied to PR")
}
return nil, fmt.Errorf("failed to fetch PR data: %w", err)
}
prData = prDataWithAnalysis.PRData
secondsInState = prDataWithAnalysis.Analysis.SecondsInState
} else {
// Use prx for PR data
prData, err = github.FetchPRData(ctx, req.URL, token, referenceTime)
}
if err != nil {
s.logger.ErrorContext(ctx, "[processRequest] Failed to fetch PR data", "url", req.URL, "source", s.dataSource, errorKey, err)
// Check if it's an access error (404, 403) - return error to client.
if IsAccessError(err) {
s.logger.WarnContext(ctx, "[processRequest] Access denied", "url", req.URL)
return nil, NewAccessError(http.StatusForbidden, "access denied to PR")
if err != nil {
s.logger.ErrorContext(ctx, "[processRequest] Failed to fetch PR data", "url", req.URL, "source", s.dataSource, errorKey, err)
// Check if it's an access error (404, 403) - return error to client.
if IsAccessError(err) {
s.logger.WarnContext(ctx, "[processRequest] Access denied", "url", req.URL)
return nil, NewAccessError(http.StatusForbidden, "access denied to PR")
}
return nil, fmt.Errorf("failed to fetch PR data: %w", err)
}
return nil, fmt.Errorf("failed to fetch PR data: %w", err)
}

s.logger.InfoContext(ctx, "[processRequest] PR data cache miss - fetched from GitHub", "url", req.URL)
Expand All @@ -1080,9 +1095,10 @@ func (s *Server) processRequest(ctx context.Context, req *CalculateRequest, toke
s.cacheCalcResult(ctx, req.URL, cfg, &breakdown, 1*time.Hour)

return &CalculateResponse{
Breakdown: breakdown,
Timestamp: time.Now(),
Commit: s.serverCommit,
Breakdown: breakdown,
Timestamp: time.Now(),
Commit: s.serverCommit,
SecondsInState: secondsInState,
}, nil
}

Expand Down Expand Up @@ -1584,8 +1600,9 @@ func (s *Server) processRepoSample(ctx context.Context, req *RepoSampleRequest,
samples := github.SamplePRs(prs, req.SampleSize)
s.logger.InfoContext(ctx, "Sampled PRs", "sample_size", len(samples))

// Collect breakdowns from each sample
// Collect breakdowns from each sample and aggregate seconds_in_state
var breakdowns []cost.Breakdown
aggregatedSeconds := make(map[string]int)
for i, pr := range samples {
prURL := fmt.Sprintf("https://github.com/%s/%s/pull/%d", req.Owner, req.Repo, pr.Number)
s.logger.InfoContext(ctx, "Processing sample PR",
Expand All @@ -1596,11 +1613,17 @@ func (s *Server) processRepoSample(ctx context.Context, req *RepoSampleRequest,
// Try cache first
prCacheKey := fmt.Sprintf("pr:%s", prURL)
prData, prCached := s.cachedPRData(ctx, prCacheKey)
var secondsInState map[string]int
if !prCached {
var err error
// Use configured data source with updatedAt for effective caching
if s.dataSource == "turnserver" {
prData, err = github.FetchPRDataViaTurnserver(ctx, prURL, token, pr.UpdatedAt)
var prDataWithAnalysis github.PRDataWithAnalysis
prDataWithAnalysis, err = github.FetchPRDataWithAnalysisViaTurnserver(ctx, prURL, token, pr.UpdatedAt)
if err == nil {
prData = prDataWithAnalysis.PRData
secondsInState = prDataWithAnalysis.Analysis.SecondsInState
}
} else {
prData, err = github.FetchPRData(ctx, prURL, token, pr.UpdatedAt)
}
Expand All @@ -1613,6 +1636,11 @@ func (s *Server) processRepoSample(ctx context.Context, req *RepoSampleRequest,
s.cachePRData(ctx, prCacheKey, prData)
}

// Aggregate seconds_in_state
for state, seconds := range secondsInState {
aggregatedSeconds[state] += seconds
}

breakdown := cost.Calculate(prData, cfg)
breakdowns = append(breakdowns, breakdown)
}
Expand All @@ -1634,10 +1662,17 @@ func (s *Server) processRepoSample(ctx context.Context, req *RepoSampleRequest,
// Extrapolate costs from samples
extrapolated := cost.ExtrapolateFromSamples(breakdowns, len(prs), totalAuthors, openPRCount, actualDays, cfg)

// Only include seconds_in_state if we have data (turnserver only)
var secondsInState map[string]int
if len(aggregatedSeconds) > 0 {
secondsInState = aggregatedSeconds
}

return &SampleResponse{
Extrapolated: extrapolated,
Timestamp: time.Now(),
Commit: s.serverCommit,
Extrapolated: extrapolated,
Timestamp: time.Now(),
Commit: s.serverCommit,
SecondsInState: secondsInState,
}, nil
}

Expand Down Expand Up @@ -1684,8 +1719,9 @@ func (s *Server) processOrgSample(ctx context.Context, req *OrgSampleRequest, to
samples := github.SamplePRs(prs, req.SampleSize)
s.logger.InfoContext(ctx, "Sampled PRs", "sample_size", len(samples))

// Collect breakdowns from each sample
// Collect breakdowns from each sample and aggregate seconds_in_state
var breakdowns []cost.Breakdown
aggregatedSeconds := make(map[string]int)
for i, pr := range samples {
prURL := fmt.Sprintf("https://github.com/%s/%s/pull/%d", pr.Owner, pr.Repo, pr.Number)
s.logger.InfoContext(ctx, "Processing sample PR",
Expand All @@ -1696,11 +1732,17 @@ func (s *Server) processOrgSample(ctx context.Context, req *OrgSampleRequest, to
// Try cache first
prCacheKey := fmt.Sprintf("pr:%s", prURL)
prData, prCached := s.cachedPRData(ctx, prCacheKey)
var secondsInState map[string]int
if !prCached {
var err error
// Use configured data source with updatedAt for effective caching
if s.dataSource == "turnserver" {
prData, err = github.FetchPRDataViaTurnserver(ctx, prURL, token, pr.UpdatedAt)
var prDataWithAnalysis github.PRDataWithAnalysis
prDataWithAnalysis, err = github.FetchPRDataWithAnalysisViaTurnserver(ctx, prURL, token, pr.UpdatedAt)
if err == nil {
prData = prDataWithAnalysis.PRData
secondsInState = prDataWithAnalysis.Analysis.SecondsInState
}
} else {
prData, err = github.FetchPRData(ctx, prURL, token, pr.UpdatedAt)
}
Expand All @@ -1713,6 +1755,11 @@ func (s *Server) processOrgSample(ctx context.Context, req *OrgSampleRequest, to
s.cachePRData(ctx, prCacheKey, prData)
}

// Aggregate seconds_in_state
for state, seconds := range secondsInState {
aggregatedSeconds[state] += seconds
}

breakdown := cost.Calculate(prData, cfg)
breakdowns = append(breakdowns, breakdown)
}
Expand All @@ -1735,10 +1782,17 @@ func (s *Server) processOrgSample(ctx context.Context, req *OrgSampleRequest, to
// Extrapolate costs from samples
extrapolated := cost.ExtrapolateFromSamples(breakdowns, len(prs), totalAuthors, totalOpenPRs, actualDays, cfg)

// Only include seconds_in_state if we have data (turnserver only)
var secondsInState map[string]int
if len(aggregatedSeconds) > 0 {
secondsInState = aggregatedSeconds
}

return &SampleResponse{
Extrapolated: extrapolated,
Timestamp: time.Now(),
Commit: s.serverCommit,
Extrapolated: extrapolated,
Timestamp: time.Now(),
Commit: s.serverCommit,
SecondsInState: secondsInState,
}, nil
}

Expand Down Expand Up @@ -2101,7 +2155,7 @@ func (s *Server) processRepoSampleWithProgress(ctx context.Context, req *RepoSam
}))

// Process samples in parallel with progress updates
breakdowns := s.processPRsInParallel(workCtx, ctx, samples, req.Owner, req.Repo, token, cfg, writer)
breakdowns, aggregatedSeconds := s.processPRsInParallel(workCtx, ctx, samples, req.Owner, req.Repo, token, cfg, writer)

if len(breakdowns) == 0 {
logSSEError(ctx, s.logger, sendSSE(writer, ProgressUpdate{
Expand All @@ -2125,12 +2179,19 @@ func (s *Server) processRepoSampleWithProgress(ctx context.Context, req *RepoSam
// Extrapolate costs from samples
extrapolated := cost.ExtrapolateFromSamples(breakdowns, len(prs), totalAuthors, openPRCount, actualDays, cfg)

// Only include seconds_in_state if we have data (turnserver only)
var secondsInState map[string]int
if len(aggregatedSeconds) > 0 {
secondsInState = aggregatedSeconds
}

// Send final result
logSSEError(ctx, s.logger, sendSSE(writer, ProgressUpdate{
Type: "done",
Result: &extrapolated,
Commit: s.serverCommit,
R2RCallout: s.r2rCallout,
Type: "done",
Result: &extrapolated,
Commit: s.serverCommit,
R2RCallout: s.r2rCallout,
SecondsInState: secondsInState,
}))
}

Expand Down Expand Up @@ -2238,7 +2299,7 @@ func (s *Server) processOrgSampleWithProgress(ctx context.Context, req *OrgSampl
}))

// Process samples in parallel with progress updates (org mode uses empty owner/repo since it's mixed)
breakdowns := s.processPRsInParallel(workCtx, ctx, samples, "", "", token, cfg, writer)
breakdowns, aggregatedSeconds := s.processPRsInParallel(workCtx, ctx, samples, "", "", token, cfg, writer)

s.logger.InfoContext(ctx, "[processOrgSampleWithProgress] Finished processing samples",
"org", req.Org,
Expand Down Expand Up @@ -2268,20 +2329,28 @@ func (s *Server) processOrgSampleWithProgress(ctx context.Context, req *OrgSampl
// Extrapolate costs from samples
extrapolated := cost.ExtrapolateFromSamples(breakdowns, len(prs), totalAuthors, totalOpenPRs, actualDays, cfg)

// Only include seconds_in_state if we have data (turnserver only)
var secondsInState map[string]int
if len(aggregatedSeconds) > 0 {
secondsInState = aggregatedSeconds
}

// Send final result
logSSEError(ctx, s.logger, sendSSE(writer, ProgressUpdate{
Type: "done",
Result: &extrapolated,
Commit: s.serverCommit,
R2RCallout: s.r2rCallout,
Type: "done",
Result: &extrapolated,
Commit: s.serverCommit,
R2RCallout: s.r2rCallout,
SecondsInState: secondsInState,
}))
}

// processPRsInParallel processes PRs in parallel and sends progress updates via SSE.
//
//nolint:revive // line-length/use-waitgroup-go: long function signature acceptable, standard wg pattern
func (s *Server) processPRsInParallel(workCtx, reqCtx context.Context, samples []github.PRSummary, defaultOwner, defaultRepo, token string, cfg cost.Config, writer http.ResponseWriter) []cost.Breakdown {
func (s *Server) processPRsInParallel(workCtx, reqCtx context.Context, samples []github.PRSummary, defaultOwner, defaultRepo, token string, cfg cost.Config, writer http.ResponseWriter) ([]cost.Breakdown, map[string]int) {
var breakdowns []cost.Breakdown
aggregatedSeconds := make(map[string]int)
var mu sync.Mutex
var sseMu sync.Mutex // Protects SSE writes to prevent corrupted chunked encoding

Expand Down Expand Up @@ -2350,12 +2419,18 @@ func (s *Server) processPRsInParallel(workCtx, reqCtx context.Context, samples [
// Cache miss - need to fetch PR data and calculate
prCacheKey := fmt.Sprintf("pr:%s", prURL)
prData, prCached := s.cachedPRData(workCtx, prCacheKey)
var secondsInState map[string]int
if !prCached {
var err error
// Use work context for actual API calls (not tied to client connection)
// Use configured data source with updatedAt for effective caching
if s.dataSource == "turnserver" {
prData, err = github.FetchPRDataViaTurnserver(workCtx, prURL, token, prSummary.UpdatedAt)
var prDataWithAnalysis github.PRDataWithAnalysis
prDataWithAnalysis, err = github.FetchPRDataWithAnalysisViaTurnserver(workCtx, prURL, token, prSummary.UpdatedAt)
if err == nil {
prData = prDataWithAnalysis.PRData
secondsInState = prDataWithAnalysis.Analysis.SecondsInState
}
} else {
prData, err = github.FetchPRData(workCtx, prURL, token, prSummary.UpdatedAt)
}
Expand All @@ -2380,6 +2455,13 @@ func (s *Server) processPRsInParallel(workCtx, reqCtx context.Context, samples [
s.cachePRData(workCtx, prCacheKey, prData)
}

// Aggregate seconds_in_state
mu.Lock()
for state, seconds := range secondsInState {
aggregatedSeconds[state] += seconds
}
mu.Unlock()

// Send "processing" update using request context for SSE
sseMu.Lock()
logSSEError(reqCtx, s.logger, sendSSE(writer, ProgressUpdate{
Expand Down Expand Up @@ -2415,5 +2497,5 @@ func (s *Server) processPRsInParallel(workCtx, reqCtx context.Context, samples [
}

wg.Wait()
return breakdowns
return breakdowns, aggregatedSeconds
}
Loading
Loading