Skip to content

Commit e8bc494

Browse files
committed
add http token parsing
1 parent d628bfc commit e8bc494

File tree

2 files changed

+204
-141
lines changed

2 files changed

+204
-141
lines changed

rules/rules.go

Lines changed: 78 additions & 137 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
package rules
22

33
import (
4+
"errors"
45
"fmt"
56
"log/slog"
6-
"strings"
77
)
88

99
type Evaluator interface {
@@ -12,16 +12,84 @@ type Evaluator interface {
1212

1313
// Rule represents an allow rule with optional HTTP method restrictions
1414
type Rule struct {
15-
Pattern string // wildcard pattern for matching
16-
Methods map[string]bool // nil means all methods allowed
17-
Raw string // rule string for logging
15+
16+
// The path segments of the url
17+
// nil means all paths allowed
18+
// a path segment of `*` acts as a wildcard.
19+
Path []string
20+
21+
// The labels of the host, e.g. ["google", "com"]
22+
// nil means no hosts allowed
23+
// subdomains automatically match
24+
Host []string
25+
26+
// The allowed http methods
27+
// nil means all methods allowed
28+
Methods map[string]struct{}
29+
30+
// Raw rule string for logging
31+
Raw string
32+
}
33+
34+
// httpToken is a string made up solely of HTTP "token" characters (tchar), as
// accepted by isHTTPTokenChar. HTTP method names are tokens.
type httpToken string

// parseHTTPToken consumes the longest prefix of token that consists of valid
// HTTP token characters and returns that prefix together with the unparsed
// remainder of the input.
//
// Tokens are parsed by character class rather than matched against a fixed
// list of methods because, beyond the standard methods such as GET and POST,
// various systems define their own, seldom used, extension methods.
// https://datatracker.ietf.org/doc/html/rfc7230#section-3.2.6
//
// An empty input is reported as an error. An input whose first byte is not a
// token character is NOT an error: the returned token is empty and the whole
// input comes back as the remainder, so callers that require a token must
// check for an empty result themselves.
func parseHTTPToken(token string) (httpToken, string, error) {
	if token == "" {
		return "", "", errors.New("expected http token, got empty string")
	}
	return doParseHTTPToken(token, nil)
}

// doParseHTTPToken scans token from the left for as long as it sees valid HTTP
// token characters. It returns acc followed by the scanned prefix as the
// parsed token, and the rest of the input (starting at the first non-token
// byte, or empty if the whole input was consumed) as the remainder. The error
// result is always nil; it exists so the signature matches parseHTTPToken.
//
// The scan is iterative: Go does not eliminate tail calls, so a recursive
// per-byte formulation would grow the stack linearly with the token length.
// acc is only read, never appended to, so the caller's backing array is left
// untouched even when it has spare capacity.
func doParseHTTPToken(token string, acc []byte) (httpToken, string, error) {
	end := 0
	for end < len(token) && isHTTPTokenChar(token[end]) {
		end++
	}

	// Common case (no previously accumulated bytes): slice the input directly
	// instead of copying it byte by byte.
	if len(acc) == 0 {
		return httpToken(token[:end]), token[end:], nil
	}
	return httpToken(string(acc) + token[:end]), token[end:], nil
}

// isHTTPTokenChar reports whether c may appear in an HTTP token (token in the
// lexer/parser sense): an ASCII letter, an ASCII digit, or one of the
// punctuation characters ! # $ % & ' * + - . ^ _ ` | ~
func isHTTPTokenChar(c byte) bool {
	// Alphanumerics.
	switch {
	case 'A' <= c && c <= 'Z', 'a' <= c && c <= 'z', '0' <= c && c <= '9':
		return true
	}

	// Punctuation that is, perhaps surprisingly, also permitted in a token.
	switch c {
	case '!', '#', '$', '%', '&', '\'', '*', '+', '-', '.', '^', '_', '`', '|', '~':
		return true
	}

	return false
}
83+
84+
func parseAllowRule(string) (Rule, error) {
85+
return Rule{}, nil
1886
}
1987

2088
// ParseAllowSpecs parses a slice of --allow specs into allow Rules.
2189
func ParseAllowSpecs(allowStrings []string) ([]Rule, error) {
2290
var out []Rule
2391
for _, s := range allowStrings {
24-
r, err := newAllowRule(s)
92+
r, err := parseAllowRule(s)
2593
if err != nil {
2694
return nil, fmt.Errorf("failed to parse allow '%s': %v", s, err)
2795
}
@@ -71,142 +139,15 @@ func (re *Engine) Evaluate(method, url string) Result {
71139

72140
// matches reports whether the rule allows the given HTTP method; the URL is not yet taken into account.
73141
func (re *Engine) matches(r Rule, method, url string) bool {
74-
// Check method if specified
75-
if r.Methods != nil && !r.Methods[strings.ToUpper(method)] {
76-
return false
77-
}
78-
79-
// Check URL pattern using wildcard matching
80-
// Try exact match first
81-
if wildcardMatch(r.Pattern, url) {
82-
return true
83-
}
84-
85-
// If pattern doesn't start with protocol, try matching against the URL without protocol
86-
if !strings.HasPrefix(r.Pattern, "http://") && !strings.HasPrefix(r.Pattern, "https://") {
87-
// Extract domain and path from URL
88-
urlWithoutProtocol := url
89-
if strings.HasPrefix(url, "https://") {
90-
urlWithoutProtocol = url[8:] // Remove "https://"
91-
} else if strings.HasPrefix(url, "http://") {
92-
urlWithoutProtocol = url[7:] // Remove "http://"
93-
}
94-
95-
// Try matching against URL without protocol
96-
if wildcardMatch(r.Pattern, urlWithoutProtocol) {
97-
return true
98-
}
99-
100-
// Also try matching just the domain part
101-
domainEnd := strings.Index(urlWithoutProtocol, "/")
102-
if domainEnd > 0 {
103-
domain := urlWithoutProtocol[:domainEnd]
104-
if wildcardMatch(r.Pattern, domain) {
105-
return true
106-
}
107-
} else {
108-
// No path, just domain
109-
if wildcardMatch(r.Pattern, urlWithoutProtocol) {
110-
return true
111-
}
112-
}
113-
}
114-
115-
return false
116-
}
117-
118-
// wildcardMatch performs wildcard pattern matching
119-
// Supports * (matches any sequence of characters)
120-
func wildcardMatch(pattern, text string) bool {
121-
pattern = strings.ToLower(pattern)
122-
text = strings.ToLower(text)
123-
124-
// Handle simple case
125-
if pattern == "*" {
142+
// If the rule doesn't have any method filters, don't restrict the allowed methods
143+
if r.Methods == nil {
126144
return true
127145
}
128146

129-
// Split pattern by '*' and check each part exists in order
130-
parts := strings.Split(pattern, "*")
131-
132-
// If no wildcards, must be exact match
133-
if len(parts) == 1 {
134-
return pattern == text
135-
}
136-
137-
textPos := 0
138-
for i, part := range parts {
139-
if part == "" {
140-
continue // Skip empty parts from consecutive '*'
141-
}
142-
143-
if i == 0 {
144-
// First part must be at the beginning
145-
if !strings.HasPrefix(text, part) {
146-
return false
147-
}
148-
textPos = len(part)
149-
} else if i == len(parts)-1 {
150-
// Last part must be at the end
151-
if !strings.HasSuffix(text[textPos:], part) {
152-
return false
153-
}
154-
} else {
155-
// Middle parts must exist in order
156-
idx := strings.Index(text[textPos:], part)
157-
if idx == -1 {
158-
return false
159-
}
160-
textPos += idx + len(part)
161-
}
147+
// If the rule has method filters and the provided method is not one of them, block the request.
148+
if _, methodIsAllowed := r.Methods[method]; !methodIsAllowed {
149+
return false
162150
}
163151

164152
return true
165153
}
166-
167-
// newAllowRule creates an allow Rule from a spec string used by --allow.
168-
// Supported formats:
169-
//
170-
// "pattern" -> allow all methods to pattern
171-
// "GET,HEAD pattern" -> allow only listed methods to pattern
172-
func newAllowRule(spec string) (Rule, error) {
173-
s := strings.TrimSpace(spec)
174-
if s == "" {
175-
return Rule{}, fmt.Errorf("invalid allow spec: empty")
176-
}
177-
178-
var methods map[string]bool
179-
pattern := s
180-
181-
// Detect optional leading methods list separated by commas and a space before pattern
182-
// e.g., "GET,HEAD github.com"
183-
if idx := strings.IndexFunc(s, func(r rune) bool { return r == ' ' || r == '\t' }); idx > 0 {
184-
left := strings.TrimSpace(s[:idx])
185-
right := strings.TrimSpace(s[idx:])
186-
// methods part is valid if it only contains letters and commas
187-
valid := left != "" && strings.IndexFunc(left, func(r rune) bool {
188-
return r != ',' && (r < 'A' || r > 'Z') && (r < 'a' || r > 'z')
189-
}) == -1
190-
if valid {
191-
methods = make(map[string]bool)
192-
for _, m := range strings.Split(left, ",") {
193-
m = strings.TrimSpace(m)
194-
if m == "" {
195-
continue
196-
}
197-
methods[strings.ToUpper(m)] = true
198-
}
199-
pattern = right
200-
}
201-
}
202-
203-
if pattern == "" {
204-
return Rule{}, fmt.Errorf("invalid allow spec: missing pattern")
205-
}
206-
207-
return Rule{
208-
Pattern: pattern,
209-
Methods: methods,
210-
Raw: "allow " + spec,
211-
}, nil
212-
}

rules/rules_test.go

Lines changed: 126 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,130 @@ package rules
22

33
import "testing"
44

5-
// Stub test file - tests removed
6-
func TestStub(t *testing.T) {
7-
// This is a stub test
8-
t.Skip("stub test file")
5+
func TestParseHTTPToken(t *testing.T) {
6+
tests := []struct {
7+
name string
8+
input string
9+
expectedToken httpToken
10+
expectedRemain string
11+
expectError bool
12+
}{
13+
{
14+
name: "empty string",
15+
input: "",
16+
expectedToken: "",
17+
expectedRemain: "",
18+
expectError: true,
19+
},
20+
{
21+
name: "simple method GET",
22+
input: "GET",
23+
expectedToken: "GET",
24+
expectedRemain: "",
25+
expectError: false,
26+
},
27+
{
28+
name: "simple method POST",
29+
input: "POST",
30+
expectedToken: "POST",
31+
expectedRemain: "",
32+
expectError: false,
33+
},
34+
{
35+
name: "method with trailing space",
36+
input: "GET ",
37+
expectedToken: "GET",
38+
expectedRemain: " ",
39+
expectError: false,
40+
},
41+
{
42+
name: "method with trailing content",
43+
input: "POST /api/users",
44+
expectedToken: "POST",
45+
expectedRemain: " /api/users",
46+
expectError: false,
47+
},
48+
{
49+
name: "all valid special characters",
50+
input: "!#$%&'*+-.^_`|~",
51+
expectedToken: "!#$%&'*+-.^_`|~",
52+
expectedRemain: "",
53+
expectError: false,
54+
},
55+
{
56+
name: "alphanumeric token",
57+
input: "ABC123xyz",
58+
expectedToken: "ABC123xyz",
59+
expectedRemain: "",
60+
expectError: false,
61+
},
62+
{
63+
name: "token with invalid character",
64+
input: "GET@test",
65+
expectedToken: "GET",
66+
expectedRemain: "@test",
67+
expectError: false,
68+
},
69+
{
70+
name: "token starting with invalid character",
71+
input: "@GET",
72+
expectedToken: "",
73+
expectedRemain: "@GET",
74+
expectError: false,
75+
},
76+
{
77+
name: "single character token",
78+
input: "A",
79+
expectedToken: "A",
80+
expectedRemain: "",
81+
expectError: false,
82+
},
83+
{
84+
name: "token with underscore and dash",
85+
input: "CUSTOM-METHOD_1",
86+
expectedToken: "CUSTOM-METHOD_1",
87+
expectedRemain: "",
88+
expectError: false,
89+
},
90+
{
91+
name: "token stops at comma",
92+
input: "GET,POST",
93+
expectedToken: "GET",
94+
expectedRemain: ",POST",
95+
expectError: false,
96+
},
97+
{
98+
name: "token stops at semicolon",
99+
input: "GET;charset=utf-8",
100+
expectedToken: "GET",
101+
expectedRemain: ";charset=utf-8",
102+
expectError: false,
103+
},
104+
}
105+
106+
for _, tt := range tests {
107+
t.Run(tt.name, func(t *testing.T) {
108+
token, remain, err := parseHTTPToken(tt.input)
109+
110+
if tt.expectError {
111+
if err == nil {
112+
t.Errorf("expected error but got none")
113+
}
114+
return
115+
}
116+
117+
if err != nil {
118+
t.Errorf("unexpected error: %v", err)
119+
return
120+
}
121+
122+
if token != tt.expectedToken {
123+
t.Errorf("expected token %q, got %q", tt.expectedToken, token)
124+
}
125+
126+
if remain != tt.expectedRemain {
127+
t.Errorf("expected remaining %q, got %q", tt.expectedRemain, remain)
128+
}
129+
})
130+
}
9131
}

0 commit comments

Comments
 (0)