Skip to content

Commit ec944b6

Browse files
committed
update names to reflect we're pattern parsing
1 parent 3fb4ef0 commit ec944b6

File tree

2 files changed

+369
-77
lines changed

2 files changed

+369
-77
lines changed

rules/rules.go

Lines changed: 177 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -17,47 +17,52 @@ type Rule struct {
1717
// The path segments of the url
1818
// nil means all paths allowed
1919
// a path segment of `*` acts as a wild card.
20-
Path []string
20+
PathPattern []segmentPattern
2121

2222
// The labels of the host, i.e. ["google", "com"]
23-
// nil means no hosts allowed
24-
// subdomains automatically match
25-
Host []string
23+
// nil means all hosts allowed
24+
// A label of `*` acts as a wild card.
25+
HostPattern []labelPattern
2626

2727
// The allowed http methods
2828
// nil means all methods allowed
29-
Methods map[string]struct{}
29+
MethodPatterns map[methodPattern]struct{}
3030

3131
// Raw rule string for logging
32-
Raw string
32+
Raw string
3333
}
3434

35-
type httpToken string
35+
type methodPattern string
36+
37+
// An asterisk is treated as matching any method
38+
func (t methodPattern) matches(input string) bool {
39+
return t == "*" || string(t) == input
40+
}
3641

