@@ -17,47 +17,52 @@ type Rule struct {
1717 // The path segments of the url
1818 // nil means all paths allowed
1919 // a path segment of `*` acts as a wild card.
20- Path []string
20+ PathPattern []segmentPattern
2121
2222 // The labels of the host, i.e. ["google", "com"]
23- // nil means no hosts allowed
24- // subdomains automatically match
25- Host []string
23+ // nil means all hosts allowed
24+ // A label of `*` acts as a wild card.
25+ HostPattern []labelPattern
2626
2727 // The allowed http methods
2828 // nil means all methods allowed
29- Methods map [string ]struct {}
29+ MethodPatterns map [methodPattern ]struct {}
3030
3131 // Raw rule string for logging
32- Raw string
32+ Raw string
3333}
3434
35- type httpToken string
// methodPattern is a single HTTP method pattern taken from a rule.
type methodPattern string

// matches reports whether input satisfies the pattern. The pattern "*"
// accepts every input; any other pattern must equal input exactly
// (the comparison is case-sensitive).
func (t methodPattern) matches(input string) bool {
	switch t {
	case "*":
		return true
	default:
		return string(t) == input
	}
}
3641
3742// Beyond the 9 methods defined in HTTP 1.1, there actually are many more seldom used extension methods by
3843// various systems.
3944// https://datatracker.ietf.org/doc/html/rfc7230#section-3.2.6
40- func parseHTTPToken (token string ) (httpToken , string , error ) {
45+ func parseMethodPattern (token string ) (methodPattern , string , error ) {
4146 if token == "" {
4247 return "" , "" , errors .New ("expected http token, got empty string" )
4348 }
44- return doParseHTTPToken (token , nil )
49+ return doParseMethodPattern (token , nil )
4550}
4651
47- func doParseHTTPToken (token string , acc []byte ) (httpToken , string , error ) {
52+ func doParseMethodPattern (token string , acc []byte ) (methodPattern , string , error ) {
4853 // BASE CASE: if the token passed in is empty, we're done parsing
4954 if token == "" {
50- return httpToken (acc ), "" , nil
55+ return methodPattern (acc ), "" , nil
5156 }
5257
5358 // If the next byte in the string is not a valid http token character, we're done parsing.
5459 if ! isHTTPTokenChar (token [0 ]) {
55- return httpToken (acc ), token , nil
60+ return methodPattern (acc ), token , nil
5661 }
5762
5863 // The next character is valid, so the http token continues
5964 acc = append (acc , token [0 ])
60- return doParseHTTPToken (token [1 :], acc )
65+ return doParseMethodPattern (token [1 :], acc )
6166}
6267
6368// The valid characters that can be in an http token (like the lexer/parser kind of token).
@@ -85,16 +90,16 @@ func isHTTPTokenChar(c byte) bool {
8590// Represents a valid host.
8691// https://datatracker.ietf.org/doc/html/rfc952
8792// https://datatracker.ietf.org/doc/html/rfc1123#page-13
88- func parseHost (input string ) (host []label , rest string , err error ) {
93+ func parseHostPattern (input string ) (host []labelPattern , rest string , err error ) {
8994 rest = input
90- var label label
95+ var label labelPattern
9196
9297 if input == "" {
9398 return nil , "" , errors .New ("expected host, got empty string" )
9499 }
95100
96101 // There should be at least one label.
97- label , rest , err = parseLabel (rest )
102+ label , rest , err = parseLabelPattern (rest )
98103 if err != nil {
99104 return nil , "" , err
100105 }
@@ -108,7 +113,7 @@ func parseHost(input string) (host []label, rest string, err error) {
108113 break
109114 }
110115
111- label , rest , err = parseLabel (rest )
116+ label , rest , err = parseLabelPattern (rest )
112117 if err != nil {
113118 return nil , "" , err
114119 }
@@ -119,9 +124,9 @@ func parseHost(input string) (host []label, rest string, err error) {
119124}
120125
121126// Represents a valid label in a hostname. For example, wobble in `wib-ble.wobble.com`.
122- type label string
127+ type labelPattern string
123128
124- func parseLabel (rest string ) (label , string , error ) {
129+ func parseLabelPattern (rest string ) (labelPattern , string , error ) {
125130 if rest == "" {
126131 return "" , "" , errors .New ("expected label, got empty string" )
127132 }
@@ -141,7 +146,7 @@ func parseLabel(rest string) (label, string, error) {
141146 return "" , "" , fmt .Errorf ("invalid label: %s" , rest [:i ])
142147 }
143148
144- return label (rest [:i ]), rest [i :], nil
149+ return labelPattern (rest [:i ]), rest [i :], nil
145150}
146151
147152func isValidLabelChar (c byte ) bool {
@@ -163,12 +168,12 @@ func isValidLabelChar(c byte) bool {
163168 }
164169}
165170
166- func parsePath (input string ) ([]segment , string , error ) {
171+ func parsePathPattern (input string ) ([]segmentPattern , string , error ) {
167172 if input == "" {
168173 return nil , "" , nil
169174 }
170175
171- var segments []segment
176+ var segments []segmentPattern
172177 rest := input
173178
174179 // If the path doesn't start with '/', it's not a valid absolute path
@@ -185,12 +190,12 @@ func parsePath(input string) ([]segment, string, error) {
185190 }
186191
187192 // Parse the next segment
188- seg , remaining , err := parsePathSegment (rest )
193+ seg , remaining , err := parsePathSegmentPattern (rest )
189194 if err != nil {
190195 return nil , "" , err
191196 }
192197
193- // If we got an empty segment and there's still input,
198+ // If we got an empty segment and there's still input,
194199 // it means we hit an invalid character
195200 if seg == "" && remaining != "" {
196201 break
@@ -208,18 +213,18 @@ func parsePath(input string) ([]segment, string, error) {
208213 return segments , rest , nil
209214}
210215
211- // Represents a valid url path segment .
212- type segment string
216+ // Represents a valid url path segment pattern.
217+ type segmentPattern string
213218
214- func parsePathSegment (input string ) (segment , string , error ) {
219+ func parsePathSegmentPattern (input string ) (segmentPattern , string , error ) {
215220 if input == "" {
216221 return "" , "" , nil
217222 }
218223
219224 var i int
220225 for i = 0 ; i < len (input ); i ++ {
221226 c := input [i ]
222-
227+
223228 // Check for percent-encoded characters (%XX)
224229 if c == '%' {
225230 if i + 2 >= len (input ) || ! isHexDigit (input [i + 1 ]) || ! isHexDigit (input [i + 2 ]) {
@@ -228,14 +233,14 @@ func parsePathSegment(input string) (segment, string, error) {
228233 i += 2
229234 continue
230235 }
231-
236+
232237 // Check for valid pchar characters
233238 if ! isPChar (c ) {
234239 break
235240 }
236241 }
237242
238- return segment (input [:i ]), input [i :], nil
243+ return segmentPattern (input [:i ]), input [i :], nil
239244}
240245
241246// isUnreserved returns true if the character is unreserved per RFC 3986
@@ -286,8 +291,72 @@ func parseKey(rule string) (string, string, error) {
286291 return "" , "" , errors .New ("expected key" )
287292}
288293
289- func parseAllowRule (string ) (Rule , error ) {
290- return Rule {}, nil
294+ func parseAllowRule (ruleStr string ) (Rule , error ) {
295+ rule := Rule {
296+ Raw : ruleStr ,
297+ }
298+
299+ rest := ruleStr
300+
301+ for rest != "" {
302+ // Parse the key
303+ key , valueRest , err := parseKey (rest )
304+ if err != nil {
305+ return Rule {}, fmt .Errorf ("failed to parse key: %v" , err )
306+ }
307+
308+ // Parse the value based on the key type
309+ switch key {
310+ case "method" :
311+ token , remaining , err := parseMethodPattern (valueRest )
312+ if err != nil {
313+ return Rule {}, fmt .Errorf ("failed to parse method: %v" , err )
314+ }
315+
316+ // Initialize Methods map if needed
317+ if rule .MethodPatterns == nil {
318+ rule .MethodPatterns = make (map [methodPattern ]struct {})
319+ }
320+ rule .MethodPatterns [token ] = struct {}{}
321+ rest = remaining
322+
323+ case "domain" :
324+ hostLabels , remaining , err := parseHostPattern (valueRest )
325+ if err != nil {
326+ return Rule {}, fmt .Errorf ("failed to parse domain: %v" , err )
327+ }
328+
329+ // Convert labels to strings in reverse order (TLD first)
330+ rule .HostPattern = make ([]labelPattern , len (hostLabels ))
331+ for i , label := range hostLabels {
332+ rule .HostPattern [len (hostLabels )- 1 - i ] = label
333+ }
334+ rest = remaining
335+
336+ case "path" :
337+ segments , remaining , err := parsePathPattern (valueRest )
338+ if err != nil {
339+ return Rule {}, fmt .Errorf ("failed to parse path: %v" , err )
340+ }
341+
342+ // Convert segments to strings
343+ rule .PathPattern = make ([]segmentPattern , len (segments ))
344+ for i , segment := range segments {
345+ rule .PathPattern [i ] = segment
346+ }
347+ rest = remaining
348+
349+ default :
350+ return Rule {}, fmt .Errorf ("unknown key: %s" , key )
351+ }
352+
353+ // Skip whitespace or comma separators
354+ for rest != "" && (rest [0 ] == ' ' || rest [0 ] == '\t' || rest [0 ] == ',' ) {
355+ rest = rest [1 :]
356+ }
357+ }
358+
359+ return rule , nil
291360}
292361
293362// ParseAllowSpecs parses a slice of --allow specs into allow Rules.
@@ -342,17 +411,85 @@ func (re *Engine) Evaluate(method, url string) Result {
342411 }
343412}
344413
345- // Matches checks if the rule matches the given method and URL using wildcard patterns
346- func ( re * Engine ) matches ( r Rule , method , url string ) bool {
347- // If the rule doesn't have any method filters, don't restrict the allowed methods
348- if r . Methods == nil {
349- return true
// protocol is the scheme part of a URL, e.g. "https" in "https://example.com".
type protocol string

// parseProtocol splits a leading "<scheme>://" off input.
//
// The first occurrence of "://" in input is taken as the separator. When it
// is present, the text before it is validated as an RFC 3986 scheme
// (ALPHA *( ALPHA / DIGIT / "+" / "-" / "." )) and returned together with
// everything after the separator. When input contains no "://", the empty
// protocol is returned and input is handed back unchanged.
//
// An error is returned when input is empty, when nothing precedes the "://"
// separator, or when the text before the separator is not a valid scheme.
func parseProtocol(input string) (protocol, string, error) {
	if input == "" {
		return "", "", errors.New("expected protocol, got empty string")
	}

	// Look for "://" separator
	idx := strings.Index(input, "://")
	if idx < 0 {
		// No protocol found
		return "", input, nil
	}

	// A separator with nothing in front of it is malformed. It must not be
	// reported as "no protocol", which would leave "://" at the start of
	// the remainder.
	if idx == 0 {
		return "", "", errors.New("empty protocol")
	}

	protocolPart := input[:idx]
	rest := input[idx+3:]

	isLetter := func(c byte) bool {
		return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z')
	}

	// First character must be alpha
	if !isLetter(protocolPart[0]) {
		return "", "", errors.New("protocol must start with a letter")
	}

	// Rest can be alphanumeric, +, -, or .
	for i := 1; i < len(protocolPart); i++ {
		c := protocolPart[i]
		if !(isLetter(c) || (c >= '0' && c <= '9') || c == '+' || c == '-' || c == '.') {
			return "", "", fmt.Errorf("invalid character in protocol: %c", c)
		}
	}

	return protocol(protocolPart), rest, nil
}
453+
// port is a port number in the range 0-65535.
type port uint16

// parsePort reads an optional ":<digits>" prefix from input.
//
// When input is empty or does not begin with ':', port 0 is returned and
// input is handed back unchanged. Otherwise the decimal digits following the
// ':' are converted and returned together with the text after the last
// digit. An error is returned when no digit follows the ':' or when the
// number exceeds 65535.
func parsePort(input string) (port, string, error) {
	// Anything that does not open with ':' carries no port.
	if len(input) == 0 || input[0] != ':' {
		return 0, input, nil
	}

	// Walk the digits once, building the value as we go. The range check
	// runs after every digit so the accumulator can never overflow.
	value := 0
	end := 1
	for ; end < len(input); end++ {
		c := input[end]
		if c < '0' || c > '9' {
			break
		}
		value = value*10 + int(c-'0')
		if value > 65535 {
			return 0, "", errors.New("port number too large (max 65535)")
		}
	}

	// end never moved, so the ':' was not followed by a digit.
	if end == 1 {
		return 0, "", errors.New("expected port number after ':'")
	}

	return port(value), input[end:], nil
}
491+
492+ // Matches checks if the rule matches the given method and URL using wildcard patterns
493+ func (re * Engine ) matches (r Rule , method , url string ) bool {
357494 return true
358495}
0 commit comments