diff --git a/backend/helpers/pluginhelper/api/api_client.go b/backend/helpers/pluginhelper/api/api_client.go index 1e7e57d44f5..b0cfccf499f 100644 --- a/backend/helpers/pluginhelper/api/api_client.go +++ b/backend/helpers/pluginhelper/api/api_client.go @@ -299,6 +299,11 @@ func (apiClient *ApiClient) SetLogger(logger log.Logger) { apiClient.logger = logger } +// GetClient returns the underlying http.Client +func (apiClient *ApiClient) GetClient() *http.Client { + return apiClient.client +} + func (apiClient *ApiClient) logDebug(format string, a ...interface{}) { if apiClient.logger != nil { apiClient.logger.Debug(format, a...) diff --git a/backend/plugins/github/models/connection.go b/backend/plugins/github/models/connection.go index 6a8c06a3738..4dba2f75689 100644 --- a/backend/plugins/github/models/connection.go +++ b/backend/plugins/github/models/connection.go @@ -56,6 +56,21 @@ type GithubConn struct { helper.MultiAuth `mapstructure:",squash"` GithubAccessToken `mapstructure:",squash" authMethod:"AccessToken"` GithubAppKey `mapstructure:",squash" authMethod:"AppKey"` + RefreshToken string `mapstructure:"refreshToken" json:"refreshToken" gorm:"type:text;serializer:encdec"` + TokenExpiresAt time.Time `mapstructure:"tokenExpiresAt" json:"tokenExpiresAt"` + RefreshTokenExpiresAt time.Time `mapstructure:"refreshTokenExpiresAt" json:"refreshTokenExpiresAt"` +} + +// UpdateToken updates the token and refresh token information +func (conn *GithubConn) UpdateToken(newToken, newRefreshToken string, expiry, refreshExpiry time.Time) { + conn.Token = newToken + conn.RefreshToken = newRefreshToken + conn.TokenExpiresAt = expiry + conn.RefreshTokenExpiresAt = refreshExpiry + + // Update the internal tokens slice used by SetupAuthentication + conn.tokens = []string{newToken} + conn.tokenIndex = 0 } // PrepareApiClient splits Token to tokens for SetupAuthentication to utilize @@ -249,7 +264,7 @@ func (conn *GithubConn) typeIs(token string) string { // total len is 40, {prefix}{showPrefix}{secret}{showSuffix} // fine-grained tokens // github_pat_{82_characters} - classicalTokenClassicalPrefixes := []string{"ghp_", "gho_", "ghs_", "ghr_"} + classicalTokenClassicalPrefixes := []string{"ghp_", "gho_", "ghs_", "ghr_", "ghu_"} classicalTokenFindGrainedPrefixes := []string{"github_pat_"} for _, prefix := range classicalTokenClassicalPrefixes { if strings.HasPrefix(token, prefix) { diff --git a/backend/plugins/github/models/connection_test.go b/backend/plugins/github/models/connection_test.go index 7b39cf0bdc5..41323b9b562 100644 --- a/backend/plugins/github/models/connection_test.go +++ b/backend/plugins/github/models/connection_test.go @@ -227,3 +227,14 @@ func TestGithubConnection_Sanitize(t *testing.T) { }) } } + +func TestTokenTypeClassification(t *testing.T) { + conn := &GithubConn{} + assert.Equal(t, GithubTokenTypeClassical, conn.typeIs("ghp_123")) + assert.Equal(t, GithubTokenTypeClassical, conn.typeIs("gho_123")) + assert.Equal(t, GithubTokenTypeClassical, conn.typeIs("ghu_123")) + assert.Equal(t, GithubTokenTypeClassical, conn.typeIs("ghs_123")) + assert.Equal(t, GithubTokenTypeClassical, conn.typeIs("ghr_123")) + assert.Equal(t, GithubTokenTypeFineGrained, conn.typeIs("github_pat_123")) + assert.Equal(t, GithubTokenTypeUnknown, conn.typeIs("some_other_token")) +} diff --git a/backend/plugins/github/models/migrationscripts/20241120_add_refresh_token_fields.go b/backend/plugins/github/models/migrationscripts/20241120_add_refresh_token_fields.go new file mode 100644 index 00000000000..b2f826da388 --- /dev/null +++ b/backend/plugins/github/models/migrationscripts/20241120_add_refresh_token_fields.go @@ -0,0 +1,53 @@ +/* +Licensed to the Apache Software Foundation (ASF) under one or more +contributor license agreements. See the NOTICE file distributed with +this work for additional information regarding copyright ownership. +The ASF licenses this file to You under the Apache License, Version 2.0 +(the "License"); you may not use this file except in compliance with +the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package migrationscripts + +import ( + "time" + + "github.com/apache/incubator-devlake/core/context" + "github.com/apache/incubator-devlake/core/errors" + "github.com/apache/incubator-devlake/helpers/migrationhelper" +) + +type githubConnection20241120 struct { + RefreshToken string `gorm:"type:text;serializer:encdec"` + TokenExpiresAt time.Time + RefreshTokenExpiresAt time.Time +} + +func (githubConnection20241120) TableName() string { + return "_tool_github_connections" +} + +type addRefreshTokenFields struct{} + +func (*addRefreshTokenFields) Up(basicRes context.BasicRes) errors.Error { + return migrationhelper.AutoMigrateTables( + basicRes, + &githubConnection20241120{}, + ) +} + +func (*addRefreshTokenFields) Version() uint64 { + return 20241120000001 +} + +func (*addRefreshTokenFields) Name() string { + return "add refresh token fields to github_connections" +} diff --git a/backend/plugins/github/models/migrationscripts/register.go b/backend/plugins/github/models/migrationscripts/register.go index b8a0722eb02..74f9d712b4c 100644 --- a/backend/plugins/github/models/migrationscripts/register.go +++ b/backend/plugins/github/models/migrationscripts/register.go @@ -55,5 +55,6 @@ func All() []plugin.MigrationScript { new(addIsDraftToPr), new(changeIssueComponentType), new(addIndexToGithubJobs), + new(addRefreshTokenFields), } } diff --git a/backend/plugins/github/tasks/api_client.go b/backend/plugins/github/tasks/api_client.go index c9bfa852c4a..268af8ecec1 100644 --- a/backend/plugins/github/tasks/api_client.go +++ b/backend/plugins/github/tasks/api_client.go @@ -26,6 +26,7 @@ import ( "github.com/apache/incubator-devlake/core/plugin" "github.com/apache/incubator-devlake/helpers/pluginhelper/api" "github.com/apache/incubator-devlake/plugins/github/models" + "github.com/apache/incubator-devlake/plugins/github/token" ) func CreateApiClient(taskCtx plugin.TaskContext, connection *models.GithubConnection) (*api.ApiAsyncClient, errors.Error) { @@ -34,6 +35,24 @@ func CreateApiClient(taskCtx plugin.TaskContext, connection *models.GithubConnec return nil, err } + // Inject TokenProvider if refresh token is present + if connection.RefreshToken != "" { + logger := taskCtx.GetLogger() + db := taskCtx.GetDal() + + // Create TokenProvider + tp := token.NewTokenProvider(connection, db, apiClient.GetClient(), logger) + + // Wrap the transport + baseTransport := apiClient.GetClient().Transport + if baseTransport == nil { + baseTransport = http.DefaultTransport + } + + rt := token.NewRefreshRoundTripper(baseTransport, tp) + apiClient.GetClient().Transport = rt + } + // create rate limit calculator rateLimiter := &api.ApiRateLimitCalculator{ UserRateLimitPerHour: connection.RateLimitPerHour, diff --git a/backend/plugins/github/token/round_tripper.go b/backend/plugins/github/token/round_tripper.go new file mode 100644 index 00000000000..45ba3e9a7b6 --- /dev/null +++ b/backend/plugins/github/token/round_tripper.go @@ -0,0 +1,90 @@ +/* +Licensed to the Apache Software Foundation (ASF) under one or more +contributor license agreements. See the NOTICE file distributed with +this work for additional information regarding copyright ownership. +The ASF licenses this file to You under the Apache License, Version 2.0 +(the "License"); you may not use this file except in compliance with +the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package token + +import ( + "net/http" +) + +// RefreshRoundTripper is an HTTP transport middleware that automatically manages OAuth token refreshes. +// It wraps an underlying http.RoundTripper and provides token refresh on auth failures. +// On 401's the round tripper will: +// - Force a refresh of the OAuth token via the TokenProvider +// - Retry the original request with the new token +type RefreshRoundTripper struct { + base http.RoundTripper + tokenProvider *TokenProvider +} + +func NewRefreshRoundTripper(base http.RoundTripper, tp *TokenProvider) *RefreshRoundTripper { + return &RefreshRoundTripper{ + base: base, + tokenProvider: tp, + } +} + +// RoundTrip implements the http.RoundTripper interface and handles automatic token refresh on 401 responses. +// It clones the request, adds the Authorization header, and retries once with a refreshed token if needed. +func (rt *RefreshRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return rt.roundTripWithRetry(req, false) +} + +// roundTripWithRetry performs the actual request with retry on 401. +// The refreshAttempted parameter tracks whether a refresh has already been tried for this request +// to prevent infinite retry loops if token refresh itself fails. +func (rt *RefreshRoundTripper) roundTripWithRetry(req *http.Request, refreshAttempted bool) (*http.Response, error) { + // Get token + token, err := rt.tokenProvider.GetToken() + if err != nil { + return nil, err + } + + // Clone request before modifying + reqClone := req.Clone(req.Context()) + reqClone.Header.Set("Authorization", "Bearer "+token) + + // Execute request + resp, reqErr := rt.base.RoundTrip(reqClone) + if reqErr != nil { + return nil, reqErr + } + + // Reactive refresh on 401 + if resp.StatusCode == http.StatusUnauthorized && !refreshAttempted { + // Close previous response body + resp.Body.Close() + + // Force refresh + if err := rt.tokenProvider.ForceRefresh(token); err != nil { + return nil, err + } + + // Get new token + newToken, err := rt.tokenProvider.GetToken() + if err != nil { + return nil, err + } + + // Retry request with new token + reqRetry := req.Clone(req.Context()) + reqRetry.Header.Set("Authorization", "Bearer "+newToken) + return rt.roundTripWithRetry(reqRetry, true) + } + + return resp, nil +} diff --git a/backend/plugins/github/token/round_tripper_test.go b/backend/plugins/github/token/round_tripper_test.go new file mode 100644 index 00000000000..6767d8ccfd1 --- /dev/null +++ b/backend/plugins/github/token/round_tripper_test.go @@ -0,0 +1,101 @@ +/* +Licensed to the Apache Software Foundation (ASF) under one or more +contributor license agreements. See the NOTICE file distributed with +this work for additional information regarding copyright ownership. +The ASF licenses this file to You under the Apache License, Version 2.0 +(the "License"); you may not use this file except in compliance with +the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package token + +import ( + "bytes" + "io" + "net/http" + "testing" + "time" + + "github.com/apache/incubator-devlake/helpers/pluginhelper/api" + "github.com/apache/incubator-devlake/impls/logruslog" + "github.com/apache/incubator-devlake/plugins/github/models" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestRoundTripper401Refresh(t *testing.T) { + mockRT := new(MockRoundTripper) + client := &http.Client{Transport: mockRT} + + conn := &models.GithubConnection{ + GithubConn: models.GithubConn{ + RefreshToken: "refresh_token", + GithubAccessToken: models.GithubAccessToken{ + AccessToken: api.AccessToken{ + Token: "old_token", + }, + }, + TokenExpiresAt: time.Now().Add(10 * time.Minute), // Not expired + GithubAppKey: models.GithubAppKey{ + AppKey: api.AppKey{ + AppId: "123", + SecretKey: "secret", + }, + }, + }, + } + + logger, _ := logruslog.NewDefaultLogger(logrus.New()) + tp := NewTokenProvider(conn, nil, client, logger) + rt := NewRefreshRoundTripper(mockRT, tp) + + // Request + req, _ := http.NewRequest("GET", "https://api.github.com/user", nil) + + // 1. First call returns 401 + resp401 := &http.Response{ + StatusCode: 401, + Body: io.NopCloser(bytes.NewBufferString("Unauthorized")), + } + mockRT.On("RoundTrip", mock.MatchedBy(func(r *http.Request) bool { + return r.Header.Get("Authorization") == "Bearer old_token" && r.URL.String() == "https://api.github.com/user" + })).Return(resp401, nil).Once() + + // 2. Refresh call (triggered by 401) + respRefresh := &http.Response{ + StatusCode: 200, + Body: io.NopCloser(bytes.NewBufferString(`{"access_token":"new_token","refresh_token":"new_refresh_token","expires_in":3600,"refresh_token_expires_in":3600}`)), + } + // The refresh call uses the same client, so it goes through mockRT too! + mockRT.On("RoundTrip", mock.MatchedBy(func(r *http.Request) bool { + return r.URL.String() == "https://github.com/login/oauth/access_token" + })).Return(respRefresh, nil).Once() + + // 3. Retry call with new token + resp200 := &http.Response{ + StatusCode: 200, + Body: io.NopCloser(bytes.NewBufferString("Success")), + } + mockRT.On("RoundTrip", mock.MatchedBy(func(r *http.Request) bool { + return r.Header.Get("Authorization") == "Bearer new_token" && r.URL.String() == "https://api.github.com/user" + })).Return(resp200, nil).Once() + + // Execute + resp, err := rt.RoundTrip(req) + assert.NoError(t, err) + assert.Equal(t, 200, resp.StatusCode) + + body, _ := io.ReadAll(resp.Body) + assert.Equal(t, "Success", string(body)) + + mockRT.AssertExpectations(t) +} diff --git a/backend/plugins/github/token/token_provider.go b/backend/plugins/github/token/token_provider.go new file mode 100644 index 00000000000..ba9941cd47d --- /dev/null +++ b/backend/plugins/github/token/token_provider.go @@ -0,0 +1,185 @@ +/* +Licensed to the Apache Software Foundation (ASF) under one or more +contributor license agreements. See the NOTICE file distributed with +this work for additional information regarding copyright ownership. +The ASF licenses this file to You under the Apache License, Version 2.0 +(the "License"); you may not use this file except in compliance with +the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package token + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "strconv" + "sync" + "time" + + "github.com/apache/incubator-devlake/core/dal" + "github.com/apache/incubator-devlake/core/errors" + "github.com/apache/incubator-devlake/core/log" + "github.com/apache/incubator-devlake/plugins/github/models" +) + +const ( + DefaultRefreshBuffer = 5 * time.Minute +) + +type TokenProvider struct { + conn *models.GithubConnection + dal dal.Dal + httpClient *http.Client + logger log.Logger + mu sync.Mutex + refreshURL string +} + +// NewTokenProvider creates a TokenProvider for the given GitHub connection using +// the provided DAL, HTTP client, and logger, and returns a pointer to it. +func NewTokenProvider(conn *models.GithubConnection, d dal.Dal, client *http.Client, logger log.Logger) *TokenProvider { + return &TokenProvider{ + conn: conn, + dal: d, + httpClient: client, + logger: logger, + refreshURL: "https://github.com/login/oauth/access_token", + } +} + +func (tp *TokenProvider) GetToken() (string, errors.Error) { + tp.mu.Lock() + defer tp.mu.Unlock() + + if tp.needsRefresh() { + if err := tp.refreshToken(); err != nil { + return "", err + } + } + return tp.conn.Token, nil +} + +func (tp *TokenProvider) needsRefresh() bool { + if tp.conn.RefreshToken == "" { + return false + } + + buffer := DefaultRefreshBuffer + if envBuffer := os.Getenv("GITHUB_TOKEN_REFRESH_BUFFER_MINUTES"); envBuffer != "" { + if val, err := strconv.Atoi(envBuffer); err == nil { + buffer = time.Duration(val) * time.Minute + } + } + + return time.Now().Add(buffer).After(tp.conn.TokenExpiresAt) +} + +func (tp *TokenProvider) refreshToken() errors.Error { + tp.logger.Info("Refreshing GitHub token for connection %d", tp.conn.ID) + + data := map[string]string{ + "refresh_token": tp.conn.RefreshToken, + "grant_type": "refresh_token", + "client_id": tp.conn.AppId, + "client_secret": tp.conn.SecretKey, + } + jsonData, _ := json.Marshal(data) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "POST", tp.refreshURL, bytes.NewBuffer(jsonData)) + if err != nil { + return errors.Convert(err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + + resp, err := tp.httpClient.Do(req) + if err != nil { + return errors.Convert(err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return errors.Convert(err) + } + + if resp.StatusCode != http.StatusOK { + // Log the response body to aid in debugging token refresh failures. + if tp.logger != nil { + tp.logger.Error(nil, "failed to refresh token from GitHub, status=%d, body=%s", resp.StatusCode, string(body)) + } + return errors.Default.New(fmt.Sprintf("failed to refresh token: %d, body: %s", resp.StatusCode, string(body))) + } + var result struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int `json:"expires_in"` + RefreshTokenExpiresIn int `json:"refresh_token_expires_in"` + } + if err := json.Unmarshal(body, &result); err != nil { + return errors.Convert(err) + } + + if result.AccessToken == "" { + bodyStr := string(body) + const maxBodySnippet = 512 + if len(bodyStr) > maxBodySnippet { + bodyStr = bodyStr[:maxBodySnippet] + "…" + } + return errors.Default.New(fmt.Sprintf("empty access token returned; response body: %s", bodyStr)) + } + + tp.conn.UpdateToken( + result.AccessToken, + result.RefreshToken, + time.Now().Add(time.Duration(result.ExpiresIn)*time.Second), + time.Now().Add(time.Duration(result.RefreshTokenExpiresIn)*time.Second), + ) + + if tp.dal != nil { + err := tp.dal.UpdateColumns(tp.conn, []dal.DalSet{ + {ColumnName: "token", Value: tp.conn.Token}, + {ColumnName: "refresh_token", Value: tp.conn.RefreshToken}, + {ColumnName: "token_expires_at", Value: tp.conn.TokenExpiresAt}, + {ColumnName: "refresh_token_expires_at", Value: tp.conn.RefreshTokenExpiresAt}, + }) + if err != nil { + tp.logger.Warn(err, "failed to persist refreshed token") + } + } + + return nil +} + +// ForceRefresh refreshes the access token if the current token is still equal to oldToken. +// The oldToken parameter should be the token value observed by the caller when it determined +// that a refresh might be needed; if the token has changed since then, another goroutine has +// already refreshed it and this method returns without performing a redundant refresh. +func (tp *TokenProvider) ForceRefresh(oldToken string) errors.Error { + tp.mu.Lock() + defer tp.mu.Unlock() + + // If the token has changed since the request was made, it means another thread + // has already refreshed it. + if tp.conn.Token != oldToken { + return nil + } + + return tp.refreshToken() +} diff --git a/backend/plugins/github/token/token_provider_test.go b/backend/plugins/github/token/token_provider_test.go new file mode 100644 index 00000000000..1c296376a6e --- /dev/null +++ b/backend/plugins/github/token/token_provider_test.go @@ -0,0 +1,180 @@ +/* +Licensed to the Apache Software Foundation (ASF) under one or more +contributor license agreements. See the NOTICE file distributed with +this work for additional information regarding copyright ownership. +The ASF licenses this file to You under the Apache License, Version 2.0 +(the "License"); you may not use this file except in compliance with +the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package token + +import ( + "bytes" + "io" + "net/http" + "os" + "sync" + "testing" + "time" + + "github.com/apache/incubator-devlake/core/errors" + "github.com/apache/incubator-devlake/helpers/pluginhelper/api" + "github.com/apache/incubator-devlake/impls/logruslog" + mockdal "github.com/apache/incubator-devlake/mocks/core/dal" + "github.com/apache/incubator-devlake/plugins/github/models" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +type MockRoundTripper struct { + mock.Mock +} + +func (m *MockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + args := m.Called(req) + return args.Get(0).(*http.Response), args.Error(1) +} + +func TestNeedsRefresh(t *testing.T) { + tp := &TokenProvider{ + conn: &models.GithubConnection{ + GithubConn: models.GithubConn{ + RefreshToken: "refresh_token", + }, + }, + } + + // Not expired, outside buffer + tp.conn.TokenExpiresAt = time.Now().Add(10 * time.Minute) + assert.False(t, tp.needsRefresh()) + + // Inside buffer + tp.conn.TokenExpiresAt = time.Now().Add(1 * time.Minute) + assert.True(t, tp.needsRefresh()) + + // Expired + tp.conn.TokenExpiresAt = time.Now().Add(-1 * time.Minute) + assert.True(t, tp.needsRefresh()) + + // No refresh token + tp.conn.RefreshToken = "" + assert.False(t, tp.needsRefresh()) +} + +func TestTokenProviderConcurrency(t *testing.T) { + mockRT := new(MockRoundTripper) + client := &http.Client{Transport: mockRT} + + conn := &models.GithubConnection{ + GithubConn: models.GithubConn{ + RefreshToken: "refresh_token", + TokenExpiresAt: time.Now().Add(-1 * time.Minute), // Expired + GithubAppKey: models.GithubAppKey{ + AppKey: api.AppKey{ + AppId: "123", + SecretKey: "secret", + }, + }, + }, + } + + logger, _ := logruslog.NewDefaultLogger(logrus.New()) + tp := NewTokenProvider(conn, nil, client, logger) + + // Mock response for refresh + respBody := `{"access_token":"new_token","refresh_token":"new_refresh_token","expires_in":3600,"refresh_token_expires_in":3600}` + resp := &http.Response{ + StatusCode: 200, + Body: io.NopCloser(bytes.NewBufferString(respBody)), + } + + // Expect exactly one call + mockRT.On("RoundTrip", mock.Anything).Return(resp, nil).Once() + + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + token, err := tp.GetToken() + assert.NoError(t, err) + assert.Equal(t, "new_token", token) + }() + } + wg.Wait() + + mockRT.AssertExpectations(t) +} + +func TestConfigurableBuffer(t *testing.T) { + os.Setenv("GITHUB_TOKEN_REFRESH_BUFFER_MINUTES", "10") + defer os.Unsetenv("GITHUB_TOKEN_REFRESH_BUFFER_MINUTES") + + tp := &TokenProvider{ + conn: &models.GithubConnection{ + GithubConn: models.GithubConn{ + RefreshToken: "refresh_token", + }, + }, + } + + // 9 minutes remaining (inside 10m buffer) + tp.conn.TokenExpiresAt = time.Now().Add(9 * time.Minute) + assert.True(t, tp.needsRefresh()) + + // 11 minutes remaining (outside 10m buffer) + tp.conn.TokenExpiresAt = time.Now().Add(11 * time.Minute) + assert.False(t, tp.needsRefresh()) +} + +func TestPersistenceFailure(t *testing.T) { + mockRT := new(MockRoundTripper) + client := &http.Client{Transport: mockRT} + mockDal := new(mockdal.Dal) + + conn := &models.GithubConnection{ + GithubConn: models.GithubConn{ + RefreshToken: "refresh_token", + GithubAccessToken: models.GithubAccessToken{ + AccessToken: api.AccessToken{ + Token: "old_token", + }, + }, + GithubAppKey: models.GithubAppKey{ + AppKey: api.AppKey{ + AppId: "123", + SecretKey: "secret", + }, + }, + }, + } + + logger, _ := logruslog.NewDefaultLogger(logrus.New()) + tp := NewTokenProvider(conn, mockDal, client, logger) + + // Mock response for refresh + respBody := `{"access_token":"new_token","refresh_token":"new_refresh_token","expires_in":3600,"refresh_token_expires_in":3600}` + resp := &http.Response{ + StatusCode: 200, + Body: io.NopCloser(bytes.NewBufferString(respBody)), + } + mockRT.On("RoundTrip", mock.Anything).Return(resp, nil).Once() + + // Mock DAL failure + mockDal.On("UpdateColumns", mock.Anything, mock.Anything, mock.AnythingOfType("[]dal.Clause")).Return(errors.Default.New("db error")) + err := tp.ForceRefresh("old_token") + assert.NoError(t, err) // Should not return error even if persistence fails + + mockRT.AssertExpectations(t) + mockDal.AssertExpectations(t) +}