4
4
"errors"
5
5
"fmt"
6
6
"log/slog"
7
+ neturl "net/url"
7
8
"strings"
8
9
)
9
10
@@ -17,11 +18,13 @@ type Rule struct {
17
18
// The path segments of the url
18
19
// nil means all paths allowed
19
20
// a path segment of `*` acts as a wild card.
21
+ // sub paths automatically match
20
22
PathPattern []segmentPattern
21
23
22
24
// The labels of the host, i.e. ["google", "com"]
23
25
// nil means all hosts allowed
24
26
// A label of `*` acts as a wild card.
27
+ // subdomains automatically match
25
28
HostPattern []labelPattern
26
29
27
30
// The allowed http methods
@@ -121,17 +124,12 @@ func parseHostPattern(input string) (host []labelPattern, rest string, err error
121
124
// Represents a valid label in a hostname. For example, wobble in `wib-ble.wobble.com`.
122
125
type labelPattern string
123
126
124
- // An `asterisk` is treated as matching anything
125
- func (lp labelPattern ) matches (input string ) bool {
126
- return lp == "*" || string (lp ) == input
127
- }
128
-
129
127
func parseLabelPattern (rest string ) (labelPattern , string , error ) {
130
128
if rest == "" {
131
129
return "" , "" , errors .New ("expected label, got empty string" )
132
130
}
133
131
134
- // If the label is simply an asterisk, good to go.
132
+ // If the label is simply an asterisk, good to go.
135
133
if rest [0 ] == '*' {
136
134
return "*" , rest [1 :], nil
137
135
}
@@ -221,11 +219,6 @@ func parsePathPattern(input string) ([]segmentPattern, string, error) {
221
219
// Represents a valid url path segmentPattern.
222
220
type segmentPattern string
223
221
224
- // An `*` is treated as matching anything
225
- func (sp segmentPattern ) matches (input string ) bool {
226
- return sp == "*" || string (sp ) == input
227
- }
228
-
229
222
func parsePathSegmentPattern (input string ) (segmentPattern , string , error ) {
230
223
if input == "" {
231
224
return "" , "" , nil
@@ -359,9 +352,7 @@ func parseAllowRule(ruleStr string) (Rule, error) {
359
352
360
353
// Convert segments to strings
361
354
rule .PathPattern = make ([]segmentPattern , len (segments ))
362
- for i , segment := range segments {
363
- rule .PathPattern [i ] = segment
364
- }
355
+ copy (rule .PathPattern , segments )
365
356
rest = remaining
366
357
367
358
default :
@@ -431,5 +422,63 @@ func (re *Engine) Evaluate(method, url string) Result {
431
422
432
423
// Matches checks if the rule matches the given method and URL using wildcard patterns
433
424
func (re * Engine ) matches (r Rule , method , url string ) bool {
425
+
426
+ // Check method patterns if they exist
427
+ if r .MethodPatterns != nil {
428
+ methodMatches := false
429
+ for mp := range r .MethodPatterns {
430
+ if string (mp ) == method || string (mp ) == "*" {
431
+ methodMatches = true
432
+ break
433
+ }
434
+ }
435
+ if ! methodMatches {
436
+ return false
437
+ }
438
+ }
439
+
440
+ parsedUrl , err := neturl .Parse (url )
441
+ if err != nil {
442
+ return false
443
+ }
444
+
445
+ if r .HostPattern != nil {
446
+ // For a host pattern to match, every label has to match or be an `*`.
447
+ // Subdomains also match automatically, meaning if the pattern is "wobble.com"
448
+ // and the real is "wibble.wobble.com", it should match. We check this by comparing
449
+ // from the end since patterns are stored in reverse order (TLD first).
450
+
451
+ labels := strings .Split (parsedUrl .Hostname (), "." )
452
+
453
+ // If the host pattern is longer than the actual host, it's definitely not a match
454
+ if len (r .HostPattern ) > len (labels ) {
455
+ return false
456
+ }
457
+
458
+ // Compare from the end of both arrays since pattern is stored in reverse order
459
+ for i , lp := range r .HostPattern {
460
+ labelIndex := len (labels ) - 1 - i
461
+ if string (lp ) != labels [labelIndex ] && lp != "*" {
462
+ return false
463
+ }
464
+ }
465
+ }
466
+
467
+ if r .PathPattern != nil {
468
+ segments := strings .Split (parsedUrl .Path , "/" )
469
+
470
+ // If the path pattern is longer than the actual path, definitely not a match
471
+ if len (r .PathPattern ) > len (segments ) {
472
+ return false
473
+ }
474
+
475
+ // Each segment in the pattern must be either as asterisk or match the actual path segment
476
+ for i , sp := range r .PathPattern {
477
+ if string (sp ) != segments [i ] && sp != "*" {
478
+ return false
479
+ }
480
+ }
481
+ }
482
+
434
483
return true
435
484
}
0 commit comments