diff --git a/github/config.go b/github/config.go index 37bc75321..5b4e8a51d 100644 --- a/github/config.go +++ b/github/config.go @@ -2,13 +2,19 @@ package github import ( "context" + "fmt" + "log" "net/http" "net/url" + "os" "path" "regexp" + "strconv" "strings" + "sync" "time" + "github.com/bradleyfalzon/ghinstallation/v2" "github.com/google/go-github/v66/github" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/logging" "github.com/shurcooL/githubv4" @@ -61,9 +67,50 @@ func RateLimitedHTTPClient(client *http.Client, writeDelay time.Duration, readDe func (c *Config) AuthenticatedHTTPClient() *http.Client { ctx := context.Background() - ts := oauth2.StaticTokenSource( - &oauth2.Token{AccessToken: c.Token}, - ) + + initialExpiry := time.Now().Add(5 * time.Minute) // fallback expiry + + ts := NewRefreshingTokenSource(c.Token, initialExpiry, func(ctx context.Context) (string, time.Time, error) { + appID, err := strconv.ParseInt(os.Getenv("GITHUB_APP_ID"), 10, 64) + if err != nil { + return "", time.Time{}, fmt.Errorf("invalid GITHUB_APP_ID: %w", err) + } + + installationID, err := strconv.ParseInt(os.Getenv("GITHUB_APP_INSTALLATION_ID"), 10, 64) + if err != nil { + return "", time.Time{}, fmt.Errorf("invalid GITHUB_APP_INSTALLATION_ID: %w", err) + } + + var pemBytes []byte + pemFile := os.Getenv("GITHUB_APP_PEM_FILE") + if pemFile != "" { + pemBytes, err = os.ReadFile(pemFile) + if err != nil { + return "", time.Time{}, fmt.Errorf("failed to read PEM file: %w", err) + } + } else { + pemBytes = []byte(os.Getenv("GITHUB_APP_PEM")) + if len(pemBytes) == 0 { + return "", time.Time{}, fmt.Errorf("GITHUB_APP_PEM is empty") + } + } + + itr, err := ghinstallation.New(http.DefaultTransport, appID, installationID, pemBytes) + if err != nil { + return "", time.Time{}, fmt.Errorf("failed to create installation transport: %w", err) + } + + token, err := itr.Token(context.Background()) + if err != nil { + return "", time.Time{}, fmt.Errorf("failed to get GitHub App token: %w", err) + } + // Estimate expiry manually since ghinstallation.Token() doesn't return it + expiry := time.Now().Add(59 * time.Minute) + + log.Printf("[INFO] Refreshed GitHub App token valid until %s", expiry.Format(time.RFC3339)) + return token, expiry, nil + }) + client := oauth2.NewClient(ctx, ts) return RateLimitedHTTPClient(client, c.WriteDelay, c.ReadDelay, c.RetryDelay, c.ParallelRequests, c.RetryableErrors, c.MaxRetries) @@ -198,3 +245,45 @@ func (injector *previewHeaderInjectorTransport) RoundTrip(req *http.Request) (*h } return injector.rt.RoundTrip(req) } + +type refreshingTokenSource struct { + mu sync.Mutex + token string + expiry time.Time + refreshFunc func(ctx context.Context) (string, time.Time, error) +} + +func (r *refreshingTokenSource) Token() (*oauth2.Token, error) { + r.mu.Lock() + defer r.mu.Unlock() + + if time.Now().Before(r.expiry.Add(-2*time.Minute)) && r.token != "" { + return &oauth2.Token{ + AccessToken: r.token, + TokenType: "Bearer", + Expiry: r.expiry, + }, nil + } + + newToken, newExpiry, err := r.refreshFunc(context.Background()) + if err != nil { + return nil, err + } + + r.token = newToken + r.expiry = newExpiry + + return &oauth2.Token{ + AccessToken: newToken, + TokenType: "Bearer", + Expiry: newExpiry, + }, nil +} + +func NewRefreshingTokenSource(initialToken string, initialExpiry time.Time, refreshFunc func(ctx context.Context) (string, time.Time, error)) oauth2.TokenSource { + return &refreshingTokenSource{ + token: initialToken, + expiry: initialExpiry, + refreshFunc: refreshFunc, + } +} diff --git a/github/config_refresh_token_test.go b/github/config_refresh_token_test.go new file mode 100644 index 000000000..bae0ef873 --- /dev/null +++ b/github/config_refresh_token_test.go @@ -0,0 +1,110 @@ +package github + +import ( + "context" + "errors" + "os" + "testing" + "time" +) + +// --- Unified Mock Refresh Function --- +func makeMockRefreshFunc(token string, expiry time.Time, fail bool) func(context.Context) (string, time.Time, error) { + return func(ctx context.Context) (string, time.Time, error) { + if fail { + return "", time.Time{}, errors.New("mock refresh failure") + } + return token, expiry, nil + } +} + +// --- RefreshingTokenSource Tests --- + +func TestRefreshingTokenSource_InitialValidToken(t *testing.T) { + exp := time.Now().Add(5 * time.Minute) + ts := NewRefreshingTokenSource("init-token", exp, makeMockRefreshFunc("new-token", time.Now().Add(10*time.Minute), false)) + + token, err := ts.Token() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if token.AccessToken != "init-token" { + t.Errorf("expected init-token, got %s", token.AccessToken) + } +} + +func TestRefreshingTokenSource_RefreshesAfterExpiry(t *testing.T) { + exp := time.Now().Add(-1 * time.Minute) + ts := NewRefreshingTokenSource("expired-token", exp, makeMockRefreshFunc("refreshed-token", time.Now().Add(10*time.Minute), false)) + + token, err := ts.Token() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if token.AccessToken != "refreshed-token" { + t.Errorf("expected refreshed-token, got %s", token.AccessToken) + } +} + +func TestRefreshingTokenSource_RefreshFails(t *testing.T) { + exp := time.Now().Add(-1 * time.Minute) + ts := NewRefreshingTokenSource("expired-token", exp, makeMockRefreshFunc("", time.Time{}, true)) + + _, err := ts.Token() + if err == nil { + t.Fatal("expected error on refresh failure, got nil") + } +} + +func TestRefreshingTokenSource_Token(t *testing.T) { + rt := NewRefreshingTokenSource("initial-token", time.Now().Add(-10*time.Minute), func(ctx context.Context) (string, time.Time, error) { + return "fake-token", time.Now().Add(10 * time.Minute), nil + }) + token, err := rt.Token() + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if token.AccessToken != "fake-token" { + t.Errorf("Expected token to be 'fake-token', got %s", token.AccessToken) + } +} + +// --- Config Behavior Tests --- + +func TestConfig_Anonymous(t *testing.T) { + cfg := &Config{Token: ""} + if !cfg.Anonymous() { + t.Error("expected anonymous to be true when token is empty") + } +} + +func TestConfig_NotAnonymous(t *testing.T) { + cfg := &Config{Token: "abc"} + if cfg.Anonymous() { + t.Error("expected anonymous to be false when token is set") + } +} + +func TestAnonymousClient(t *testing.T) { + config := &Config{} + if !config.Anonymous() { + t.Error("Expected config to be anonymous when no token is set") + } + client := config.AnonymousHTTPClient() + if client == nil { + t.Fatal("Expected a non-nil HTTP client") + } +} + +func TestAuthenticatedClientWithMock(t *testing.T) { + os.Setenv("GITHUB_APP_ID", "123456") + os.Setenv("GITHUB_APP_INSTALLATION_ID", "654321") + os.Setenv("GITHUB_APP_PEM", "dummy-pem-content") + + cfg := &Config{Token: "initial", BaseURL: "https://api.github.com"} + + client := cfg.AuthenticatedHTTPClient() + if client == nil { + t.Fatal("Expected non-nil authenticated HTTP client") + } +} \ No newline at end of file