Skip to content

Commit a63d5d1

Browse files
committed
Cache check results
1 parent 4fcfaa5 commit a63d5d1

File tree

3 files changed

+235
-4
lines changed

3 files changed

+235
-4
lines changed

pkg/github/issues_test.go

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,23 @@ func Test_GetIssue(t *testing.T) {
5252
},
5353
}
5454

55+
mockPrivateIssue := &github.Issue{
56+
Number: github.Ptr(42),
57+
Title: github.Ptr("Test Issue"),
58+
Body: github.Ptr("This is a test issue"),
59+
State: github.Ptr("open"),
60+
HTMLURL: github.Ptr("https://github.com/owner/repo/issues/42"),
61+
User: &github.User{
62+
Login: github.Ptr("privateuser"),
63+
},
64+
Repository: &github.Repository{
65+
Name: github.Ptr("repo"),
66+
Owner: &github.User{
67+
Login: github.Ptr("owner"),
68+
},
69+
},
70+
}
71+
5572
tests := []struct {
5673
name string
5774
mockedClient *http.Client
@@ -101,7 +118,7 @@ func Test_GetIssue(t *testing.T) {
101118
mockedClient: mock.NewMockedHTTPClient(
102119
mock.WithRequestMatch(
103120
mock.GetReposIssuesByOwnerByRepoByIssueNumber,
104-
mockIssue,
121+
mockPrivateIssue,
105122
),
106123
),
107124
gqlHTTPClient: githubv4mock.NewMockedHTTPClient(
@@ -122,7 +139,7 @@ func Test_GetIssue(t *testing.T) {
122139
map[string]any{
123140
"owner": githubv4.String("owner"),
124141
"name": githubv4.String("repo"),
125-
"username": githubv4.String("testuser"),
142+
"username": githubv4.String("privateuser"),
126143
},
127144
githubv4mock.DataResponse(map[string]any{
128145
"repository": map[string]any{
@@ -140,7 +157,7 @@ func Test_GetIssue(t *testing.T) {
140157
"repo": "repo",
141158
"issue_number": float64(42),
142159
},
143-
expectedIssue: mockIssue,
160+
expectedIssue: mockPrivateIssue,
144161
lockdownEnabled: true,
145162
},
146163
{

pkg/lockdown/lockdown.go

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,70 @@ import (
44
"context"
55
"fmt"
66
"strings"
7+
"sync"
8+
"time"
79

810
"github.com/shurcooL/githubv4"
911
)
1012

13+
type repoAccessKey struct {
14+
owner string
15+
repo string
16+
username string
17+
}
18+
19+
type repoAccessEntry struct {
20+
isPrivate bool
21+
hasPush bool
22+
loadedAt time.Time
23+
}
24+
25+
var (
26+
repoAccessCache sync.Map
27+
repoAccessInfoFunc = repoAccessInfo
28+
timeNow = time.Now
29+
)
30+
31+
// repoAccessRefreshInterval defines how long to cache repository access
32+
// information before refreshing it.
33+
const repoAccessRefreshInterval = 10 * time.Minute
34+
35+
func newRepoAccessKey(username, owner, repo string) repoAccessKey {
36+
return repoAccessKey{
37+
owner: strings.ToLower(owner),
38+
repo: strings.ToLower(repo),
39+
username: strings.ToLower(username),
40+
}
41+
}
42+
1143
// ShouldRemoveContent determines if content should be removed based on
1244
// lockdown mode rules. It checks if the repository is private and if the user
1345
// has push access to the repository.
1446
func ShouldRemoveContent(ctx context.Context, client *githubv4.Client, username, owner, repo string) (bool, error) {
15-
isPrivate, hasPushAccess, err := repoAccessInfo(ctx, client, username, owner, repo)
47+
key := newRepoAccessKey(username, owner, repo)
48+
49+
now := timeNow()
50+
if cached, ok := repoAccessCache.Load(key); ok {
51+
entry := cached.(repoAccessEntry)
52+
if now.Sub(entry.loadedAt) < repoAccessRefreshInterval {
53+
if entry.isPrivate {
54+
return false, nil
55+
}
56+
return !entry.hasPush, nil
57+
}
58+
}
59+
60+
isPrivate, hasPushAccess, err := repoAccessInfoFunc(ctx, client, username, owner, repo)
1661
if err != nil {
1762
return false, err
1863
}
1964

65+
repoAccessCache.Store(key, repoAccessEntry{
66+
isPrivate: isPrivate,
67+
hasPush: hasPushAccess,
68+
loadedAt: timeNow(),
69+
})
70+
2071
// Do not filter content for private repositories
2172
if isPrivate {
2273
return false, nil
@@ -25,6 +76,14 @@ func ShouldRemoveContent(ctx context.Context, client *githubv4.Client, username,
2576
return !hasPushAccess, nil
2677
}
2778

79+
// clearRepoAccessCache removes all cached repository access information; used by tests.
80+
func clearRepoAccessCache() {
81+
repoAccessCache.Range(func(key, _ any) bool {
82+
repoAccessCache.Delete(key)
83+
return true
84+
})
85+
}
86+
2887
func repoAccessInfo(ctx context.Context, client *githubv4.Client, username, owner, repo string) (bool, bool, error) {
2988
if client == nil {
3089
return false, false, fmt.Errorf("nil GraphQL client")

pkg/lockdown/lockdown_test.go

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
package lockdown
2+
3+
import (
4+
"context"
5+
"errors"
6+
"testing"
7+
"time"
8+
9+
"github.com/shurcooL/githubv4"
10+
)
11+
12+
func TestShouldRemoveContentCachesResultsWithinInterval(t *testing.T) {
13+
clearRepoAccessCache()
14+
defer clearRepoAccessCache()
15+
16+
originalInfoFunc := repoAccessInfoFunc
17+
defer func() { repoAccessInfoFunc = originalInfoFunc }()
18+
19+
originalTimeNow := timeNow
20+
defer func() { timeNow = originalTimeNow }()
21+
22+
fixed := time.Now()
23+
timeNow = func() time.Time { return fixed }
24+
25+
callCount := 0
26+
repoAccessInfoFunc = func(_ context.Context, _ *githubv4.Client, _, _, _ string) (bool, bool, error) {
27+
callCount++
28+
return false, true, nil
29+
}
30+
31+
ctx := context.Background()
32+
33+
remove, err := ShouldRemoveContent(ctx, nil, "User", "Owner", "Repo")
34+
if err != nil {
35+
t.Fatalf("unexpected error on first call: %v", err)
36+
}
37+
if remove {
38+
t.Fatalf("expected remove=false when user has push access")
39+
}
40+
41+
remove, err = ShouldRemoveContent(ctx, nil, "user", "owner", "repo")
42+
if err != nil {
43+
t.Fatalf("unexpected error on cached call: %v", err)
44+
}
45+
if remove {
46+
t.Fatalf("expected remove=false when cached entry reused")
47+
}
48+
if callCount != 1 {
49+
t.Fatalf("expected cached result to prevent additional repo access queries, got %d", callCount)
50+
}
51+
}
52+
53+
func TestShouldRemoveContentRefreshesAfterInterval(t *testing.T) {
54+
clearRepoAccessCache()
55+
defer clearRepoAccessCache()
56+
57+
originalInfoFunc := repoAccessInfoFunc
58+
defer func() { repoAccessInfoFunc = originalInfoFunc }()
59+
60+
originalTimeNow := timeNow
61+
defer func() { timeNow = originalTimeNow }()
62+
63+
base := time.Now()
64+
current := base
65+
timeNow = func() time.Time { return current }
66+
67+
callCount := 0
68+
repoAccessInfoFunc = func(_ context.Context, _ *githubv4.Client, _, _, _ string) (bool, bool, error) {
69+
callCount++
70+
if callCount == 1 {
71+
return false, false, nil
72+
}
73+
return false, true, nil
74+
}
75+
76+
ctx := context.Background()
77+
78+
remove, err := ShouldRemoveContent(ctx, nil, "user", "owner", "repo")
79+
if err != nil {
80+
t.Fatalf("unexpected error on first call: %v", err)
81+
}
82+
if !remove {
83+
t.Fatalf("expected remove=true when user lacks push access")
84+
}
85+
if callCount != 1 {
86+
t.Fatalf("expected first call to query once, got %d", callCount)
87+
}
88+
89+
current = base.Add(9 * time.Minute)
90+
remove, err = ShouldRemoveContent(ctx, nil, "user", "owner", "repo")
91+
if err != nil {
92+
t.Fatalf("unexpected error before refresh interval: %v", err)
93+
}
94+
if !remove {
95+
t.Fatalf("expected remove=true before refresh interval expires")
96+
}
97+
if callCount != 1 {
98+
t.Fatalf("expected cached value before refresh interval, got %d calls", callCount)
99+
}
100+
101+
current = base.Add(11 * time.Minute)
102+
remove, err = ShouldRemoveContent(ctx, nil, "user", "owner", "repo")
103+
if err != nil {
104+
t.Fatalf("unexpected error after refresh interval: %v", err)
105+
}
106+
if remove {
107+
t.Fatalf("expected remove=false after permissions refreshed")
108+
}
109+
if callCount != 2 {
110+
t.Fatalf("expected refreshed access info after interval, got %d calls", callCount)
111+
}
112+
}
113+
114+
func TestShouldRemoveContentDoesNotCacheErrors(t *testing.T) {
115+
clearRepoAccessCache()
116+
defer clearRepoAccessCache()
117+
118+
originalInfoFunc := repoAccessInfoFunc
119+
defer func() { repoAccessInfoFunc = originalInfoFunc }()
120+
121+
originalTimeNow := timeNow
122+
defer func() { timeNow = originalTimeNow }()
123+
124+
now := time.Now()
125+
timeNow = func() time.Time { return now }
126+
127+
callCount := 0
128+
repoAccessInfoFunc = func(_ context.Context, _ *githubv4.Client, _, _, _ string) (bool, bool, error) {
129+
callCount++
130+
if callCount == 1 {
131+
return false, false, errors.New("boom")
132+
}
133+
return false, false, nil
134+
}
135+
136+
ctx := context.Background()
137+
138+
if _, err := ShouldRemoveContent(ctx, nil, "user", "owner", "repo"); err == nil {
139+
t.Fatal("expected error on first call")
140+
}
141+
if callCount != 1 {
142+
t.Fatalf("expected single call after error, got %d", callCount)
143+
}
144+
145+
remove, err := ShouldRemoveContent(ctx, nil, "user", "owner", "repo")
146+
if err != nil {
147+
t.Fatalf("unexpected error on retry: %v", err)
148+
}
149+
if !remove {
150+
t.Fatalf("expected remove=true when user lacks push access")
151+
}
152+
if callCount != 2 {
153+
t.Fatalf("expected repo access to be queried again after error, got %d calls", callCount)
154+
}
155+
}

0 commit comments

Comments
 (0)