@@ -24,6 +24,21 @@ type Server struct {
2424 manager oauth2.Manager
2525}
2626
27+ // SetClientHandler 设置客户端处理
28+ func (s * Server ) SetClientHandler (handler ClientHandler ) {
29+ s .cfg .Handler .ClientHandler = handler
30+ }
31+
32+ // SetUserHandler 设置用户处理
33+ func (s * Server ) SetUserHandler (handler UserHandler ) {
34+ s .cfg .Handler .UserHandler = handler
35+ }
36+
37+ // SetScopeHandler 设置授权范围处理
38+ func (s * Server ) SetScopeHandler (handler ScopeHandler ) {
39+ s .cfg .Handler .ScopeHandler = handler
40+ }
41+
2742// checkResponseType 检查允许的授权类型
2843func (s * Server ) checkResponseType (rt oauth2.ResponseType ) bool {
2944 for _ , art := range s .cfg .AllowedResponseType {
@@ -95,7 +110,7 @@ func (s *Server) HandleAuthorizeRequest(w http.ResponseWriter, authReq *Authoriz
95110// HandleTokenRequest 处理令牌请求
96111// cli 获取客户端信息
97112// user 获取用户信息
98- func (s * Server ) HandleTokenRequest (w http.ResponseWriter , r * http.Request , ch ClientHandler , uh UserHandler ) (err error ) {
113+ func (s * Server ) HandleTokenRequest (w http.ResponseWriter , r * http.Request ) (err error ) {
99114 if r .Method != "POST" {
100115 err = ErrRequestMethodInvalid
101116 return
@@ -111,14 +126,10 @@ func (s *Server) HandleTokenRequest(w http.ResponseWriter, r *http.Request, ch C
111126 }
112127
113128 var ti oauth2.TokenInfo
114- clientID , clientSecret , err := ch (r )
129+ clientID , clientSecret , err := s . cfg . Handler . ClientHandler (r )
115130 if err != nil {
116131 return
117132 }
118- if clientID == "" || clientSecret == "" {
119- err = ErrClientInvalid
120- return
121- }
122133 tgr := & oauth2.TokenGenerateRequest {
123134 ClientID : clientID ,
124135 ClientSecret : clientSecret ,
@@ -131,19 +142,38 @@ func (s *Server) HandleTokenRequest(w http.ResponseWriter, r *http.Request, ch C
131142 tgr .IsGenerateRefresh = true
132143 ti , err = s .manager .GenerateAccessToken (oauth2 .AuthorizationCodeCredentials , tgr )
133144 case oauth2 .PasswordCredentials :
134- userID , uerr := uh (r .Form .Get ("username" ), r .Form .Get ("password" ))
145+ userID , uerr := s . cfg . Handler . UserHandler (r .Form .Get ("username" ), r .Form .Get ("password" ))
135146 if uerr != nil {
136147 err = uerr
137148 return
138149 }
139150 tgr .UserID = userID
140151 tgr .Scope = r .Form .Get ("scope" )
141152 tgr .IsGenerateRefresh = true
153+ ti , err = s .manager .GenerateAccessToken (oauth2 .PasswordCredentials , tgr )
142154 case oauth2 .ClientCredentials :
143155 tgr .Scope = r .Form .Get ("scope" )
156+ ti , err = s .manager .GenerateAccessToken (oauth2 .ClientCredentials , tgr )
144157 case oauth2 .RefreshCredentials :
145158 tgr .Refresh = r .Form .Get ("refresh_token" )
146159 tgr .Scope = r .Form .Get ("scope" )
160+ if tgr .Scope != "" { // 检查授权范围
161+ rti , rerr := s .manager .LoadRefreshToken (tgr .Refresh )
162+ if rerr != nil {
163+ err = rerr
164+ return
165+ } else if rti .GetClientID () != tgr .ClientID {
166+ err = ErrRefreshInvalid
167+ return
168+ } else if verr := s .cfg .Handler .ScopeHandler (tgr .Scope , rti .GetScope ()); verr != nil {
169+ err = verr
170+ return
171+ }
172+ }
173+ ti , err = s .manager .RefreshAccessToken (tgr )
174+ if err == nil {
175+ ti .SetRefresh ("" )
176+ }
147177 }
148178
149179 if err != nil {
0 commit comments