Skip to content

Commit 69732ad

Browse files
committed
Fixed server package
1 parent 222cdc9 commit 69732ad

File tree

5 files changed

+67
-27
lines changed

5 files changed

+67
-27
lines changed

manage.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,8 @@ type Manager interface {
2929
GenerateAccessToken(rt GrantType, tgr *TokenGenerateRequest) (accessToken TokenInfo, err error)
3030

3131
// RefreshAccessToken 更新访问令牌
32-
// refresh 更新令牌
33-
// scope 作用域
34-
RefreshAccessToken(refresh, scope string) (accessToken TokenInfo, err error)
32+
// tgr 生成令牌的请求参数
33+
RefreshAccessToken(tgr *TokenGenerateRequest) (accessToken TokenInfo, err error)
3534

3635
// RemoveAccessToken 删除访问令牌
3736
// access 访问令牌

server/authorize.go

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
package server
22

33
import (
4-
"encoding/base64"
54
"net/http"
6-
"strings"
75

86
"gopkg.in/oauth2.v2"
97
)
@@ -18,36 +16,40 @@ type AuthorizeRequest struct {
1816
UserID string
1917
}
2018

21-
// ClientHandler 获取客户端信息
19+
// ClientHandler 客户端处理(获取请求的客户端认证信息)
2220
type ClientHandler func(r *http.Request) (clientID, clientSecret string, err error)
2321

24-
// UserHandler 获取用户信息
22+
// UserHandler 用户处理(密码模式,根据用户名、密码获取用户标识)
2523
type UserHandler func(username, password string) (userID string, err error)
2624

25+
// ScopeHandler 授权范围处理(更新令牌时的授权范围检查)
26+
type ScopeHandler func(new, old string) (err error)
27+
28+
// TokenRequestHandler 令牌请求处理
29+
type TokenRequestHandler struct {
30+
ClientHandler ClientHandler
31+
UserHandler UserHandler
32+
ScopeHandler ScopeHandler
33+
}
34+
2735
// ClientFormHandler 客户端表单信息
2836
func ClientFormHandler(r *http.Request) (clientID, clientSecret string, err error) {
2937
clientID = r.Form.Get("client_id")
3038
clientSecret = r.Form.Get("client_secret")
39+
if clientID == "" || clientSecret == "" {
40+
err = ErrAuthorizationFormInvalid
41+
}
3142
return
3243
}
3344

3445
// ClientBasicHandler 客户端基础认证信息
3546
func ClientBasicHandler(r *http.Request) (clientID, clientSecret string, err error) {
36-
s := strings.SplitN(r.Header.Get("Authorization"), " ", 2)
37-
if len(s) != 2 || s[0] != "Basic" {
38-
err = ErrAuthorizationHeaderInvalid
39-
return
40-
}
41-
b, err := base64.StdEncoding.DecodeString(s[1])
42-
if err != nil {
43-
return
44-
}
45-
pair := strings.SplitN(string(b), ":", 2)
46-
if len(pair) != 2 {
47+
username, password, ok := r.BasicAuth()
48+
if !ok {
4749
err = ErrAuthorizationHeaderInvalid
4850
return
4951
}
50-
clientID = pair[0]
51-
clientSecret = pair[1]
52+
clientID = username
53+
clientSecret = password
5254
return
5355
}

server/config.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ type Config struct {
1010
AllowedResponseType []oauth2.ResponseType
1111
// AllowedGrantType 允许的授权模式(默认authorization_code)
1212
AllowedGrantType []oauth2.GrantType
13+
// Handler 令牌请求处理
14+
Handler *TokenRequestHandler
1315
}
1416

1517
// NewConfig 创建默认的配置参数
@@ -18,5 +20,6 @@ func NewConfig() *Config {
1820
TokenType: "Bearer",
1921
AllowedResponseType: []oauth2.ResponseType{oauth2.Code},
2022
AllowedGrantType: []oauth2.GrantType{oauth2.AuthorizationCodeCredentials},
23+
Handler: &TokenRequestHandler{},
2124
}
2225
}

server/error.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,12 @@ var (
1818
// ErrUserInvalid User invalid
1919
ErrUserInvalid = errors.New("user invalid")
2020

21+
// ErrAuthorizationFormInvalid Authorization form invalid
22+
ErrAuthorizationFormInvalid = errors.New("authorization form invalid")
23+
2124
// ErrAuthorizationHeaderInvalid Authorization header invalid
2225
ErrAuthorizationHeaderInvalid = errors.New("authorization header invalid")
26+
27+
// ErrRefreshInvalid Refresh token invalid
28+
ErrRefreshInvalid = errors.New("refresh token invalid")
2329
)

server/server.go

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,21 @@ type Server struct {
2424
manager oauth2.Manager
2525
}
2626

27+
// SetClientHandler 设置客户端处理
28+
func (s *Server) SetClientHandler(handler ClientHandler) {
29+
s.cfg.Handler.ClientHandler = handler
30+
}
31+
32+
// SetUserHandler 设置用户处理
33+
func (s *Server) SetUserHandler(handler UserHandler) {
34+
s.cfg.Handler.UserHandler = handler
35+
}
36+
37+
// SetScopeHandler 设置授权范围处理
38+
func (s *Server) SetScopeHandler(handler ScopeHandler) {
39+
s.cfg.Handler.ScopeHandler = handler
40+
}
41+
2742
// checkResponseType 检查允许的授权类型
2843
func (s *Server) checkResponseType(rt oauth2.ResponseType) bool {
2944
for _, art := range s.cfg.AllowedResponseType {
@@ -95,7 +110,7 @@ func (s *Server) HandleAuthorizeRequest(w http.ResponseWriter, authReq *Authoriz
95110
// HandleTokenRequest 处理令牌请求
96111
// cli 获取客户端信息
97112
// user 获取用户信息
98-
func (s *Server) HandleTokenRequest(w http.ResponseWriter, r *http.Request, ch ClientHandler, uh UserHandler) (err error) {
113+
func (s *Server) HandleTokenRequest(w http.ResponseWriter, r *http.Request) (err error) {
99114
if r.Method != "POST" {
100115
err = ErrRequestMethodInvalid
101116
return
@@ -111,14 +126,10 @@ func (s *Server) HandleTokenRequest(w http.ResponseWriter, r *http.Request, ch C
111126
}
112127

113128
var ti oauth2.TokenInfo
114-
clientID, clientSecret, err := ch(r)
129+
clientID, clientSecret, err := s.cfg.Handler.ClientHandler(r)
115130
if err != nil {
116131
return
117132
}
118-
if clientID == "" || clientSecret == "" {
119-
err = ErrClientInvalid
120-
return
121-
}
122133
tgr := &oauth2.TokenGenerateRequest{
123134
ClientID: clientID,
124135
ClientSecret: clientSecret,
@@ -131,19 +142,38 @@ func (s *Server) HandleTokenRequest(w http.ResponseWriter, r *http.Request, ch C
131142
tgr.IsGenerateRefresh = true
132143
ti, err = s.manager.GenerateAccessToken(oauth2.AuthorizationCodeCredentials, tgr)
133144
case oauth2.PasswordCredentials:
134-
userID, uerr := uh(r.Form.Get("username"), r.Form.Get("password"))
145+
userID, uerr := s.cfg.Handler.UserHandler(r.Form.Get("username"), r.Form.Get("password"))
135146
if uerr != nil {
136147
err = uerr
137148
return
138149
}
139150
tgr.UserID = userID
140151
tgr.Scope = r.Form.Get("scope")
141152
tgr.IsGenerateRefresh = true
153+
ti, err = s.manager.GenerateAccessToken(oauth2.PasswordCredentials, tgr)
142154
case oauth2.ClientCredentials:
143155
tgr.Scope = r.Form.Get("scope")
156+
ti, err = s.manager.GenerateAccessToken(oauth2.ClientCredentials, tgr)
144157
case oauth2.RefreshCredentials:
145158
tgr.Refresh = r.Form.Get("refresh_token")
146159
tgr.Scope = r.Form.Get("scope")
160+
if tgr.Scope != "" { // 检查授权范围
161+
rti, rerr := s.manager.LoadRefreshToken(tgr.Refresh)
162+
if rerr != nil {
163+
err = rerr
164+
return
165+
} else if rti.GetClientID() != tgr.ClientID {
166+
err = ErrRefreshInvalid
167+
return
168+
} else if verr := s.cfg.Handler.ScopeHandler(tgr.Scope, rti.GetScope()); verr != nil {
169+
err = verr
170+
return
171+
}
172+
}
173+
ti, err = s.manager.RefreshAccessToken(tgr)
174+
if err == nil {
175+
ti.SetRefresh("")
176+
}
147177
}
148178

149179
if err != nil {

0 commit comments

Comments
 (0)