Skip to content

Commit 55e37c2

Browse files
committed
Implement case-insensitive lookups + redirections
fixes #1
1 parent 9b1fa09 commit 55e37c2

File tree

4 files changed

+285
-16
lines changed

4 files changed

+285
-16
lines changed

router.go

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -96,14 +96,21 @@ func NotFound(w http.ResponseWriter, req *http.Request) {
9696
// Router is a http.Handler which can be used to dispatch requests to different
9797
// handler functions via configurable routes
9898
type Router struct {
99-
node
99+
node // embed the root node
100100

101-
// Enables automatic redirection if the current route can't be matched but
101+
// Enables automatic redirection if the current route can't be matched but a
102102
// handler for the path with (without) the trailing slash exists.
103103
// For example if /foo/ is requested but a route only exists for /foo, the
104104
// client is redirected to /foo with http status code 301.
105105
RedirectTrailingSlash bool
106106

107+
// Enables automatic redirection if the current route can't be matched but a
108+
// case-insensitive lookup of the path finds a handler.
109+
// The router then permanent redirects (http status code 301) to the
110+
// corrected path.
111+
// For example /FOO and /Foo could be redirected to /foo.
112+
RedirectCaseInsensitive bool
113+
107114
// Configurable handler func which is used when no matching route is found.
108115
// Default is the NotFound func of this package.
109116
NotFound http.HandlerFunc
@@ -124,8 +131,9 @@ var _ http.Handler = New()
124131
// requested Host.
125132
func New() *Router {
126133
return &Router{
127-
RedirectTrailingSlash: true,
128-
NotFound: NotFound,
134+
RedirectTrailingSlash: true,
135+
RedirectCaseInsensitive: true,
136+
NotFound: NotFound,
129137
}
130138
}
131139

@@ -226,11 +234,18 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
226234
}
227235
http.Redirect(w, req, path, http.StatusMovedPermanently)
228236
return
229-
} else { // Handle 404
230-
if r.NotFound != nil {
231-
r.NotFound(w, req)
232-
} else {
233-
http.NotFound(w, req)
237+
} else if r.RedirectCaseInsensitive {
238+
fixedPath, found := r.findCaseInsensitivePath(req.Method, path, r.RedirectTrailingSlash)
239+
if found {
240+
http.Redirect(w, req, string(fixedPath), http.StatusMovedPermanently)
241+
return
234242
}
235243
}
244+
245+
// Handle 404
246+
if r.NotFound != nil {
247+
r.NotFound(w, req)
248+
} else {
249+
http.NotFound(w, req)
250+
}
236251
}

