Skip to content

Commit 1d61827

Browse files
authored
Merge pull request #31 from tstromberg/main
transparent test->PR lookups
2 parents 1aac68f + f65df9f commit 1d61827

File tree

10 files changed

+353
-78
lines changed

10 files changed

+353
-78
lines changed

cmd/server/main.go

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,10 @@ const (
3232
minMaskHeaderLength = 20 // Minimum header length before we show full "[REDACTED]"
3333
)
3434

35-
// getEnvOrDefault returns the value of the environment variable or the default if not set.
36-
func getEnvOrDefault(key, defaultValue string) string {
37-
if value := os.Getenv(key); value != "" {
38-
return value
39-
}
40-
return defaultValue
41-
}
35+
// contextKey is a custom type for context keys to avoid collisions.
36+
type contextKey string
37+
38+
const reservationTokenKey contextKey = "reservation_token"
4239

4340
var (
4441
webhookSecret = flag.String("webhook-secret", os.Getenv("GITHUB_WEBHOOK_SECRET"), "GitHub webhook secret for signature verification")
@@ -50,8 +47,12 @@ var (
5047
maxConnsPerIP = flag.Int("max-conns-per-ip", 10, "Maximum WebSocket connections per IP")
5148
maxConnsTotal = flag.Int("max-conns-total", 1000, "Maximum total WebSocket connections")
5249
rateLimit = flag.Int("rate-limit", 100, "Maximum requests per minute per IP")
53-
allowedEvents = flag.String("allowed-events", getEnvOrDefault("ALLOWED_WEBHOOK_EVENTS", "*"),
54-
"Comma-separated list of allowed webhook event types (use '*' for all, default: '*')")
50+
allowedEvents = flag.String("allowed-events", func() string {
51+
if value := os.Getenv("ALLOWED_WEBHOOK_EVENTS"); value != "" {
52+
return value
53+
}
54+
return "*"
55+
}(), "Comma-separated list of allowed webhook event types (use '*' for all, default: '*')")
5556
debugHeaders = flag.Bool("debug-headers", false, "Log request headers for debugging (security warning: may log sensitive data)")
5657
)
5758

@@ -244,7 +245,7 @@ func main() {
244245
}
245246

246247
// Set reservation token in request context so websocket handler can commit it
247-
r = r.WithContext(context.WithValue(r.Context(), "reservation_token", reservationToken))
248+
r = r.WithContext(context.WithValue(r.Context(), reservationTokenKey, reservationToken))
248249

249250
// Log successful auth and proceed to upgrade
250251
log.Printf("WebSocket UPGRADE: ip=%s duration=%v", ip, time.Since(startTime))

go.sum

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,8 @@
11
github.com/codeGROOVE-dev/retry v1.2.0 h1:xYpYPX2PQZmdHwuiQAGGzsBm392xIMl4nfMEFApQnu8=
22
github.com/codeGROOVE-dev/retry v1.2.0/go.mod h1:8OgefgV1XP7lzX2PdKlCXILsYKuz6b4ZpHa/20iLi8E=
3-
golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI=
4-
golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8=
53
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
64
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
7-
golang.org/x/net v0.44.0 h1:evd8IRDyfNBMBTTY5XRF1vaZlD+EmWx6x8PkhR04H/I=
8-
golang.org/x/net v0.44.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY=
9-
golang.org/x/net v0.45.0 h1:RLBg5JKixCy82FtLJpeNlVM0nrSqpCRYzVU1n8kj0tM=
10-
golang.org/x/net v0.45.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY=
115
golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4=
126
golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210=
13-
golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk=
14-
golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4=
157
golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k=
168
golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM=

pkg/client/client.go

Lines changed: 139 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"time"
1212

1313
"github.com/codeGROOVE-dev/retry"
14+
"github.com/codeGROOVE-dev/sprinkler/pkg/github"
1415
"golang.org/x/net/websocket"
1516
)
1617

@@ -45,8 +46,9 @@ type Event struct {
4546
Timestamp time.Time `json:"timestamp"`
4647
Raw map[string]any
4748
Type string `json:"type"`
48-
URL string `json:"url"`
49+
URL string `json:"url"` // PR URL (or repo URL for check events with race condition)
4950
DeliveryID string `json:"delivery_id,omitempty"`
51+
CommitSHA string `json:"commit_sha,omitempty"` // Commit SHA for check events
5052
}
5153

5254
// Config holds the configuration for the client.
@@ -85,9 +87,16 @@ type Client struct {
8587
ws *websocket.Conn
8688
stopCh chan struct{}
8789
stoppedCh chan struct{}
88-
writeCh chan any // Channel for serializing all writes
90+
stopOnce sync.Once // Ensures Stop() is only executed once
91+
writeCh chan any // Channel for serializing all writes
8992
eventCount int
9093
retries int
94+
95+
// Cache for commit SHA to PR number lookups (for check event race condition)
96+
commitPRCache map[string][]int // key: "owner/repo:sha", value: PR numbers
97+
commitCacheKeys []string // track insertion order for LRU eviction
98+
cacheMu sync.RWMutex
99+
maxCacheSize int
91100
}
92101

