Skip to content

Commit 6469c67

Browse files
committed
Add manager test file
1 parent 4c41505 commit 6469c67

File tree

3 files changed

+99
-13
lines changed

3 files changed

+99
-13
lines changed

manage/manage_test.go

Lines changed: 76 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,24 +4,96 @@ import (
44
"testing"
55

66
. "github.com/smartystreets/goconvey/convey"
7+
"gopkg.in/oauth2.v2"
78
"gopkg.in/oauth2.v2/generates"
89
"gopkg.in/oauth2.v2/models"
910
"gopkg.in/oauth2.v2/store/client"
1011
"gopkg.in/oauth2.v2/store/token"
1112
)
1213

1314
func TestManager(t *testing.T) {
14-
Convey("Manager Test", t, func() {
15+
Convey("Manager test", t, func() {
1516
manager := NewManager()
1617

1718
manager.MapClientModel(models.NewClient())
1819
manager.MapTokenModel(models.NewToken())
1920
manager.MapAuthorizeGenerate(generates.NewAuthorizeGenerate())
2021
manager.MapAccessGenerate(generates.NewAccessGenerate())
2122
manager.MapClientStorage(client.NewTempStore())
22-
manager.MustTokenStorage(token.NewRedisStore(
23-
&token.RedisConfig{Addr: "192.168.33.70:6379"},
24-
))
23+
24+
Convey("GetClient test", func() {
25+
cli, err := manager.GetClient("1")
26+
So(err, ShouldBeNil)
27+
So(cli.GetSecret(), ShouldEqual, "11")
28+
})
29+
30+
Convey("Redis store test", func() {
31+
manager.MustTokenStorage(token.NewRedisStore(
32+
&token.RedisConfig{Addr: "192.168.33.70:6379"},
33+
))
34+
testManager(manager)
35+
})
36+
37+
Convey("MongoDB store test", func() {
38+
manager.MustTokenStorage(token.NewMongoStore(
39+
&token.MongoConfig{URL: "mongodb://admin:[email protected]:27017"},
40+
))
41+
testManager(manager)
42+
})
2543

2644
})
2745
}
46+
47+
func testManager(manager oauth2.Manager) {
48+
reqParams := &oauth2.TokenGenerateRequest{
49+
ClientID: "1",
50+
UserID: "123456",
51+
RedirectURI: "http://localhost/oauth2",
52+
Scope: "all",
53+
}
54+
code, err := manager.GenerateAuthToken(oauth2.Code, reqParams)
55+
So(err, ShouldBeNil)
56+
So(code, ShouldNotBeEmpty)
57+
58+
atParams := &oauth2.TokenGenerateRequest{
59+
ClientID: "1",
60+
RedirectURI: "http://localhost/oauth2",
61+
Code: code,
62+
IsGenerateRefresh: true,
63+
}
64+
accessToken, refreshToken, err := manager.GenerateAccessToken(oauth2.AuthorizationCodeCredentials, atParams)
65+
So(err, ShouldBeNil)
66+
So(accessToken, ShouldNotBeEmpty)
67+
So(refreshToken, ShouldNotBeEmpty)
68+
69+
_, err = manager.LoadAccessToken(code)
70+
So(err, ShouldNotBeNil)
71+
72+
ainfo, err := manager.LoadAccessToken(accessToken)
73+
So(err, ShouldBeNil)
74+
So(ainfo.GetClientID(), ShouldEqual, atParams.ClientID)
75+
76+
rinfo, err := manager.LoadRefreshToken(refreshToken)
77+
So(err, ShouldBeNil)
78+
So(rinfo.GetClientID(), ShouldEqual, atParams.ClientID)
79+
80+
refreshAT, err := manager.RefreshAccessToken(refreshToken, "owner")
81+
So(err, ShouldBeNil)
82+
So(refreshAT, ShouldNotBeEmpty)
83+
84+
_, err = manager.LoadAccessToken(accessToken)
85+
So(err, ShouldNotBeNil)
86+
87+
refreshAInfo, err := manager.LoadAccessToken(refreshAT)
88+
So(err, ShouldBeNil)
89+
So(refreshAInfo.GetScope(), ShouldEqual, "owner")
90+
91+
err = manager.RemoveRefreshToken(refreshToken)
92+
So(err, ShouldBeNil)
93+
94+
_, err = manager.LoadAccessToken(refreshAT)
95+
So(err, ShouldNotBeNil)
96+
97+
_, err = manager.LoadRefreshToken(refreshToken)
98+
So(err, ShouldNotBeNil)
99+
}

manage/manager.go

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ type Config struct {
1717
func NewManager() *Manager {
1818
m := &Manager{
1919
injector: inject.New(),
20+
rtcfg: make(map[oauth2.ResponseType]*Config),
21+
gtcfg: make(map[oauth2.GrantType]*Config),
2022
}
2123
// 设定参数默认值
2224
// 设定授权码的有效期为10分钟
@@ -126,14 +128,17 @@ func (m *Manager) MustTokenStorage(stor oauth2.TokenStore, err error) {
126128

127129
// GetClient 获取客户端信息
128130
func (m *Manager) GetClient(clientID string) (cli oauth2.ClientInfo, err error) {
129-
err = m.injector.Apply(func(stor oauth2.ClientStore) {
131+
_, ierr := m.injector.Invoke(func(stor oauth2.ClientStore) {
130132
cli, err = stor.GetByID(clientID)
131133
if err != nil {
132134
return
133135
} else if cli == nil {
134136
err = ErrClientNotFound
135137
}
136138
})
139+
if err == nil && ierr != nil {
140+
err = ierr
141+
}
137142
return
138143
}
139144

@@ -182,7 +187,7 @@ func (m *Manager) GenerateAuthToken(rt oauth2.ResponseType, tgr *oauth2.TokenGen
182187
// GenerateAccessToken 生成访问令牌、更新令牌
183188
// gt 授权模式
184189
// tgr 生成令牌的参数
185-
func (m *Manager) GenerateAccessToken(gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (token, refresh string, err error) {
190+
func (m *Manager) GenerateAccessToken(gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (access, refresh string, err error) {
186191
if gt == oauth2.AuthorizationCodeCredentials { // 授权码模式
187192
ti, terr := m.LoadAccessToken(tgr.Code)
188193
if terr != nil {
@@ -211,7 +216,7 @@ func (m *Manager) GenerateAccessToken(gt oauth2.GrantType, tgr *oauth2.TokenGene
211216
UserID: tgr.UserID,
212217
CreateAt: time.Now(),
213218
}
214-
tv, rv, terr := gen.Token(td, tgr.IsGenerateRefresh)
219+
av, rv, terr := gen.Token(td, tgr.IsGenerateRefresh)
215220
if terr != nil {
216221
err = terr
217222
return
@@ -223,7 +228,7 @@ func (m *Manager) GenerateAccessToken(gt oauth2.GrantType, tgr *oauth2.TokenGene
223228
ti.SetAuthType(gt.String())
224229
ti.SetAccessCreateAt(td.CreateAt)
225230
ti.SetAccessExpiresIn(m.gtcfg[gt].TokenExp)
226-
ti.SetAccess(tv)
231+
ti.SetAccess(av)
227232
if rv != "" {
228233
ti.SetRefreshCreateAt(td.CreateAt)
229234
ti.SetRefreshExpiresIn(m.gtcfg[gt].RefreshExp)
@@ -233,7 +238,8 @@ func (m *Manager) GenerateAccessToken(gt oauth2.GrantType, tgr *oauth2.TokenGene
233238
if err != nil {
234239
return
235240
}
236-
token = tv
241+
access = av
242+
refresh = rv
237243
})
238244
if ierr != nil && err == nil {
239245
err = ierr
@@ -269,11 +275,11 @@ func (m *Manager) RefreshAccessToken(refresh, scope string) (token string, err e
269275
if scope != "" {
270276
ti.SetScope(scope)
271277
}
272-
if verr := stor.Create(ti); verr != nil {
278+
if verr := stor.RemoveByAccess(access); verr != nil {
273279
err = verr
274280
return
275281
}
276-
if verr := stor.RemoveByAccess(access); verr != nil {
282+
if verr := stor.Create(ti); verr != nil {
277283
err = verr
278284
return
279285
}

store/token/redis.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,16 @@ func (rs *RedisStore) Create(info oauth2.TokenInfo) (err error) {
7070

7171
// remove
7272
func (rs *RedisStore) remove(key string) (err error) {
73-
del := rs.cli.Del(key)
74-
if verr := del.Err(); verr != nil {
73+
info, err := rs.get(key)
74+
if err != nil || info == nil {
75+
return
76+
}
77+
pipe := rs.cli.Pipeline()
78+
pipe.Del(info.GetAccess())
79+
if v := info.GetRefresh(); v != "" {
80+
pipe.Del(v)
81+
}
82+
if _, verr := pipe.Exec(); verr != nil {
7583
err = verr
7684
}
7785
return

0 commit comments

Comments
 (0)