Skip to content

Commit 1a58389

Browse files
committed
redis auth
1 parent b89507b commit 1a58389

File tree

8 files changed

+145
-8
lines changed

8 files changed

+145
-8
lines changed

auth/auth.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package auth
22

33
import (
4+
"context"
45
"errors"
56
"net/http"
67
"net/url"
@@ -10,7 +11,7 @@ import (
1011
)
1112

1213
type Auth interface {
13-
Validate(wr http.ResponseWriter, req *http.Request) (string, bool)
14+
Validate(ctx context.Context, wr http.ResponseWriter, req *http.Request) (string, bool)
1415
Stop()
1516
}
1617

@@ -29,6 +30,10 @@ func NewAuth(paramstr string, logger *clog.CondLogger) (Auth, error) {
2930
return NewHMACAuth(url, logger)
3031
case "cert":
3132
return NewCertAuth(url, logger)
33+
case "redis":
34+
return NewRedisAuth(url, false, logger)
35+
case "redis-cluster":
36+
return NewRedisAuth(url, true, logger)
3237
case "none":
3338
return NoAuth{}, nil
3439
default:

auth/basic.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package auth
22

33
import (
4+
"context"
45
"encoding/base64"
56
"errors"
67
"fmt"
@@ -114,7 +115,7 @@ func (auth *BasicAuth) reloadLoop(interval time.Duration) {
114115
}
115116
}
116117

117-
func (auth *BasicAuth) Validate(wr http.ResponseWriter, req *http.Request) (string, bool) {
118+
func (auth *BasicAuth) Validate(_ context.Context, wr http.ResponseWriter, req *http.Request) (string, bool) {
118119
hdr := req.Header.Get("Proxy-Authorization")
119120
if hdr == "" {
120121
requireBasicAuth(wr, req, auth.hiddenDomain)

auth/cert.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package auth
33
import (
44
"bufio"
55
"bytes"
6+
"context"
67
"encoding/hex"
78
"errors"
89
"fmt"
@@ -65,7 +66,7 @@ func NewCertAuth(param_url *url.URL, logger *clog.CondLogger) (*CertAuth, error)
6566
return auth, nil
6667
}
6768

68-
func (auth *CertAuth) Validate(wr http.ResponseWriter, req *http.Request) (string, bool) {
69+
func (auth *CertAuth) Validate(_ context.Context, wr http.ResponseWriter, req *http.Request) (string, bool) {
6970
if req.TLS == nil || len(req.TLS.VerifiedChains) < 1 || len(req.TLS.VerifiedChains[0]) < 1 {
7071
http.Error(wr, BAD_REQ_MSG, http.StatusBadRequest)
7172
return "", false

auth/common.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,13 @@ package auth
22

33
import (
44
"crypto/subtle"
5+
"errors"
56
"net"
67
"net/http"
78
"strconv"
89
"strings"
10+
11+
"github.com/tg123/go-htpasswd"
912
)
1013

1114
func matchHiddenDomain(host, hidden_domain string) bool {
@@ -28,3 +31,16 @@ func requireBasicAuth(wr http.ResponseWriter, req *http.Request, hidden_domain s
2831
wr.Write([]byte(AUTH_REQUIRED_MSG))
2932
}
3033
}
34+
35+
func makePasswdMatcher(encoded string) (htpasswd.EncodedPasswd, error) {
36+
for _, p := range htpasswd.DefaultSystems {
37+
matcher, err := p(encoded)
38+
if err != nil {
39+
return nil, err
40+
}
41+
if matcher != nil {
42+
return matcher, nil
43+
}
44+
}
45+
return nil, errors.New("no suitable password encoding system found")
46+
}

auth/hmac.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package auth
22

33
import (
4+
"context"
45
"crypto/hmac"
56
"crypto/sha256"
67
"encoding/base64"
@@ -83,7 +84,7 @@ func VerifyHMACLoginAndPassword(secret []byte, login, password string) bool {
8384
return hmac.Equal(token.Signature[:], expectedMAC)
8485
}
8586

86-
func (auth *HMACAuth) Validate(wr http.ResponseWriter, req *http.Request) (string, bool) {
87+
func (auth *HMACAuth) Validate(_ context.Context, wr http.ResponseWriter, req *http.Request) (string, bool) {
8788
hdr := req.Header.Get("Proxy-Authorization")
8889
if hdr == "" {
8990
requireBasicAuth(wr, req, auth.hiddenDomain)

auth/noauth.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
package auth
22

3-
import "net/http"
3+
import (
4+
"context"
5+
"net/http"
6+
)
47

58
type NoAuth struct{}
69

7-
func (_ NoAuth) Validate(wr http.ResponseWriter, req *http.Request) (string, bool) {
10+
func (_ NoAuth) Validate(_ context.Context, _ http.ResponseWriter, _ *http.Request) (string, bool) {
811
return "", true
912
}
1013

auth/redis.go

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
package auth
2+
3+
import (
4+
"context"
5+
"encoding/base64"
6+
"net/http"
7+
"net/url"
8+
"strconv"
9+
"strings"
10+
11+
clog "github.com/SenseUnit/dumbproxy/log"
12+
13+
"github.com/redis/go-redis/v9"
14+
)
15+
16+
type RedisAuth struct {
17+
logger *clog.CondLogger
18+
hiddenDomain string
19+
r redis.Cmdable
20+
keyPrefix string
21+
}
22+
23+
func NewRedisAuth(param_url *url.URL, cluster bool, logger *clog.CondLogger) (*RedisAuth, error) {
24+
values, err := url.ParseQuery(param_url.RawQuery)
25+
if err != nil {
26+
return nil, err
27+
}
28+
auth := &RedisAuth{
29+
logger: logger,
30+
hiddenDomain: strings.ToLower(values.Get("hidden_domain")),
31+
keyPrefix: values.Get("key_prefix"),
32+
}
33+
if cluster {
34+
opts, err := redis.ParseClusterURL(values.Get("url"))
35+
if err != nil {
36+
return nil, err
37+
}
38+
auth.r = redis.NewClusterClient(opts)
39+
} else {
40+
opts, err := redis.ParseURL(values.Get("url"))
41+
if err != nil {
42+
return nil, err
43+
}
44+
auth.r = redis.NewClient(opts)
45+
}
46+
return auth, nil
47+
}
48+
49+
func (auth *RedisAuth) Validate(ctx context.Context, wr http.ResponseWriter, req *http.Request) (string, bool) {
50+
hdr := req.Header.Get("Proxy-Authorization")
51+
if hdr == "" {
52+
requireBasicAuth(wr, req, auth.hiddenDomain)
53+
return "", false
54+
}
55+
hdr_parts := strings.SplitN(hdr, " ", 2)
56+
if len(hdr_parts) != 2 || strings.ToLower(hdr_parts[0]) != "basic" {
57+
requireBasicAuth(wr, req, auth.hiddenDomain)
58+
return "", false
59+
}
60+
61+
token := hdr_parts[1]
62+
data, err := base64.StdEncoding.DecodeString(token)
63+
if err != nil {
64+
requireBasicAuth(wr, req, auth.hiddenDomain)
65+
return "", false
66+
}
67+
68+
pair := strings.SplitN(string(data), ":", 2)
69+
if len(pair) != 2 {
70+
requireBasicAuth(wr, req, auth.hiddenDomain)
71+
return "", false
72+
}
73+
74+
login := pair[0]
75+
password := pair[1]
76+
77+
encodedPasswd, err := auth.r.Get(ctx, auth.keyPrefix+login).Result()
78+
if err != nil {
79+
auth.logger.Debug("error fetching key %q from Redis: %v", auth.keyPrefix+login, err)
80+
requireBasicAuth(wr, req, auth.hiddenDomain)
81+
return "", false
82+
}
83+
matcher, err := makePasswdMatcher(encodedPasswd)
84+
if err != nil {
85+
auth.logger.Debug("can't create password matcher from Redis key %q: %v", auth.keyPrefix+login, err)
86+
requireBasicAuth(wr, req, auth.hiddenDomain)
87+
return "", false
88+
}
89+
90+
if matcher.MatchesPassword(password) {
91+
if auth.hiddenDomain != "" &&
92+
(req.Host == auth.hiddenDomain || req.URL.Host == auth.hiddenDomain) {
93+
wr.Header().Set("Content-Length", strconv.Itoa(len([]byte(AUTH_TRIGGERED_MSG))))
94+
wr.Header().Set("Pragma", "no-cache")
95+
wr.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate")
96+
wr.Header().Set("Expires", EPOCH_EXPIRE)
97+
wr.Header()["Date"] = nil
98+
wr.WriteHeader(http.StatusOK)
99+
wr.Write([]byte(AUTH_TRIGGERED_MSG))
100+
return "", false
101+
} else {
102+
return login, true
103+
}
104+
}
105+
requireBasicAuth(wr, req, auth.hiddenDomain)
106+
return "", false
107+
}
108+
109+
func (auth *RedisAuth) Stop() {
110+
}

handler/handler.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,8 @@ func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) {
192192
return
193193
}
194194

195-
username, ok := s.auth.Validate(wr, req)
195+
ctx := req.Context()
196+
username, ok := s.auth.Validate(ctx, wr, req)
196197
localAddr := getLocalAddr(req.Context())
197198
s.logger.Info("Request: %v => %v %q %v %v %v", req.RemoteAddr, localAddr, username, req.Proto, req.Method, req.URL)
198199

@@ -208,7 +209,6 @@ func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) {
208209
ipHints = &hintValues[0]
209210
}
210211
}
211-
ctx := req.Context()
212212
ctx = ddto.BoundDialerParamsToContext(ctx, ipHints, trimAddrPort(localAddr))
213213
ctx = ddto.FilterParamsToContext(ctx, req, username)
214214
req = req.WithContext(ctx)

0 commit comments

Comments
 (0)