Skip to content

Commit e79e5c0

Browse files
authored
Merge pull request #41 from tstromberg/main
Add app tier support
2 parents de12e2b + 4c0adff commit e79e5c0

18 files changed

+547
-132
lines changed

cmd/server/main.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ var (
5353
return "*"
5454
}(), "Comma-separated list of allowed webhook event types (use '*' for all, default: '*')")
5555
debugHeaders = flag.Bool("debug-headers", false, "Log request headers for debugging (security warning: may log sensitive data)")
56+
enforceTiers = flag.Bool("enforce-tiers", false, "Enforce GitHub Marketplace tier restrictions (default: false, logs warnings only)")
5657
)
5758

5859
//nolint:funlen,gocognit,lll,revive,maintidx // Main function orchestrates entire server setup and cannot be split without losing clarity
@@ -88,9 +89,15 @@ func main() {
8889

8990
// CORS support removed - WebSocket clients should handle auth via Authorization header
9091

91-
hub := srv.NewHub()
92+
hub := srv.NewHub(*enforceTiers)
9293
go hub.Run(ctx)
9394

95+
if *enforceTiers {
96+
log.Println("Tier enforcement ENABLED - private repo access restricted to Pro/Flock tiers")
97+
} else {
98+
log.Println("Tier enforcement DISABLED - will log warnings only (all users can access private repos)")
99+
}
100+
94101
// Create connection limiter for WebSocket connections
95102
connLimiter := security.NewConnectionLimiter(*maxConnsPerIP, *maxConnsTotal)
96103

pkg/github/interface.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ type APIClient interface {
1010

1111
// ValidateOrgMembership validates that the authenticated user is a member of the specified organization.
1212
ValidateOrgMembership(ctx context.Context, org string) (username string, orgs []string, err error)
13+
14+
// UserTier fetches the user's GitHub Marketplace subscription tier.
15+
UserTier(ctx context.Context, username string) (Tier, error)
1316
}
1417

1518
// Ensure Client implements APIClient interface.

pkg/github/mock.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@ type MockClient struct {
1313
Username string
1414
LastValidatedOrg string
1515
Orgs []string
16+
Tier Tier // Mock tier to return
1617
UserAndOrgsCalls int
1718
ValidateOrgMembershipCalls int
19+
UserTierCalls int
1820
mu sync.Mutex
1921
}
2022

@@ -57,5 +59,20 @@ func (m *MockClient) ValidateOrgMembership(_ context.Context, org string) (usern
5759
return m.Username, m.Orgs, nil
5860
}
5961

62+
// UserTier returns the mock tier.
63+
func (m *MockClient) UserTier(_ context.Context, _ string) (Tier, error) {
64+
m.mu.Lock()
65+
defer m.mu.Unlock()
66+
m.UserTierCalls++
67+
if m.Err != nil {
68+
return TierFree, m.Err
69+
}
70+
// If no tier is set, default to TierFree
71+
if m.Tier == "" {
72+
return TierFree, nil
73+
}
74+
return m.Tier, nil
75+
}
76+
6077
// Ensure MockClient implements APIClient interface.
6178
var _ APIClient = (*MockClient)(nil)

pkg/github/tier.go

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
package github
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"errors"
7+
"fmt"
8+
"io"
9+
"net/http"
10+
"strings"
11+
)
12+
13+
// Tier represents a GitHub Marketplace pricing tier.
14+
type Tier string
15+
16+
const (
17+
// TierFree is the default tier for users without a paid subscription.
18+
TierFree Tier = "free"
19+
// TierPro is the Pro tier for individual users.
20+
TierPro Tier = "pro"
21+
// TierFlock is the Flock tier for teams/organizations.
22+
TierFlock Tier = "flock"
23+
)
24+
25+
// MarketplacePlan represents a GitHub Marketplace subscription plan.
26+
type MarketplacePlan struct {
27+
Name string `json:"name"`
28+
}
29+
30+
// MarketplaceAccount represents a GitHub Marketplace account subscription.
31+
type MarketplaceAccount struct {
32+
Plan MarketplacePlan `json:"plan"`
33+
}
34+
35+
// UserTier fetches the user's GitHub Marketplace subscription tier.
36+
// Returns TierFree if no subscription exists or on API errors (graceful degradation).
37+
func (c *Client) UserTier(ctx context.Context, username string) (Tier, error) {
38+
if username == "" {
39+
return TierFree, errors.New("username cannot be empty")
40+
}
41+
42+
// Query GitHub Marketplace API for user's subscription
43+
url := fmt.Sprintf("https://api.github.com/marketplace_listing/accounts/%s", username)
44+
45+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody)
46+
if err != nil {
47+
return TierFree, fmt.Errorf("failed to create marketplace API request: %w", err)
48+
}
49+
50+
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.token))
51+
req.Header.Set("Accept", "application/vnd.github+json")
52+
req.Header.Set("X-Github-Api-Version", "2022-11-28")
53+
req.Header.Set("User-Agent", "webhook-sprinkler/1.0")
54+
55+
resp, err := c.httpClient.Do(req)
56+
if err != nil {
57+
c.logger.Warn("marketplace API request failed", "error", err, "username", username)
58+
return TierFree, fmt.Errorf("marketplace API request failed: %w", err)
59+
}
60+
defer func() {
61+
if closeErr := resp.Body.Close(); closeErr != nil {
62+
c.logger.Warn("failed to close marketplace response body", "error", closeErr)
63+
}
64+
}()
65+
66+
// 404 means no subscription - this is normal, not an error
67+
if resp.StatusCode == http.StatusNotFound {
68+
c.logger.Info("no marketplace subscription found", "username", username)
69+
return TierFree, nil
70+
}
71+
72+
// Handle other error status codes
73+
if resp.StatusCode != http.StatusOK {
74+
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<10)) // Read up to 1KB for error message
75+
if err != nil {
76+
c.logger.Warn("failed to read error response", "error", err)
77+
}
78+
c.logger.Warn("marketplace API returned error",
79+
"status", resp.StatusCode,
80+
"username", username,
81+
"body", string(body))
82+
return TierFree, fmt.Errorf("marketplace API error: status %d", resp.StatusCode)
83+
}
84+
85+
// Parse successful response
86+
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) // 1MB limit
87+
if err != nil {
88+
c.logger.Warn("failed to read marketplace response", "error", err)
89+
return TierFree, fmt.Errorf("failed to read marketplace response: %w", err)
90+
}
91+
92+
var account MarketplaceAccount
93+
if err := json.Unmarshal(body, &account); err != nil {
94+
c.logger.Warn("failed to parse marketplace response", "error", err, "body", string(body))
95+
return TierFree, fmt.Errorf("failed to parse marketplace response: %w", err)
96+
}
97+
98+
tier := mapPlanToTier(account.Plan.Name)
99+
c.logger.Info("marketplace tier detected",
100+
"username", username,
101+
"plan", account.Plan.Name,
102+
"tier", tier)
103+
104+
return tier, nil
105+
}
106+
107+
// mapPlanToTier maps GitHub Marketplace plan names to internal tier constants.
108+
// Update this function when your marketplace listing is approved with the actual plan names.
109+
func mapPlanToTier(planName string) Tier {
110+
switch strings.ToLower(planName) {
111+
case "pro":
112+
return TierPro
113+
case "flock", "team", "enterprise":
114+
return TierFlock
115+
default:
116+
// Unknown plan names default to free tier
117+
return TierFree
118+
}
119+
}

