Skip to content

Feat/app token refresh #2695

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 92 additions & 3 deletions github/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,19 @@ package github

import (
"context"
"fmt"
"log"
"net/http"
"net/url"
"os"
"path"
"regexp"
"strconv"
"strings"
"sync"
"time"

"github.com/bradleyfalzon/ghinstallation/v2"
"github.com/google/go-github/v66/github"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/logging"
"github.com/shurcooL/githubv4"
Expand Down Expand Up @@ -61,9 +67,50 @@ func RateLimitedHTTPClient(client *http.Client, writeDelay time.Duration, readDe
func (c *Config) AuthenticatedHTTPClient() *http.Client {

ctx := context.Background()
ts := oauth2.StaticTokenSource(
&oauth2.Token{AccessToken: c.Token},
)

initialExpiry := time.Now().Add(5 * time.Minute) // fallback expiry

ts := NewRefreshingTokenSource(c.Token, initialExpiry, func(ctx context.Context) (string, time.Time, error) {
appID, err := strconv.ParseInt(os.Getenv("GITHUB_APP_ID"), 10, 64)
if err != nil {
return "", time.Time{}, fmt.Errorf("invalid GITHUB_APP_ID: %w", err)
}

installationID, err := strconv.ParseInt(os.Getenv("GITHUB_APP_INSTALLATION_ID"), 10, 64)
if err != nil {
return "", time.Time{}, fmt.Errorf("invalid GITHUB_APP_INSTALLATION_ID: %w", err)
}

var pemBytes []byte
pemFile := os.Getenv("GITHUB_APP_PEM_FILE")
if pemFile != "" {
pemBytes, err = os.ReadFile(pemFile)
if err != nil {
return "", time.Time{}, fmt.Errorf("failed to read PEM file: %w", err)
}
} else {
pemBytes = []byte(os.Getenv("GITHUB_APP_PEM"))
if len(pemBytes) == 0 {
return "", time.Time{}, fmt.Errorf("GITHUB_APP_PEM is empty")
}
}

itr, err := ghinstallation.New(http.DefaultTransport, appID, installationID, pemBytes)
if err != nil {
return "", time.Time{}, fmt.Errorf("failed to create installation transport: %w", err)
}

token, err := itr.Token(context.Background())
if err != nil {
return "", time.Time{}, fmt.Errorf("failed to get GitHub App token: %w", err)
}
// Estimate expiry manually since ghinstallation.Token() doesn't return it
expiry := time.Now().Add(59 * time.Minute)

log.Printf("[INFO] Refreshed GitHub App token valid until %s", expiry.Format(time.RFC3339))
return token, expiry, nil
})

client := oauth2.NewClient(ctx, ts)

return RateLimitedHTTPClient(client, c.WriteDelay, c.ReadDelay, c.RetryDelay, c.ParallelRequests, c.RetryableErrors, c.MaxRetries)
Expand Down Expand Up @@ -198,3 +245,45 @@ func (injector *previewHeaderInjectorTransport) RoundTrip(req *http.Request) (*h
}
return injector.rt.RoundTrip(req)
}

type refreshingTokenSource struct {
mu sync.Mutex
token string
expiry time.Time
refreshFunc func(ctx context.Context) (string, time.Time, error)
}

func (r *refreshingTokenSource) Token() (*oauth2.Token, error) {
r.mu.Lock()
defer r.mu.Unlock()

if time.Now().Before(r.expiry.Add(-2*time.Minute)) && r.token != "" {
return &oauth2.Token{
AccessToken: r.token,
TokenType: "Bearer",
Expiry: r.expiry,
}, nil
}

newToken, newExpiry, err := r.refreshFunc(context.Background())
if err != nil {
return nil, err
}

r.token = newToken
r.expiry = newExpiry

return &oauth2.Token{
AccessToken: newToken,
TokenType: "Bearer",
Expiry: newExpiry,
}, nil
}

func NewRefreshingTokenSource(initialToken string, initialExpiry time.Time, refreshFunc func(ctx context.Context) (string, time.Time, error)) oauth2.TokenSource {
return &refreshingTokenSource{
token: initialToken,
expiry: initialExpiry,
refreshFunc: refreshFunc,
}
}
110 changes: 110 additions & 0 deletions github/config_refresh_token_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
package github

import (
"context"
"errors"
"os"
"testing"
"time"
)

// --- Unified Mock Refresh Function ---
func makeMockRefreshFunc(token string, expiry time.Time, fail bool) func(context.Context) (string, time.Time, error) {
return func(ctx context.Context) (string, time.Time, error) {
if fail {
return "", time.Time{}, errors.New("mock refresh failure")
}
return token, expiry, nil
}
}

// --- RefreshingTokenSource Tests ---

func TestRefreshingTokenSource_InitialValidToken(t *testing.T) {
exp := time.Now().Add(5 * time.Minute)
ts := NewRefreshingTokenSource("init-token", exp, makeMockRefreshFunc("new-token", time.Now().Add(10*time.Minute), false))

token, err := ts.Token()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if token.AccessToken != "init-token" {
t.Errorf("expected init-token, got %s", token.AccessToken)
}
}

func TestRefreshingTokenSource_RefreshesAfterExpiry(t *testing.T) {
exp := time.Now().Add(-1 * time.Minute)
ts := NewRefreshingTokenSource("expired-token", exp, makeMockRefreshFunc("refreshed-token", time.Now().Add(10*time.Minute), false))

token, err := ts.Token()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if token.AccessToken != "refreshed-token" {
t.Errorf("expected refreshed-token, got %s", token.AccessToken)
}
}

func TestRefreshingTokenSource_RefreshFails(t *testing.T) {
exp := time.Now().Add(-1 * time.Minute)
ts := NewRefreshingTokenSource("expired-token", exp, makeMockRefreshFunc("", time.Time{}, true))

_, err := ts.Token()
if err == nil {
t.Fatal("expected error on refresh failure, got nil")
}
}

func TestRefreshingTokenSource_Token(t *testing.T) {
rt := NewRefreshingTokenSource("initial-token", time.Now().Add(-10*time.Minute), func(ctx context.Context) (string, time.Time, error) {
return "fake-token", time.Now().Add(10 * time.Minute), nil
})
token, err := rt.Token()
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if token.AccessToken != "fake-token" {
t.Errorf("Expected token to be 'fake-token', got %s", token.AccessToken)
}
}

// --- Config Behavior Tests ---

func TestConfig_Anonymous(t *testing.T) {
cfg := &Config{Token: ""}
if !cfg.Anonymous() {
t.Error("expected anonymous to be true when token is empty")
}
}

func TestConfig_NotAnonymous(t *testing.T) {
cfg := &Config{Token: "abc"}
if cfg.Anonymous() {
t.Error("expected anonymous to be false when token is set")
}
}

func TestAnonymousClient(t *testing.T) {
config := &Config{}
if !config.Anonymous() {
t.Error("Expected config to be anonymous when no token is set")
}
client := config.AnonymousHTTPClient()
if client == nil {
t.Fatal("Expected a non-nil HTTP client")
}
}

func TestAuthenticatedClientWithMock(t *testing.T) {
os.Setenv("GITHUB_APP_ID", "123456")
os.Setenv("GITHUB_APP_INSTALLATION_ID", "654321")
os.Setenv("GITHUB_APP_PEM", "dummy-pem-content")

cfg := &Config{Token: "initial", BaseURL: "https://api.github.com"}

client := cfg.AuthenticatedHTTPClient()
if client == nil {
t.Fatal("Expected non-nil authenticated HTTP client")
}
}