Skip to content

Commit 1a66e79

Browse files
ysinghcCopilot
andauthored
feat(github): add support for refresh tokens and token management (#8667)
* feat(github): add support for refresh tokens and token management * Update backend/plugins/github/token/token_provider.go Co-authored-by: Copilot <[email protected]> * Update backend/plugins/github/token/token_provider.go Co-authored-by: Copilot <[email protected]> * Update backend/plugins/github/token/token_provider.go Co-authored-by: Copilot <[email protected]> * added documentation for roundtripper and fixed the infinite loop issue and added tests * fixed the illegal plugins/github/models import * Remove error conversion on token refresh failure and add token provider tests * renamed TestIsUserToServerToken to TestTokenTypeClassification for clarity * Update backend/plugins/github/token/token_provider.go Co-authored-by: Copilot <[email protected]> * Update backend/plugins/github/token/token_provider.go Co-authored-by: Copilot <[email protected]> --------- Co-authored-by: Copilot <[email protected]>
1 parent a0f1c98 commit 1a66e79

File tree

10 files changed

+661
-1
lines changed

10 files changed

+661
-1
lines changed

backend/helpers/pluginhelper/api/api_client.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,11 @@ func (apiClient *ApiClient) SetLogger(logger log.Logger) {
299299
apiClient.logger = logger
300300
}
301301

302+
// GetClient returns the underlying http.Client
303+
func (apiClient *ApiClient) GetClient() *http.Client {
304+
return apiClient.client
305+
}
306+
302307
func (apiClient *ApiClient) logDebug(format string, a ...interface{}) {
303308
if apiClient.logger != nil {
304309
apiClient.logger.Debug(format, a...)

backend/plugins/github/models/connection.go

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,21 @@ type GithubConn struct {
5656
helper.MultiAuth `mapstructure:",squash"`
5757
GithubAccessToken `mapstructure:",squash" authMethod:"AccessToken"`
5858
GithubAppKey `mapstructure:",squash" authMethod:"AppKey"`
59+
RefreshToken string `mapstructure:"refreshToken" json:"refreshToken" gorm:"type:text;serializer:encdec"`
60+
TokenExpiresAt time.Time `mapstructure:"tokenExpiresAt" json:"tokenExpiresAt"`
61+
RefreshTokenExpiresAt time.Time `mapstructure:"refreshTokenExpiresAt" json:"refreshTokenExpiresAt"`
62+
}
63+
64+
// UpdateToken updates the token and refresh token information
65+
func (conn *GithubConn) UpdateToken(newToken, newRefreshToken string, expiry, refreshExpiry time.Time) {
66+
conn.Token = newToken
67+
conn.RefreshToken = newRefreshToken
68+
conn.TokenExpiresAt = expiry
69+
conn.RefreshTokenExpiresAt = refreshExpiry
70+
71+
// Update the internal tokens slice used by SetupAuthentication
72+
conn.tokens = []string{newToken}
73+
conn.tokenIndex = 0
5974
}
6075

6176
// PrepareApiClient splits Token to tokens for SetupAuthentication to utilize
@@ -249,7 +264,7 @@ func (conn *GithubConn) typeIs(token string) string {
249264
// total len is 40, {prefix}{showPrefix}{secret}{showSuffix}
250265
// fine-grained tokens
251266
// github_pat_{82_characters}
252-
classicalTokenClassicalPrefixes := []string{"ghp_", "gho_", "ghs_", "ghr_"}
267+
classicalTokenClassicalPrefixes := []string{"ghp_", "gho_", "ghs_", "ghr_", "ghu_"}
253268
classicalTokenFindGrainedPrefixes := []string{"github_pat_"}
254269
for _, prefix := range classicalTokenClassicalPrefixes {
255270
if strings.HasPrefix(token, prefix) {

backend/plugins/github/models/connection_test.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,3 +227,14 @@ func TestGithubConnection_Sanitize(t *testing.T) {
227227
})
228228
}
229229
}
230+
231+
func TestTokenTypeClassification(t *testing.T) {
232+
conn := &GithubConn{}
233+
assert.Equal(t, GithubTokenTypeClassical, conn.typeIs("ghp_123"))
234+
assert.Equal(t, GithubTokenTypeClassical, conn.typeIs("gho_123"))
235+
assert.Equal(t, GithubTokenTypeClassical, conn.typeIs("ghu_123"))
236+
assert.Equal(t, GithubTokenTypeClassical, conn.typeIs("ghs_123"))
237+
assert.Equal(t, GithubTokenTypeClassical, conn.typeIs("ghr_123"))
238+
assert.Equal(t, GithubTokenTypeFineGrained, conn.typeIs("github_pat_123"))
239+
assert.Equal(t, GithubTokenTypeUnknown, conn.typeIs("some_other_token"))
240+
}
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
/*
2+
Licensed to the Apache Software Foundation (ASF) under one or more
3+
contributor license agreements. See the NOTICE file distributed with
4+
this work for additional information regarding copyright ownership.
5+
The ASF licenses this file to You under the Apache License, Version 2.0
6+
(the "License"); you may not use this file except in compliance with
7+
the License. You may obtain a copy of the License at
8+
9+
http://www.apache.org/licenses/LICENSE-2.0
10+
11+
Unless required by applicable law or agreed to in writing, software
12+
distributed under the License is distributed on an "AS IS" BASIS,
13+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
See the License for the specific language governing permissions and
15+
limitations under the License.
16+
*/
17+
18+
package migrationscripts
19+
20+
import (
21+
"time"
22+
23+
"github.com/apache/incubator-devlake/core/context"
24+
"github.com/apache/incubator-devlake/core/errors"
25+
"github.com/apache/incubator-devlake/helpers/migrationhelper"
26+
)
27+
28+
type githubConnection20241120 struct {
29+
RefreshToken string `gorm:"type:text;serializer:encdec"`
30+
TokenExpiresAt time.Time
31+
RefreshTokenExpiresAt time.Time
32+
}
33+
34+
func (githubConnection20241120) TableName() string {
35+
return "_tool_github_connections"
36+
}
37+
38+
type addRefreshTokenFields struct{}
39+
40+
func (*addRefreshTokenFields) Up(basicRes context.BasicRes) errors.Error {
41+
return migrationhelper.AutoMigrateTables(
42+
basicRes,
43+
&githubConnection20241120{},
44+
)
45+
}
46+
47+
func (*addRefreshTokenFields) Version() uint64 {
48+
return 20241120000001
49+
}
50+
51+
func (*addRefreshTokenFields) Name() string {
52+
return "add refresh token fields to github_connections"
53+
}

backend/plugins/github/models/migrationscripts/register.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,5 +55,6 @@ func All() []plugin.MigrationScript {
5555
new(addIsDraftToPr),
5656
new(changeIssueComponentType),
5757
new(addIndexToGithubJobs),
58+
new(addRefreshTokenFields),
5859
}
5960
}

backend/plugins/github/tasks/api_client.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import (
2626
"github.com/apache/incubator-devlake/core/plugin"
2727
"github.com/apache/incubator-devlake/helpers/pluginhelper/api"
2828
"github.com/apache/incubator-devlake/plugins/github/models"
29+
"github.com/apache/incubator-devlake/plugins/github/token"
2930
)
3031

3132
func CreateApiClient(taskCtx plugin.TaskContext, connection *models.GithubConnection) (*api.ApiAsyncClient, errors.Error) {
@@ -34,6 +35,24 @@ func CreateApiClient(taskCtx plugin.TaskContext, connection *models.GithubConnec
3435
return nil, err
3536
}
3637

38+
// Inject TokenProvider if refresh token is present
39+
if connection.RefreshToken != "" {
40+
logger := taskCtx.GetLogger()
41+
db := taskCtx.GetDal()
42+
43+
// Create TokenProvider
44+
tp := token.NewTokenProvider(connection, db, apiClient.GetClient(), logger)
45+
46+
// Wrap the transport
47+
baseTransport := apiClient.GetClient().Transport
48+
if baseTransport == nil {
49+
baseTransport = http.DefaultTransport
50+
}
51+
52+
rt := token.NewRefreshRoundTripper(baseTransport, tp)
53+
apiClient.GetClient().Transport = rt
54+
}
55+
3756
// create rate limit calculator
3857
rateLimiter := &api.ApiRateLimitCalculator{
3958
UserRateLimitPerHour: connection.RateLimitPerHour,
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
/*
2+
Licensed to the Apache Software Foundation (ASF) under one or more
3+
contributor license agreements. See the NOTICE file distributed with
4+
this work for additional information regarding copyright ownership.
5+
The ASF licenses this file to You under the Apache License, Version 2.0
6+
(the "License"); you may not use this file except in compliance with
7+
the License. You may obtain a copy of the License at
8+
9+
http://www.apache.org/licenses/LICENSE-2.0
10+
11+
Unless required by applicable law or agreed to in writing, software
12+
distributed under the License is distributed on an "AS IS" BASIS,
13+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
See the License for the specific language governing permissions and
15+
limitations under the License.
16+
*/
17+
18+
package token
19+
20+
import (
21+
"net/http"
22+
)
23+
24+
// RefreshRoundTripper is an HTTP transport middleware that automatically manages OAuth token refreshes.
25+
// It wraps an underlying http.RoundTripper and provides token refresh on auth failures.
26+
// On 401's the round tripper will:
27+
// - Force a refresh of the OAuth token via the TokenProvider
28+
// - Retry the original request with the new token
29+
type RefreshRoundTripper struct {
30+
base http.RoundTripper
31+
tokenProvider *TokenProvider
32+
}
33+
34+
func NewRefreshRoundTripper(base http.RoundTripper, tp *TokenProvider) *RefreshRoundTripper {
35+
return &RefreshRoundTripper{
36+
base: base,
37+
tokenProvider: tp,
38+
}
39+
}
40+
41+
// RoundTrip implements the http.RoundTripper interface and handles automatic token refresh on 401 responses.
42+
// It clones the request, adds the Authorization header, and retries once with a refreshed token if needed.
43+
func (rt *RefreshRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
44+
return rt.roundTripWithRetry(req, false)
45+
}
46+
47+
// roundTripWithRetry performs the actual request with retry on 401.
48+
// The refreshAttempted parameter tracks whether a refresh has already been tried for this request
49+
// to prevent infinite retry loops if token refresh itself fails.
50+
func (rt *RefreshRoundTripper) roundTripWithRetry(req *http.Request, refreshAttempted bool) (*http.Response, error) {
51+
// Get token
52+
token, err := rt.tokenProvider.GetToken()
53+
if err != nil {
54+
return nil, err
55+
}
56+
57+
// Clone request before modifying
58+
reqClone := req.Clone(req.Context())
59+
reqClone.Header.Set("Authorization", "Bearer "+token)
60+
61+
// Execute request
62+
resp, reqErr := rt.base.RoundTrip(reqClone)
63+
if reqErr != nil {
64+
return nil, reqErr
65+
}
66+
67+
// Reactive refresh on 401
68+
if resp.StatusCode == http.StatusUnauthorized && !refreshAttempted {
69+
// Close previous response body
70+
resp.Body.Close()
71+
72+
// Force refresh
73+
if err := rt.tokenProvider.ForceRefresh(token); err != nil {
74+
return nil, err
75+
}
76+
77+
// Get new token
78+
newToken, err := rt.tokenProvider.GetToken()
79+
if err != nil {
80+
return nil, err
81+
}
82+
83+
// Retry request with new token
84+
reqRetry := req.Clone(req.Context())
85+
reqRetry.Header.Set("Authorization", "Bearer "+newToken)
86+
return rt.roundTripWithRetry(reqRetry, true)
87+
}
88+
89+
return resp, nil
90+
}
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
/*
2+
Licensed to the Apache Software Foundation (ASF) under one or more
3+
contributor license agreements. See the NOTICE file distributed with
4+
this work for additional information regarding copyright ownership.
5+
The ASF licenses this file to You under the Apache License, Version 2.0
6+
(the "License"); you may not use this file except in compliance with
7+
the License. You may obtain a copy of the License at
8+
9+
http://www.apache.org/licenses/LICENSE-2.0
10+
11+
Unless required by applicable law or agreed to in writing, software
12+
distributed under the License is distributed on an "AS IS" BASIS,
13+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
See the License for the specific language governing permissions and
15+
limitations under the License.
16+
*/
17+
18+
package token
19+
20+
import (
21+
"bytes"
22+
"io"
23+
"net/http"
24+
"testing"
25+
"time"
26+
27+
"github.com/apache/incubator-devlake/helpers/pluginhelper/api"
28+
"github.com/apache/incubator-devlake/impls/logruslog"
29+
"github.com/apache/incubator-devlake/plugins/github/models"
30+
"github.com/sirupsen/logrus"
31+
"github.com/stretchr/testify/assert"
32+
"github.com/stretchr/testify/mock"
33+
)
34+
35+
func TestRoundTripper401Refresh(t *testing.T) {
36+
mockRT := new(MockRoundTripper)
37+
client := &http.Client{Transport: mockRT}
38+
39+
conn := &models.GithubConnection{
40+
GithubConn: models.GithubConn{
41+
RefreshToken: "refresh_token",
42+
GithubAccessToken: models.GithubAccessToken{
43+
AccessToken: api.AccessToken{
44+
Token: "old_token",
45+
},
46+
},
47+
TokenExpiresAt: time.Now().Add(10 * time.Minute), // Not expired
48+
GithubAppKey: models.GithubAppKey{
49+
AppKey: api.AppKey{
50+
AppId: "123",
51+
SecretKey: "secret",
52+
},
53+
},
54+
},
55+
}
56+
57+
logger, _ := logruslog.NewDefaultLogger(logrus.New())
58+
tp := NewTokenProvider(conn, nil, client, logger)
59+
rt := NewRefreshRoundTripper(mockRT, tp)
60+
61+
// Request
62+
req, _ := http.NewRequest("GET", "https://api.github.com/user", nil)
63+
64+
// 1. First call returns 401
65+
resp401 := &http.Response{
66+
StatusCode: 401,
67+
Body: io.NopCloser(bytes.NewBufferString("Unauthorized")),
68+
}
69+
mockRT.On("RoundTrip", mock.MatchedBy(func(r *http.Request) bool {
70+
return r.Header.Get("Authorization") == "Bearer old_token" && r.URL.String() == "https://api.github.com/user"
71+
})).Return(resp401, nil).Once()
72+
73+
// 2. Refresh call (triggered by 401)
74+
respRefresh := &http.Response{
75+
StatusCode: 200,
76+
Body: io.NopCloser(bytes.NewBufferString(`{"access_token":"new_token","refresh_token":"new_refresh_token","expires_in":3600,"refresh_token_expires_in":3600}`)),
77+
}
78+
// The refresh call uses the same client, so it goes through mockRT too!
79+
mockRT.On("RoundTrip", mock.MatchedBy(func(r *http.Request) bool {
80+
return r.URL.String() == "https://github.com/login/oauth/access_token"
81+
})).Return(respRefresh, nil).Once()
82+
83+
// 3. Retry call with new token
84+
resp200 := &http.Response{
85+
StatusCode: 200,
86+
Body: io.NopCloser(bytes.NewBufferString("Success")),
87+
}
88+
mockRT.On("RoundTrip", mock.MatchedBy(func(r *http.Request) bool {
89+
return r.Header.Get("Authorization") == "Bearer new_token" && r.URL.String() == "https://api.github.com/user"
90+
})).Return(resp200, nil).Once()
91+
92+
// Execute
93+
resp, err := rt.RoundTrip(req)
94+
assert.NoError(t, err)
95+
assert.Equal(t, 200, resp.StatusCode)
96+
97+
body, _ := io.ReadAll(resp.Body)
98+
assert.Equal(t, "Success", string(body))
99+
100+
mockRT.AssertExpectations(t)
101+
}

0 commit comments

Comments
 (0)