diff --git a/example/server/server.go b/example/server/server.go index eea240d..7ecc124 100644 --- a/example/server/server.go +++ b/example/server/server.go @@ -75,12 +75,12 @@ func main() { srv.SetUserAuthorizationHandler(userAuthorizeHandler) - srv.SetInternalErrorHandler(func(err error) (re *errors.Response) { + srv.SetInternalErrorHandler(func(ctx context.Context, err error) (re *errors.Response) { log.Println("Internal Error:", err.Error()) return }) - srv.SetResponseErrorHandler(func(re *errors.Response) { + srv.SetResponseErrorHandler(func(ctx context.Context, re *errors.Response) { log.Println("Response Error:", re.Error.Error()) }) diff --git a/server/handler.go b/server/handler.go index 81d54a1..638fad3 100755 --- a/server/handler.go +++ b/server/handler.go @@ -33,13 +33,13 @@ type ( RefreshingValidationHandler func(ti oauth2.TokenInfo) (allowed bool, err error) // ResponseErrorHandler response error handing - ResponseErrorHandler func(re *errors.Response) + ResponseErrorHandler func(ctx context.Context, re *errors.Response) // InternalErrorHandler internal error handing - InternalErrorHandler func(err error) (re *errors.Response) + InternalErrorHandler func(ctx context.Context, err error) (re *errors.Response) // PreRedirectErrorHandler is used to override "redirect-on-error" behavior - PreRedirectErrorHandler func(w http.ResponseWriter, req *AuthorizeRequest, err error) error + PreRedirectErrorHandler func(ctx context.Context, w http.ResponseWriter, req *AuthorizeRequest, err error) error // AuthorizeScopeHandler set the authorized scope AuthorizeScopeHandler func(w http.ResponseWriter, r *http.Request) (scope string, err error) diff --git a/server/server.go b/server/server.go index f4dba2d..d1cfae6 100755 --- a/server/server.go +++ b/server/server.go @@ -63,7 +63,7 @@ type Server struct { func (s *Server) handleError(w http.ResponseWriter, req *AuthorizeRequest, err error) error { if fn := s.PreRedirectErrorHandler; fn != nil { - return fn(w, req, err) + return fn(req.Request.Context(), w, req, err) } return s.redirectError(w, req, err) @@ -74,7 +74,7 @@ func (s *Server) redirectError(w http.ResponseWriter, req *AuthorizeRequest, err return err } - data, _, _ := s.GetErrorData(err) + data, _, _ := s.GetErrorData(req.Request.Context(), err) return s.redirect(w, req, data) } @@ -89,8 +89,8 @@ func (s *Server) redirect(w http.ResponseWriter, req *AuthorizeRequest, data map return nil } -func (s *Server) tokenError(w http.ResponseWriter, err error) error { - data, statusCode, header := s.GetErrorData(err) +func (s *Server) tokenError(ctx context.Context, w http.ResponseWriter, err error) error { + data, statusCode, header := s.GetErrorData(ctx, err) return s.token(w, data, header, statusCode) } @@ -511,19 +511,19 @@ func (s *Server) HandleTokenRequest(w http.ResponseWriter, r *http.Request) erro gt, tgr, err := s.ValidationTokenRequest(r) if err != nil { - return s.tokenError(w, err) + return s.tokenError(ctx, w, err) } ti, err := s.GetAccessToken(ctx, gt, tgr) if err != nil { - return s.tokenError(w, err) + return s.tokenError(ctx, w, err) } return s.token(w, s.GetTokenData(ti), nil) } // GetErrorData get error response data -func (s *Server) GetErrorData(err error) (map[string]interface{}, int, http.Header) { +func (s *Server) GetErrorData(ctx context.Context, err error) (map[string]interface{}, int, http.Header) { var re errors.Response if v, ok := errors.Descriptions[err]; ok { re.Error = err @@ -531,7 +531,7 @@ func (s *Server) GetErrorData(err error) (map[string]interface{}, int, http.Head re.StatusCode = errors.StatusCodes[err] } else { if fn := s.InternalErrorHandler; fn != nil { - if v := fn(err); v != nil { + if v := fn(ctx, err); v != nil { re = *v } } @@ -544,7 +544,7 @@ func (s *Server) GetErrorData(err error) (map[string]interface{}, int, http.Head } if fn := s.ResponseErrorHandler; fn != nil { - fn(&re) + fn(ctx, &re) } data := make(map[string]interface{}) diff --git a/server/server_test.go b/server/server_test.go index 1eb0bcd..d110150 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -296,12 +296,12 @@ func TestClientCredentials(t *testing.T) { srv = server.NewDefaultServer(manager) srv.SetClientInfoHandler(server.ClientFormHandler) - srv.SetInternalErrorHandler(func(err error) (re *errors.Response) { + srv.SetInternalErrorHandler(func(ctx context.Context, err error) (re *errors.Response) { t.Log("OAuth 2.0 Error:", err.Error()) return }) - srv.SetResponseErrorHandler(func(re *errors.Response) { + srv.SetResponseErrorHandler(func(ctx context.Context, re *errors.Response) { t.Log("Response Error:", re.Error) })