pkg/github/tier_cache.go

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
package github
2+
3+
import (
4+
"sync"
5+
"time"
6+
)
7+
8+
// TierCache caches tier lookups to reduce API calls to GitHub Marketplace.
9+
// Implements thread-safe in-memory caching with TTL expiration.
10+
type TierCache struct {
11+
mu sync.RWMutex
12+
cache map[string]*tierEntry
13+
ttl time.Duration
14+
stopCh chan struct{}
15+
stopped chan struct{}
16+
}
17+
18+
// tierEntry represents a cached tier with expiration.
19+
type tierEntry struct {
20+
tier Tier
21+
expiresAt time.Time
22+
}
23+
24+
// NewTierCache creates a new tier cache with the specified TTL.
25+
// A background goroutine periodically cleans up expired entries.
26+
func NewTierCache(ttl time.Duration) *TierCache {
27+
tc := &TierCache{
28+
cache: make(map[string]*tierEntry),
29+
ttl: ttl,
30+
stopCh: make(chan struct{}),
31+
stopped: make(chan struct{}),
32+
}
33+
34+
// Start cleanup goroutine
35+
go tc.cleanupLoop()
36+
37+
return tc
38+
}
39+
40+
// Get retrieves a tier from the cache.
41+
// Returns (tier, true) if found and not expired, (TierFree, false) otherwise.
42+
func (tc *TierCache) Get(username string) (Tier, bool) {
43+
tc.mu.RLock()
44+
defer tc.mu.RUnlock()
45+
46+
entry, ok := tc.cache[username]
47+
if !ok {
48+
return TierFree, false
49+
}
50+
51+
// Check expiration
52+
if time.Now().After(entry.expiresAt) {
53+
return TierFree, false
54+
}
55+
56+
return entry.tier, true
57+
}
58+
59+
// Set stores a tier in the cache with TTL expiration.
60+
func (tc *TierCache) Set(username string, tier Tier) {
61+
tc.mu.Lock()
62+
defer tc.mu.Unlock()
63+
64+
tc.cache[username] = &tierEntry{
65+
tier: tier,
66+
expiresAt: time.Now().Add(tc.ttl),
67+
}
68+
}
69+
70+
// Stop stops the cleanup goroutine and waits for it to finish.
71+
// Safe to call multiple times (subsequent calls are no-ops).
72+
func (tc *TierCache) Stop() {
73+
select {
74+
case <-tc.stopCh:
75+
// Already stopped
76+
return
77+
default:
78+
close(tc.stopCh)
79+
<-tc.stopped
80+
}
81+
}
82+
83+
// cleanupLoop runs periodically to remove expired cache entries.
84+
func (tc *TierCache) cleanupLoop() {
85+
defer close(tc.stopped)
86+
87+
ticker := time.NewTicker(1 * time.Hour)
88+
defer ticker.Stop()
89+
90+
for {
91+
select {
92+
case <-tc.stopCh:
93+
return
94+
case <-ticker.C:
95+
tc.cleanup()
96+
}
97+
}
98+
}
99+
100+
// cleanup removes expired entries from the cache.
101+
func (tc *TierCache) cleanup() {
102+
tc.mu.Lock()
103+
defer tc.mu.Unlock()
104+
105+
now := time.Now()
106+
for username, entry := range tc.cache {
107+
if now.After(entry.expiresAt) {
108+
delete(tc.cache, username)
109+
}
110+
}
111+
}

