Skip to content

Commit 154dd51

Browse files
authored
Merge pull request #6 from bootjp/feature/2
fix missing status code
2 parents 5bb5362 + d8eaaa3 commit 154dd51

File tree

2 files changed

+83
-21
lines changed

2 files changed

+83
-21
lines changed

path_auth.go

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -31,19 +31,7 @@ var (
3131
)
3232

3333
// ErrKeyAuthMissing is error type when PathAuth middleware is unable to extract value from lookups
34-
type ErrKeyAuthMissing struct {
35-
Err error
36-
}
37-
38-
// Error returns errors text
39-
func (e *ErrKeyAuthMissing) Error() string {
40-
return e.Err.Error()
41-
}
42-
43-
// Unwrap unwraps error
44-
func (e *ErrKeyAuthMissing) Unwrap() error {
45-
return e.Err
46-
}
34+
var ErrKeyAuthMissing = echo.NewHTTPError(http.StatusBadRequest, "Missing key in the request")
4735

4836
// PathAuth returns an PathAuth middleware.
4937
//
@@ -74,20 +62,35 @@ func PathAuthWithConfig(config PathAuthConfig) echo.MiddlewareFunc {
7462
if config.Skipper(c) {
7563
return next(c)
7664
}
77-
valid, err := config.Validator(c.Param(config.Param), c)
78-
if err != nil {
65+
66+
if !extract(config.Param, c.ParamNames()) {
7967
return &echo.HTTPError{
80-
Code: http.StatusUnauthorized,
81-
Message: "Unauthorized",
82-
Internal: err,
68+
Code: http.StatusBadRequest,
69+
Message: http.StatusText(http.StatusBadRequest),
70+
Internal: ErrKeyAuthMissing,
8371
}
8472
}
8573

86-
if valid {
74+
valid, err := config.Validator(c.Param(config.Param), c)
75+
if err == nil && valid {
8776
return next(c)
8877
}
8978

90-
return echo.NewHTTPError(http.StatusBadRequest)
79+
return &echo.HTTPError{
80+
Code: http.StatusUnauthorized,
81+
Message: http.StatusText(http.StatusUnauthorized),
82+
Internal: err,
83+
}
9184
}
9285
}
9386
}
87+
88+
func extract(cParam string, params []string) bool {
89+
for _, param := range params {
90+
if cParam == param {
91+
return true
92+
}
93+
}
94+
95+
return false
96+
}

path_auth_test.go

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ func TestKeyAuth(t *testing.T) {
2626
handlerCalled := false
2727
handler := func(c echo.Context) error {
2828
handlerCalled = true
29+
//nolint:wrapcheck
2930
return c.String(http.StatusOK, "test")
3031
}
3132
middlewareChain := PathAuth("apikey", testKeyValidator)(handler)
@@ -48,6 +49,7 @@ func TestKeyAuth(t *testing.T) {
4849
handlerCalled := false
4950
handler := func(c echo.Context) error {
5051
handlerCalled = true
52+
//nolint:wrapcheck
5153
return c.String(http.StatusOK, "test")
5254
}
5355
middlewareChain := PathAuth("apikey", testKeyValidator)(handler)
@@ -63,9 +65,55 @@ func TestKeyAuth(t *testing.T) {
6365
err := middlewareChain(c)
6466

6567
assert.Error(t, err)
68+
assert.EqualError(t, err, "code=401, message=Unauthorized, internal=some user defined error")
6669
assert.False(t, handlerCalled)
6770
})
71+
t.Run("auth no error failed", func(t *testing.T) {
72+
handlerCalled := false
73+
handler := func(c echo.Context) error {
74+
handlerCalled = true
75+
//nolint:wrapcheck
76+
return c.String(http.StatusOK, "test")
77+
}
78+
middlewareChain := PathAuth("apikey", testKeyValidator)(handler)
6879

80+
e := echo.New()
81+
e.GET("/:apikey", middlewareChain)
82+
83+
req := httptest.NewRequest(http.MethodGet, "/", nil)
84+
rec := httptest.NewRecorder()
85+
86+
c := e.NewContext(req, rec)
87+
e.Router().Find(http.MethodGet, "/no-error", c)
88+
err := middlewareChain(c)
89+
90+
assert.Error(t, err)
91+
assert.EqualError(t, err, "code=401, message=Unauthorized")
92+
assert.False(t, handlerCalled)
93+
})
94+
t.Run("auth nokey", func(t *testing.T) {
95+
handlerCalled := false
96+
handler := func(c echo.Context) error {
97+
handlerCalled = true
98+
//nolint:wrapcheck
99+
return c.String(http.StatusOK, "test")
100+
}
101+
middlewareChain := PathAuth("undef", testKeyValidator)(handler)
102+
103+
e := echo.New()
104+
e.GET("/:apikey", middlewareChain)
105+
106+
req := httptest.NewRequest(http.MethodGet, "/", nil)
107+
rec := httptest.NewRecorder()
108+
109+
c := e.NewContext(req, rec)
110+
e.Router().Find(http.MethodGet, "/error-key", c)
111+
err := middlewareChain(c)
112+
113+
assert.Error(t, err)
114+
assert.EqualError(t, err, "code=400, message=Bad Request, internal=code=400, message=Missing key in the request")
115+
assert.False(t, handlerCalled)
116+
})
69117
}
70118

71119
func TestPathAuthWithConfig(t *testing.T) {
@@ -103,7 +151,7 @@ func TestPathAuthWithConfig(t *testing.T) {
103151
return req
104152
},
105153
expectHandlerCalled: false,
106-
expectError: "code=400, message=Bad Request",
154+
expectError: "code=401, message=Unauthorized",
107155
},
108156
}
109157

@@ -112,6 +160,7 @@ func TestPathAuthWithConfig(t *testing.T) {
112160
handlerCalled := false
113161
handler := func(c echo.Context) error {
114162
handlerCalled = true
163+
//nolint:wrapcheck
115164
return c.String(http.StatusOK, "test")
116165
}
117166
config := PathAuthConfig{
@@ -154,6 +203,7 @@ func TestPathAuthWithConfig_panicsOnEmptyValidator(t *testing.T) {
154203
"PathAuth: requires a validator function",
155204
func() {
156205
handler := func(c echo.Context) error {
206+
//nolint:wrapcheck
157207
return c.String(http.StatusOK, "test")
158208
}
159209
PathAuthWithConfig(PathAuthConfig{
@@ -169,6 +219,7 @@ func TestPathAuthWithConfig_panicsOnEmptyParam(t *testing.T) {
169219
"PathAuth: requires a param",
170220
func() {
171221
handler := func(c echo.Context) error {
222+
//nolint:wrapcheck
172223
return c.String(http.StatusOK, "test")
173224
}
174225
PathAuthWithConfig(PathAuthConfig{
@@ -185,6 +236,7 @@ func TestPathAuthWithConfig_panicsOnEmptyParam(t *testing.T) {
185236
"PathAuth: requires a param",
186237
func() {
187238
handler := func(c echo.Context) error {
239+
//nolint:wrapcheck
188240
return c.String(http.StatusOK, "test")
189241
}
190242
PathAuth("", func(auth string, c echo.Context) (bool, error) {
@@ -193,3 +245,10 @@ func TestPathAuthWithConfig_panicsOnEmptyParam(t *testing.T) {
193245
},
194246
)
195247
}
248+
249+
func TestExtract(t *testing.T) {
250+
251+
assert.True(t, extract("apikey", []string{"apikey", "valid-key"}))
252+
assert.False(t, extract("apikey", []string{"valid-key"}))
253+
254+
}

0 commit comments

Comments
 (0)