Skip to content

Commit b5f24be

Browse files
committed
Add mongodb token store
1 parent f23fea5 commit b5f24be

File tree

7 files changed

+334
-118
lines changed

7 files changed

+334
-118
lines changed

manage/manager.go

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,7 @@ func (m *Manager) RefreshAccessToken(refresh, scope string) (token string, err e
247247
if err != nil {
248248
return
249249
}
250+
access := ti.GetAccess()
250251
_, ierr := m.injector.Invoke(func(stor oauth2.TokenStore, gen oauth2.AccessGenerate) {
251252
cli, cerr := m.GetClient(ti.GetClientID())
252253
if cerr != nil {
@@ -272,7 +273,7 @@ func (m *Manager) RefreshAccessToken(refresh, scope string) (token string, err e
272273
err = verr
273274
return
274275
}
275-
if verr := stor.RemoveByRefresh(refresh); verr != nil {
276+
if verr := stor.RemoveByAccess(access); verr != nil {
276277
err = verr
277278
return
278279
}
@@ -330,19 +331,8 @@ func (m *Manager) LoadAccessToken(access string) (info oauth2.TokenInfo, err err
330331
err = ErrAccessInvalid
331332
return
332333
} else if ti.GetRefresh() != "" && ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(ct) { // 检查更新令牌是否过期
333-
// 删除过期的访问令牌
334-
if verr := stor.RemoveByRefresh(ti.GetRefresh()); verr != nil {
335-
err = verr
336-
return
337-
}
338334
err = ErrRefreshExpired
339335
} else if ti.GetAccessCreateAt().Add(ti.GetAccessExpiresIn()).Before(ct) { // 检查访问令牌是否过期
340-
if ti.GetRefresh() == "" { // 删除过期的访问令牌
341-
if verr := stor.RemoveByAccess(access); verr != nil {
342-
err = verr
343-
return
344-
}
345-
}
346336
err = ErrAccessExpired
347337
return
348338
}
@@ -369,11 +359,6 @@ func (m *Manager) LoadRefreshToken(refresh string) (info oauth2.TokenInfo, err e
369359
err = ErrRefreshInvalid
370360
return
371361
} else if ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(time.Now()) {
372-
// 删除过期的更新令牌
373-
if verr := stor.RemoveByRefresh(refresh); verr != nil {
374-
err = verr
375-
return
376-
}
377362
err = ErrRefreshExpired
378363
return
379364
}

models/token.go

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,17 @@ import "time"
44

55
// Token 令牌信息
66
type Token struct {
7-
ClientID string // 客户端标识
8-
UserID string // 用户标识
9-
RedirectURI string // 重定向URI
10-
Scope string // 权限范围
11-
AuthType string // 令牌授权类型
12-
Access string // 访问令牌
13-
AccessCreateAt time.Time // 访问令牌创建时间
14-
AccessExpiresIn time.Duration // 访问令牌有效期
15-
Refresh string // 更新令牌
16-
RefreshCreateAt time.Time // 更新令牌创建时间
17-
RefreshExpiresIn time.Duration // 更新令牌有效期
7+
ClientID string `bson:"ClientID"` // 客户端标识
8+
UserID string `bson:"UserID"` // 用户标识
9+
RedirectURI string `bson:"RedirectURI"` // 重定向URI
10+
Scope string `bson:"Scope"` // 权限范围
11+
AuthType string `bson:"AuthType"` // 令牌授权类型
12+
Access string `bson:"Access"` // 访问令牌
13+
AccessCreateAt time.Time `bson:"AccessCreateAt"` // 访问令牌创建时间
14+
AccessExpiresIn time.Duration `bson:"AccessExpiresIn"` // 访问令牌有效期
15+
Refresh string `bson:"Refresh"` // 更新令牌
16+
RefreshCreateAt time.Time `bson:"RefreshCreateAt"` // 更新令牌创建时间
17+
RefreshExpiresIn time.Duration `bson:"RefreshExpiresIn"` // 更新令牌有效期
1818
}
1919

2020
// GetClientID 客户端ID

store/token/mongo.go

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
package token
2+
3+
import (
4+
"time"
5+
6+
"gopkg.in/LyricTian/lib.v2/mongo"
7+
"gopkg.in/mgo.v2"
8+
"gopkg.in/mgo.v2/bson"
9+
"gopkg.in/oauth2.v2"
10+
"gopkg.in/oauth2.v2/models"
11+
)
12+
13+
// MongoConfig MongoDB Configuration
14+
type MongoConfig struct {
15+
// Connection String
16+
URL string
17+
// DB Name(default oauth2)
18+
DB string
19+
// Collection Name(default tokens)
20+
C string
21+
}
22+
23+
// NewMongoStore 创建MongoDB的令牌存储
24+
func NewMongoStore(cfg *MongoConfig) (store oauth2.TokenStore, err error) {
25+
if cfg.DB == "" {
26+
cfg.DB = "oauth2"
27+
}
28+
if cfg.C == "" {
29+
cfg.C = "tokens"
30+
}
31+
handler, err := mongo.InitHandlerWithDB(cfg.URL, cfg.DB)
32+
if err != nil {
33+
return
34+
}
35+
// 创建自动过期索引
36+
err = handler.C(cfg.C).EnsureIndex(mgo.Index{
37+
Key: []string{"ExpiredAt"},
38+
ExpireAfter: time.Second,
39+
})
40+
if err != nil {
41+
return
42+
}
43+
err = handler.C(cfg.C).EnsureIndexKey("Access")
44+
if err != nil {
45+
return
46+
}
47+
err = handler.C(cfg.C).EnsureIndexKey("Refresh")
48+
if err != nil {
49+
return
50+
}
51+
store = &MongoStore{
52+
handler: handler,
53+
cfg: cfg,
54+
}
55+
return
56+
}
57+
58+
// MongoStore MongoDB Store
59+
type MongoStore struct {
60+
cfg *MongoConfig
61+
handler *mongo.Handler
62+
}
63+
64+
// Create 存储令牌信息
65+
func (ms *MongoStore) Create(info oauth2.TokenInfo) (err error) {
66+
tm := info.(*models.Token)
67+
var expiredAt time.Time
68+
if refresh := tm.Refresh; refresh != "" {
69+
expiredAt = tm.RefreshCreateAt.Add(tm.RefreshExpiresIn)
70+
rinfo, rerr := ms.GetByRefresh(refresh)
71+
if rerr != nil {
72+
err = rerr
73+
return
74+
}
75+
if rinfo != nil {
76+
expiredAt = rinfo.GetRefreshCreateAt().Add(rinfo.GetRefreshExpiresIn())
77+
}
78+
}
79+
if expiredAt.IsZero() {
80+
expiredAt = tm.AccessCreateAt.Add(tm.AccessExpiresIn)
81+
}
82+
doc := map[string]interface{}{
83+
"ExpiredAt": expiredAt,
84+
"ClientID": tm.ClientID,
85+
"UserID": tm.UserID,
86+
"RedirectURI": tm.RedirectURI,
87+
"Scope": tm.Scope,
88+
"AuthType": tm.AuthType,
89+
"Access": tm.Access,
90+
"AccessCreateAt": tm.AccessCreateAt,
91+
"AccessExpiresIn": tm.AccessExpiresIn,
92+
"Refresh": tm.Refresh,
93+
"RefreshCreateAt": tm.RefreshCreateAt,
94+
"RefreshExpiresIn": tm.RefreshExpiresIn,
95+
}
96+
97+
ms.handler.CHandle(ms.cfg.C, func(c *mgo.Collection) {
98+
err = c.Insert(doc)
99+
})
100+
return
101+
}
102+
103+
func (ms *MongoStore) remove(selector interface{}) (err error) {
104+
ms.handler.CHandle(ms.cfg.C, func(c *mgo.Collection) {
105+
err = c.Remove(selector)
106+
})
107+
return
108+
}
109+
110+
// RemoveByAccess 移除令牌
111+
func (ms *MongoStore) RemoveByAccess(access string) (err error) {
112+
err = ms.remove(bson.M{"Access": access})
113+
return
114+
}
115+
116+
// RemoveByRefresh 移除令牌
117+
func (ms *MongoStore) RemoveByRefresh(refresh string) (err error) {
118+
err = ms.remove(bson.M{"Refresh": refresh})
119+
return
120+
}
121+
122+
func (ms *MongoStore) get(find interface{}) (info oauth2.TokenInfo, err error) {
123+
ms.handler.CHandle(ms.cfg.C, func(c *mgo.Collection) {
124+
var tm models.Token
125+
aerr := c.Find(find).Select(bson.M{"_id": 0}).One(&tm)
126+
if aerr != nil {
127+
if aerr == mgo.ErrNotFound {
128+
return
129+
}
130+
err = aerr
131+
return
132+
}
133+
info = &tm
134+
})
135+
return
136+
}
137+
138+
// GetByAccess 获取令牌数据
139+
func (ms *MongoStore) GetByAccess(access string) (info oauth2.TokenInfo, err error) {
140+
info, err = ms.get(bson.M{"Access": access})
141+
return
142+
}
143+
144+
// GetByRefresh 获取令牌数据
145+
func (ms *MongoStore) GetByRefresh(refresh string) (info oauth2.TokenInfo, err error) {
146+
info, err = ms.get(bson.M{"Refresh": refresh})
147+
return
148+
}

store/token/mongo_test.go

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
package token
2+
3+
import (
4+
"testing"
5+
6+
. "github.com/smartystreets/goconvey/convey"
7+
)
8+
9+
const (
10+
mongoURL = "mongodb://admin:[email protected]:27017"
11+
)
12+
13+
func TestMongoStore(t *testing.T) {
14+
Convey("Test mongo store", t, func() {
15+
cfg := &MongoConfig{
16+
URL: mongoURL,
17+
}
18+
store, err := NewMongoStore(cfg)
19+
So(err, ShouldBeNil)
20+
21+
Convey("Test mongo store access", func() {
22+
testAccessStore(store)
23+
})
24+
25+
Convey("Test mongo store refresh", func() {
26+
testRefreshStore(store)
27+
})
28+
})
29+
}
30+
31+
func TestMongoStoreAccessExpired(t *testing.T) {
32+
Convey("Test mongo store access token expired", t, func() {
33+
cfg := &MongoConfig{
34+
URL: mongoURL,
35+
}
36+
store, err := NewMongoStore(cfg)
37+
So(err, ShouldBeNil)
38+
testAccessExpired(store)
39+
})
40+
}
41+
42+
func TestMongoStoreRefreshExpired(t *testing.T) {
43+
Convey("Test mongo store refresh token expired", t, func() {
44+
cfg := &MongoConfig{
45+
URL: mongoURL,
46+
}
47+
store, err := NewMongoStore(cfg)
48+
So(err, ShouldBeNil)
49+
testRefreshExpired(store)
50+
})
51+
}

store/token/redis_config.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ package token
22

33
import "time"
44

5-
// RedisConfig Redis配置参数
5+
// RedisConfig Redis Configuration
66
type RedisConfig struct {
77
// The network type, either tcp or unix.
88
// Default is tcp.

0 commit comments

Comments
 (0)