@@ -2,31 +2,97 @@ package httpmw
2
2
3
3
import (
4
4
"fmt"
5
+ "net"
5
6
"net/http"
7
+ "strings"
8
+ "sync"
6
9
"time"
7
10
8
11
"github.com/go-chi/httprate"
9
12
13
+ "cdr.dev/slog"
10
14
"github.com/coder/wgtunnel/tunneld/httpapi"
11
15
"github.com/coder/wgtunnel/tunnelsdk"
12
16
)
13
17
18
+ type RateLimitConfig struct {
19
+ Log slog.Logger
20
+
21
+ // Count of the amount of requests allowed in the Window. If the Count is
22
+ // zero, the rate limiter is disabled.
23
+ Count int
24
+ Window time.Duration
25
+
26
+ // RealIPHeader is the header to use to get the real IP address of the
27
+ // request. If this is empty, the request's RemoteAddr is used.
28
+ RealIPHeader string
29
+ }
30
+
14
31
// RateLimit returns a handler that limits requests based on IP.
15
- func RateLimit (count int , window time. Duration ) func (http.Handler ) http.Handler {
16
- if count <= 0 {
32
+ func RateLimit (cfg RateLimitConfig ) func (http.Handler ) http.Handler {
33
+ if cfg . Count <= 0 {
17
34
return func (handler http.Handler ) http.Handler {
18
35
return handler
19
36
}
20
37
}
21
38
39
+ var logMissingHeaderOnce sync.Once
40
+
22
41
return httprate .Limit (
23
- count ,
24
- window ,
25
- httprate .WithKeyByIP (),
42
+ cfg .Count ,
43
+ cfg .Window ,
44
+ httprate .WithKeyFuncs (func (r * http.Request ) (string , error ) {
45
+ if cfg .RealIPHeader != "" {
46
+ val := r .Header .Get (cfg .RealIPHeader )
47
+ if val != "" {
48
+ val = strings .TrimSpace (strings .Split (val , "," )[0 ])
49
+ return canonicalizeIP (val ), nil
50
+ }
51
+
52
+ logMissingHeaderOnce .Do (func () {
53
+ cfg .Log .Warn (r .Context (), "real IP header not found or invalid on request" , slog .F ("header" , cfg .RealIPHeader ), slog .F ("value" , val ))
54
+ })
55
+ }
56
+
57
+ return httprate .KeyByIP (r )
58
+ }),
26
59
httprate .WithLimitHandler (func (rw http.ResponseWriter , r * http.Request ) {
27
60
httpapi .Write (r .Context (), rw , http .StatusTooManyRequests , tunnelsdk.Response {
28
- Message : fmt .Sprintf ("You've been rate limited for sending more than %v requests in %v." , count , window ),
61
+ Message : fmt .Sprintf ("You've been rate limited for sending more than %v requests in %v." , cfg . Count , cfg . Window ),
29
62
})
30
63
}),
31
64
)
32
65
}
66
+
67
+ // canonicalizeIP returns a form of ip suitable for comparison to other IPs.
68
+ // For IPv4 addresses, this is simply the whole string.
69
+ // For IPv6 addresses, this is the /64 prefix.
70
+ //
71
+ // This function is taken directly from go-chi/httprate:
72
+ // https://github.com/go-chi/httprate/blob/0ea2148d09a46ae62efcad05b70d87418d8e4f43/httprate.go#L111
73
+ func canonicalizeIP (ip string ) string {
74
+ isIPv6 := false
75
+ // This is how net.ParseIP decides if an address is IPv6
76
+ // https://cs.opensource.google/go/go/+/refs/tags/go1.17.7:src/net/ip.go;l=704
77
+ for i := 0 ; ! isIPv6 && i < len (ip ); i ++ {
78
+ switch ip [i ] {
79
+ case '.' :
80
+ // IPv4
81
+ return ip
82
+ case ':' :
83
+ // IPv6
84
+ isIPv6 = true
85
+ }
86
+ }
87
+ if ! isIPv6 {
88
+ // Not an IP address at all
89
+ return ip
90
+ }
91
+
92
+ ipv6 := net .ParseIP (ip )
93
+ if ipv6 == nil {
94
+ return ip
95
+ }
96
+
97
+ return ipv6 .Mask (net .CIDRMask (64 , 128 )).String ()
98
+ }
0 commit comments