Skip to content

Commit 3de4875

Browse files
committed
implement top level matching
1 parent 62edfb2 commit 3de4875

File tree

2 files changed

+332
-15
lines changed

2 files changed

+332
-15
lines changed

rules/rules.go

Lines changed: 63 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"errors"
55
"fmt"
66
"log/slog"
7+
neturl "net/url"
78
"strings"
89
)
910

@@ -17,11 +18,13 @@ type Rule struct {
1718
// The path segments of the url
1819
// nil means all paths allowed
1920
// a path segment of `*` acts as a wild card.
21+
// sub paths automatically match
2022
PathPattern []segmentPattern
2123

2224
// The labels of the host, i.e. ["google", "com"]
2325
// nil means all hosts allowed
2426
// A label of `*` acts as a wild card.
27+
// subdomains automatically match
2528
HostPattern []labelPattern
2629

2730
// The allowed http methods
@@ -121,17 +124,12 @@ func parseHostPattern(input string) (host []labelPattern, rest string, err error
121124
// Represents a valid label in a hostname. For example, wobble in `wib-ble.wobble.com`.
122125
type labelPattern string
123126

124-
// An `asterisk` is treated as matching anything
125-
func (lp labelPattern) matches(input string) bool {
126-
return lp == "*" || string(lp) == input
127-
}
128-
129127
func parseLabelPattern(rest string) (labelPattern, string, error) {
130128
if rest == "" {
131129
return "", "", errors.New("expected label, got empty string")
132130
}
133131

134-
// If the label is simply an asterisk, good to go.
132+
// If the label is simply an asterisk, good to go.
135133
if rest[0] == '*' {
136134
return "*", rest[1:], nil
137135
}
@@ -221,11 +219,6 @@ func parsePathPattern(input string) ([]segmentPattern, string, error) {
221219
// Represents a valid url path segmentPattern.
222220
type segmentPattern string
223221

224-
// An `*` is treated as matching anything
225-
func (sp segmentPattern) matches(input string) bool {
226-
return sp == "*" || string(sp) == input
227-
}
228-
229222
func parsePathSegmentPattern(input string) (segmentPattern, string, error) {
230223
if input == "" {
231224
return "", "", nil
@@ -359,9 +352,7 @@ func parseAllowRule(ruleStr string) (Rule, error) {
359352

360353
// Convert segments to strings
361354
rule.PathPattern = make([]segmentPattern, len(segments))
362-
for i, segment := range segments {
363-
rule.PathPattern[i] = segment
364-
}
355+
copy(rule.PathPattern, segments)
365356
rest = remaining
366357

367358
default:
@@ -431,5 +422,63 @@ func (re *Engine) Evaluate(method, url string) Result {
431422

432423
// Matches checks if the rule matches the given method and URL using wildcard patterns
433424
func (re *Engine) matches(r Rule, method, url string) bool {
425+
426+
// Check method patterns if they exist
427+
if r.MethodPatterns != nil {
428+
methodMatches := false
429+
for mp := range r.MethodPatterns {
430+
if string(mp) == method || string(mp) == "*" {
431+
methodMatches = true
432+
break
433+
}
434+
}
435+
if !methodMatches {
436+
return false
437+
}
438+
}
439+
440+
parsedUrl, err := neturl.Parse(url)
441+
if err != nil {
442+
return false
443+
}
444+
445+
if r.HostPattern != nil {
446+
// For a host pattern to match, every label has to match or be an `*`.
447+
// Subdomains also match automatically, meaning if the pattern is "wobble.com"
448+
// and the real is "wibble.wobble.com", it should match. We check this by comparing
449+
// from the end since patterns are stored in reverse order (TLD first).
450+
451+
labels := strings.Split(parsedUrl.Hostname(), ".")
452+
453+
// If the host pattern is longer than the actual host, it's definitely not a match
454+
if len(r.HostPattern) > len(labels) {
455+
return false
456+
}
457+
458+
// Compare from the end of both arrays since pattern is stored in reverse order
459+
for i, lp := range r.HostPattern {
460+
labelIndex := len(labels) - 1 - i
461+
if string(lp) != labels[labelIndex] && lp != "*" {
462+
return false
463+
}
464+
}
465+
}
466+
467+
if r.PathPattern != nil {
468+
segments := strings.Split(parsedUrl.Path, "/")
469+
470+
// If the path pattern is longer than the actual path, definitely not a match
471+
if len(r.PathPattern) > len(segments) {
472+
return false
473+
}
474+
475+
// Each segment in the pattern must be either as asterisk or match the actual path segment
476+
for i, sp := range r.PathPattern {
477+
if string(sp) != segments[i] && sp != "*" {
478+
return false
479+
}
480+
}
481+
}
482+
434483
return true
435484
}

0 commit comments

Comments
 (0)