77 "fmt"
88 "net/http"
99 "net/http/pprof"
10+ "net/netip"
1011 "path"
1112 "strings"
1213
@@ -15,6 +16,7 @@ import (
1516 "github.com/database64128/shadowsocks-go/conn"
1617 "github.com/database64128/shadowsocks-go/tlscerts"
1718 "go.uber.org/zap"
19+ "go4.org/netipx"
1820)
1921
2022// Config stores the configuration for the RESTful API.
@@ -25,16 +27,13 @@ type Config struct {
2527 // DebugPprof enables pprof endpoints for debugging and profiling.
2628 DebugPprof bool `json:"debugPprof"`
2729
28- // EnableTrustedProxyCheck enables trusted proxy checks .
29- EnableTrustedProxyCheck bool `json:"enableTrustedProxyCheck "`
30+ // TrustedProxies specifies the IP address prefixes of trusted proxies .
31+ TrustedProxies []netip. Prefix `json:"trustedProxies "`
3032
31- // TrustedProxies is the list of trusted proxies.
32- // This only takes effect if EnableTrustedProxyCheck is true.
33- TrustedProxies []string `json:"trustedProxies"`
34-
35- // ProxyHeader is the header used to determine the client's IP address.
36- // If empty, the remote peer's address is used.
37- ProxyHeader string `json:"proxyHeader"`
33+ // RealIPHeaderKey specifies the header field to use for determining
34+ // the client's real IP address when the request is from a trusted proxy.
35+ // If empty, the real IP address is not appended to [http.Request.RemoteAddr].
36+ RealIPHeaderKey string `json:"realIPHeaderKey"`
3837
3938 // StaticPath is the path where static files are served from.
4039 // If empty, static file serving is disabled.
@@ -188,16 +187,21 @@ func (c *Config) NewServer(logger *zap.Logger, listenConfigCache conn.ListenConf
188187 basePath = joinPatternPath (basePath , c .SecretPath )
189188 }
190189
190+ realIP , err := newRealIPMiddleware (logger , c .TrustedProxies , c .RealIPHeaderKey )
191+ if err != nil {
192+ return nil , nil , fmt .Errorf ("failed to create real IP middleware: %w" , err )
193+ }
194+
191195 if c .DebugPprof {
192196 register := func (path string , handler http.HandlerFunc ) {
193197 pattern := "GET " + joinPatternPath (basePath , path )
194- mux .Handle (pattern , logPprofRequests (logger , handler ))
198+ mux .Handle (pattern , realIP ( logPprofRequests (logger , handler ) ))
195199 }
196200
197201 // [pprof.Index] requires the URL path to start with "/debug/pprof/".
198202 indexPath := joinPatternPath (basePath , "/debug/pprof/" )
199203 prefix := strings .TrimSuffix (indexPath , "/debug/pprof/" )
200- mux .Handle (indexPath , logPprofRequests (logger , http .StripPrefix (prefix , http .HandlerFunc (pprof .Index ))))
204+ mux .Handle (indexPath , realIP ( logPprofRequests (logger , http .StripPrefix (prefix , http .HandlerFunc (pprof .Index ) ))))
201205
202206 register ("/debug/pprof/cmdline" , pprof .Cmdline )
203207 register ("/debug/pprof/profile" , pprof .Profile )
@@ -210,11 +214,11 @@ func (c *Config) NewServer(logger *zap.Logger, listenConfigCache conn.ListenConf
210214 sm := ssm .NewServerManager ()
211215 sm .RegisterHandlers (func (method , path string , handler restapi.HandlerFunc ) {
212216 pattern := method + " " + joinPatternPath (apiSSMv1Path , path )
213- mux .Handle (pattern , logAPIRequests (logger , handler ))
217+ mux .Handle (pattern , realIP ( logAPIRequests (logger , handler ) ))
214218 })
215219
216220 if c .StaticPath != "" {
217- mux .Handle ("GET /" , logFileServerRequests (logger , http .FileServer (http .Dir (c .StaticPath ))))
221+ mux .Handle ("GET /" , realIP ( logFileServerRequests (logger , http .FileServer (http .Dir (c .StaticPath ) ))))
218222 }
219223
220224 errorLog , err := zap .NewStdLogAt (logger , zap .ErrorLevel )
@@ -250,6 +254,49 @@ func joinPatternPath(elem ...string) string {
250254 return p
251255}
252256
257+ // newRealIPMiddleware returns a middleware that appends the content of realIPHeaderKey
258+ // to [http.Request.RemoteAddr] if the request is from a trusted proxy.
259+ //
260+ // If realIPHeaderKey is empty, the middleware is a no-op.
261+ func newRealIPMiddleware (logger * zap.Logger , trustedProxies []netip.Prefix , realIPHeaderKey string ) (func (http.Handler ) http.Handler , error ) {
262+ if realIPHeaderKey == "" {
263+ return func (h http.Handler ) http.Handler {
264+ return h
265+ }, nil
266+ }
267+
268+ var sb netipx.IPSetBuilder
269+ for _ , p := range trustedProxies {
270+ sb .AddPrefix (p )
271+ }
272+
273+ proxySet , err := sb .IPSet ()
274+ if err != nil {
275+ return nil , fmt .Errorf ("failed to build trusted proxy prefix set: %w" , err )
276+ }
277+
278+ return func (h http.Handler ) http.Handler {
279+ return http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
280+ if v := r .Header [realIPHeaderKey ]; len (v ) > 0 {
281+ proxyAddrPort , err := netip .ParseAddrPort (r .RemoteAddr )
282+ if err != nil {
283+ logger .Warn ("Failed to parse HTTP request remote address" ,
284+ zap .String ("remoteAddr" , r .RemoteAddr ),
285+ zap .Error (err ),
286+ )
287+ return
288+ }
289+
290+ if proxySet .Contains (proxyAddrPort .Addr ()) {
291+ r .RemoteAddr = fmt .Sprintf ("%s (%s: %v)" , r .RemoteAddr , realIPHeaderKey , v )
292+ }
293+ }
294+
295+ h .ServeHTTP (w , r )
296+ })
297+ }, nil
298+ }
299+
253300// logPprofRequests is a middleware that logs pprof requests.
254301func logPprofRequests (logger * zap.Logger , h http.Handler ) http.Handler {
255302 return http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
0 commit comments