Skip to content

Commit f23fea5

Browse files
committed
Add redis token store
1 parent 7c480ea commit f23fea5

File tree

11 files changed

+358
-31
lines changed

11 files changed

+358
-31
lines changed

generates/access_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ func TestAccess(t *testing.T) {
1313
Convey("Test Access Generate", t, func() {
1414
data := &oauth2.GenerateBasic{
1515
Client: &models.Client{
16-
ClientID: "123456",
17-
Secret: "123456",
16+
ID: "123456",
17+
Secret: "123456",
1818
},
1919
UserID: "000000",
2020
CreateAt: time.Now(),

generates/authorize_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ func TestAuthorize(t *testing.T) {
1313
Convey("Test Authorize Generate", t, func() {
1414
data := &oauth2.GenerateBasic{
1515
Client: &models.Client{
16-
ClientID: "123456",
17-
Secret: "123456",
16+
ID: "123456",
17+
Secret: "123456",
1818
},
1919
UserID: "000000",
2020
CreateAt: time.Now(),

manage/manager.go

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -87,15 +87,15 @@ func (m *Manager) MapTokenGenerate(gen oauth2.AccessGenerate) {
8787
}
8888

8989
// MapClientStorage 注入客户端信息存储接口
90-
func (m *Manager) MapClientStorage(stor oauth2.ClientStorage) {
90+
func (m *Manager) MapClientStorage(stor oauth2.ClientStore) {
9191
if stor == nil {
9292
panic(ErrNilValue)
9393
}
9494
m.injector.Map(stor)
9595
}
9696

9797
// MustClientStorage 强制注入客户端信息存储接口
98-
func (m *Manager) MustClientStorage(stor oauth2.ClientStorage, err error) {
98+
func (m *Manager) MustClientStorage(stor oauth2.ClientStore, err error) {
9999
if err != nil {
100100
panic(err)
101101
}
@@ -106,15 +106,15 @@ func (m *Manager) MustClientStorage(stor oauth2.ClientStorage, err error) {
106106
}
107107

108108
// MapTokenStorage 注入令牌信息存储接口
109-
func (m *Manager) MapTokenStorage(stor oauth2.TokenStorage) {
109+
func (m *Manager) MapTokenStorage(stor oauth2.TokenStore) {
110110
if stor == nil {
111111
panic(ErrNilValue)
112112
}
113113
m.injector.Map(stor)
114114
}
115115

116116
// MustTokenStorage 强制注入令牌信息存储接口
117-
func (m *Manager) MustTokenStorage(stor oauth2.TokenStorage, err error) {
117+
func (m *Manager) MustTokenStorage(stor oauth2.TokenStore, err error) {
118118
if err != nil {
119119
panic(err)
120120
}
@@ -126,7 +126,7 @@ func (m *Manager) MustTokenStorage(stor oauth2.TokenStorage, err error) {
126126

127127
// GetClient 获取客户端信息
128128
func (m *Manager) GetClient(clientID string) (cli oauth2.ClientInfo, err error) {
129-
err = m.injector.Apply(func(stor oauth2.ClientStorage) {
129+
err = m.injector.Apply(func(stor oauth2.ClientStore) {
130130
cli, err = stor.GetByID(clientID)
131131
if err != nil {
132132
return
@@ -148,7 +148,7 @@ func (m *Manager) GenerateAuthToken(rt oauth2.ResponseType, tgr *oauth2.TokenGen
148148
err = verr
149149
return
150150
}
151-
_, ierr := m.injector.Invoke(func(ti oauth2.TokenInfo, gen oauth2.AuthorizeGenerate, stor oauth2.TokenStorage) {
151+
_, ierr := m.injector.Invoke(func(ti oauth2.TokenInfo, gen oauth2.AuthorizeGenerate, stor oauth2.TokenStore) {
152152
td := &oauth2.GenerateBasic{
153153
Client: cli,
154154
UserID: tgr.UserID,
@@ -191,7 +191,12 @@ func (m *Manager) GenerateAccessToken(gt oauth2.GrantType, tgr *oauth2.TokenGene
191191
} else if ti.GetRedirectURI() != tgr.RedirectURI || ti.GetClientID() != tgr.ClientID {
192192
err = ErrAuthTokenInvalid
193193
return
194+
} else if verr := m.RemoveAccessToken(tgr.Code); verr != nil { // 删除授权码
195+
err = verr
196+
return
194197
}
198+
tgr.UserID = ti.GetUserID()
199+
tgr.Scope = ti.GetScope()
195200
}
196201
cli, err := m.GetClient(tgr.ClientID)
197202
if err != nil {
@@ -200,7 +205,7 @@ func (m *Manager) GenerateAccessToken(gt oauth2.GrantType, tgr *oauth2.TokenGene
200205
err = ErrClientInvalid
201206
return
202207
}
203-
_, ierr := m.injector.Invoke(func(ti oauth2.TokenInfo, gen oauth2.AccessGenerate, stor oauth2.TokenStorage) {
208+
_, ierr := m.injector.Invoke(func(ti oauth2.TokenInfo, gen oauth2.AccessGenerate, stor oauth2.TokenStore) {
204209
td := &oauth2.GenerateBasic{
205210
Client: cli,
206211
UserID: tgr.UserID,
@@ -242,7 +247,7 @@ func (m *Manager) RefreshAccessToken(refresh, scope string) (token string, err e
242247
if err != nil {
243248
return
244249
}
245-
_, ierr := m.injector.Invoke(func(stor oauth2.TokenStorage, gen oauth2.AccessGenerate) {
250+
_, ierr := m.injector.Invoke(func(stor oauth2.TokenStore, gen oauth2.AccessGenerate) {
246251
cli, cerr := m.GetClient(ti.GetClientID())
247252
if cerr != nil {
248253
err = cerr
@@ -285,7 +290,7 @@ func (m *Manager) RemoveAccessToken(access string) (err error) {
285290
err = ErrAccessInvalid
286291
return
287292
}
288-
_, ierr := m.injector.Invoke(func(stor oauth2.TokenStorage) {
293+
_, ierr := m.injector.Invoke(func(stor oauth2.TokenStore) {
289294
err = stor.RemoveByAccess(access)
290295
})
291296
if ierr != nil && err == nil {
@@ -300,7 +305,7 @@ func (m *Manager) RemoveRefreshToken(refresh string) (err error) {
300305
err = ErrAccessInvalid
301306
return
302307
}
303-
_, ierr := m.injector.Invoke(func(stor oauth2.TokenStorage) {
308+
_, ierr := m.injector.Invoke(func(stor oauth2.TokenStore) {
304309
err = stor.RemoveByRefresh(refresh)
305310
})
306311
if ierr != nil && err == nil {
@@ -315,7 +320,7 @@ func (m *Manager) LoadAccessToken(access string) (info oauth2.TokenInfo, err err
315320
err = ErrAccessInvalid
316321
return
317322
}
318-
_, ierr := m.injector.Invoke(func(stor oauth2.TokenStorage) {
323+
_, ierr := m.injector.Invoke(func(stor oauth2.TokenStore) {
319324
ct := time.Now()
320325
ti, terr := stor.GetByAccess(access)
321326
if terr != nil {
@@ -355,7 +360,7 @@ func (m *Manager) LoadRefreshToken(refresh string) (info oauth2.TokenInfo, err e
355360
err = ErrRefreshInvalid
356361
return
357362
}
358-
_, ierr := m.injector.Invoke(func(stor oauth2.TokenStorage) {
363+
_, ierr := m.injector.Invoke(func(stor oauth2.TokenStore) {
359364
ti, terr := stor.GetByRefresh(refresh)
360365
if terr != nil {
361366
err = terr

model.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ type (
1212
GetSecret() string
1313
// 客户端域名URL
1414
GetDomain() string
15-
// Other data
16-
GetOtherData() interface{}
15+
// 用户数据
16+
GetUserData() interface{}
1717
}
1818

1919
// TokenInfo 令牌信息模型接口

models/client.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@ package models
22

33
// Client 客户端信息
44
type Client struct {
5-
ClientID string // 客户端ID
6-
Secret string // 密钥
7-
Domain string // 域名url
5+
ID string // 客户端ID
6+
Secret string // 密钥
7+
Domain string // 域名url
88
}
99

1010
// GetID 客户端ID
1111
func (c *Client) GetID() string {
12-
return c.ClientID
12+
return c.ID
1313
}
1414

1515
// GetSecret 客户端秘钥
@@ -22,7 +22,7 @@ func (c *Client) GetDomain() string {
2222
return c.Domain
2323
}
2424

25-
// GetOtherData Other data
26-
func (c *Client) GetOtherData() interface{} {
25+
// GetUserData 用户数据
26+
func (c *Client) GetUserData() interface{} {
2727
return nil
2828
}

storages/.gitkeep

Whitespace-only changes.

storage.go renamed to store.go

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@ package oauth2
22

33
// 提供存储接口
44
type (
5-
// ClientStorage 客户端信息存储接口
6-
ClientStorage interface {
5+
// ClientStore 客户端信息存储接口
6+
ClientStore interface {
77
// GetByID 根据ID获取客户端信息
88
GetByID(id string) (ClientInfo, error)
99
}
1010

11-
// TokenStorage 令牌信息存储接口
12-
TokenStorage interface {
11+
// TokenStore 令牌信息存储接口
12+
TokenStore interface {
1313
// Create 创建并存储新的令牌信息
1414
Create(info TokenInfo) error
1515

@@ -19,9 +19,6 @@ type (
1919
// RemoveByRefresh 使用更新令牌删除令牌信息
2020
RemoveByRefresh(refresh string) error
2121

22-
// 使用访问令牌取出令牌信息数据(获取并删除)
23-
TakeByAccess(access string) (TokenInfo, error)
24-
2522
// 使用访问令牌获取令牌信息数据
2623
GetByAccess(access string) (TokenInfo, error)
2724

store/client/temp.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package client
2+
3+
import (
4+
"errors"
5+
6+
"gopkg.in/oauth2.v2"
7+
"gopkg.in/oauth2.v2/models"
8+
)
9+
10+
// NewTempStore 创建客户端临时存储实例
11+
func NewTempStore() oauth2.ClientStore {
12+
return &TempStore{
13+
data: map[string]*models.Client{
14+
"1": &models.Client{
15+
ID: "1",
16+
Secret: "11",
17+
Domain: "http://localhost",
18+
},
19+
},
20+
}
21+
}
22+
23+
// TempStore 客户端信息的临时存储
24+
type TempStore struct {
25+
data map[string]*models.Client
26+
}
27+
28+
// GetByID 获取客户端信息
29+
func (ts *TempStore) GetByID(id string) (cli oauth2.ClientInfo, err error) {
30+
if c, ok := ts.data[id]; ok {
31+
cli = c
32+
return
33+
}
34+
err = errors.New("not found")
35+
return
36+
}

store/token/redis.go

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
package token
2+
3+
import (
4+
"encoding/json"
5+
6+
"gopkg.in/oauth2.v2"
7+
"gopkg.in/oauth2.v2/models"
8+
"gopkg.in/redis.v4"
9+
)
10+
11+
// NewRedisStore 创建redis存储的实例
12+
func NewRedisStore(cfg *RedisConfig) (store oauth2.TokenStore, err error) {
13+
opt := &redis.Options{
14+
Network: cfg.Network,
15+
Addr: cfg.Addr,
16+
Password: cfg.Password,
17+
DB: cfg.DB,
18+
MaxRetries: cfg.MaxRetries,
19+
DialTimeout: cfg.DialTimeout,
20+
ReadTimeout: cfg.ReadTimeout,
21+
WriteTimeout: cfg.WriteTimeout,
22+
PoolSize: cfg.PoolSize,
23+
PoolTimeout: cfg.PoolTimeout,
24+
}
25+
cli := redis.NewClient(opt)
26+
if verr := cli.Ping().Err(); verr != nil {
27+
err = verr
28+
return
29+
}
30+
store = &RedisStore{cli: cli}
31+
return
32+
}
33+
34+
// RedisStore 令牌的redis存储
35+
type RedisStore struct {
36+
cli *redis.Client
37+
}
38+
39+
// Create 存储令牌信息
40+
func (rs *RedisStore) Create(info oauth2.TokenInfo) (err error) {
41+
jv, err := json.Marshal(info)
42+
if err != nil {
43+
return
44+
}
45+
pipe := rs.cli.Pipeline()
46+
47+
aexp := info.GetAccessExpiresIn()
48+
if refresh := info.GetRefresh(); refresh != "" {
49+
exp := info.GetRefreshExpiresIn()
50+
ttl := rs.cli.TTL(refresh)
51+
if verr := ttl.Err(); verr != nil {
52+
err = verr
53+
return
54+
}
55+
if v := ttl.Val(); v.Seconds() > 0 {
56+
exp = v
57+
}
58+
if aexp.Seconds() > exp.Seconds() {
59+
aexp = exp
60+
}
61+
pipe.Set(refresh, jv, exp)
62+
}
63+
pipe.Set(info.GetAccess(), jv, aexp)
64+
65+
if _, verr := pipe.Exec(); verr != nil {
66+
err = verr
67+
}
68+
return
69+
}
70+
71+
// remove
72+
func (rs *RedisStore) remove(key string) (err error) {
73+
del := rs.cli.Del(key)
74+
if verr := del.Err(); verr != nil {
75+
err = verr
76+
}
77+
return
78+
}
79+
80+
// RemoveByAccess 移除令牌
81+
func (rs *RedisStore) RemoveByAccess(access string) (err error) {
82+
err = rs.remove(access)
83+
return
84+
}
85+
86+
// RemoveByRefresh 移除令牌
87+
func (rs *RedisStore) RemoveByRefresh(refresh string) (err error) {
88+
err = rs.remove(refresh)
89+
return
90+
}
91+
92+
func (rs *RedisStore) get(key string) (ti oauth2.TokenInfo, err error) {
93+
gv, gerr := rs.cli.Get(key).Result()
94+
if gerr != nil {
95+
if gerr == redis.Nil {
96+
return
97+
}
98+
err = gerr
99+
return
100+
}
101+
var tm models.Token
102+
if verr := json.Unmarshal([]byte(gv), &tm); verr != nil {
103+
err = verr
104+
return
105+
}
106+
ti = &tm
107+
return
108+
}
109+
110+
// GetByAccess 获取令牌数据
111+
func (rs *RedisStore) GetByAccess(access string) (ti oauth2.TokenInfo, err error) {
112+
ti, err = rs.get(access)
113+
return
114+
}
115+
116+
// GetByRefresh 获取令牌数据
117+
func (rs *RedisStore) GetByRefresh(refresh string) (ti oauth2.TokenInfo, err error) {
118+
ti, err = rs.get(refresh)
119+
return
120+
}

0 commit comments

Comments
 (0)