Skip to content

Commit 41a8a95

Browse files
committed
🥥 api: add real ip middleware
1 parent 6f54cfd commit 41a8a95

File tree

2 files changed

+65
-16
lines changed

2 files changed

+65
-16
lines changed

api/api.go

Lines changed: 60 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"fmt"
88
"net/http"
99
"net/http/pprof"
10+
"net/netip"
1011
"path"
1112
"strings"
1213

@@ -15,6 +16,7 @@ import (
1516
"github.com/database64128/shadowsocks-go/conn"
1617
"github.com/database64128/shadowsocks-go/tlscerts"
1718
"go.uber.org/zap"
19+
"go4.org/netipx"
1820
)
1921

2022
// Config stores the configuration for the RESTful API.
@@ -25,16 +27,13 @@ type Config struct {
2527
// DebugPprof enables pprof endpoints for debugging and profiling.
2628
DebugPprof bool `json:"debugPprof"`
2729

28-
// EnableTrustedProxyCheck enables trusted proxy checks.
29-
EnableTrustedProxyCheck bool `json:"enableTrustedProxyCheck"`
30+
// TrustedProxies specifies the IP address prefixes of trusted proxies.
31+
TrustedProxies []netip.Prefix `json:"trustedProxies"`
3032

31-
// TrustedProxies is the list of trusted proxies.
32-
// This only takes effect if EnableTrustedProxyCheck is true.
33-
TrustedProxies []string `json:"trustedProxies"`
34-
35-
// ProxyHeader is the header used to determine the client's IP address.
36-
// If empty, the remote peer's address is used.
37-
ProxyHeader string `json:"proxyHeader"`
33+
// RealIPHeaderKey specifies the header field to use for determining
34+
// the client's real IP address when the request is from a trusted proxy.
35+
// If empty, the real IP address is not appended to [http.Request.RemoteAddr].
36+
RealIPHeaderKey string `json:"realIPHeaderKey"`
3837

3938
// StaticPath is the path where static files are served from.
4039
// If empty, static file serving is disabled.
@@ -188,16 +187,21 @@ func (c *Config) NewServer(logger *zap.Logger, listenConfigCache conn.ListenConf
188187
basePath = joinPatternPath(basePath, c.SecretPath)
189188
}
190189

190+
realIP, err := newRealIPMiddleware(logger, c.TrustedProxies, c.RealIPHeaderKey)
191+
if err != nil {
192+
return nil, nil, fmt.Errorf("failed to create real IP middleware: %w", err)
193+
}
194+
191195
if c.DebugPprof {
192196
register := func(path string, handler http.HandlerFunc) {
193197
pattern := "GET " + joinPatternPath(basePath, path)
194-
mux.Handle(pattern, logPprofRequests(logger, handler))
198+
mux.Handle(pattern, realIP(logPprofRequests(logger, handler)))
195199
}
196200

197201
// [pprof.Index] requires the URL path to start with "/debug/pprof/".
198202
indexPath := joinPatternPath(basePath, "/debug/pprof/")
199203
prefix := strings.TrimSuffix(indexPath, "/debug/pprof/")
200-
mux.Handle(indexPath, logPprofRequests(logger, http.StripPrefix(prefix, http.HandlerFunc(pprof.Index))))
204+
mux.Handle(indexPath, realIP(logPprofRequests(logger, http.StripPrefix(prefix, http.HandlerFunc(pprof.Index)))))
201205

202206
register("/debug/pprof/cmdline", pprof.Cmdline)
203207
register("/debug/pprof/profile", pprof.Profile)
@@ -210,11 +214,11 @@ func (c *Config) NewServer(logger *zap.Logger, listenConfigCache conn.ListenConf
210214
sm := ssm.NewServerManager()
211215
sm.RegisterHandlers(func(method, path string, handler restapi.HandlerFunc) {
212216
pattern := method + " " + joinPatternPath(apiSSMv1Path, path)
213-
mux.Handle(pattern, logAPIRequests(logger, handler))
217+
mux.Handle(pattern, realIP(logAPIRequests(logger, handler)))
214218
})
215219

216220
if c.StaticPath != "" {
217-
mux.Handle("GET /", logFileServerRequests(logger, http.FileServer(http.Dir(c.StaticPath))))
221+
mux.Handle("GET /", realIP(logFileServerRequests(logger, http.FileServer(http.Dir(c.StaticPath)))))
218222
}
219223

220224
errorLog, err := zap.NewStdLogAt(logger, zap.ErrorLevel)
@@ -250,6 +254,49 @@ func joinPatternPath(elem ...string) string {
250254
return p
251255
}
252256

257+
// newRealIPMiddleware returns a middleware that appends the content of realIPHeaderKey
258+
// to [http.Request.RemoteAddr] if the request is from a trusted proxy.
259+
//
260+
// If realIPHeaderKey is empty, the middleware is a no-op.
261+
func newRealIPMiddleware(logger *zap.Logger, trustedProxies []netip.Prefix, realIPHeaderKey string) (func(http.Handler) http.Handler, error) {
262+
if realIPHeaderKey == "" {
263+
return func(h http.Handler) http.Handler {
264+
return h
265+
}, nil
266+
}
267+
268+
var sb netipx.IPSetBuilder
269+
for _, p := range trustedProxies {
270+
sb.AddPrefix(p)
271+
}
272+
273+
proxySet, err := sb.IPSet()
274+
if err != nil {
275+
return nil, fmt.Errorf("failed to build trusted proxy prefix set: %w", err)
276+
}
277+
278+
return func(h http.Handler) http.Handler {
279+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
280+
if v := r.Header[realIPHeaderKey]; len(v) > 0 {
281+
proxyAddrPort, err := netip.ParseAddrPort(r.RemoteAddr)
282+
if err != nil {
283+
logger.Warn("Failed to parse HTTP request remote address",
284+
zap.String("remoteAddr", r.RemoteAddr),
285+
zap.Error(err),
286+
)
287+
return
288+
}
289+
290+
if proxySet.Contains(proxyAddrPort.Addr()) {
291+
r.RemoteAddr = fmt.Sprintf("%s (%s: %v)", r.RemoteAddr, realIPHeaderKey, v)
292+
}
293+
}
294+
295+
h.ServeHTTP(w, r)
296+
})
297+
}, nil
298+
}
299+
253300
// logPprofRequests is a middleware that logs pprof requests.
254301
func logPprofRequests(logger *zap.Logger, h http.Handler) http.Handler {
255302
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

docs/config.json

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -520,9 +520,11 @@
520520
"api": {
521521
"enabled": true,
522522
"debugPprof": false,
523-
"enableTrustedProxyCheck": false,
524-
"trustedProxies": [],
525-
"proxyHeader": "X-Forwarded-For",
523+
"trustedProxies": [
524+
"127.0.0.1/32",
525+
"::1/128"
526+
],
527+
"realIPHeaderKey": "X-Forwarded-For",
526528
"staticPath": "",
527529
"secretPath": "4paZvyoK3dCjyQXU33md5huJMMYVD9o8",
528530
"listeners": [

0 commit comments

Comments
 (0)