44 "errors"
55 "fmt"
66 "log/slog"
7+ neturl "net/url"
78 "strings"
89)
910
@@ -17,11 +18,13 @@ type Rule struct {
1718 // The path segments of the url
1819 // nil means all paths allowed
1920 // a path segment of `*` acts as a wild card.
21+ // sub paths automatically match
2022 PathPattern []segmentPattern
2123
2224 // The labels of the host, i.e. ["google", "com"]
2325 // nil means all hosts allowed
2426 // A label of `*` acts as a wild card.
27+ // subdomains automatically match
2528 HostPattern []labelPattern
2629
2730 // The allowed http methods
@@ -121,17 +124,12 @@ func parseHostPattern(input string) (host []labelPattern, rest string, err error
121124// Represents a valid label in a hostname. For example, wobble in `wib-ble.wobble.com`.
122125type labelPattern string
123126
124- // An `asterisk` is treated as matching anything
125- func (lp labelPattern ) matches (input string ) bool {
126- return lp == "*" || string (lp ) == input
127- }
128-
129127func parseLabelPattern (rest string ) (labelPattern , string , error ) {
130128 if rest == "" {
131129 return "" , "" , errors .New ("expected label, got empty string" )
132130 }
133131
134- // If the label is simply an asterisk, good to go.
132+ // If the label is simply an asterisk, good to go.
135133 if rest [0 ] == '*' {
136134 return "*" , rest [1 :], nil
137135 }
@@ -221,11 +219,6 @@ func parsePathPattern(input string) ([]segmentPattern, string, error) {
221219// Represents a valid url path segmentPattern.
222220type segmentPattern string
223221
224- // An `*` is treated as matching anything
225- func (sp segmentPattern ) matches (input string ) bool {
226- return sp == "*" || string (sp ) == input
227- }
228-
229222func parsePathSegmentPattern (input string ) (segmentPattern , string , error ) {
230223 if input == "" {
231224 return "" , "" , nil
@@ -359,9 +352,7 @@ func parseAllowRule(ruleStr string) (Rule, error) {
359352
360353 // Convert segments to strings
361354 rule .PathPattern = make ([]segmentPattern , len (segments ))
362- for i , segment := range segments {
363- rule .PathPattern [i ] = segment
364- }
355+ copy (rule .PathPattern , segments )
365356 rest = remaining
366357
367358 default :
@@ -431,5 +422,63 @@ func (re *Engine) Evaluate(method, url string) Result {
431422
432423// Matches checks if the rule matches the given method and URL using wildcard patterns
433424func (re * Engine ) matches (r Rule , method , url string ) bool {
425+
426+ // Check method patterns if they exist
427+ if r .MethodPatterns != nil {
428+ methodMatches := false
429+ for mp := range r .MethodPatterns {
430+ if string (mp ) == method || string (mp ) == "*" {
431+ methodMatches = true
432+ break
433+ }
434+ }
435+ if ! methodMatches {
436+ return false
437+ }
438+ }
439+
440+ parsedUrl , err := neturl .Parse (url )
441+ if err != nil {
442+ return false
443+ }
444+
445+ if r .HostPattern != nil {
446+ // For a host pattern to match, every label has to match or be an `*`.
447+ // Subdomains also match automatically, meaning if the pattern is "wobble.com"
448+ // and the real is "wibble.wobble.com", it should match. We check this by comparing
449+ // from the end since patterns are stored in reverse order (TLD first).
450+
451+ labels := strings .Split (parsedUrl .Hostname (), "." )
452+
453+ // If the host pattern is longer than the actual host, it's definitely not a match
454+ if len (r .HostPattern ) > len (labels ) {
455+ return false
456+ }
457+
458+ // Compare from the end of both arrays since pattern is stored in reverse order
459+ for i , lp := range r .HostPattern {
460+ labelIndex := len (labels ) - 1 - i
461+ if string (lp ) != labels [labelIndex ] && lp != "*" {
462+ return false
463+ }
464+ }
465+ }
466+
467+ if r .PathPattern != nil {
468+ segments := strings .Split (parsedUrl .Path , "/" )
469+
470+ // If the path pattern is longer than the actual path, definitely not a match
471+ if len (r .PathPattern ) > len (segments ) {
472+ return false
473+ }
474+
475+ // Each segment in the pattern must be either as asterisk or match the actual path segment
476+ for i , sp := range r .PathPattern {
477+ if string (sp ) != segments [i ] && sp != "*" {
478+ return false
479+ }
480+ }
481+ }
482+
434483 return true
435484}
0 commit comments