Skip to content

Commit 20ac475

Browse files
committed
update LoadAccessToken to consider refresh exp
1 parent ebe017a commit 20ac475

File tree

3 files changed

+16
-6
lines changed

3 files changed

+16
-6
lines changed

manage/manage_test.go

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package manage_test
22

33
import (
44
"testing"
5+
"time"
56

67
"gopkg.in/oauth2.v3"
78
"gopkg.in/oauth2.v3/manage"
@@ -128,10 +129,11 @@ func testZeroRefreshExpirationManager(tgr *oauth2.TokenGenerateRequest, manager
128129
So(code, ShouldNotBeEmpty)
129130

130131
atParams := &oauth2.TokenGenerateRequest{
131-
ClientID: tgr.ClientID,
132-
ClientSecret: "11",
133-
RedirectURI: tgr.RedirectURI,
134-
Code: code,
132+
ClientID: tgr.ClientID,
133+
ClientSecret: "11",
134+
RedirectURI: tgr.RedirectURI,
135+
AccessTokenExp: time.Hour,
136+
Code: code,
135137
}
136138
ati, err := manager.GenerateAccessToken(oauth2.AuthorizationCode, atParams)
137139
So(err, ShouldBeNil)
@@ -145,4 +147,11 @@ func testZeroRefreshExpirationManager(tgr *oauth2.TokenGenerateRequest, manager
145147
So(tokenInfo, ShouldNotBeNil)
146148
So(tokenInfo.GetRefresh(), ShouldEqual, refreshToken)
147149
So(tokenInfo.GetRefreshExpiresIn(), ShouldEqual, 0)
150+
151+
// LoadAccessToken also checks refresh expiry
152+
tokenInfo, err = manager.LoadAccessToken(accessToken)
153+
So(err, ShouldBeNil)
154+
So(tokenInfo, ShouldNotBeNil)
155+
So(tokenInfo.GetRefresh(), ShouldEqual, refreshToken)
156+
So(tokenInfo.GetRefreshExpiresIn(), ShouldEqual, 0)
148157
}

manage/manager.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -464,7 +464,8 @@ func (m *Manager) LoadAccessToken(access string) (info oauth2.TokenInfo, err err
464464
} else if ti == nil || ti.GetAccess() != access {
465465
err = errors.ErrInvalidAccessToken
466466
return
467-
} else if ti.GetRefresh() != "" && ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(ct) {
467+
} else if ti.GetRefresh() != "" && ti.GetRefreshExpiresIn() != 0 &&
468+
ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(ct) {
468469
err = errors.ErrExpiredRefreshToken
469470
return
470471
} else if ti.GetAccessCreateAt().Add(ti.GetAccessExpiresIn()).Before(ct) {

store/token.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ func (ts *TokenStore) Create(info oauth2.TokenInfo) (err error) {
6363
if err != nil {
6464
return
6565
}
66-
_, _, err = tx.Set(info.GetAccess(), basicID, &buntdb.SetOptions{Expires: true, TTL: aexp})
66+
_, _, err = tx.Set(info.GetAccess(), basicID, &buntdb.SetOptions{Expires: expires, TTL: aexp})
6767
return
6868
})
6969
return

0 commit comments

Comments
 (0)