Skip to content

Commit b565f7c

Browse files
authored
fix: authentication middleware is implemented by changing from framework droplet to framework gin (#2254)
1 parent ffa596d commit b565f7c

File tree

15 files changed

+163
-195
lines changed

15 files changed

+163
-195
lines changed

api/internal/core/server/http.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ import (
2727

2828
"github.com/apisix/manager-api/internal"
2929
"github.com/apisix/manager-api/internal/conf"
30-
"github.com/apisix/manager-api/internal/filter"
3130
"github.com/apisix/manager-api/internal/handler"
3231
)
3332

@@ -37,7 +36,7 @@ func (s *server) setupAPI() {
3736
var newMws []droplet.Middleware
3837
// default middleware order: resp_reshape, auto_input, traffic_log
3938
// We should put err_transform at second to catch all error
40-
newMws = append(newMws, mws[0], &handler.ErrorTransformMiddleware{}, &filter.AuthenticationMiddleware{})
39+
newMws = append(newMws, mws[0], &handler.ErrorTransformMiddleware{})
4140
newMws = append(newMws, mws[1:]...)
4241
return newMws
4342
}

api/internal/core/server/server.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ type server struct {
3535
options *Options
3636
}
3737

38-
type Options struct {}
38+
type Options struct{}
3939

4040
// NewServer Create a server manager
4141
func NewServer(options *Options) (*server, error) {

api/internal/filter/authentication.go

Lines changed: 53 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -17,84 +17,67 @@
1717
package filter
1818

1919
import (
20-
"errors"
2120
"net/http"
2221
"strings"
2322

2423
"github.com/dgrijalva/jwt-go"
25-
"github.com/shiningrush/droplet"
26-
"github.com/shiningrush/droplet/data"
27-
"github.com/shiningrush/droplet/middleware"
24+
"github.com/gin-gonic/gin"
2825

2926
"github.com/apisix/manager-api/internal/conf"
3027
"github.com/apisix/manager-api/internal/log"
3128
)
3229

33-
type AuthenticationMiddleware struct {
34-
middleware.BaseMiddleware
35-
}
36-
37-
func (mw *AuthenticationMiddleware) Handle(ctx droplet.Context) error {
38-
httpReq := ctx.Get(middleware.KeyHttpRequest)
39-
if httpReq == nil {
40-
err := errors.New("input middleware cannot get http request")
41-
42-
// Wrong usage, just panic here and let recoverHandler to deal with
43-
panic(err)
44-
}
45-
46-
req := httpReq.(*http.Request)
47-
48-
if req.URL.Path == "/apisix/admin/tool/version" || req.URL.Path == "/apisix/admin/user/login" {
49-
return mw.BaseMiddleware.Handle(ctx)
50-
}
51-
52-
if !strings.HasPrefix(req.URL.Path, "/apisix") {
53-
return mw.BaseMiddleware.Handle(ctx)
54-
}
55-
56-
// Need check the auth header
57-
tokenStr := req.Header.Get("Authorization")
58-
59-
// verify token
60-
token, err := jwt.ParseWithClaims(tokenStr, &jwt.StandardClaims{}, func(token *jwt.Token) (interface{}, error) {
61-
return []byte(conf.AuthConf.Secret), nil
62-
})
63-
64-
// TODO: design the response error code
65-
response := data.Response{Code: 010013, Message: "request unauthorized"}
66-
67-
if err != nil || token == nil || !token.Valid {
68-
log.Warnf("token validate failed: %s", err)
69-
log.Warn("please check the secret in conf.yaml")
70-
ctx.SetOutput(&data.SpecCodeResponse{StatusCode: http.StatusUnauthorized, Response: response})
71-
return nil
72-
}
73-
74-
claims, ok := token.Claims.(*jwt.StandardClaims)
75-
if !ok {
76-
log.Warnf("token validate failed: %s, %v", err, token.Valid)
77-
ctx.SetOutput(&data.SpecCodeResponse{StatusCode: http.StatusUnauthorized, Response: response})
78-
return nil
79-
}
80-
81-
if err := token.Claims.Valid(); err != nil {
82-
log.Warnf("token claims validate failed: %s", err)
83-
ctx.SetOutput(&data.SpecCodeResponse{StatusCode: http.StatusUnauthorized, Response: response})
84-
return nil
30+
func Authentication() gin.HandlerFunc {
31+
return func(c *gin.Context) {
32+
if c.Request.URL.Path == "/apisix/admin/user/login" ||
33+
c.Request.URL.Path == "/apisix/admin/tool/version" ||
34+
!strings.HasPrefix(c.Request.URL.Path, "/apisix") {
35+
c.Next()
36+
return
37+
}
38+
39+
tokenStr := c.GetHeader("Authorization")
40+
// verify token
41+
token, err := jwt.ParseWithClaims(tokenStr, &jwt.StandardClaims{}, func(token *jwt.Token) (interface{}, error) {
42+
return []byte(conf.AuthConf.Secret), nil
43+
})
44+
45+
errResp := gin.H{
46+
"code": 010013,
47+
"message": "request unauthorized",
48+
}
49+
50+
if err != nil || token == nil || !token.Valid {
51+
log.Warnf("token validate failed: %s", err)
52+
c.AbortWithStatusJSON(http.StatusUnauthorized, errResp)
53+
return
54+
}
55+
56+
claims, ok := token.Claims.(*jwt.StandardClaims)
57+
if !ok {
58+
log.Warnf("token validate failed: %s, %v", err, token.Valid)
59+
c.AbortWithStatusJSON(http.StatusUnauthorized, errResp)
60+
return
61+
}
62+
63+
if err := token.Claims.Valid(); err != nil {
64+
log.Warnf("token claims validate failed: %s", err)
65+
c.AbortWithStatusJSON(http.StatusUnauthorized, errResp)
66+
return
67+
}
68+
69+
if claims.Subject == "" {
70+
log.Warn("token claims subject empty")
71+
c.AbortWithStatusJSON(http.StatusUnauthorized, errResp)
72+
return
73+
}
74+
75+
if _, ok := conf.UserList[claims.Subject]; !ok {
76+
log.Warnf("user not exists by token claims subject %s", claims.Subject)
77+
c.AbortWithStatusJSON(http.StatusUnauthorized, errResp)
78+
return
79+
}
80+
81+
c.Next()
8582
}
86-
87-
if claims.Subject == "" {
88-
log.Warn("token claims subject empty")
89-
ctx.SetOutput(&data.SpecCodeResponse{StatusCode: http.StatusUnauthorized, Response: response})
90-
return nil
91-
}
92-
93-
if _, ok := conf.UserList[claims.Subject]; !ok {
94-
log.Warnf("user not exists by token claims subject %s", claims.Subject)
95-
ctx.SetOutput(&data.SpecCodeResponse{StatusCode: http.StatusUnauthorized, Response: response})
96-
return nil
97-
}
98-
99-
return mw.BaseMiddleware.Handle(ctx)
10083
}

api/internal/filter/authentication_test.go

Lines changed: 22 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,12 @@
1717
package filter
1818

1919
import (
20-
"errors"
2120
"net/http"
22-
"net/url"
2321
"testing"
2422
"time"
2523

2624
"github.com/dgrijalva/jwt-go"
27-
"github.com/shiningrush/droplet"
28-
"github.com/shiningrush/droplet/data"
29-
"github.com/shiningrush/droplet/middleware"
25+
"github.com/gin-gonic/gin"
3026
"github.com/stretchr/testify/assert"
3127

3228
"github.com/apisix/manager-api/internal/conf"
@@ -44,73 +40,35 @@ func genToken(username string, issueAt, expireAt int64) string {
4440
return signedToken
4541
}
4642

47-
type mockMiddleware struct {
48-
middleware.BaseMiddleware
49-
}
50-
51-
func (mw *mockMiddleware) Handle(ctx droplet.Context) error {
52-
return errors.New("next middleware")
53-
}
54-
55-
func testPanic(t *testing.T, mw AuthenticationMiddleware, ctx droplet.Context) {
56-
defer func() {
57-
panicErr := recover()
58-
assert.Contains(t, panicErr.(error).Error(), "input middleware cannot get http request")
59-
}()
60-
_ = mw.Handle(ctx)
61-
}
62-
6343
func TestAuthenticationMiddleware_Handle(t *testing.T) {
64-
ctx := droplet.NewContext()
65-
fakeReq, _ := http.NewRequest(http.MethodGet, "", nil)
66-
expectOutput := &data.SpecCodeResponse{
67-
Response: data.Response{
68-
Code: 010013,
69-
Message: "request unauthorized",
70-
},
71-
StatusCode: http.StatusUnauthorized,
72-
}
73-
74-
mw := AuthenticationMiddleware{}
75-
mockMw := mockMiddleware{}
76-
mw.SetNext(&mockMw)
77-
78-
// test without http.Request
79-
testPanic(t, mw, ctx)
80-
81-
ctx.Set(middleware.KeyHttpRequest, fakeReq)
44+
r := gin.New()
45+
r.Use(Authentication())
46+
r.GET("/*path", func(c *gin.Context) {
47+
})
8248

83-
// test without token check
84-
fakeReq.URL = &url.URL{Path: "/apisix/admin/user/login"}
85-
assert.Equal(t, mw.Handle(ctx), errors.New("next middleware"))
49+
w := performRequest(r, "GET", "/apisix/admin/user/login", nil)
50+
assert.Equal(t, http.StatusOK, w.Code)
8651

87-
// test without authorization header
88-
fakeReq.URL = &url.URL{Path: "/apisix/admin/routes"}
89-
assert.Nil(t, mw.Handle(ctx))
90-
assert.Equal(t, expectOutput, ctx.Output().(*data.SpecCodeResponse))
52+
w = performRequest(r, "GET", "/apisix/admin/routes", nil)
53+
assert.Equal(t, http.StatusUnauthorized, w.Code)
9154

9255
// test with token expire
9356
expireToken := genToken("admin", time.Now().Unix(), time.Now().Unix()-60*3600)
94-
fakeReq.Header.Set("Authorization", expireToken)
95-
assert.Nil(t, mw.Handle(ctx))
96-
assert.Equal(t, expectOutput, ctx.Output().(*data.SpecCodeResponse))
57+
w = performRequest(r, "GET", "/apisix/admin/routes", map[string]string{"Authorization": expireToken})
58+
assert.Equal(t, http.StatusUnauthorized, w.Code)
9759

98-
// test with temp subject
99-
tempSubjectToken := genToken("", time.Now().Unix(), time.Now().Unix()+60*3600)
100-
fakeReq.Header.Set("Authorization", tempSubjectToken)
101-
assert.Nil(t, mw.Handle(ctx))
102-
assert.Equal(t, expectOutput, ctx.Output().(*data.SpecCodeResponse))
60+
// test with empty subject
61+
emptySubjectToken := genToken("", time.Now().Unix(), time.Now().Unix()+60*3600)
62+
w = performRequest(r, "GET", "/apisix/admin/routes", map[string]string{"Authorization": emptySubjectToken})
63+
assert.Equal(t, http.StatusUnauthorized, w.Code)
10364

104-
// test username doesn't exist
105-
userToken := genToken("user1", time.Now().Unix(), time.Now().Unix()+60*3600)
106-
fakeReq.Header.Set("Authorization", userToken)
107-
assert.Nil(t, mw.Handle(ctx))
108-
assert.Equal(t, expectOutput, ctx.Output().(*data.SpecCodeResponse))
65+
// test token with nonexistent username
66+
nonexistentUserToken := genToken("user1", time.Now().Unix(), time.Now().Unix()+60*3600)
67+
w = performRequest(r, "GET", "/apisix/admin/routes", map[string]string{"Authorization": nonexistentUserToken})
68+
assert.Equal(t, http.StatusUnauthorized, w.Code)
10969

11070
// test auth success
111-
adminToken := genToken("admin", time.Now().Unix(), time.Now().Unix()+60*3600)
112-
fakeReq.Header.Set("Authorization", adminToken)
113-
ctx.SetOutput("test data")
114-
assert.Equal(t, mw.Handle(ctx), errors.New("next middleware"))
115-
assert.Equal(t, "test data", ctx.Output().(string))
71+
validToken := genToken("admin", time.Now().Unix(), time.Now().Unix()+60*3600)
72+
w = performRequest(r, "GET", "/apisix/admin/routes", map[string]string{"Authorization": validToken})
73+
assert.Equal(t, http.StatusOK, w.Code)
11674
}

api/internal/filter/ip_filter_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ func TestIPFilter_Handle(t *testing.T) {
3535
r.GET("/", func(c *gin.Context) {
3636
})
3737

38-
w := performRequest(r, "GET", "/")
38+
w := performRequest(r, "GET", "/", nil)
3939
assert.Equal(t, 200, w.Code)
4040

4141
// should forbidden
@@ -45,7 +45,7 @@ func TestIPFilter_Handle(t *testing.T) {
4545
r.GET("/fbd", func(c *gin.Context) {
4646
})
4747

48-
w = performRequest(r, "GET", "/fbd")
48+
w = performRequest(r, "GET", "/fbd", nil)
4949
assert.Equal(t, 403, w.Code)
5050

5151
// should allowed
@@ -54,7 +54,7 @@ func TestIPFilter_Handle(t *testing.T) {
5454
r.Use(IPFilter())
5555
r.GET("/test", func(c *gin.Context) {
5656
})
57-
w = performRequest(r, "GET", "/test")
57+
w = performRequest(r, "GET", "/test", nil)
5858
assert.Equal(t, 200, w.Code)
5959

6060
// should forbidden

api/internal/filter/logging_test.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,11 @@ import (
2727
"github.com/apisix/manager-api/internal/log"
2828
)
2929

30-
func performRequest(r http.Handler, method, path string) *httptest.ResponseRecorder {
30+
func performRequest(r http.Handler, method, path string, headers map[string]string) *httptest.ResponseRecorder {
3131
req := httptest.NewRequest(method, path, nil)
32+
for key, val := range headers {
33+
req.Header.Add(key, val)
34+
}
3235
w := httptest.NewRecorder()
3336
r.ServeHTTP(w, req)
3437
return w
@@ -41,6 +44,6 @@ func TestRequestLogHandler(t *testing.T) {
4144
r.GET("/", func(c *gin.Context) {
4245
})
4346

44-
w := performRequest(r, "GET", "/")
47+
w := performRequest(r, "GET", "/", nil)
4548
assert.Equal(t, 200, w.Code)
4649
}

0 commit comments

Comments
 (0)