Skip to content

Commit 3455e5c

Browse files
authored
Implement Basic Auth (#371)
1 parent 4df7cbe commit 3455e5c

File tree

5 files changed

+504
-1
lines changed

5 files changed

+504
-1
lines changed

pkg/gofr/gofr.go

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package gofr
22

33
import (
44
"fmt"
5-
65
"net/http"
76
"os"
87
"strconv"
@@ -277,6 +276,23 @@ func (o *otelErrorHandler) Handle(e error) {
277276
o.logger.Error(e.Error())
278277
}
279278

279+
func (a *App) EnableBasicAuth(credentials ...string) {
280+
if len(credentials)%2 != 0 {
281+
a.container.Error("Invalid number of arguments for EnableBasicAuth")
282+
}
283+
284+
users := make(map[string]string)
285+
for i := 0; i < len(credentials); i += 2 {
286+
users[credentials[i]] = credentials[i+1]
287+
}
288+
289+
a.httpServer.router.Use(middleware.BasicAuthMiddleware(middleware.BasicAuthProvider{Users: users}))
290+
}
291+
292+
func (a *App) EnableBasicAuthWithFunc(validateFunc func(username, password string) bool) {
293+
a.httpServer.router.Use(middleware.BasicAuthMiddleware(middleware.BasicAuthProvider{ValidateFunc: validateFunc}))
294+
}
295+
280296
func (a *App) Subscribe(topic string, handler SubscribeFunc) {
281297
if a.container.GetSubscriber() == nil {
282298
a.container.Logger.Errorf("Subscriber not initialized in the container")
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
package middleware
2+
3+
import (
4+
"encoding/base64"
5+
"net/http"
6+
"strings"
7+
)
8+
9+
const credentialLength = 2
10+
11+
type BasicAuthProvider struct {
12+
Users map[string]string
13+
ValidateFunc func(username, password string) bool
14+
}
15+
16+
func BasicAuthMiddleware(basicAuthProvider BasicAuthProvider) func(handler http.Handler) http.Handler {
17+
return func(handler http.Handler) http.Handler {
18+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
19+
authHeader := r.Header.Get("Authorization")
20+
if authHeader == "" {
21+
http.Error(w, "Unauthorized: Authorization header missing", http.StatusUnauthorized)
22+
return
23+
}
24+
25+
authParts := strings.Split(authHeader, " ")
26+
if len(authParts) != 2 || authParts[0] != "basic" {
27+
http.Error(w, "Unauthorized: Invalid Authorization header", http.StatusUnauthorized)
28+
return
29+
}
30+
31+
payload, err := base64.StdEncoding.DecodeString(authParts[1])
32+
if err != nil {
33+
http.Error(w, "Unauthorized: Invalid credentials format", http.StatusUnauthorized)
34+
return
35+
}
36+
37+
credentials := strings.Split(string(payload), ":")
38+
if len(credentials) != credentialLength {
39+
http.Error(w, "Unauthorized: Invalid credentials", http.StatusUnauthorized)
40+
return
41+
}
42+
43+
if basicAuthProvider.ValidateFunc != nil {
44+
if !basicAuthProvider.ValidateFunc(credentials[0], credentials[1]) {
45+
http.Error(w, "Unauthorized: Invalid username or password", http.StatusUnauthorized)
46+
return
47+
}
48+
} else {
49+
if storedPass, ok := basicAuthProvider.Users[credentials[0]]; !ok || storedPass != credentials[1] {
50+
http.Error(w, "Unauthorized: Invalid username or password", http.StatusUnauthorized)
51+
return
52+
}
53+
}
54+
55+
handler.ServeHTTP(w, r)
56+
})
57+
}
58+
}
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
package middleware
2+
3+
import (
4+
"net/http"
5+
"net/http/httptest"
6+
"testing"
7+
8+
"github.com/stretchr/testify/assert"
9+
)
10+
11+
func TestBasicAuthMiddleware(t *testing.T) {
12+
validationFunc := func(user, pass string) bool {
13+
if user == "abc" && pass == "pass123" {
14+
return true
15+
}
16+
17+
return false
18+
}
19+
20+
testCases := []struct {
21+
name string
22+
authHeader string
23+
authProvider BasicAuthProvider
24+
expectedStatusCode int
25+
}{
26+
{
27+
name: "Valid Authorization",
28+
authHeader: "basic dXNlcjpwYXNzd29yZA==",
29+
authProvider: BasicAuthProvider{Users: map[string]string{"user": "password"}},
30+
expectedStatusCode: http.StatusOK,
31+
},
32+
{
33+
name: "Valid Authorization with validation Func",
34+
authHeader: "basic YWJjOnBhc3MxMjM=",
35+
authProvider: BasicAuthProvider{ValidateFunc: validationFunc},
36+
expectedStatusCode: http.StatusOK,
37+
},
38+
{
39+
name: "false from validation Func",
40+
authHeader: "basic dXNlcjpwYXNzd29yZA==",
41+
authProvider: BasicAuthProvider{ValidateFunc: validationFunc},
42+
expectedStatusCode: http.StatusUnauthorized,
43+
},
44+
{
45+
name: "No Authorization Header",
46+
authHeader: "",
47+
authProvider: BasicAuthProvider{},
48+
expectedStatusCode: http.StatusUnauthorized,
49+
},
50+
{
51+
name: "Invalid Authorization Header",
52+
authHeader: "Bearer token",
53+
authProvider: BasicAuthProvider{},
54+
expectedStatusCode: http.StatusUnauthorized,
55+
},
56+
{
57+
name: "Invalid encoding",
58+
authHeader: "basic invalidbase64encoding==",
59+
authProvider: BasicAuthProvider{},
60+
expectedStatusCode: http.StatusUnauthorized,
61+
},
62+
{
63+
name: "improper credentials format",
64+
authHeader: "basic dXNlcis=",
65+
authProvider: BasicAuthProvider{},
66+
expectedStatusCode: http.StatusUnauthorized,
67+
},
68+
{
69+
name: "Unauthorized",
70+
authHeader: "basic dXNlcjpwYXNzd29yZA==",
71+
authProvider: BasicAuthProvider{},
72+
expectedStatusCode: http.StatusUnauthorized,
73+
},
74+
}
75+
76+
for _, tc := range testCases {
77+
t.Run(tc.name, func(t *testing.T) {
78+
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
79+
w.WriteHeader(http.StatusOK)
80+
})
81+
82+
req := httptest.NewRequest(http.MethodGet, "/", http.NoBody)
83+
req.Header.Set("Authorization", tc.authHeader)
84+
rr := httptest.NewRecorder()
85+
86+
authMiddleware := BasicAuthMiddleware(tc.authProvider)
87+
authMiddleware(handler).ServeHTTP(rr, req)
88+
89+
assert.Equal(t, tc.expectedStatusCode, rr.Code)
90+
})
91+
}
92+
}