3742
// Beyond the 9 methods defined in HTTP 1.1, there are many more seldom-used extension methods defined by
3843
// various systems.
3944
// https://datatracker.ietf.org/doc/html/rfc7230#section-3.2.6
40-
func parseHTTPToken(token string) (httpToken, string, error) {
45+
func parseMethodPattern(token string) (methodPattern, string, error) {
4146
if token == "" {
4247
return "", "", errors.New("expected http token, got empty string")
4348
}
44-
return doParseHTTPToken(token, nil)
49+
return doParseMethodPattern(token, nil)
4550
}
4651

47-
func doParseHTTPToken(token string, acc []byte) (httpToken, string, error) {
52+
func doParseMethodPattern(token string, acc []byte) (methodPattern, string, error) {
4853
// BASE CASE: if the token passed in is empty, we're done parsing
4954
if token == "" {
50-
return httpToken(acc), "", nil
55+
return methodPattern(acc), "", nil
5156
}
5257

5358
// If the next byte in the string is not a valid http token character, we're done parsing.
5459
if !isHTTPTokenChar(token[0]) {
55-
return httpToken(acc), token, nil
60+
return methodPattern(acc), token, nil
5661
}
5762

5863
// The next character is valid, so the http token continues
5964
acc = append(acc, token[0])
60-
return doParseHTTPToken(token[1:], acc)
65+
return doParseMethodPattern(token[1:], acc)
6166
}
6267

6368
// The valid characters that can be in an http token (like the lexer/parser kind of token).
@@ -85,16 +90,16 @@ func isHTTPTokenChar(c byte) bool {
8590
// Represents a valid host.
8691
// https://datatracker.ietf.org/doc/html/rfc952
8792
// https://datatracker.ietf.org/doc/html/rfc1123#page-13
88-
func parseHost(input string) (host []label, rest string, err error) {
93+
func parseHostPattern(input string) (host []labelPattern, rest string, err error) {
8994
rest = input
90-
var label label
95+
var label labelPattern
9196

9297
if input == "" {
9398
return nil, "", errors.New("expected host, got empty string")
9499
}
95100

96101
// There should be at least one label.
97-
label, rest, err = parseLabel(rest)
102+
label, rest, err = parseLabelPattern(rest)
98103
if err != nil {
99104
return nil, "", err
100105
}
@@ -108,7 +113,7 @@ func parseHost(input string) (host []label, rest string, err error) {
108113
break
109114
}
110115

111-
label, rest, err = parseLabel(rest)
116+
label, rest, err = parseLabelPattern(rest)
112117
if err != nil {
113118
return nil, "", err
114119
}
@@ -119,9 +124,9 @@ func parseHost(input string) (host []label, rest string, err error) {
119124
}
120125

121126
// Represents a valid label in a hostname. For example, wobble in `wib-ble.wobble.com`.
122-
type label string
127+
type labelPattern string
123128

124-
func parseLabel(rest string) (label, string, error) {
129+
func parseLabelPattern(rest string) (labelPattern, string, error) {
125130
if rest == "" {
126131
return "", "", errors.New("expected label, got empty string")
127132
}
@@ -141,7 +146,7 @@ func parseLabel(rest string) (label, string, error) {
141146
return "", "", fmt.Errorf("invalid label: %s", rest[:i])
142147
}
143148

144-
return label(rest[:i]), rest[i:], nil
149+
return labelPattern(rest[:i]), rest[i:], nil
145150
}
146151

147152
func isValidLabelChar(c byte) bool {
@@ -163,12 +168,12 @@ func isValidLabelChar(c byte) bool {
163168
}
164169
}
165170

166-
func parsePath(input string) ([]segment, string, error) {
171+
func parsePathPattern(input string) ([]segmentPattern, string, error) {
167172
if input == "" {
168173
return nil, "", nil
169174
}
170175

171-
var segments []segment
176+
var segments []segmentPattern
172177
rest := input
173178

174179
// If the path doesn't start with '/', it's not a valid absolute path
@@ -185,12 +190,12 @@ func parsePath(input string) ([]segment, string, error) {
185190
}
186191

187192
// Parse the next segment
188-
seg, remaining, err := parsePathSegment(rest)
193+
seg, remaining, err := parsePathSegmentPattern(rest)
189194
if err != nil {
190195
return nil, "", err
191196
}
192197

193-
// If we got an empty segment and there's still input,
198+
// If we got an empty segment and there's still input,
194199
// it means we hit an invalid character
195200
if seg == "" && remaining != "" {
196201
break
@@ -208,18 +213,18 @@ func parsePath(input string) ([]segment, string, error) {
208213
return segments, rest, nil
209214
}
210215

211-
// Represents a valid url path segment.
212-
type segment string
216+
// Represents a pattern for a single URL path segment.
217+
type segmentPattern string
213218

214-
func parsePathSegment(input string) (segment, string, error) {
219+
func parsePathSegmentPattern(input string) (segmentPattern, string, error) {
215220
if input == "" {
216221
return "", "", nil
217222
}
218223

219224
var i int
220225
for i = 0; i < len(input); i++ {
221226
c := input[i]
222-
227+
223228
// Check for percent-encoded characters (%XX)
224229
if c == '%' {
225230
if i+2 >= len(input) || !isHexDigit(input[i+1]) || !isHexDigit(input[i+2]) {
@@ -228,14 +233,14 @@ func parsePathSegment(input string) (segment, string, error) {
228233
i += 2
229234
continue
230235
}
231-
236+
232237
// Check for valid pchar characters
233238
if !isPChar(c) {
234239
break
235240
}
236241
}
237242

238-
return segment(input[:i]), input[i:], nil
243+
return segmentPattern(input[:i]), input[i:], nil
239244
}
240245

241246
// isUnreserved returns true if the character is unreserved per RFC 3986
@@ -286,8 +291,72 @@ func parseKey(rule string) (string, string, error) {
286291
return "", "", errors.New("expected key")
287292
}
288293

289-
func parseAllowRule(string) (Rule, error) {
290-
return Rule{}, nil
294+
func parseAllowRule(ruleStr string) (Rule, error) {
295+
rule := Rule{
296+
Raw: ruleStr,
297+
}
298+
299+
rest := ruleStr
300+
301+
for rest != "" {
302+
// Parse the key
303+
key, valueRest, err := parseKey(rest)
304+
if err != nil {
305+
return Rule{}, fmt.Errorf("failed to parse key: %v", err)
306+
}
307+
308+
// Parse the value based on the key type
309+
switch key {
310+
case "method":
311+
token, remaining, err := parseMethodPattern(valueRest)
312+
if err != nil {
313+
return Rule{}, fmt.Errorf("failed to parse method: %v", err)
314+
}
315+
316+
// Initialize the MethodPatterns map if needed
317+
if rule.MethodPatterns == nil {
318+
rule.MethodPatterns = make(map[methodPattern]struct{})
319+
}
320+
rule.MethodPatterns[token] = struct{}{}
321+
rest = remaining
322+
323+
case "domain":
324+
hostLabels, remaining, err := parseHostPattern(valueRest)
325+
if err != nil {
326+
return Rule{}, fmt.Errorf("failed to parse domain: %v", err)
327+
}
328+
329+
// Store the parsed labels in reverse order (TLD first)
330+
rule.HostPattern = make([]labelPattern, len(hostLabels))
331+
for i, label := range hostLabels {
332+
rule.HostPattern[len(hostLabels)-1-i] = label
333+
}
334+
rest = remaining
335+
336+
case "path":
337+
segments, remaining, err := parsePathPattern(valueRest)
338+
if err != nil {
339+
return Rule{}, fmt.Errorf("failed to parse path: %v", err)
340+
}
341+
342+
// Copy the parsed segments into the rule
343+
rule.PathPattern = make([]segmentPattern, len(segments))
344+
for i, segment := range segments {
345+
rule.PathPattern[i] = segment
346+
}
347+
rest = remaining
348+
349+
default:
350+
return Rule{}, fmt.Errorf("unknown key: %s", key)
351+
}
352+
353+
// Skip whitespace or comma separators
354+
for rest != "" && (rest[0] == ' ' || rest[0] == '\t' || rest[0] == ',') {
355+
rest = rest[1:]
356+
}
357+
}
358+
359+
return rule, nil
291360
}
292361

293362
// ParseAllowSpecs parses a slice of --allow specs into allow Rules.
@@ -342,17 +411,85 @@ func (re *Engine) Evaluate(method, url string) Result {
342411
}
343412
}
344413

345-
// Matches checks if the rule matches the given method and URL using wildcard patterns
346-
func (re *Engine) matches(r Rule, method, url string) bool {
347-
// If the rule doesn't have any method filters, don't restrict the allowed methods
348-
if r.Methods == nil {
349-
return true
414+
type protocol string
415+
416+
func parseProtocol(input string) (protocol, string, error) {
417+
if input == "" {
418+
return "", "", errors.New("expected protocol, got empty string")
350419
}
351420

352-
// If the rule has method filters and the provided method is not one of them, block the request.
353-
if _, methodIsAllowed := r.Methods[method]; !methodIsAllowed {
354-
return false
421+
// Look for "://" separator
422+
if idx := strings.Index(input, "://"); idx > 0 {
423+
protocolPart := input[:idx]
424+
rest := input[idx+3:]
425+
426+
// Validate protocol characters (scheme per RFC 3986)
427+
// scheme = ALPHA *( ALPHA / DIGIT / "+" / "-" / "." )
428+
if len(protocolPart) == 0 {
429+
return "", "", errors.New("empty protocol")
430+
}
431+
432+
// First character must be alpha
433+
if !((protocolPart[0] >= 'A' && protocolPart[0] <= 'Z') ||
434+
(protocolPart[0] >= 'a' && protocolPart[0] <= 'z')) {
435+
return "", "", errors.New("protocol must start with a letter")
436+
}
437+
438+
// Rest can be alphanumeric, +, -, or .
439+
for i := 1; i < len(protocolPart); i++ {
440+
c := protocolPart[i]
441+
if !((c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') ||
442+
(c >= '0' && c <= '9') || c == '+' || c == '-' || c == '.') {
443+
return "", "", fmt.Errorf("invalid character in protocol: %c", c)
444+
}
445+
}
446+
447+
return protocol(protocolPart), rest, nil
355448
}
356449

450+
// No protocol found
451+
return "", input, nil
452+
}
453+
454+
type port uint16
455+
456+
func parsePort(input string) (port, string, error) {
457+
if input == "" {
458+
return 0, "", nil
459+
}
460+
461+
// Port must start with ':'
462+
if input[0] != ':' {
463+
return 0, input, nil
464+
}
465+
466+
// Find the end of the port number
467+
i := 1
468+
for i < len(input) && input[i] >= '0' && input[i] <= '9' {
469+
i++
470+
}
471+
472+
// No digits found after ':'
473+
if i == 1 {
474+
return 0, "", errors.New("expected port number after ':'")
475+
}
476+
477+
portStr := input[1:i]
478+
rest := input[i:]
479+
480+
// Convert to uint16 (port range is 0-65535)
481+
portNum := 0
482+
for _, digit := range portStr {
483+
portNum = portNum*10 + int(digit-'0')
484+
if portNum > 65535 {
485+
return 0, "", errors.New("port number too large (max 65535)")
486+
}
487+
}
488+
489+
return port(portNum), rest, nil
490+
}
491+
492+
// matches checks if the rule matches the given method and URL using wildcard patterns
493+
func (re *Engine) matches(r Rule, method, url string) bool {
357494
return true
358495
}

0 commit comments

Comments
 (0)