@@ -17,47 +17,52 @@ type Rule struct {
17
17
// The path segments of the url
18
18
// nil means all paths allowed
19
19
// a path segment of `*` acts as a wild card.
20
- Path []string
20
+ PathPattern []segmentPattern
21
21
22
22
// The labels of the host, i.e. ["google", "com"]
23
- // nil means no hosts allowed
24
- // subdomains automatically match
25
- Host []string
23
+ // nil means all hosts allowed
24
+ // A label of `*` acts as a wild card.
25
+ HostPattern []labelPattern
26
26
27
27
// The allowed http methods
28
28
// nil means all methods allowed
29
- Methods map [string ]struct {}
29
+ MethodPatterns map [methodPattern ]struct {}
30
30
31
31
// Raw rule string for logging
32
- Raw string
32
+ Raw string
33
33
}
34
34
35
- type httpToken string
35
+ type methodPattern string
36
+
37
+ // An asterisk is treated as matching any method
38
+ func (t methodPattern ) matches (input string ) bool {
39
+ return t == "*" || string (t ) == input
40
+ }
36
41
37
42
// Beyond the 9 methods defined in HTTP 1.1, there actually are many more seldom used extension methods by
38
43
// various systems.
39
44
// https://datatracker.ietf.org/doc/html/rfc7230#section-3.2.6
40
- func parseHTTPToken (token string ) (httpToken , string , error ) {
45
+ func parseMethodPattern (token string ) (methodPattern , string , error ) {
41
46
if token == "" {
42
47
return "" , "" , errors .New ("expected http token, got empty string" )
43
48
}
44
- return doParseHTTPToken (token , nil )
49
+ return doParseMethodPattern (token , nil )
45
50
}
46
51
47
- func doParseHTTPToken (token string , acc []byte ) (httpToken , string , error ) {
52
+ func doParseMethodPattern (token string , acc []byte ) (methodPattern , string , error ) {
48
53
// BASE CASE: if the token passed in is empty, we're done parsing
49
54
if token == "" {
50
- return httpToken (acc ), "" , nil
55
+ return methodPattern (acc ), "" , nil
51
56
}
52
57
53
58
// If the next byte in the string is not a valid http token character, we're done parsing.
54
59
if ! isHTTPTokenChar (token [0 ]) {
55
- return httpToken (acc ), token , nil
60
+ return methodPattern (acc ), token , nil
56
61
}
57
62
58
63
// The next character is valid, so the http token continues
59
64
acc = append (acc , token [0 ])
60
- return doParseHTTPToken (token [1 :], acc )
65
+ return doParseMethodPattern (token [1 :], acc )
61
66
}
62
67
63
68
// The valid characters that can be in an http token (like the lexer/parser kind of token).
@@ -85,16 +90,16 @@ func isHTTPTokenChar(c byte) bool {
85
90
// Represents a valid host.
86
91
// https://datatracker.ietf.org/doc/html/rfc952
87
92
// https://datatracker.ietf.org/doc/html/rfc1123#page-13
88
- func parseHost (input string ) (host []label , rest string , err error ) {
93
+ func parseHostPattern (input string ) (host []labelPattern , rest string , err error ) {
89
94
rest = input
90
- var label label
95
+ var label labelPattern
91
96
92
97
if input == "" {
93
98
return nil , "" , errors .New ("expected host, got empty string" )
94
99
}
95
100
96
101
// There should be at least one label.
97
- label , rest , err = parseLabel (rest )
102
+ label , rest , err = parseLabelPattern (rest )
98
103
if err != nil {
99
104
return nil , "" , err
100
105
}
@@ -108,7 +113,7 @@ func parseHost(input string) (host []label, rest string, err error) {
108
113
break
109
114
}
110
115
111
- label , rest , err = parseLabel (rest )
116
+ label , rest , err = parseLabelPattern (rest )
112
117
if err != nil {
113
118
return nil , "" , err
114
119
}
@@ -119,9 +124,9 @@ func parseHost(input string) (host []label, rest string, err error) {
119
124
}
120
125
121
126
// Represents a valid label in a hostname. For example, wobble in `wib-ble.wobble.com`.
122
- type label string
127
+ type labelPattern string
123
128
124
- func parseLabel (rest string ) (label , string , error ) {
129
+ func parseLabelPattern (rest string ) (labelPattern , string , error ) {
125
130
if rest == "" {
126
131
return "" , "" , errors .New ("expected label, got empty string" )
127
132
}
@@ -141,7 +146,7 @@ func parseLabel(rest string) (label, string, error) {
141
146
return "" , "" , fmt .Errorf ("invalid label: %s" , rest [:i ])
142
147
}
143
148
144
- return label (rest [:i ]), rest [i :], nil
149
+ return labelPattern (rest [:i ]), rest [i :], nil
145
150
}
146
151
147
152
func isValidLabelChar (c byte ) bool {
@@ -163,12 +168,12 @@ func isValidLabelChar(c byte) bool {
163
168
}
164
169
}
165
170
166
- func parsePath (input string ) ([]segment , string , error ) {
171
+ func parsePathPattern (input string ) ([]segmentPattern , string , error ) {
167
172
if input == "" {
168
173
return nil , "" , nil
169
174
}
170
175
171
- var segments []segment
176
+ var segments []segmentPattern
172
177
rest := input
173
178
174
179
// If the path doesn't start with '/', it's not a valid absolute path
@@ -185,12 +190,12 @@ func parsePath(input string) ([]segment, string, error) {
185
190
}
186
191
187
192
// Parse the next segment
188
- seg , remaining , err := parsePathSegment (rest )
193
+ seg , remaining , err := parsePathSegmentPattern (rest )
189
194
if err != nil {
190
195
return nil , "" , err
191
196
}
192
197
193
- // If we got an empty segment and there's still input,
198
+ // If we got an empty segment and there's still input,
194
199
// it means we hit an invalid character
195
200
if seg == "" && remaining != "" {
196
201
break
@@ -208,18 +213,18 @@ func parsePath(input string) ([]segment, string, error) {
208
213
return segments , rest , nil
209
214
}
210
215
211
- // Represents a valid url path segment .
212
- type segment string
216
+ // Represents a valid url path segmentPattern .
217
+ type segmentPattern string
213
218
214
- func parsePathSegment (input string ) (segment , string , error ) {
219
+ func parsePathSegmentPattern (input string ) (segmentPattern , string , error ) {
215
220
if input == "" {
216
221
return "" , "" , nil
217
222
}
218
223
219
224
var i int
220
225
for i = 0 ; i < len (input ); i ++ {
221
226
c := input [i ]
222
-
227
+
223
228
// Check for percent-encoded characters (%XX)
224
229
if c == '%' {
225
230
if i + 2 >= len (input ) || ! isHexDigit (input [i + 1 ]) || ! isHexDigit (input [i + 2 ]) {
@@ -228,14 +233,14 @@ func parsePathSegment(input string) (segment, string, error) {
228
233
i += 2
229
234
continue
230
235
}
231
-
236
+
232
237
// Check for valid pchar characters
233
238
if ! isPChar (c ) {
234
239
break
235
240
}
236
241
}
237
242
238
- return segment (input [:i ]), input [i :], nil
243
+ return segmentPattern (input [:i ]), input [i :], nil
239
244
}
240
245
241
246
// isUnreserved returns true if the character is unreserved per RFC 3986
@@ -286,8 +291,72 @@ func parseKey(rule string) (string, string, error) {
286
291
return "" , "" , errors .New ("expected key" )
287
292
}
288
293
289
- func parseAllowRule (string ) (Rule , error ) {
290
- return Rule {}, nil
294
+ func parseAllowRule (ruleStr string ) (Rule , error ) {
295
+ rule := Rule {
296
+ Raw : ruleStr ,
297
+ }
298
+
299
+ rest := ruleStr
300
+
301
+ for rest != "" {
302
+ // Parse the key
303
+ key , valueRest , err := parseKey (rest )
304
+ if err != nil {
305
+ return Rule {}, fmt .Errorf ("failed to parse key: %v" , err )
306
+ }
307
+
308
+ // Parse the value based on the key type
309
+ switch key {
310
+ case "method" :
311
+ token , remaining , err := parseMethodPattern (valueRest )
312
+ if err != nil {
313
+ return Rule {}, fmt .Errorf ("failed to parse method: %v" , err )
314
+ }
315
+
316
+ // Initialize Methods map if needed
317
+ if rule .MethodPatterns == nil {
318
+ rule .MethodPatterns = make (map [methodPattern ]struct {})
319
+ }
320
+ rule .MethodPatterns [token ] = struct {}{}
321
+ rest = remaining
322
+
323
+ case "domain" :
324
+ hostLabels , remaining , err := parseHostPattern (valueRest )
325
+ if err != nil {
326
+ return Rule {}, fmt .Errorf ("failed to parse domain: %v" , err )
327
+ }
328
+
329
+ // Convert labels to strings in reverse order (TLD first)
330
+ rule .HostPattern = make ([]labelPattern , len (hostLabels ))
331
+ for i , label := range hostLabels {
332
+ rule .HostPattern [len (hostLabels )- 1 - i ] = label
333
+ }
334
+ rest = remaining
335
+
336
+ case "path" :
337
+ segments , remaining , err := parsePathPattern (valueRest )
338
+ if err != nil {
339
+ return Rule {}, fmt .Errorf ("failed to parse path: %v" , err )
340
+ }
341
+
342
+ // Convert segments to strings
343
+ rule .PathPattern = make ([]segmentPattern , len (segments ))
344
+ for i , segment := range segments {
345
+ rule .PathPattern [i ] = segment
346
+ }
347
+ rest = remaining
348
+
349
+ default :
350
+ return Rule {}, fmt .Errorf ("unknown key: %s" , key )
351
+ }
352
+
353
+ // Skip whitespace or comma separators
354
+ for rest != "" && (rest [0 ] == ' ' || rest [0 ] == '\t' || rest [0 ] == ',' ) {
355
+ rest = rest [1 :]
356
+ }
357
+ }
358
+
359
+ return rule , nil
291
360
}
292
361
293
362
// ParseAllowSpecs parses a slice of --allow specs into allow Rules.
@@ -342,17 +411,85 @@ func (re *Engine) Evaluate(method, url string) Result {
342
411
}
343
412
}
344
413
345
- // Matches checks if the rule matches the given method and URL using wildcard patterns
346
- func ( re * Engine ) matches ( r Rule , method , url string ) bool {
347
- // If the rule doesn't have any method filters, don't restrict the allowed methods
348
- if r . Methods == nil {
349
- return true
414
+ type protocol string
415
+
416
+ func parseProtocol ( input string ) ( protocol , string , error ) {
417
+ if input == "" {
418
+ return "" , "" , errors . New ( "expected protocol, got empty string" )
350
419
}
351
420
352
- // If the rule has method filters and the provided method is not one of them, block the request.
353
- if _ , methodIsAllowed := r .Methods [method ]; ! methodIsAllowed {
354
- return false
421
+ // Look for "://" separator
422
+ if idx := strings .Index (input , "://" ); idx > 0 {
423
+ protocolPart := input [:idx ]
424
+ rest := input [idx + 3 :]
425
+
426
+ // Validate protocol characters (scheme per RFC 3986)
427
+ // scheme = ALPHA *( ALPHA / DIGIT / "+" / "-" / "." )
428
+ if len (protocolPart ) == 0 {
429
+ return "" , "" , errors .New ("empty protocol" )
430
+ }
431
+
432
+ // First character must be alpha
433
+ if ! ((protocolPart [0 ] >= 'A' && protocolPart [0 ] <= 'Z' ) ||
434
+ (protocolPart [0 ] >= 'a' && protocolPart [0 ] <= 'z' )) {
435
+ return "" , "" , errors .New ("protocol must start with a letter" )
436
+ }
437
+
438
+ // Rest can be alphanumeric, +, -, or .
439
+ for i := 1 ; i < len (protocolPart ); i ++ {
440
+ c := protocolPart [i ]
441
+ if ! ((c >= 'A' && c <= 'Z' ) || (c >= 'a' && c <= 'z' ) ||
442
+ (c >= '0' && c <= '9' ) || c == '+' || c == '-' || c == '.' ) {
443
+ return "" , "" , fmt .Errorf ("invalid character in protocol: %c" , c )
444
+ }
445
+ }
446
+
447
+ return protocol (protocolPart ), rest , nil
355
448
}
356
449
450
+ // No protocol found
451
+ return "" , input , nil
452
+ }
453
+
454
+ type port uint16
455
+
456
+ func parsePort (input string ) (port , string , error ) {
457
+ if input == "" {
458
+ return 0 , "" , nil
459
+ }
460
+
461
+ // Port must start with ':'
462
+ if input [0 ] != ':' {
463
+ return 0 , input , nil
464
+ }
465
+
466
+ // Find the end of the port number
467
+ i := 1
468
+ for i < len (input ) && input [i ] >= '0' && input [i ] <= '9' {
469
+ i ++
470
+ }
471
+
472
+ // No digits found after ':'
473
+ if i == 1 {
474
+ return 0 , "" , errors .New ("expected port number after ':'" )
475
+ }
476
+
477
+ portStr := input [1 :i ]
478
+ rest := input [i :]
479
+
480
+ // Convert to uint16 (port range is 0-65535)
481
+ portNum := 0
482
+ for _ , digit := range portStr {
483
+ portNum = portNum * 10 + int (digit - '0' )
484
+ if portNum > 65535 {
485
+ return 0 , "" , errors .New ("port number too large (max 65535)" )
486
+ }
487
+ }
488
+
489
+ return port (portNum ), rest , nil
490
+ }
491
+
492
+ // Matches checks if the rule matches the given method and URL using wildcard patterns
493
+ func (re * Engine ) matches (r Rule , method , url string ) bool {
357
494
return true
358
495
}
0 commit comments