93102
// New creates a new robust WebSocket client.
@@ -118,10 +127,13 @@ func New(config Config) (*Client, error) {
118127
}
119128

120129
return &Client{
121-
config: config,
122-
stopCh: make(chan struct{}),
123-
stoppedCh: make(chan struct{}),
124-
logger: logger,
130+
config: config,
131+
stopCh: make(chan struct{}),
132+
stoppedCh: make(chan struct{}),
133+
logger: logger,
134+
commitPRCache: make(map[string][]int),
135+
commitCacheKeys: make([]string, 0, 512),
136+
maxCacheSize: 512,
125137
}, nil
126138
}
127139

@@ -220,16 +232,27 @@ func (c *Client) Start(ctx context.Context) error {
220232
}
221233

222234
// Stop gracefully stops the client.
235+
// Safe to call multiple times - only the first call will take effect.
236+
// Also safe to call before Start() or if Start() was never called.
223237
func (c *Client) Stop() {
224-
close(c.stopCh)
225-
c.mu.Lock()
226-
if c.ws != nil {
227-
if closeErr := c.ws.Close(); closeErr != nil {
228-
c.logger.Error("Error closing websocket on shutdown", "error", closeErr)
238+
c.stopOnce.Do(func() {
239+
close(c.stopCh)
240+
c.mu.Lock()
241+
if c.ws != nil {
242+
if closeErr := c.ws.Close(); closeErr != nil {
243+
c.logger.Error("Error closing websocket on shutdown", "error", closeErr)
244+
}
229245
}
230-
}
231-
c.mu.Unlock()
232-
<-c.stoppedCh
246+
c.mu.Unlock()
247+
248+
// Wait for Start() to finish, but with timeout in case Start() was never called
249+
select {
250+
case <-c.stoppedCh:
251+
// Start() completed normally
252+
case <-time.After(100 * time.Millisecond):
253+
// Start() was never called or hasn't started yet - that's ok
254+
}
255+
})
233256
}
234257

235258
// connect establishes a WebSocket connection and handles events.
@@ -608,18 +631,119 @@ func (c *Client) readEvents(ctx context.Context, ws *websocket.Conn) error {
608631
event.DeliveryID = deliveryID
609632
}
610633

634+
if commitSHA, ok := response["commit_sha"].(string); ok {
635+
event.CommitSHA = commitSHA
636+
}
637+
611638
c.mu.Lock()
612639
c.eventCount++
613640
eventNum := c.eventCount
614641
c.mu.Unlock()
615642

643+
// Handle check events with repo-only URLs (GitHub race condition)
644+
// Automatically expand into per-PR events using GitHub API with caching
645+
if (event.Type == "check_run" || event.Type == "check_suite") && event.CommitSHA != "" && !strings.Contains(event.URL, "/pull/") {
646+
// Extract owner/repo from URL
647+
parts := strings.Split(event.URL, "/")
648+
if len(parts) >= 5 && parts[2] == "github.com" {
649+
owner := parts[3]
650+
repo := parts[4]
651+
key := owner + "/" + repo + ":" + event.CommitSHA
652+
653+
// Check cache first
654+
c.cacheMu.RLock()
655+
cached, ok := c.commitPRCache[key]
656+
c.cacheMu.RUnlock()
657+
658+
var prs []int
659+
if ok {
660+
// Cache hit - return copy to prevent external modifications
661+
prs = make([]int, len(cached))
662+
copy(prs, cached)
663+
c.logger.Info("Check event with repo URL - using cached PR lookup",
664+
"commit_sha", event.CommitSHA,
665+
"repo_url", event.URL,
666+
"type", event.Type,
667+
"pr_count", len(prs),
668+
"cache_hit", true)
669+
} else {
670+
// Cache miss - look up via GitHub API
671+
c.logger.Info("Check event with repo URL - looking up PRs via GitHub API",
672+
"commit_sha", event.CommitSHA,
673+
"repo_url", event.URL,
674+
"type", event.Type,
675+
"cache_hit", false)
676+
677+
gh := github.NewClient(c.config.Token)
678+
var err error
679+
prs, err = gh.FindPRsForCommit(ctx, owner, repo, event.CommitSHA)
680+
if err != nil {
681+
c.logger.Warn("Failed to look up PRs for commit",
682+
"commit_sha", event.CommitSHA,
683+
"owner", owner,
684+
"repo", repo,
685+
"error", err)
686+
// Don't cache errors - try again next time
687+
} else {
688+
// Cache the result (even if empty)
689+
c.cacheMu.Lock()
690+
if _, exists := c.commitPRCache[key]; !exists {
691+
c.commitCacheKeys = append(c.commitCacheKeys, key)
692+
// Evict oldest 25% if cache is full
693+
if len(c.commitCacheKeys) > c.maxCacheSize {
694+
n := c.maxCacheSize / 4
695+
for i := range n {
696+
delete(c.commitPRCache, c.commitCacheKeys[i])
697+
}
698+
c.commitCacheKeys = c.commitCacheKeys[n:]
699+
}
700+
}
701+
// Store copy to prevent external modifications
702+
cached := make([]int, len(prs))
703+
copy(cached, prs)
704+
c.commitPRCache[key] = cached
705+
c.cacheMu.Unlock()
706+
707+
c.logger.Info("Cached PR lookup result",
708+
"commit_sha", event.CommitSHA,
709+
"pr_count", len(prs))
710+
}
711+
}
712+
713+
// Emit events for each PR found
714+
if len(prs) > 0 {
715+
for _, n := range prs {
716+
e := event // Copy the event
717+
e.URL = fmt.Sprintf("https://github.com/%s/%s/pull/%d", owner, repo, n)
718+
719+
if c.config.OnEvent != nil {
720+
c.logger.Info("Event received (expanded from commit)",
721+
"timestamp", e.Timestamp.Format("15:04:05"),
722+
"event_number", eventNum,
723+
"type", e.Type,
724+
"url", e.URL,
725+
"commit_sha", e.CommitSHA,
726+
"delivery_id", e.DeliveryID)
727+
c.config.OnEvent(e)
728+
}
729+
}
730+
continue // Skip the normal event handling since we expanded it
731+
}
732+
c.logger.Info("No PRs found for commit - may be push to main",
733+
"commit_sha", event.CommitSHA,
734+
"owner", owner,
735+
"repo", repo)
736+
}
737+
}
738+
616739
// Log event
617740
if c.config.Verbose {
618741
c.logger.Info("Event received",
619742
"event_number", eventNum,
620743
"timestamp", event.Timestamp.Format("15:04:05"),
621744
"type", event.Type,
622745
"url", event.URL,
746+
"commit_sha", event.CommitSHA,
623747
"delivery_id", event.DeliveryID,
624748
"raw", event.Raw)
625749
} else {
@@ -629,6 +753,7 @@ func (c *Client) readEvents(ctx context.Context, ws *websocket.Conn) error {
629753
"event_number", eventNum,
630754
"type", event.Type,
631755
"url", event.URL,
756+
"commit_sha", event.CommitSHA,
632757
"delivery_id", event.DeliveryID)
633758
} else {
634759
c.logger.Info("Event received",

pkg/client/client_test.go

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
package client
2+
3+
import (
4+
"context"
5+
"sync"
6+
"testing"
7+
"time"
8+
)
9+
10+
// TestStopMultipleCalls verifies that calling Stop() multiple times is safe
11+
// and doesn't panic with "close of closed channel".
12+
func TestStopMultipleCalls(t *testing.T) {
13+
// Create a client with minimal config
14+
client, err := New(Config{
15+
ServerURL: "ws://localhost:8080",
16+
Token: "test-token",
17+
Organization: "test-org",
18+
NoReconnect: true, // Disable reconnect to make test faster
19+
})
20+
if err != nil {
21+
t.Fatalf("Failed to create client: %v", err)
22+
}
23+
24+
// Start the client in a goroutine
25+
ctx, cancel := context.WithCancel(context.Background())
26+
defer cancel()
27+
28+
go func() {
29+
// Expected to fail to connect, but that's ok for this test
30+
if err := client.Start(ctx); err != nil {
31+
// Error is expected in tests - client can't connect to non-existent server
32+
}
33+
}()
34+
35+
// Give it a moment to initialize
36+
time.Sleep(10 * time.Millisecond)
37+
38+
// Call Stop() multiple times concurrently
39+
// This should NOT panic with "close of closed channel"
40+
var wg sync.WaitGroup
41+
for i := 0; i < 10; i++ {
42+
wg.Add(1)
43+
go func() {
44+
defer wg.Done()
45+
client.Stop() // Should be safe to call multiple times
46+
}()
47+
}
48+
49+
// Wait for all Stop() calls to complete
50+
wg.Wait()
51+
52+
// If we get here without a panic, the test passes
53+
}
54+
55+
// TestStopBeforeStart verifies that calling Stop() before Start() is safe.
56+
func TestStopBeforeStart(t *testing.T) {
57+
client, err := New(Config{
58+
ServerURL: "ws://localhost:8080",
59+
Token: "test-token",
60+
Organization: "test-org",
61+
NoReconnect: true,
62+
})
63+
if err != nil {
64+
t.Fatalf("Failed to create client: %v", err)
65+
}
66+
67+
// Call Stop() before Start()
68+
client.Stop()
69+
70+
// Now try to start - should exit cleanly
71+
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
72+
defer cancel()
73+
74+
err = client.Start(ctx)
75+
// We expect either context.DeadlineExceeded or "stop requested"
76+
if err == nil {
77+
t.Error("Expected Start() to fail after Stop(), but it succeeded")
78+
}
79+
}

0 commit comments

Comments
 (0)