pkg/srv/client.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010

1111
"golang.org/x/net/websocket"
1212

13+
"github.com/codeGROOVE-dev/sprinkler/pkg/github"
1314
"github.com/codeGROOVE-dev/sprinkler/pkg/logger"
1415
)
1516

@@ -52,12 +53,13 @@ type Client struct {
5253
userOrgs map[string]bool
5354
ID string
5455
subscription Subscription
56+
tier github.Tier // GitHub Marketplace tier
5557
closeOnce sync.Once
5658
closed uint32 // Atomic flag: 1 if closed, 0 if open
5759
}
5860

5961
// NewClient creates a new client.
60-
func NewClient(ctx context.Context, id string, sub Subscription, conn *websocket.Conn, hub *Hub, userOrgs []string) *Client {
62+
func NewClient(ctx context.Context, id string, sub Subscription, conn *websocket.Conn, hub *Hub, userOrgs []string, tier github.Tier) *Client {
6163
// Limit the number of orgs to prevent memory exhaustion
6264
const maxOrgs = 1000
6365
orgsToProcess := userOrgs
@@ -85,6 +87,7 @@ func NewClient(ctx context.Context, id string, sub Subscription, conn *websocket
8587
hub: hub,
8688
done: make(chan struct{}),
8789
userOrgs: orgsMap,
90+
tier: tier,
8891
}
8992
}
9093

@@ -197,3 +200,9 @@ func (c *Client) Close() {
197200
func (c *Client) IsClosed() bool {
198201
return atomic.LoadUint32(&c.closed) != 0
199202
}
203+
204+
// CanAccessPrivateRepos returns true if the client's tier allows private repo access.
205+
// Only Pro and Flock tiers have access to private repository events.
206+
func (c *Client) CanAccessPrivateRepos() bool {
207+
return c.tier == github.TierPro || c.tier == github.TierFlock
208+
}

pkg/srv/client_test_helper.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
package srv
2+
3+
import (
4+
"context"
5+
6+
"golang.org/x/net/websocket"
7+
8+
"github.com/codeGROOVE-dev/sprinkler/pkg/github"
9+
)
10+
11+
// NewClientForTest creates a new client for testing with default TierFree.
12+
// This maintains backward compatibility with existing tests.
13+
func NewClientForTest(ctx context.Context, id string, sub Subscription, conn *websocket.Conn, hub *Hub, userOrgs []string) *Client {
14+
return NewClient(ctx, id, sub, conn, hub, userOrgs, github.TierFree)
15+
}

0 commit comments

Comments
 (0)