diff --git a/README.md b/README.md index 418dc25..1910a5c 100644 --- a/README.md +++ b/README.md @@ -25,12 +25,12 @@ curl -fsSL https://raw.githubusercontent.com/coder/boundary/main/install.sh | ba ```bash # Allow only requests to github.com -boundary --allow "github.com" -- curl https://github.com +boundary --allow "domain=github.com" -- curl https://github.com # Allow full access to GitHub issues API, but only GET/HEAD elsewhere on GitHub boundary \ - --allow "github.com/api/issues/*" \ - --allow "GET,HEAD github.com" \ + --allow "domain=github.com path=/api/issues/*" \ + --allow "method=GET,HEAD domain=github.com" \ -- npm install # Default deny-all: everything is blocked unless explicitly allowed @@ -41,16 +41,20 @@ boundary -- curl https://example.com ### Format ```text ---allow "pattern" # All HTTP methods allowed ---allow "METHOD[,METHOD] pattern" # Specific methods only +--allow "key=value [key=value ...]" ``` +**Keys:** +- `method` - HTTP method(s), comma-separated (GET, POST, etc.) +- `domain` - Domain/hostname pattern +- `path` - URL path pattern + ### Examples ```bash -boundary --allow "github.com" -- git pull -boundary --allow "*.github.com" -- npm install # GitHub subdomains -boundary --allow "api.*" -- ./app # Any API domain -boundary --allow "GET,HEAD api.github.com" -- curl https://api.github.com +boundary --allow "domain=github.com" -- git pull +boundary --allow "domain=*.github.com" -- npm install # GitHub subdomains +boundary --allow "method=GET,HEAD domain=api.github.com" -- curl https://api.github.com +boundary --allow "method=POST domain=api.example.com path=/users" -- ./app ``` Wildcards: `*` matches any characters. All traffic is denied unless explicitly allowed. @@ -58,8 +62,8 @@ Wildcards: `*` matches any characters. All traffic is denied unless explicitly a ## Logging ```bash -boundary --log-level info --allow "*" -- npm install # Show all requests -boundary --log-level debug --allow "github.com" -- git pull # Debug info +boundary --log-level info --allow "method=*" -- npm install # Show all requests +boundary --log-level debug --allow "domain=github.com" -- git pull # Debug info ``` **Log Levels:** `error`, `warn` (default), `info`, `debug` @@ -70,10 +74,10 @@ When you can't or don't want to run with sudo privileges, use `--unprivileged`: ```bash # Run without network isolation (uses HTTP_PROXY/HTTPS_PROXY environment variables) -boundary --unprivileged --allow "github.com" -- npm install +boundary --unprivileged --allow "domain=github.com" -- npm install # Useful in containers or restricted environments -boundary --unprivileged --allow "*.npmjs.org" --allow "registry.npmjs.org" -- npm install +boundary --unprivileged --allow "domain=*.npmjs.org" --allow "domain=registry.npmjs.org" -- npm install ``` **Unprivileged Mode:** diff --git a/boundary.go b/boundary.go index d3e98a6..821f83b 100644 --- a/boundary.go +++ b/boundary.go @@ -11,11 +11,11 @@ import ( "github.com/coder/boundary/audit" "github.com/coder/boundary/jail" "github.com/coder/boundary/proxy" - "github.com/coder/boundary/rules" + "github.com/coder/boundary/rulesengine" ) type Config struct { - RuleEngine rules.Evaluator + RuleEngine rulesengine.Engine Auditor audit.Auditor TLSConfig *tls.Config Logger *slog.Logger diff --git a/cli/cli.go b/cli/cli.go index dabaa46..f24a886 100644 --- a/cli/cli.go +++ b/cli/cli.go @@ -14,7 +14,7 @@ import ( "github.com/coder/boundary" "github.com/coder/boundary/audit" "github.com/coder/boundary/jail" - "github.com/coder/boundary/rules" + "github.com/coder/boundary/rulesengine" "github.com/coder/boundary/tls" "github.com/coder/boundary/util" "github.com/coder/serpent" @@ -38,10 +38,10 @@ func NewCommand() *serpent.Command { // may be called something different when used as a subcommand / there will be a leading binary (i.e. `coder boundary` vs. `boundary`). cmd.Long += `Examples: # Allow only requests to github.com - boundary --allow "github.com" -- curl https://github.com + boundary --allow "domain=github.com" -- curl https://github.com # Monitor all requests to specific domains (allow only those) - boundary --allow "github.com/api/issues/*" --allow "GET,HEAD github.com" -- npm install + boundary --allow "domain=github.com path=/api/issues/*" --allow "method=GET,HEAD domain=github.com" -- npm install # Block everything by default (implicit)` @@ -114,14 +114,14 @@ func Run(ctx context.Context, config Config, args []string) error { } // Parse allow rules - allowRules, err := rules.ParseAllowSpecs(config.AllowStrings) + allowRules, err := rulesengine.ParseAllowSpecs(config.AllowStrings) if err != nil { logger.Error("Failed to parse allow rules", "error", err) return fmt.Errorf("failed to parse allow rules: %v", err) } // Create rule engine - ruleEngine := rules.NewRuleEngine(allowRules, logger) + ruleEngine := rulesengine.NewRuleEngine(allowRules, logger) // Create auditor auditor := audit.NewLogAuditor(logger) diff --git a/e2e_tests/boundary_integration_test.go b/e2e_tests/boundary_integration_test.go index 84ff284..d966a0c 100644 --- a/e2e_tests/boundary_integration_test.go +++ b/e2e_tests/boundary_integration_test.go @@ -78,8 +78,8 @@ func TestBoundaryIntegration(t *testing.T) { // Start boundary process with sudo boundaryCmd := exec.CommandContext(ctx, "/tmp/boundary-test", - "--allow", "dev.coder.com", - "--allow", "jsonplaceholder.typicode.com", + "--allow", "domain=dev.coder.com", + "--allow", "domain=jsonplaceholder.typicode.com", "--log-level", "debug", "--", "bash", "-c", "sleep 10 && echo 'Test completed'") diff --git a/proxy/proxy.go b/proxy/proxy.go index e2aa537..fa47aae 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -15,12 +15,12 @@ import ( "sync/atomic" "github.com/coder/boundary/audit" - "github.com/coder/boundary/rules" + "github.com/coder/boundary/rulesengine" ) // Server handles HTTP and HTTPS requests with rule-based filtering type Server struct { - ruleEngine rules.Evaluator + ruleEngine rulesengine.Engine auditor audit.Auditor logger *slog.Logger tlsConfig *tls.Config @@ -33,7 +33,7 @@ type Server struct { // Config holds configuration for the proxy server type Config struct { HTTPPort int - RuleEngine rules.Evaluator + RuleEngine rulesengine.Engine Auditor audit.Auditor Logger *slog.Logger TLSConfig *tls.Config @@ -254,8 +254,8 @@ Request: %s %s Host: %s To allow this request, restart boundary with: - --allow "%s" # Allow all methods to this host - --allow "%s %s" # Allow only %s requests to this host + --allow "domain=%s" # Allow all methods to this host + --allow "method=%s domain=%s" # Allow only %s requests to this host For more help: https://github.com/coder/boundary `, @@ -639,7 +639,7 @@ func (p *Server) constructFullURL(req *http.Request, hostname string) string { // writeBlockedResponseStreaming writes a blocked response directly to the TLS connection func (p *Server) writeBlockedResponseStreaming(tlsConn *tls.Conn, req *http.Request) { - response := fmt.Sprintf("HTTP/1.1 403 Forbidden\r\nContent-Type: text/plain\r\nConnection: close\r\n\r\n🚫 Request Blocked by Boundary\n\nRequest: %s %s\nHost: %s\n\nTo allow this request, restart boundary with:\n --allow \"%s\"\n", + response := fmt.Sprintf("HTTP/1.1 403 Forbidden\r\nContent-Type: text/plain\r\nConnection: close\r\n\r\n🚫 Request Blocked by Boundary\n\nRequest: %s %s\nHost: %s\n\nTo allow this request, restart boundary with:\n --allow \"domain=%s\"\n", req.Method, req.URL.Path, req.Host, req.Host) _, _ = tlsConn.Write([]byte(response)) } diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index fe61391..2128198 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -17,7 +17,7 @@ import ( "github.com/stretchr/testify/require" "github.com/coder/boundary/audit" - "github.com/coder/boundary/rules" + "github.com/coder/boundary/rulesengine" ) // mockAuditor is a simple mock auditor for testing @@ -35,13 +35,13 @@ func TestProxyServerBasicHTTP(t *testing.T) { })) // Create test rules (allow all for testing) - testRules, err := rules.ParseAllowSpecs([]string{"*"}) + testRules, err := rulesengine.ParseAllowSpecs([]string{"method=*"}) if err != nil { t.Fatalf("Failed to parse test rules: %v", err) } // Create rule engine - ruleEngine := rules.NewRuleEngine(testRules, logger) + ruleEngine := rulesengine.NewRuleEngine(testRules, logger) // Create mock auditor auditor := &mockAuditor{} @@ -116,13 +116,13 @@ func TestProxyServerBasicHTTPS(t *testing.T) { })) // Create test rules (allow all for testing) - testRules, err := rules.ParseAllowSpecs([]string{"*"}) + testRules, err := rulesengine.ParseAllowSpecs([]string{"method=*"}) if err != nil { t.Fatalf("Failed to parse test rules: %v", err) } // Create rule engine - ruleEngine := rules.NewRuleEngine(testRules, logger) + ruleEngine := rulesengine.NewRuleEngine(testRules, logger) // Create mock auditor auditor := &mockAuditor{} @@ -210,13 +210,13 @@ func TestProxyServerCONNECT(t *testing.T) { })) // Create test rules (allow all for testing) - testRules, err := rules.ParseAllowSpecs([]string{"*"}) + testRules, err := rulesengine.ParseAllowSpecs([]string{"method=*"}) if err != nil { t.Fatalf("Failed to parse test rules: %v", err) } // Create rule engine - ruleEngine := rules.NewRuleEngine(testRules, logger) + ruleEngine := rulesengine.NewRuleEngine(testRules, logger) // Create mock auditor auditor := &mockAuditor{} diff --git a/rules/rules.go b/rules/rules.go deleted file mode 100644 index ab64cc4..0000000 --- a/rules/rules.go +++ /dev/null @@ -1,212 +0,0 @@ -package rules - -import ( - "fmt" - "log/slog" - "strings" -) - -type Evaluator interface { - Evaluate(method, url string) Result -} - -// Rule represents an allow rule with optional HTTP method restrictions -type Rule struct { - Pattern string // wildcard pattern for matching - Methods map[string]bool // nil means all methods allowed - Raw string // rule string for logging -} - -// ParseAllowSpecs parses a slice of --allow specs into allow Rules. -func ParseAllowSpecs(allowStrings []string) ([]Rule, error) { - var out []Rule - for _, s := range allowStrings { - r, err := newAllowRule(s) - if err != nil { - return nil, fmt.Errorf("failed to parse allow '%s': %v", s, err) - } - out = append(out, r) - } - return out, nil -} - -// Engine evaluates HTTP requests against a set of rules -type Engine struct { - rules []Rule - logger *slog.Logger -} - -// NewRuleEngine creates a new rule engine -func NewRuleEngine(rules []Rule, logger *slog.Logger) *Engine { - return &Engine{ - rules: rules, - logger: logger, - } -} - -// Result contains the result of rule evaluation -type Result struct { - Allowed bool - Rule string // The rule that matched (if any) -} - -// Evaluate evaluates a request and returns both result and matching rule -func (re *Engine) Evaluate(method, url string) Result { - // Check if any allow rule matches - for _, rule := range re.rules { - if re.matches(rule, method, url) { - return Result{ - Allowed: true, - Rule: rule.Raw, - } - } - } - - // Default deny if no allow rules match - return Result{ - Allowed: false, - Rule: "", - } -} - -// Matches checks if the rule matches the given method and URL using wildcard patterns -func (re *Engine) matches(r Rule, method, url string) bool { - // Check method if specified - if r.Methods != nil && !r.Methods[strings.ToUpper(method)] { - return false - } - - // Check URL pattern using wildcard matching - // Try exact match first - if wildcardMatch(r.Pattern, url) { - return true - } - - // If pattern doesn't start with protocol, try matching against the URL without protocol - if !strings.HasPrefix(r.Pattern, "http://") && !strings.HasPrefix(r.Pattern, "https://") { - // Extract domain and path from URL - urlWithoutProtocol := url - if strings.HasPrefix(url, "https://") { - urlWithoutProtocol = url[8:] // Remove "https://" - } else if strings.HasPrefix(url, "http://") { - urlWithoutProtocol = url[7:] // Remove "http://" - } - - // Try matching against URL without protocol - if wildcardMatch(r.Pattern, urlWithoutProtocol) { - return true - } - - // Also try matching just the domain part - domainEnd := strings.Index(urlWithoutProtocol, "/") - if domainEnd > 0 { - domain := urlWithoutProtocol[:domainEnd] - if wildcardMatch(r.Pattern, domain) { - return true - } - } else { - // No path, just domain - if wildcardMatch(r.Pattern, urlWithoutProtocol) { - return true - } - } - } - - return false -} - -// wildcardMatch performs wildcard pattern matching -// Supports * (matches any sequence of characters) -func wildcardMatch(pattern, text string) bool { - pattern = strings.ToLower(pattern) - text = strings.ToLower(text) - - // Handle simple case - if pattern == "*" { - return true - } - - // Split pattern by '*' and check each part exists in order - parts := strings.Split(pattern, "*") - - // If no wildcards, must be exact match - if len(parts) == 1 { - return pattern == text - } - - textPos := 0 - for i, part := range parts { - if part == "" { - continue // Skip empty parts from consecutive '*' - } - - if i == 0 { - // First part must be at the beginning - if !strings.HasPrefix(text, part) { - return false - } - textPos = len(part) - } else if i == len(parts)-1 { - // Last part must be at the end - if !strings.HasSuffix(text[textPos:], part) { - return false - } - } else { - // Middle parts must exist in order - idx := strings.Index(text[textPos:], part) - if idx == -1 { - return false - } - textPos += idx + len(part) - } - } - - return true -} - -// newAllowRule creates an allow Rule from a spec string used by --allow. -// Supported formats: -// -// "pattern" -> allow all methods to pattern -// "GET,HEAD pattern" -> allow only listed methods to pattern -func newAllowRule(spec string) (Rule, error) { - s := strings.TrimSpace(spec) - if s == "" { - return Rule{}, fmt.Errorf("invalid allow spec: empty") - } - - var methods map[string]bool - pattern := s - - // Detect optional leading methods list separated by commas and a space before pattern - // e.g., "GET,HEAD github.com" - if idx := strings.IndexFunc(s, func(r rune) bool { return r == ' ' || r == '\t' }); idx > 0 { - left := strings.TrimSpace(s[:idx]) - right := strings.TrimSpace(s[idx:]) - // methods part is valid if it only contains letters and commas - valid := left != "" && strings.IndexFunc(left, func(r rune) bool { - return r != ',' && (r < 'A' || r > 'Z') && (r < 'a' || r > 'z') - }) == -1 - if valid { - methods = make(map[string]bool) - for _, m := range strings.Split(left, ",") { - m = strings.TrimSpace(m) - if m == "" { - continue - } - methods[strings.ToUpper(m)] = true - } - pattern = right - } - } - - if pattern == "" { - return Rule{}, fmt.Errorf("invalid allow spec: missing pattern") - } - - return Rule{ - Pattern: pattern, - Methods: methods, - Raw: "allow " + spec, - }, nil -} diff --git a/rules/rules_test.go b/rules/rules_test.go deleted file mode 100644 index eb702fe..0000000 --- a/rules/rules_test.go +++ /dev/null @@ -1,9 +0,0 @@ -package rules - -import "testing" - -// Stub test file - tests removed -func TestStub(t *testing.T) { - // This is a stub test - t.Skip("stub test file") -} diff --git a/rulesengine/engine.go b/rulesengine/engine.go new file mode 100644 index 0000000..a9c9c58 --- /dev/null +++ b/rulesengine/engine.go @@ -0,0 +1,122 @@ +package rulesengine + +import ( + "log/slog" + neturl "net/url" + "strings" +) + +// Engine evaluates HTTP requests against a set of rules. +type Engine struct { + rules []Rule + logger *slog.Logger +} + +// NewRuleEngine creates a new rule engine +func NewRuleEngine(rules []Rule, logger *slog.Logger) Engine { + return Engine{ + rules: rules, + logger: logger, + } +} + +// Result contains the result of rule evaluation +type Result struct { + Allowed bool + Rule string // The rule that matched (if any) +} + +// Evaluate evaluates a request and returns both result and matching rule +func (re *Engine) Evaluate(method, url string) Result { + // Check if any allow rule matches + for _, rule := range re.rules { + if re.matches(rule, method, url) { + return Result{ + Allowed: true, + Rule: rule.Raw, + } + } + } + + // Default deny if no allow rules match + return Result{ + Allowed: false, + Rule: "", + } +} + +// Matches checks if the rule matches the given method and URL using wildcard patterns +func (re *Engine) matches(r Rule, method, url string) bool { + + // Check method patterns if they exist + if r.MethodPatterns != nil { + methodMatches := false + for mp := range r.MethodPatterns { + if string(mp) == method || string(mp) == "*" { + methodMatches = true + break + } + } + if !methodMatches { + re.logger.Debug("rule does not match", "reason", "method pattern mismatch", "rule", r.Raw, "method", method, "url", url) + return false + } + } + + parsedUrl, err := neturl.Parse(url) + if err != nil { + re.logger.Debug("rule does not match", "reason", "invalid URL", "rule", r.Raw, "method", method, "url", url, "error", err) + return false + } + + if r.HostPattern != nil { + // For a host pattern to match, every label has to match or be an `*`. + // Subdomains also match automatically, meaning if the pattern is "example.com" + // and the real is "api.example.com", it should match. We check this by comparing + // from the end of the actual hostname with the pattern (which is in normal order). + + labels := strings.Split(parsedUrl.Hostname(), ".") + + // If the host pattern is longer than the actual host, it's definitely not a match + if len(r.HostPattern) > len(labels) { + re.logger.Debug("rule does not match", "reason", "host pattern too long", "rule", r.Raw, "method", method, "url", url, "pattern_length", len(r.HostPattern), "hostname_labels", len(labels)) + return false + } + + // Since host patterns cannot end with asterisk, we only need to handle: + // "example.com" or "*.example.com" - match from the end (allowing subdomains) + for i, lp := range r.HostPattern { + labelIndex := len(labels) - len(r.HostPattern) + i + if string(lp) != labels[labelIndex] && lp != "*" { + re.logger.Debug("rule does not match", "reason", "host pattern label mismatch", "rule", r.Raw, "method", method, "url", url, "expected", string(lp), "actual", labels[labelIndex]) + return false + } + } + } + + if r.PathPattern != nil { + segments := strings.Split(parsedUrl.Path, "/") + + // Skip the first empty segment if the path starts with "/" + if len(segments) > 0 && segments[0] == "" { + segments = segments[1:] + } + + // If the path pattern is longer than the actual path, definitely not a match + if len(r.PathPattern) > len(segments) { + re.logger.Debug("rule does not match", "reason", "path pattern too long", "rule", r.Raw, "method", method, "url", url, "pattern_length", len(r.PathPattern), "path_segments", len(segments)) + return false + } + + // Each segment in the pattern must be either as asterisk or match the actual path segment + for i, sp := range r.PathPattern { + if string(sp) != segments[i] && sp != "*" { + re.logger.Debug("rule does not match", "reason", "path pattern segment mismatch", "rule", r.Raw, "method", method, "url", url, "expected", string(sp), "actual", segments[i]) + return false + } + } + } + + re.logger.Debug("rule matches", "reason", "all patterns matched", "rule", r.Raw, "method", method, "url", url) + return true +} diff --git a/rulesengine/engine_test.go b/rulesengine/engine_test.go new file mode 100644 index 0000000..2a5137f --- /dev/null +++ b/rulesengine/engine_test.go @@ -0,0 +1,271 @@ +package rulesengine + +import ( + "log/slog" + "testing" +) + +func TestEngineMatches(t *testing.T) { + logger := slog.Default() + engine := NewRuleEngine(nil, logger) + + tests := []struct { + name string + rule Rule + method string + url string + expected bool + }{ + // Method pattern tests + { + name: "method matches exact", + rule: Rule{ + MethodPatterns: map[string]struct{}{"GET": {}}, + }, + method: "GET", + url: "https://example.com/api", + expected: true, + }, + { + name: "method does not match", + rule: Rule{ + MethodPatterns: map[string]struct{}{"POST": {}}, + }, + method: "GET", + url: "https://example.com/api", + expected: false, + }, + { + name: "method wildcard matches any", + rule: Rule{ + MethodPatterns: map[string]struct{}{"*": {}}, + }, + method: "PUT", + url: "https://example.com/api", + expected: true, + }, + { + name: "no method pattern allows all methods", + rule: Rule{ + HostPattern: []string{"example", "com"}, + }, + method: "DELETE", + url: "https://example.com/api", + expected: true, + }, + + // Host pattern tests + { + name: "host matches exact", + rule: Rule{ + HostPattern: []string{"example", "com"}, + }, + method: "GET", + url: "https://example.com/api", + expected: true, + }, + { + name: "host does not match", + rule: Rule{ + HostPattern: []string{"example", "org"}, + }, + method: "GET", + url: "https://example.com/api", + expected: false, + }, + { + name: "subdomain matches", + rule: Rule{ + HostPattern: []string{"example", "com"}, + }, + method: "GET", + url: "https://api.example.com/users", + expected: true, + }, + { + name: "host pattern too long", + rule: Rule{ + HostPattern: []string{"v1", "api", "example", "com"}, + }, + method: "GET", + url: "https://api.example.com/users", + expected: false, + }, + { + name: "host wildcard matches", + rule: Rule{ + HostPattern: []string{"*", "com"}, + }, + method: "GET", + url: "https://test.com/api", + expected: true, + }, + { + name: "multiple host wildcards", + rule: Rule{ + HostPattern: []string{"*", "*"}, + }, + method: "GET", + url: "https://api.example.com/users", + expected: true, + }, + + // Path pattern tests + { + name: "path matches exact", + rule: Rule{ + PathPattern: []string{"api", "users"}, + }, + method: "GET", + url: "https://example.com/api/users", + expected: true, + }, + { + name: "path does not match", + rule: Rule{ + PathPattern: []string{"api", "posts"}, + }, + method: "GET", + url: "https://example.com/api/users", + expected: false, + }, + { + name: "subpath matches", + rule: Rule{ + PathPattern: []string{"api"}, + }, + method: "GET", + url: "https://example.com/api/users/123", + expected: true, + }, + { + name: "path pattern too long", + rule: Rule{ + PathPattern: []string{"api", "v1", "users", "profile"}, + }, + method: "GET", + url: "https://example.com/api/v1/users", + expected: false, + }, + { + name: "path wildcard matches", + rule: Rule{ + PathPattern: []string{"api", "*", "profile"}, + }, + method: "GET", + url: "https://example.com/api/users/profile", + expected: true, + }, + { + name: "multiple path wildcards", + rule: Rule{ + PathPattern: []string{"*", "*"}, + }, + method: "GET", + url: "https://example.com/api/users/123", + expected: true, + }, + + // Combined pattern tests + { + name: "all patterns match", + rule: Rule{ + MethodPatterns: map[string]struct{}{"POST": {}}, + HostPattern: []string{"api", "com"}, + PathPattern: []string{"users"}, + }, + method: "POST", + url: "https://api.com/users", + expected: true, + }, + { + name: "method fails combined test", + rule: Rule{ + MethodPatterns: map[string]struct{}{"POST": {}}, + HostPattern: []string{"api", "com"}, + PathPattern: []string{"users"}, + }, + method: "GET", + url: "https://api.com/users", + expected: false, + }, + { + name: "host fails combined test", + rule: Rule{ + MethodPatterns: map[string]struct{}{"POST": {}}, + HostPattern: []string{"api", "org"}, + PathPattern: []string{"users"}, + }, + method: "POST", + url: "https://api.com/users", + expected: false, + }, + { + name: "path fails combined test", + rule: Rule{ + MethodPatterns: map[string]struct{}{"POST": {}}, + HostPattern: []string{"api", "com"}, + PathPattern: []string{"posts"}, + }, + method: "POST", + url: "https://api.com/users", + expected: false, + }, + { + name: "all wildcards match", + rule: Rule{ + MethodPatterns: map[string]struct{}{"*": {}}, + HostPattern: []string{"*", "*"}, + PathPattern: []string{"*", "*"}, + }, + method: "PATCH", + url: "https://test.example.com/api/users/123", + expected: true, + }, + + // Edge cases + { + name: "empty rule matches everything", + rule: Rule{}, + method: "GET", + url: "https://example.com/api/users", + expected: true, + }, + { + name: "invalid URL", + rule: Rule{ + HostPattern: []string{"example", "com"}, + }, + method: "GET", + url: "not-a-valid-url", + expected: false, + }, + { + name: "root path", + rule: Rule{ + PathPattern: []string{}, + }, + method: "GET", + url: "https://example.com/", + expected: true, + }, + { + name: "localhost host", + rule: Rule{ + HostPattern: []string{"localhost"}, + }, + method: "GET", + url: "http://localhost:8080/api", + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := engine.matches(tt.rule, tt.method, tt.url) + if result != tt.expected { + t.Errorf("expected %v, got %v", tt.expected, result) + } + }) + } +} diff --git a/rulesengine/rules.go b/rulesengine/rules.go new file mode 100644 index 0000000..6a038ec --- /dev/null +++ b/rulesengine/rules.go @@ -0,0 +1,393 @@ +package rulesengine + +import ( + "errors" + "fmt" + "strings" +) + +// Rule represents an allow rule passed to the cli with --allow or read from the config file. +// Rules have a specific grammar that we need to parse carefully. +// Example: --allow="method=GET,PATCH domain=wibble.wobble.com, path=/posts/*" +type Rule struct { + + // The path segments of the url. + // - nil means all paths allowed + // - a path segment of `*` acts as a wild card. + // - sub paths automatically match + PathPattern []string + + // The labels of the host, i.e. ["google", "com"]. + // - nil means all hosts allowed + // - A label of `*` acts as a wild card. + // - subdomains automatically match + HostPattern []string + + // The allowed http methods. + // - nil means all methods allowed + MethodPatterns map[string]struct{} + + // Raw rule string for logging + Raw string +} + +// ParseAllowSpecs parses a slice of --allow specs into allow Rules. +func ParseAllowSpecs(allowStrings []string) ([]Rule, error) { + var out []Rule + for _, s := range allowStrings { + r, err := parseAllowRule(s) + if err != nil { + return nil, fmt.Errorf("failed to parse allow '%s': %v", s, err) + } + out = append(out, r) + } + return out, nil +} + +// parseAllowRule takes an allow rule string and tries to parse it as a rule. +func parseAllowRule(ruleStr string) (Rule, error) { + rule := Rule{ + Raw: ruleStr, + } + + // Functions called by this function used a really common pattern: recursive descent parsing. + // All the helper functions for parsing an allow rule will be called like `thing, rest, err := parseThing(rest)`. + // What's going on here is that we try to parse some expected text from the front of the string. + // If we succeed, we get back the thing we parsed and the remaining text. If we fail, we get back a non nil error. + rest := ruleStr + var key string + var err error + + // Ann allow rule can have as many key=value pairs as needed, we go until there's no more text in the rule. + for rest != "" { + // Parse the key + key, rest, err = parseKey(rest) + if err != nil { + return Rule{}, fmt.Errorf("failed to parse key: %v", err) + } + + // Parse the value based on the key type + switch key { + case "method": + // Initialize Methods map if needed + if rule.MethodPatterns == nil { + rule.MethodPatterns = make(map[string]struct{}) + } + + var method string + for { + method, rest, err = parseMethodPattern(rest) + if err != nil { + return Rule{}, fmt.Errorf("failed to parse method: %v", err) + } + + rule.MethodPatterns[method] = struct{}{} + + // Check if there's a comma for more methods + if rest != "" && rest[0] == ',' { + rest = rest[1:] // Skip the comma + continue + } + + break + } + + case "domain": + var host []string + host, rest, err = parseHostPattern(rest) + if err != nil { + return Rule{}, fmt.Errorf("failed to parse domain: %v", err) + } + + // Convert labels to strings + rule.HostPattern = append(rule.HostPattern, host...) + + case "path": + var segments []string + segments, rest, err = parsePathPattern(rest) + if err != nil { + return Rule{}, fmt.Errorf("failed to parse path: %v", err) + } + + // Convert segments to strings + rule.PathPattern = append(rule.PathPattern, segments...) + + default: + return Rule{}, fmt.Errorf("unknown key: %s", key) + } + + // Skip whitespace separators (only support mac and linux so \r\n shouldn't be a thing) + for rest != "" && (rest[0] == ' ' || rest[0] == '\t' || rest[0] == '\n') { + rest = rest[1:] + } + } + + return rule, nil +} + +// Beyond the 9 methods defined in HTTP 1.1, there actually are many more seldom used extension methods by +// various systems. +// https://datatracker.ietf.org/doc/html/rfc7230#section-3.2.6 +func parseMethodPattern(token string) (string, string, error) { + if token == "" { + return "", "", errors.New("expected http token, got empty string") + } + return doParseMethodPattern(token, nil) +} + +func doParseMethodPattern(token string, acc []byte) (string, string, error) { + // BASE CASE: if the token passed in is empty, we're done parsing + if token == "" { + return string(acc), "", nil + } + + // If the next byte in the string is not a valid http token character, we're done parsing. + if !isHTTPTokenChar(token[0]) { + return string(acc), token, nil + } + + // The next character is valid, so the http token continues + acc = append(acc, token[0]) + return doParseMethodPattern(token[1:], acc) +} + +// The valid characters that can be in an http token (like the lexer/parser kind of token). +func isHTTPTokenChar(c byte) bool { + switch { + // Alpha numeric is fine. + case c >= 'A' && c <= 'Z': + return true + case c >= 'a' && c <= 'z': + return true + case c >= '0' && c <= '9': + return true + + // These special characters are also allowed unbelievably. + case c == '!' || c == '#' || c == '$' || c == '%' || c == '&' || + c == '\'' || c == '*' || c == '+' || c == '-' || c == '.' || + c == '^' || c == '_' || c == '`' || c == '|' || c == '~': + return true + + default: + return false + } +} + +// Represents a valid host. +// https://datatracker.ietf.org/doc/html/rfc952 +// https://datatracker.ietf.org/doc/html/rfc1123#page-13 +func parseHostPattern(input string) ([]string, string, error) { + rest := input + var host []string + var err error + + if input == "" { + return nil, "", errors.New("expected host, got empty string") + } + + // There should be at least one label. + var label string + label, rest, err = parseLabelPattern(rest) + if err != nil { + return nil, "", err + } + host = append(host, label) + + // A host is just a bunch of labels separated by `.` characters. + var found bool + for { + rest, found = strings.CutPrefix(rest, ".") + if !found { + break + } + + label, rest, err = parseLabelPattern(rest) + if err != nil { + return nil, "", err + } + host = append(host, label) + } + + // Validate: host patterns cannot end with asterisk + if len(host) > 0 && host[len(host)-1] == "*" { + return nil, "", errors.New("host patterns cannot end with asterisk") + } + + return host, rest, nil +} + +func parseLabelPattern(rest string) (string, string, error) { + if rest == "" { + return "", "", errors.New("expected label, got empty string") + } + + // If the label is simply an asterisk, good to go. + if rest[0] == '*' { + return "*", rest[1:], nil + } + + // First try to get a valid leading char. Leading char in a label cannot be a hyphen. + if !isValidLabelChar(rest[0]) || rest[0] == '-' { + return "", "", fmt.Errorf("could not pull label from front of string: %s", rest) + } + + // Go until the next character is not a valid char + var i int + for i = 1; i < len(rest) && isValidLabelChar(rest[i]); i += 1 { + } + + // Final char in a label cannot be a hyphen. + if rest[i-1] == '-' { + return "", "", fmt.Errorf("invalid label: %s", rest[:i]) + } + + return rest[:i], rest[i:], nil +} + +func isValidLabelChar(c byte) bool { + switch { + // Alpha numeric is fine. + case c >= 'A' && c <= 'Z': + return true + case c >= 'a' && c <= 'z': + return true + case c >= '0' && c <= '9': + return true + + // Hyphens are good + case c == '-': + return true + + default: + return false + } +} + +func parsePathPattern(input string) ([]string, string, error) { + if input == "" { + return nil, "", nil + } + + rest := input + var segments []string + var err error + + // If the path doesn't start with '/', it's not a valid absolute path + // But we'll be flexible and parse relative paths too + for { + // Skip leading slash if present + if rest != "" && rest[0] == '/' { + rest = rest[1:] + } + + // If we've consumed all input, we're done + if rest == "" { + break + } + + // Parse the next segment + var segment string + segment, rest, err = parsePathSegmentPattern(rest) + if err != nil { + return nil, "", err + } + + // If we got an empty segment and there's still input, + // it means we hit an invalid character + if segment == "" && rest != "" { + break + } + + segments = append(segments, segment) + + // If there's no slash after the segment, we're done parsing the path + if rest == "" || rest[0] != '/' { + break + } + } + + return segments, rest, nil +} + +func parsePathSegmentPattern(input string) (string, string, error) { + if input == "" { + return "", "", nil + } + + if len(input) > 0 && input[0] == '*' { + if len(input) > 1 && input[1] != '/' { + return "", "", fmt.Errorf("path segment wildcards must be for the entire segment, got: %s", input) + } + + return "*", input[1:], nil + } + + var i int + for i = 0; i < len(input); i++ { + c := input[i] + + // Check for percent-encoded characters (%XX) + if c == '%' { + if i+2 >= len(input) || !isHexDigit(input[i+1]) || !isHexDigit(input[i+2]) { + break + } + i += 2 + continue + } + + // Check for valid pchar characters + if !isPChar(c) { + break + } + } + + return input[:i], input[i:], nil +} + +// isUnreserved returns true if the character is unreserved per RFC 3986 +// unreserved = ALPHA / DIGIT / "-" / "." / "_" / "~" +func isUnreserved(c byte) bool { + return (c >= 'A' && c <= 'Z') || + (c >= 'a' && c <= 'z') || + (c >= '0' && c <= '9') || + c == '-' || c == '.' || c == '_' || c == '~' +} + +// isSubDelim returns true if the character is a sub-delimiter per RFC 3986 +// sub-delims = "!" / "$" / "&" / "'" / "(" / ")" / "*" / "+" / "," / ";" / "=" +func isSubDelim(c byte) bool { + return c == '!' || c == '$' || c == '&' || c == '\'' || + c == '(' || c == ')' || c == '*' || c == '+' || + c == ',' || c == ';' || c == '=' +} + +// isPChar returns true if the character is valid in a path segment (excluding percent-encoded) +// pchar = unreserved / sub-delims / ":" / "@" +func isPChar(c byte) bool { + return isUnreserved(c) || isSubDelim(c) || c == ':' || c == '@' +} + +// isHexDigit returns true if the character is a hexadecimal digit +func isHexDigit(c byte) bool { + return (c >= '0' && c <= '9') || + (c >= 'A' && c <= 'F') || + (c >= 'a' && c <= 'f') +} + +// parseKey parses the predefined keys that the cli can handle. Also strips the `=` following the key. +func parseKey(rule string) (string, string, error) { + if rule == "" { + return "", "", errors.New("expected key") + } + + // These are the current keys we support. + keys := []string{"method", "domain", "path"} + + for _, key := range keys { + if rest, found := strings.CutPrefix(rule, key+"="); found { + return key, rest, nil + } + } + + return "", "", errors.New("expected key") +} diff --git a/rulesengine/rules_test.go b/rulesengine/rules_test.go new file mode 100644 index 0000000..8ccf94e --- /dev/null +++ b/rulesengine/rules_test.go @@ -0,0 +1,1233 @@ +package rulesengine + +import ( + "fmt" + "log/slog" + "testing" +) + +func TestParseHTTPToken(t *testing.T) { + tests := []struct { + name string + input string + expectedToken string + expectedRemain string + expectError bool + }{ + { + name: "empty string", + input: "", + expectedToken: "", + expectedRemain: "", + expectError: true, + }, + { + name: "simple method GET", + input: "GET", + expectedToken: "GET", + expectedRemain: "", + expectError: false, + }, + { + name: "simple method POST", + input: "POST", + expectedToken: "POST", + expectedRemain: "", + expectError: false, + }, + { + name: "method with trailing space", + input: "GET ", + expectedToken: "GET", + expectedRemain: " ", + expectError: false, + }, + { + name: "method with trailing content", + input: "POST /api/users", + expectedToken: "POST", + expectedRemain: " /api/users", + expectError: false, + }, + { + name: "all valid special characters", + input: "!#$%&'*+-.^_`|~", + expectedToken: "!#$%&'*+-.^_`|~", + expectedRemain: "", + expectError: false, + }, + { + name: "alphanumeric token", + input: "ABC123xyz", + expectedToken: "ABC123xyz", + expectedRemain: "", + expectError: false, + }, + { + name: "token with invalid character", + input: "GET@test", + expectedToken: "GET", + expectedRemain: "@test", + expectError: false, + }, + { + name: "token starting with invalid character", + input: "@GET", + expectedToken: "", + expectedRemain: "@GET", + expectError: false, + }, + { + name: "single character token", + input: "A", + expectedToken: "A", + expectedRemain: "", + expectError: false, + }, + { + name: "token with underscore and dash", + input: "CUSTOM-METHOD_1", + expectedToken: "CUSTOM-METHOD_1", + expectedRemain: "", + expectError: false, + }, + { + name: "token stops at comma", + input: "GET,POST", + expectedToken: "GET", + expectedRemain: ",POST", + expectError: false, + }, + { + name: "token stops at semicolon", + input: "GET;charset=utf-8", + expectedToken: "GET", + expectedRemain: ";charset=utf-8", + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + token, remain, err := parseMethodPattern(tt.input) + + if tt.expectError { + if err == nil { + t.Errorf("expected error but got none") + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + if token != tt.expectedToken { + t.Errorf("expected token %q, got %q", tt.expectedToken, token) + } + + if remain != tt.expectedRemain { + t.Errorf("expected remaining %q, got %q", tt.expectedRemain, remain) + } + }) + } +} + +func TestParseHost(t *testing.T) { + tests := []struct { + name string + input string + expectedHost []string + expectedRest string + expectError bool + }{ + { + name: "empty string", + input: "", + expectedHost: nil, + expectedRest: "", + expectError: true, + }, + { + name: "simple domain", + input: "google.com", + expectedHost: []string{"google", "com"}, + expectedRest: "", + expectError: false, + }, + { + name: "subdomain", + input: "api.google.com", + expectedHost: []string{"api", "google", "com"}, + expectedRest: "", + expectError: false, + }, + { + name: "single label", + input: "localhost", + expectedHost: []string{"localhost"}, + expectedRest: "", + expectError: false, + }, + { + name: "domain with trailing content", + input: "example.org/path", + expectedHost: []string{"example", "org"}, + expectedRest: "/path", + expectError: false, + }, + { + name: "domain with port", + input: "localhost:8080", + expectedHost: []string{"localhost"}, + expectedRest: ":8080", + expectError: false, + }, + { + name: "numeric labels", + input: "192.168.1.1", + expectedHost: []string{"192", "168", "1", "1"}, + expectedRest: "", + expectError: false, + }, + { + name: "hyphenated domain", + input: "my-site.example-domain.co.uk", + expectedHost: []string{"my-site", "example-domain", "co", "uk"}, + expectedRest: "", + expectError: false, + }, + { + name: "alphanumeric labels", + input: "a1b2c3.test123.com", + expectedHost: []string{"a1b2c3", "test123", "com"}, + expectedRest: "", + expectError: false, + }, + { + name: "starts with hyphen", + input: "-invalid.com", + expectedHost: nil, + expectedRest: "", + expectError: true, + }, + { + name: "ends with hyphen", + input: "invalid-.com", + expectedHost: nil, + expectedRest: "", + expectError: true, + }, + { + name: "label ends with hyphen", + input: "test.invalid-.com", + expectedHost: nil, + expectedRest: "", + expectError: true, + }, + { + name: "invalid character", + input: "test@example.com", + expectedHost: []string{"test"}, + expectedRest: "@example.com", + expectError: false, + }, + { + name: "empty label", + input: "test..com", + expectedHost: nil, + expectedRest: "", + expectError: true, + }, + { + name: "trailing dot", + input: "example.com.", + expectedHost: nil, + expectedRest: "", + expectError: true, + }, + { + name: "single character labels", + input: "a.b.c", + expectedHost: []string{"a", "b", "c"}, + expectedRest: "", + expectError: false, + }, + { + name: "mixed case", + input: "Example.COM", + expectedHost: []string{"Example", "COM"}, + expectedRest: "", + expectError: false, + }, + { + name: "wildcard subdomain", + input: "*.example.com", + expectedHost: []string{"*", "example", "com"}, + expectedRest: "", + expectError: false, + }, + { + name: "wildcard domain - should error", + input: "api.*", + expectedHost: nil, + expectedRest: "", + expectError: true, + }, + { + name: "multiple wildcards", + input: "*.*.com", + expectedHost: []string{"*", "*", "com"}, + expectedRest: "", + expectError: false, + }, + { + name: "wildcard with trailing content", + input: "*.example.com/path", + expectedHost: []string{"*", "example", "com"}, + expectedRest: "/path", + expectError: false, + }, + { + name: "host pattern ending with asterisk - rejected", + input: "api.*", + expectedHost: nil, + expectedRest: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + hostResult, rest, err := parseHostPattern(tt.input) + + if tt.expectError { + if err == nil { + t.Errorf("expected error but got none") + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + if len(hostResult) != len(tt.expectedHost) { + t.Errorf("expected host length %d, got %d", len(tt.expectedHost), len(hostResult)) + return + } + + for i, expectedLabel := range tt.expectedHost { + if hostResult[i] != expectedLabel { + t.Errorf("expected label[%d] %q, got %q", i, expectedLabel, hostResult[i]) + } + } + + if rest != tt.expectedRest { + t.Errorf("expected remaining %q, got %q", tt.expectedRest, rest) + } + }) + } +} + +func TestParseLabel(t *testing.T) { + tests := []struct { + name string + input string + expectedLabel string + expectedRest string + expectError bool + }{ + { + name: "empty string", + input: "", + expectedLabel: "", + expectedRest: "", + expectError: true, + }, + { + name: "simple label", + input: "test", + expectedLabel: "test", + expectedRest: "", + expectError: false, + }, + { + name: "label with dot", + input: "test.com", + expectedLabel: "test", + expectedRest: ".com", + expectError: false, + }, + { + name: "label with hyphen", + input: "my-site", + expectedLabel: "my-site", + expectedRest: "", + expectError: false, + }, + { + name: "alphanumeric label", + input: "test123", + expectedLabel: "test123", + expectedRest: "", + expectError: false, + }, + { + name: "starts with hyphen", + input: "-invalid", + expectedLabel: "", + expectedRest: "", + expectError: true, + }, + { + name: "ends with hyphen", + input: "invalid-", + expectedLabel: "", + expectedRest: "", + expectError: true, + }, + { + name: "ends with hyphen followed by dot", + input: "invalid-.com", + expectedLabel: "", + expectedRest: "", + expectError: true, + }, + { + name: "single character", + input: "a", + expectedLabel: "a", + expectedRest: "", + expectError: false, + }, + { + name: "numeric label", + input: "123", + expectedLabel: "123", + expectedRest: "", + expectError: false, + }, + { + name: "mixed case", + input: "Test", + expectedLabel: "Test", + expectedRest: "", + expectError: false, + }, + { + name: "invalid character", + input: "test@invalid", + expectedLabel: "test", + expectedRest: "@invalid", + expectError: false, + }, + { + name: "starts with number", + input: "1test", + expectedLabel: "1test", + expectedRest: "", + expectError: false, + }, + { + name: "label with trailing slash", + input: "api/path", + expectedLabel: "api", + expectedRest: "/path", + expectError: false, + }, + { + name: "wildcard label", + input: "*", + expectedLabel: "*", + expectedRest: "", + expectError: false, + }, + { + name: "wildcard with dot", + input: "*.com", + expectedLabel: "*", + expectedRest: ".com", + expectError: false, + }, + { + name: "wildcard with trailing content", + input: "*/path", + expectedLabel: "*", + expectedRest: "/path", + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + labelResult, rest, err := parseLabelPattern(tt.input) + + if tt.expectError { + if err == nil { + t.Errorf("expected error but got none") + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + if labelResult != tt.expectedLabel { + t.Errorf("expected label %q, got %q", tt.expectedLabel, labelResult) + } + + if rest != tt.expectedRest { + t.Errorf("expected remaining %q, got %q", tt.expectedRest, rest) + } + }) + } +} + +func TestParsePathSegment(t *testing.T) { + tests := []struct { + name string + input string + expectedSegment string + expectedRest string + expectError bool + }{ + { + name: "empty string", + input: "", + expectedSegment: "", + expectedRest: "", + expectError: false, + }, + { + name: "simple segment", + input: "api", + expectedSegment: "api", + expectedRest: "", + expectError: false, + }, + { + name: "segment with slash", + input: "api/users", + expectedSegment: "api", + expectedRest: "/users", + expectError: false, + }, + { + name: "segment with unreserved chars", + input: "my-file.txt_version~1", + expectedSegment: "my-file.txt_version~1", + expectedRest: "", + expectError: false, + }, + { + name: "segment with sub-delims", + input: "filter='test'&sort=name", + expectedSegment: "filter='test'&sort=name", + expectedRest: "", + expectError: false, + }, + { + name: "segment with colon and at", + input: "user:password@domain", + expectedSegment: "user:password@domain", + expectedRest: "", + expectError: false, + }, + { + name: "percent encoded segment", + input: "hello%20world", + expectedSegment: "hello%20world", + expectedRest: "", + expectError: false, + }, + { + name: "multiple percent encoded", + input: "%3Fkey%3Dvalue%26other%3D123", + expectedSegment: "%3Fkey%3Dvalue%26other%3D123", + expectedRest: "", + expectError: false, + }, + { + name: "invalid percent encoding incomplete", + input: "test%2", + expectedSegment: "test", + expectedRest: "%2", + expectError: false, + }, + { + name: "invalid percent encoding non-hex", + input: "test%ZZ", + expectedSegment: "test", + expectedRest: "%ZZ", + expectError: false, + }, + { + name: "segment stops at space", + input: "test hello", + expectedSegment: "test", + expectedRest: " hello", + expectError: false, + }, + { + name: "segment with question mark stops", + input: "path?query=value", + expectedSegment: "path", + expectedRest: "?query=value", + expectError: false, + }, + { + name: "segment with hash stops", + input: "path#fragment", + expectedSegment: "path", + expectedRest: "#fragment", + expectError: false, + }, + { + name: "numeric segment", + input: "123456", + expectedSegment: "123456", + expectedRest: "", + expectError: false, + }, + { + name: "mixed alphanumeric", + input: "abc123XYZ", + expectedSegment: "abc123XYZ", + expectedRest: "", + expectError: false, + }, + { + name: "all sub-delims", + input: "!$&'()*+,;=", + expectedSegment: "!$&'()*+,;=", + expectedRest: "", + expectError: false, + }, + { + name: "segment with brackets", + input: "test[bracket]", + expectedSegment: "test", + expectedRest: "[bracket]", + expectError: false, + }, + { + name: "wildcard segment", + input: "*", + expectedSegment: "*", + expectedRest: "", + expectError: false, + }, + { + name: "wildcard with slash", + input: "*/users", + expectedSegment: "*", + expectedRest: "/users", + expectError: false, + }, + { + name: "wildcard at end with slash", + input: "*", + expectedSegment: "*", + expectedRest: "", + expectError: false, + }, + { + name: "invalid partial wildcard", + input: "*abc", + expectedSegment: "", + expectedRest: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + segment, rest, err := parsePathSegmentPattern(tt.input) + + if tt.expectError { + if err == nil { + t.Errorf("expected error but got none") + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + if segment != tt.expectedSegment { + t.Errorf("expected segment %q, got %q", tt.expectedSegment, segment) + } + + if rest != tt.expectedRest { + t.Errorf("expected rest %q, got %q", tt.expectedRest, rest) + } + }) + } +} + +func TestParsePath(t *testing.T) { + tests := []struct { + name string + input string + expectedSegments []string + expectedRest string + expectError bool + }{ + { + name: "empty string", + input: "", + expectedSegments: nil, + expectedRest: "", + expectError: false, + }, + { + name: "single segment", + input: "/api", + expectedSegments: []string{"api"}, + expectedRest: "", + expectError: false, + }, + { + name: "multiple segments", + input: "/api/v1/users", + expectedSegments: []string{"api", "v1", "users"}, + expectedRest: "", + expectError: false, + }, + { + name: "relative path", + input: "api/users", + expectedSegments: []string{"api", "users"}, + expectedRest: "", + expectError: false, + }, + { + name: "path with trailing slash", + input: "/api/users/", + expectedSegments: []string{"api", "users"}, + expectedRest: "", + expectError: false, + }, + { + name: "path with query string", + input: "/api/users?limit=10", + expectedSegments: []string{"api", "users"}, + expectedRest: "?limit=10", + expectError: false, + }, + { + name: "path with fragment", + input: "/docs/api#authentication", + expectedSegments: []string{"docs", "api"}, + expectedRest: "#authentication", + expectError: false, + }, + { + name: "path with encoded segments", + input: "/api/hello%20world/test", + expectedSegments: []string{"api", "hello%20world", "test"}, + expectedRest: "", + expectError: false, + }, + { + name: "path with special chars", + input: "/api/filter='test'&sort=name/results", + expectedSegments: []string{"api", "filter='test'&sort=name", "results"}, + expectedRest: "", + expectError: false, + }, + { + name: "just slash", + input: "/", + expectedSegments: nil, + expectedRest: "", + expectError: false, + }, + { + name: "empty segments", + input: "/api//users", + expectedSegments: []string{"api"}, + expectedRest: "/users", + expectError: false, + }, + { + name: "path with port-like segment", + input: "/host:8080/status", + expectedSegments: []string{"host:8080", "status"}, + expectedRest: "", + expectError: false, + }, + { + name: "path stops at space", + input: "/api/test hello", + expectedSegments: []string{"api", "test"}, + expectedRest: " hello", + expectError: false, + }, + { + name: "path with hyphens and underscores", + input: "/my-api/user_data/file-name.txt", + expectedSegments: []string{"my-api", "user_data", "file-name.txt"}, + expectedRest: "", + expectError: false, + }, + { + name: "path with tildes", + input: "/api/~user/docs~backup", + expectedSegments: []string{"api", "~user", "docs~backup"}, + expectedRest: "", + expectError: false, + }, + { + name: "numeric segments", + input: "/api/v2/users/12345", + expectedSegments: []string{"api", "v2", "users", "12345"}, + expectedRest: "", + expectError: false, + }, + { + name: "single character segments", + input: "/a/b/c", + expectedSegments: []string{"a", "b", "c"}, + expectedRest: "", + expectError: false, + }, + { + name: "path with at symbol", + input: "/user@domain.com/profile", + expectedSegments: []string{"user@domain.com", "profile"}, + expectedRest: "", + expectError: false, + }, + { + name: "path with wildcard segment", + input: "/api/*/users", + expectedSegments: []string{"api", "*", "users"}, + expectedRest: "", + expectError: false, + }, + { + name: "path with multiple wildcards", + input: "/*/v1/*/profile", + expectedSegments: []string{"*", "v1", "*", "profile"}, + expectedRest: "", + expectError: false, + }, + { + name: "path ending with wildcard", + input: "/api/users/*", + expectedSegments: []string{"api", "users", "*"}, + expectedRest: "", + expectError: false, + }, + { + name: "path starting with wildcard", + input: "/*/users", + expectedSegments: []string{"*", "users"}, + expectedRest: "", + expectError: false, + }, + { + name: "path with wildcard and query", + input: "/api/*/users?limit=10", + expectedSegments: []string{"api", "*", "users"}, + expectedRest: "?limit=10", + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + segments, rest, err := parsePathPattern(tt.input) + + if tt.expectError { + if err == nil { + t.Errorf("expected error but got none") + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + if len(segments) != len(tt.expectedSegments) { + t.Errorf("expected %d segments, got %d", len(tt.expectedSegments), len(segments)) + return + } + + for i, expectedSeg := range tt.expectedSegments { + if segments[i] != expectedSeg { + t.Errorf("expected segment[%d] %q, got %q", i, expectedSeg, segments[i]) + } + } + + if rest != tt.expectedRest { + t.Errorf("expected rest %q, got %q", tt.expectedRest, rest) + } + }) + } +} + +func TestParseAllowRule(t *testing.T) { + tests := []struct { + name string + input string + expectedRule Rule + expectError bool + }{ + { + name: "empty string", + input: "", + expectedRule: Rule{ + Raw: "", + }, + expectError: false, + }, + { + name: "method only", + input: "method=GET", + expectedRule: Rule{ + Raw: "method=GET", + MethodPatterns: map[string]struct{}{"GET": {}}, + }, + expectError: false, + }, + { + name: "domain only", + input: "domain=google.com", + expectedRule: Rule{ + Raw: "domain=google.com", + HostPattern: []string{"google", "com"}, + }, + expectError: false, + }, + { + name: "path only", + input: "path=/api/v1", + expectedRule: Rule{ + Raw: "path=/api/v1", + PathPattern: []string{"api", "v1"}, + }, + expectError: false, + }, + { + name: "method and domain", + input: "method=POST domain=api.example.com", + expectedRule: Rule{ + Raw: "method=POST domain=api.example.com", + MethodPatterns: map[string]struct{}{"POST": {}}, + HostPattern: []string{"api", "example", "com"}, + }, + expectError: false, + }, + { + name: "all three keys", + input: "method=DELETE domain=test.com path=/resources/456", + expectedRule: Rule{ + Raw: "method=DELETE domain=test.com path=/resources/456", + MethodPatterns: map[string]struct{}{"DELETE": {}}, + HostPattern: []string{"test", "com"}, + PathPattern: []string{"resources", "456"}, + }, + expectError: false, + }, + { + name: "wildcard domain", + input: "domain=*.example.com", + expectedRule: Rule{ + Raw: "domain=*.example.com", + HostPattern: []string{"*", "example", "com"}, + }, + expectError: false, + }, + { + name: "wildcard path", + input: "path=/api/*/users", + expectedRule: Rule{ + Raw: "path=/api/*/users", + PathPattern: []string{"api", "*", "users"}, + }, + expectError: false, + }, + { + name: "wildcard method", + input: "method=*", + expectedRule: Rule{ + Raw: "method=*", + MethodPatterns: map[string]struct{}{"*": {}}, + }, + expectError: false, + }, + { + name: "all wildcards - domain ending with asterisk should error", + input: "method=* domain=*.* path=/*/", + expectedRule: Rule{}, + expectError: true, + }, + { + name: "invalid key", + input: "invalid=value", + expectedRule: Rule{}, + expectError: true, + }, + { + name: "missing value", + input: "method=", + expectedRule: Rule{}, + expectError: true, + }, + { + name: "invalid method", + input: "method=@invalid", + expectedRule: Rule{}, + expectError: true, + }, + { + name: "invalid domain", + input: "domain=-invalid.com", + expectedRule: Rule{}, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rule, err := parseAllowRule(tt.input) + + if tt.expectError { + if err == nil { + t.Errorf("expected error but got none") + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + // Check Raw field + if rule.Raw != tt.expectedRule.Raw { + t.Errorf("expected Raw %q, got %q", tt.expectedRule.Raw, rule.Raw) + } + + // Check MethodPatterns + if tt.expectedRule.MethodPatterns == nil { + if rule.MethodPatterns != nil { + t.Errorf("expected MethodPatterns to be nil, got %v", rule.MethodPatterns) + } + } else { + if rule.MethodPatterns == nil { + t.Errorf("expected MethodPatterns %v, got nil", tt.expectedRule.MethodPatterns) + } else { + if len(rule.MethodPatterns) != len(tt.expectedRule.MethodPatterns) { + t.Errorf("expected %d methods, got %d", len(tt.expectedRule.MethodPatterns), len(rule.MethodPatterns)) + } + for method := range tt.expectedRule.MethodPatterns { + if _, exists := rule.MethodPatterns[method]; !exists { + t.Errorf("expected method %q not found", method) + } + } + } + } + + // Check HostPattern + if len(rule.HostPattern) != len(tt.expectedRule.HostPattern) { + t.Errorf("expected HostPattern length %d, got %d", len(tt.expectedRule.HostPattern), len(rule.HostPattern)) + } else { + for i, expectedLabel := range tt.expectedRule.HostPattern { + if rule.HostPattern[i] != expectedLabel { + t.Errorf("expected HostPattern[%d] %q, got %q", i, expectedLabel, rule.HostPattern[i]) + } + } + } + + // Check PathPattern + if len(rule.PathPattern) != len(tt.expectedRule.PathPattern) { + t.Errorf("expected PathPattern length %d, got %d", len(tt.expectedRule.PathPattern), len(rule.PathPattern)) + } else { + for i, expectedSegment := range tt.expectedRule.PathPattern { + if rule.PathPattern[i] != expectedSegment { + t.Errorf("expected PathPattern[%d] %q, got %q", i, expectedSegment, rule.PathPattern[i]) + } + } + } + }) + } +} + +func TestReadmeExamples(t *testing.T) { + logger := slog.Default() + + tests := []struct { + name string + allowRule string + testCases []struct { + method string + url string + expected bool + } + }{ + { + name: "domain only - github.com", + allowRule: "domain=github.com", + testCases: []struct { + method string + url string + expected bool + }{ + {"GET", "https://github.com", true}, + {"POST", "https://github.com/user/repo", true}, + {"GET", "https://api.github.com", true}, // subdomain match + {"GET", "https://example.com", false}, + }, + }, + { + name: "domain with path - github.com/api/issues/*", + allowRule: "domain=github.com path=/api/issues/*", + testCases: []struct { + method string + url string + expected bool + }{ + {"GET", "https://github.com/api/issues/123", true}, + {"POST", "https://github.com/api/issues/new", true}, + {"GET", "https://github.com/api/users", false}, // wrong path + {"GET", "https://example.com/api/issues/123", false}, // wrong domain + }, + }, + { + name: "method with domain - GET,HEAD github.com", + allowRule: "method=GET,HEAD domain=github.com", + testCases: []struct { + method string + url string + expected bool + }{ + {"GET", "https://github.com/user/repo", true}, + {"HEAD", "https://github.com/user/repo", true}, + {"POST", "https://github.com/user/repo", false}, // wrong method + {"GET", "https://example.com", false}, // wrong domain + }, + }, + { + name: "wildcard subdomain - *.github.com", + allowRule: "domain=*.github.com", + testCases: []struct { + method string + url string + expected bool + }{ + {"GET", "https://api.github.com", true}, + {"GET", "https://raw.github.com", true}, + {"GET", "https://github.com", false}, // no subdomain + {"GET", "https://example.com", false}, + }, + }, + { + name: "method with domain and specific host", + allowRule: "method=GET,HEAD domain=api.github.com", + testCases: []struct { + method string + url string + expected bool + }{ + {"GET", "https://api.github.com/users", true}, + {"HEAD", "https://api.github.com/repos", true}, + {"POST", "https://api.github.com/users", false}, // wrong method + {"GET", "https://github.com", false}, // wrong domain + }, + }, + { + name: "method with domain and path", + allowRule: "method=POST domain=api.example.com path=/users", + testCases: []struct { + method string + url string + expected bool + }{ + {"POST", "https://api.example.com/users", true}, + {"POST", "https://api.example.com/users/123", true}, // subpath match + {"GET", "https://api.example.com/users", false}, // wrong method + {"POST", "https://api.example.com/posts", false}, // wrong path + {"POST", "https://example.com/users", false}, // wrong domain + }, + }, + { + name: "method wildcard - all methods", + allowRule: "method=*", + testCases: []struct { + method string + url string + expected bool + }{ + {"GET", "https://example.com", true}, + {"POST", "https://example.com", true}, + {"DELETE", "https://example.com", true}, + {"PATCH", "https://example.com", true}, + {"OPTIONS", "https://example.com", true}, + }, + }, + { + name: "multiple wildcards - wildcard subdomains", + allowRule: "domain=*.npmjs.org", + testCases: []struct { + method string + url string + expected bool + }{ + {"GET", "https://registry.npmjs.org", true}, + {"GET", "https://api.npmjs.org", true}, + {"GET", "https://npmjs.org", false}, // no subdomain + {"GET", "https://example.com", false}, + }, + }, + { + name: "registry domain exact match", + allowRule: "domain=registry.npmjs.org", + testCases: []struct { + method string + url string + expected bool + }{ + {"GET", "https://registry.npmjs.org", true}, + {"GET", "https://registry.npmjs.org/package", true}, + {"GET", "https://api.npmjs.org", false}, // different subdomain + {"GET", "https://npmjs.org", false}, // missing subdomain + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Parse the allow rule + rule, err := parseAllowRule(tt.allowRule) + if err != nil { + t.Fatalf("Failed to parse allow rule %q: %v", tt.allowRule, err) + } + + // Create engine with the single rule + engine := NewRuleEngine([]Rule{rule}, logger) + + // Test each case + for i, tc := range tt.testCases { + t.Run(fmt.Sprintf("case_%d_%s_%s", i, tc.method, tc.url), func(t *testing.T) { + result := engine.matches(rule, tc.method, tc.url) + if result != tc.expected { + t.Errorf("Rule %q with method %q and URL %q: expected %v, got %v", + tt.allowRule, tc.method, tc.url, tc.expected, result) + } + }) + } + }) + } +}