router_test.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,10 @@ func TestRouterNotFound(t *testing.T) {
138138
}{
139139
{"/path/", NotFound, 301, "map[Location:[/path]]"}, // TSR -/
140140
{"/dir", NotFound, 301, "map[Location:[/dir/]]"}, // TSR +/
141+
{"/PATH", NotFound, 301, "map[Location:[/path]]"}, // Fixed Case
142+
{"/DIR/", NotFound, 301, "map[Location:[/dir/]]"}, // Fixed Case
143+
{"/PATH/", NotFound, 301, "map[Location:[/path]]"}, // Fixed Case -/
144+
{"/DIR", NotFound, 301, "map[Location:[/dir/]]"}, // Fixed Case +/
141145
{"/../path", NotFound, 301, "map[Location:[/path]]"}, // CleanPath
142146
{"/nope", NotFound, 404, ""}, // NotFound
143147
{"/nope", nil, 404, ""}, // NotFound

tree.go

Lines changed: 126 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@
44

55
package httprouter
66

7+
import (
8+
"strings"
9+
"unicode"
10+
)
11+
712
func min(a, b int) int {
813
if a <= b {
914
return a
@@ -55,7 +60,7 @@ func (n *node) addRoute(method, path string, handle Handle) {
5560
n.priority++
5661
// non-empty tree
5762
if len(n.path) > 0 || len(n.children) > 0 {
58-
OUTER:
63+
WALK:
5964
for {
6065
// Find the longest common prefix.
6166
// This also implies that the commom prefix contains no ':' or '*'
@@ -92,7 +97,7 @@ func (n *node) addRoute(method, path string, handle Handle) {
9297
if len(path) >= len(n.path) && n.path == path[:len(n.path)] {
9398
// check for longer wildcard, e.g. :name and :names
9499
if len(n.path) >= len(path) || path[len(n.path)] == '/' {
95-
continue OUTER
100+
continue WALK
96101
}
97102
}
98103

@@ -105,15 +110,15 @@ func (n *node) addRoute(method, path string, handle Handle) {
105110
if n.nType == param && c == '/' && len(n.children) == 1 {
106111
n = n.children[0]
107112
n.priority++
108-
continue OUTER
113+
continue WALK
109114
}
110115

111116
// Check if a child with the next path byte exists
112117
for i, index := range n.indices {
113118
if c == index {
114119
i = n.incrementChildPrio(i)
115120
n = n.children[i]
116-
continue OUTER
121+
continue WALK
117122
}
118123
}
119124

@@ -253,8 +258,7 @@ func (n *node) insertChild(method, path string, handle Handle) {
253258
// made if a handle exists with an extra (without the) trailing slash for the
254259
// given path.
255260
func (n *node) getValue(method, path string) (handle Handle, vars map[string]string, tsr bool) {
256-
// Walk tree nodes
257-
OUTER:
261+
WALK: // Outer loop for walking the tree nodes
258262
for len(path) >= len(n.path) && path[:len(n.path)] == n.path {
259263
path = path[len(n.path):]
260264

@@ -349,7 +353,7 @@ OUTER:
349353
for i, index := range n.indices {
350354
if c == index {
351355
n = n.children[i]
352-
continue OUTER
356+
continue WALK
353357
}
354358
}
355359

@@ -365,3 +369,118 @@ OUTER:
365369
tsr = (len(path)+1 == len(n.path) && n.path[len(path)] == '/' && n.handle[method] != nil) || (path == "/")
366370
return
367371
}
372+
373+
// Makes a case-insensitive lookup of the given path and tries to find a handler
374+
// for the given request method. It can optionally also fix trailing slashes.
375+
// It returns the case-corrected path and a bool indicating wether the lookup
376+
// was successful.
377+
func (n *node) findCaseInsensitivePath(method, path string, fixTrailingSlash bool) (ciPath []byte, found bool) {
378+
ciPath = make([]byte, 0, len(path)+1) // preallocate enough memory
379+
380+
// Outer loop for walking the tree nodes
381+
for len(path) >= len(n.path) && strings.ToLower(path[:len(n.path)]) == strings.ToLower(n.path) {
382+
path = path[len(n.path):]
383+
ciPath = append(ciPath, n.path...)
384+
385+
if len(path) == 0 {
386+
// Check if this node has a handle registered for the given node
387+
if n.handle[method] != nil {
388+
return ciPath, true
389+
}
390+
391+
// No handle found.
392+
// Try to fix the path by adding a trailing slash
393+
if fixTrailingSlash {
394+
for i, index := range n.indices {
395+
if index == '/' {
396+
n = n.children[i]
397+
if (n.path == "/" && n.handle[method] != nil) ||
398+
(n.nType == catchAll && n.children[0].handle[method] != nil) {
399+
return append(ciPath, '/'), true
400+
}
401+
return
402+
}
403+
}
404+
}
405+
return
406+
407+
} else if n.wildChild {
408+
n = n.children[0]
409+
410+
switch n.nType {
411+
case param:
412+
// find param end (either '/'' or path end)
413+
k := 0
414+
for k < len(path) && path[k] != '/' {
415+
k++
416+
}
417+
418+
// add param value to case insensitive path
419+
ciPath = append(ciPath, path[:k]...)
420+
421+
// we need to go deeper!
422+
if k < len(path) {
423+
if len(n.children) > 0 {
424+
path = path[k:]
425+
n = n.children[0]
426+
continue
427+
} else { // ... but we can't
428+
if fixTrailingSlash && len(path) == k+1 {
429+
return ciPath, true
430+
}
431+
return
432+
}
433+
}
434+
435+
if n.handle[method] != nil {
436+
return ciPath, true
437+
} else if fixTrailingSlash && len(n.children) == 1 {
438+
// No handle found. Check if a handle for this path + a
439+
// trailing slash exists
440+
n = n.children[0]
441+
if n.path == "/" && n.handle[method] != nil {
442+
return append(ciPath, '/'), true
443+
}
444+
}
445+
return
446+
447+
case catchAll:
448+
return append(ciPath, path...), true
449+
450+
default:
451+
panic("Unknown node type")
452+
}
453+
454+
} else {
455+
r := unicode.ToLower(rune(path[0]))
456+
for i, index := range n.indices {
457+
// must use recursive approach since both index and
458+
// ToLower(index) could exist. We must check both.
459+
if r == unicode.ToLower(rune(index)) {
460+
out, found := n.children[i].findCaseInsensitivePath(method, path, fixTrailingSlash)
461+
if found {
462+
return append(ciPath, out...), true
463+
}
464+
}
465+
}
466+
467+
// Nothing found. We can recommend to redirect to the same URL
468+
// without a trailing slash if a leaf exists for that path
469+
found = (fixTrailingSlash && path == "/" && n.handle[method] != nil)
470+
return
471+
}
472+
}
473+
474+
// Nothing found.
475+
// Try to fix the path by adding / removing a trailing slash
476+
if fixTrailingSlash {
477+
if len(path)+1 == len(n.path) && n.path[len(path)] == '/' &&
478+
strings.ToLower(path) == strings.ToLower(n.path[:len(path)]) &&
479+
n.handle[method] != nil {
480+
return append(ciPath, n.path...), true
481+
} else if path == "/" {
482+
return ciPath, true
483+
}
484+
}
485+
return
486+
}

tree_test.go

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,3 +457,134 @@ func TestTreeTrailingSlashRedirect(t *testing.T) {
457457
}
458458
}
459459
}
460+
461+
func TestTreeFindCaseInsensitivePath(t *testing.T) {
462+
tree := &node{}
463+
464+
routes := [...]string{
465+
"/hi",
466+
"/b/",
467+
"/ABC/",
468+
"/search/:query",
469+
"/cmd/:tool/",
470+
"/src/*filepath",
471+
"/x",
472+
"/x/y",
473+
"/y/",
474+
"/y/z",
475+
"/0/:id",
476+
"/0/:id/1",
477+
"/1/:id/",
478+
"/1/:id/2",
479+
"/aa",
480+
"/a/",
481+
"/doc",
482+
"/doc/go_faq.html",
483+
"/doc/go1.html",
484+
"/doc/go/away",
485+
"/no/a",
486+
"/no/b",
487+
}
488+
489+
for _, route := range routes {
490+
recv := catchPanic(func() {
491+
tree.addRoute("GET", route, fakeHandler(route))
492+
})
493+
if recv != nil {
494+
t.Fatalf("panic inserting route '%s': %v", route, recv)
495+
}
496+
}
497+
498+
// Check out == in for all registered routes
499+
// With fixTrailingSlash = true
500+
for _, route := range routes {
501+
out, found := tree.findCaseInsensitivePath("GET", route, true)
502+
if !found {
503+
t.Errorf("Route '%s' not found!", route)
504+
} else if string(out) != route {
505+
t.Errorf("Wrong result for route '%s': %s", route, string(out))
506+
}
507+
}
508+
// With fixTrailingSlash = false
509+
for _, route := range routes {
510+
out, found := tree.findCaseInsensitivePath("GET", route, false)
511+
if !found {
512+
t.Errorf("Route '%s' not found!", route)
513+
} else if string(out) != route {
514+
t.Errorf("Wrong result for route '%s': %s", route, string(out))
515+
}
516+
}
517+
518+
tests := []struct {
519+
in string
520+
out string
521+
found bool
522+
slash bool
523+
}{
524+
{"/HI", "/hi", true, false},
525+
{"/HI/", "/hi", true, true},
526+
{"/B", "/b/", true, true},
527+
{"/B/", "/b/", true, false},
528+
{"/abc", "/ABC/", true, true},
529+
{"/abc/", "/ABC/", true, false},
530+
{"/aBc", "/ABC/", true, true},
531+
{"/aBc/", "/ABC/", true, false},
532+
{"/abC", "/ABC/", true, true},
533+
{"/abC/", "/ABC/", true, false},
534+
{"/SEARCH/QUERY", "/search/QUERY", true, false},
535+
{"/SEARCH/QUERY/", "/search/QUERY", true, true},
536+
{"/CMD/TOOL/", "/cmd/TOOL/", true, false},
537+
{"/CMD/TOOL", "/cmd/TOOL/", true, true},
538+
{"/SRC/FILE/PATH", "/src/FILE/PATH", true, false},
539+
{"/x/Y", "/x/y", true, false},
540+
{"/x/Y/", "/x/y", true, true},
541+
{"/X/y", "/x/y", true, false},
542+
{"/X/y/", "/x/y", true, true},
543+
{"/X/Y", "/x/y", true, false},
544+
{"/X/Y/", "/x/y", true, true},
545+
{"/Y/", "/y/", true, false},
546+
{"/Y", "/y/", true, true},
547+
{"/Y/z", "/y/z", true, false},
548+
{"/Y/z/", "/y/z", true, true},
549+
{"/Y/Z", "/y/z", true, false},
550+
{"/Y/Z/", "/y/z", true, true},
551+
{"/y/Z", "/y/z", true, false},
552+
{"/y/Z/", "/y/z", true, true},
553+
{"/Aa", "/aa", true, false},
554+
{"/Aa/", "/aa", true, true},
555+
{"/AA", "/aa", true, false},
556+
{"/AA/", "/aa", true, true},
557+
{"/aA", "/aa", true, false},
558+
{"/aA/", "/aa", true, true},
559+
{"/A/", "/a/", true, false},
560+
{"/A", "/a/", true, true},
561+
{"/DOC", "/doc", true, false},
562+
{"/DOC/", "/doc", true, true},
563+
{"/NO", "", false, true},
564+
{"/DOC/GO", "", false, true},
565+
}
566+
// With fixTrailingSlash = true
567+
for _, test := range tests {
568+
out, found := tree.findCaseInsensitivePath("GET", test.in, true)
569+
if found != test.found || (found && (string(out) != test.out)) {
570+
t.Errorf("Wrong result for '%s': got %s, %t; want %s, %t",
571+
test.in, string(out), found, test.out, test.found)
572+
return
573+
}
574+
}
575+
// With fixTrailingSlash = false
576+
for _, test := range tests {
577+
out, found := tree.findCaseInsensitivePath("GET", test.in, false)
578+
if test.slash {
579+
if found { // test needs a trailingSlash fix. It must not be found!
580+
t.Errorf("Found without fixTrailingSlash: %s; got %s", test.in, string(out))
581+
}
582+
} else {
583+
if found != test.found || (found && (string(out) != test.out)) {
584+
t.Errorf("Wrong result for '%s': got %s, %t; want %s, %t",
585+
test.in, string(out), found, test.out, test.found)
586+
return
587+
}
588+
}
589+
}
590+
}

0 commit comments

Comments
 (0)