Skip to content

Commit df1706e

Browse files
committed
feat: add Forwarded header parsing and real IP extraction with tests
1 parent ce0b12a commit df1706e

File tree

3 files changed

+135
-4
lines changed

3 files changed

+135
-4
lines changed

context.go

Lines changed: 74 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ type Context interface {
4040
// Scheme returns the HTTP protocol scheme, `http` or `https`.
4141
Scheme() string
4242

43+
SchemeForwarded() *Forwarded
44+
4345
// RealIP returns the client's network address based on `X-Forwarded-For`
4446
// or `X-Real-IP` request header.
4547
// The behavior can be configured using `Echo#IPExtractor`.
@@ -234,6 +236,13 @@ const (
234236
ContextKeyHeaderAllow = "echo_header_allow"
235237
)
236238

239+
type Forwarded struct {
240+
By []string
241+
For []string
242+
Host []string
243+
Proto []string
244+
}
245+
237246
const (
238247
defaultMemory = 32 << 20 // 32 MB
239248
indexPage = "index.html"
@@ -293,24 +302,85 @@ func (c *context) Scheme() string {
293302
return "http"
294303
}
295304

305+
func (c *context) SchemeForwarded() *Forwarded {
306+
// Parse and get "Forwarded" header.
307+
// See : https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Forwarded
308+
if scheme := c.request.Header.Get(HeaderForwarded); scheme != "" {
309+
f, err := c.parseForwarded(scheme)
310+
if err != nil {
311+
return nil
312+
}
313+
return &f
314+
}
315+
return nil
316+
}
317+
318+
func (c *context) parseForwarded(input string) (Forwarded, error) {
319+
forwarded := Forwarded{}
320+
entries := strings.Split(input, ",")
321+
322+
for _, entry := range entries {
323+
entry = strings.TrimSpace(entry)
324+
pairs := strings.Split(entry, ";")
325+
326+
for _, pair := range pairs {
327+
parts := strings.SplitN(pair, "=", 2)
328+
if len(parts) != 2 {
329+
return forwarded, fmt.Errorf("invalid pair: %s", pair)
330+
}
331+
332+
key := strings.TrimSpace(parts[0])
333+
value, err := url.QueryUnescape(strings.TrimSpace(parts[1]))
334+
if err != nil {
335+
return forwarded, fmt.Errorf("failed to unescape value: %w", err)
336+
}
337+
value = strings.Trim(value, "\"[]")
338+
switch key {
339+
case "by":
340+
forwarded.By = append(forwarded.By, value)
341+
case "for":
342+
forwarded.For = append(forwarded.For, value)
343+
case "host":
344+
forwarded.Host = append(forwarded.Host, value)
345+
case "proto":
346+
forwarded.Proto = append(forwarded.Proto, value)
347+
default:
348+
return forwarded, fmt.Errorf("unknown key: %s", key)
349+
}
350+
}
351+
}
352+
353+
return forwarded, nil
354+
}
355+
296356
func (c *context) RealIP() string {
297357
if c.echo != nil && c.echo.IPExtractor != nil {
298358
return c.echo.IPExtractor(c.request)
299359
}
360+
// Check if the "Forwarded" header is present in the request.
361+
if d := c.request.Header.Get(HeaderForwarded); d != "" {
362+
// Parse the "Forwarded" header.
363+
scheme, err := c.parseForwarded(d)
364+
if err != nil {
365+
return "" // Return an empty string if parsing fails.
366+
}
367+
if len(scheme.For) > 0 {
368+
return scheme.For[0] // Return first for item
369+
}
370+
return ""
371+
}
300372
// Fall back to legacy behavior
301373
if ip := c.request.Header.Get(HeaderXForwardedFor); ip != "" {
302374
i := strings.IndexAny(ip, ",")
303375
if i > 0 {
304376
xffip := strings.TrimSpace(ip[:i])
305-
xffip = strings.TrimPrefix(xffip, "[")
306-
xffip = strings.TrimSuffix(xffip, "]")
377+
xffip = strings.Trim(xffip, "\"[]")
307378
return xffip
308379
}
309380
return ip
310381
}
311382
if ip := c.request.Header.Get(HeaderXRealIP); ip != "" {
312-
ip = strings.TrimPrefix(ip, "[")
313-
ip = strings.TrimSuffix(ip, "]")
383+
ip = strings.Trim(ip, "\"[]")
314384
return ip
315385
}
316386
ra, _, _ := net.SplitHostPort(c.request.RemoteAddr)

context_test.go

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -961,6 +961,50 @@ func TestContext_Scheme(t *testing.T) {
961961
}
962962
}
963963

964+
func TestContext_SchemeForwarded(t *testing.T) {
965+
tests := []struct {
966+
c Context
967+
s *Forwarded
968+
}{
969+
{
970+
&context{
971+
request: &http.Request{
972+
Header: http.Header{HeaderForwarded: []string{"for=192.0.2.60;proto=http;by=203.0.113.43"}},
973+
},
974+
},
975+
&Forwarded{
976+
For: []string{"192.0.2.60"},
977+
Proto: []string{"http"},
978+
By: []string{"203.0.113.43"},
979+
},
980+
},
981+
{
982+
&context{
983+
request: &http.Request{
984+
Header: http.Header{HeaderForwarded: []string{"for=192.0.2.43, for=198.51.100.17"}},
985+
},
986+
},
987+
&Forwarded{
988+
For: []string{"192.0.2.43", "198.51.100.17"},
989+
},
990+
},
991+
{
992+
&context{
993+
request: &http.Request{
994+
Header: http.Header{HeaderForwarded: []string{"for=192.0.2.43, for=[2001:db8:cafe::17]"}},
995+
},
996+
},
997+
&Forwarded{
998+
For: []string{"192.0.2.43", "2001:db8:cafe::17"},
999+
},
1000+
},
1001+
}
1002+
1003+
for _, tt := range tests {
1004+
assert.Equal(t, tt.s, tt.c.SchemeForwarded())
1005+
}
1006+
}
1007+
9641008
func TestContext_IsWebSocket(t *testing.T) {
9651009
tests := []struct {
9661010
c Context
@@ -1062,6 +1106,22 @@ func TestContext_RealIP(t *testing.T) {
10621106
},
10631107
"127.0.0.1",
10641108
},
1109+
{
1110+
&context{
1111+
request: &http.Request{
1112+
Header: http.Header{HeaderForwarded: []string{"for=192.0.2.43, for=198.51.100.17"}},
1113+
},
1114+
},
1115+
"192.0.2.43",
1116+
},
1117+
{
1118+
&context{
1119+
request: &http.Request{
1120+
Header: http.Header{HeaderForwarded: []string{"for=[2001:db8:85a3:8d3:1319:8a2e:370:7348], for=2001:db8::1"}},
1121+
},
1122+
},
1123+
"2001:db8:85a3:8d3:1319:8a2e:370:7348",
1124+
},
10651125
{
10661126
&context{
10671127
request: &http.Request{

echo.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,7 @@ const (
221221
HeaderUpgrade = "Upgrade"
222222
HeaderVary = "Vary"
223223
HeaderWWWAuthenticate = "WWW-Authenticate"
224+
HeaderForwarded = "Forwarded"
224225
HeaderXForwardedFor = "X-Forwarded-For"
225226
HeaderXForwardedProto = "X-Forwarded-Proto"
226227
HeaderXForwardedProtocol = "X-Forwarded-Protocol"

0 commit comments

Comments
 (0)