pkg/gofr/service/basic_auth.go

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
package service
2+
3+
import (
4+
"context"
5+
b64 "encoding/base64"
6+
"net/http"
7+
)
8+
9+
type BasicAuthConfig struct {
10+
UserName string
11+
Password string
12+
}
13+
14+
func (a *BasicAuthConfig) addOption(h HTTP) HTTP {
15+
return &BasicAuthProvider{
16+
userName: a.UserName,
17+
password: a.Password,
18+
HTTP: h,
19+
}
20+
}
21+
22+
type BasicAuthProvider struct {
23+
userName string
24+
password string
25+
26+
HTTP
27+
}
28+
29+
func (ba *BasicAuthProvider) addAuthorizationHeader(headers map[string]string) error {
30+
decodedPassword, err := b64.StdEncoding.DecodeString(ba.password)
31+
if err != nil {
32+
return err
33+
}
34+
35+
encodedAuth := b64.StdEncoding.EncodeToString([]byte(ba.userName + ":" + string(decodedPassword)))
36+
37+
headers["Authorization"] = "basic " + encodedAuth
38+
39+
return nil
40+
}
41+
42+
func (ba *BasicAuthProvider) Get(ctx context.Context, path string, queryParams map[string]interface{}) (*http.Response, error) {
43+
return ba.GetWithHeaders(ctx, path, queryParams, nil)
44+
}
45+
46+
func (ba *BasicAuthProvider) GetWithHeaders(ctx context.Context, path string, queryParams map[string]interface{},
47+
headers map[string]string) (*http.Response, error) {
48+
err := ba.populateHeaders(headers)
49+
if err != nil {
50+
return nil, err
51+
}
52+
53+
return ba.HTTP.GetWithHeaders(ctx, path, queryParams, headers)
54+
}
55+
56+
func (ba *BasicAuthProvider) Post(ctx context.Context, path string, queryParams map[string]interface{},
57+
body []byte) (*http.Response, error) {
58+
return ba.PostWithHeaders(ctx, path, queryParams, body, nil)
59+
}
60+
61+
func (ba *BasicAuthProvider) PostWithHeaders(ctx context.Context, path string, queryParams map[string]interface{},
62+
body []byte, headers map[string]string) (*http.Response, error) {
63+
err := ba.populateHeaders(headers)
64+
if err != nil {
65+
return nil, err
66+
}
67+
68+
return ba.HTTP.PostWithHeaders(ctx, path, queryParams, body, headers)
69+
}
70+
71+
func (ba *BasicAuthProvider) Put(ctx context.Context, api string, queryParams map[string]interface{}, body []byte) (*http.Response, error) {
72+
return ba.PutWithHeaders(ctx, api, queryParams, body, nil)
73+
}
74+
75+
func (ba *BasicAuthProvider) PutWithHeaders(ctx context.Context, path string, queryParams map[string]interface{},
76+
body []byte, headers map[string]string) (*http.Response, error) {
77+
err := ba.populateHeaders(headers)
78+
if err != nil {
79+
return nil, err
80+
}
81+
82+
return ba.HTTP.PutWithHeaders(ctx, path, queryParams, body, headers)
83+
}
84+
85+
func (ba *BasicAuthProvider) Patch(ctx context.Context, path string, queryParams map[string]interface{},
86+
body []byte) (*http.Response, error) {
87+
return ba.PatchWithHeaders(ctx, path, queryParams, body, nil)
88+
}
89+
90+
func (ba *BasicAuthProvider) PatchWithHeaders(ctx context.Context, path string, queryParams map[string]interface{},
91+
body []byte, headers map[string]string) (*http.Response, error) {
92+
err := ba.populateHeaders(headers)
93+
if err != nil {
94+
return nil, err
95+
}
96+
97+
return ba.HTTP.PatchWithHeaders(ctx, path, queryParams, body, headers)
98+
}
99+
100+
func (ba *BasicAuthProvider) Delete(ctx context.Context, path string, body []byte) (*http.Response, error) {
101+
return ba.DeleteWithHeaders(ctx, path, body, nil)
102+
}
103+
104+
func (ba *BasicAuthProvider) DeleteWithHeaders(ctx context.Context, path string, body []byte,
105+
headers map[string]string) (*http.Response, error) {
106+
err := ba.populateHeaders(headers)
107+
if err != nil {
108+
return nil, err
109+
}
110+
111+
return ba.HTTP.DeleteWithHeaders(ctx, path, body, headers)
112+
}
113+
114+
func (ba *BasicAuthProvider) populateHeaders(headers map[string]string) error {
115+
if headers == nil {
116+
headers = make(map[string]string)
117+
}
118+
119+
err := ba.addAuthorizationHeader(headers)
120+
if err != nil {
121+
return err
122+
}
123+
124+
return nil
125+
}

0 commit comments

Comments
 (0)