diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..d9d33c0 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,46 @@ +name: CI + +on: + push: + pull_request: + +jobs: + test: + name: Test + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest] + runs-on: ${{ matrix.os }} + + steps: + - name: Check out code + uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: '1.25' + check-latest: true + + - name: Cache Go modules + uses: actions/cache@v4 + with: + path: | + ~/.cache/go-build + ~/go/pkg/mod + key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-go- + + - name: Download dependencies + run: go mod download + + - name: Verify dependencies + run: go mod verify + + - name: Run tests + run: go test -v -race ./... + + - name: Run build + run: go build -v ./... \ No newline at end of file diff --git a/main.go b/main.go index e07736b..2572b3d 100644 --- a/main.go +++ b/main.go @@ -12,7 +12,7 @@ import ( "syscall" "time" - "github.com/coder/jail/netjail" + "github.com/coder/jail/network" "github.com/coder/jail/proxy" "github.com/coder/jail/rules" "github.com/coder/jail/tls" @@ -181,7 +181,7 @@ func runJail(inv *serpent.Invocation) error { } // Create network jail configuration - netjailConfig := netjail.Config{ + networkConfig := network.JailConfig{ HTTPPort: 8040, HTTPSPort: 8043, NetJailName: "jail", @@ -189,7 +189,7 @@ func runJail(inv *serpent.Invocation) error { } // Create network jail - netjailInstance, err := netjail.NewNetJail(netjailConfig, logger) + networkInstance, err := network.NewJail(networkConfig, logger) if err != nil { logger.Error("Failed to create network jail", "error", err) return fmt.Errorf("failed to create network jail: %v", err) @@ -203,7 +203,7 @@ func runJail(inv *serpent.Invocation) error { go func() { sig := <-sigChan logger.Info("Received signal during setup, cleaning up...", "signal", sig) - if err := netjailInstance.Cleanup(); err != nil { + if err := networkInstance.Cleanup(); err != nil { logger.Error("Emergency cleanup failed", "error", err) } os.Exit(1) @@ -212,7 +212,7 @@ func runJail(inv *serpent.Invocation) error { // Ensure cleanup happens no matter what defer func() { logger.Debug("Starting cleanup process") - if err := netjailInstance.Cleanup(); err != nil { + if err := networkInstance.Cleanup(); err != nil { logger.Error("Failed to cleanup network jail", "error", err) } else { logger.Debug("Cleanup completed successfully") @@ -220,15 +220,15 @@ func runJail(inv *serpent.Invocation) error { }() // Setup network jail - if err := netjailInstance.Setup(netjailConfig.HTTPPort, netjailConfig.HTTPSPort); err != nil { + if err := networkInstance.Setup(networkConfig.HTTPPort, networkConfig.HTTPSPort); err != nil { logger.Error("Failed to setup network jail", "error", err) return fmt.Errorf("failed to setup network jail: %v", err) } // Create proxy server proxyConfig := proxy.Config{ - HTTPPort: netjailConfig.HTTPPort, - HTTPSPort: netjailConfig.HTTPSPort, + HTTPPort: networkConfig.HTTPPort, + HTTPSPort: networkConfig.HTTPSPort, RuleEngine: ruleEngine, Logger: logger, TLSConfig: tlsConfig, @@ -253,7 +253,7 @@ func runJail(inv *serpent.Invocation) error { // Execute command in network jail go func() { defer cancel() - if err := netjailInstance.Execute(args, extraEnv); err != nil { + if err := networkInstance.Execute(args, extraEnv); err != nil { logger.Error("Command execution failed", "error", err) } }() diff --git a/netjail/linux_stub.go b/netjail/linux_stub.go deleted file mode 100644 index 4f3041f..0000000 --- a/netjail/linux_stub.go +++ /dev/null @@ -1,13 +0,0 @@ -//go:build !linux - -package netjail - -import ( - "fmt" - "log/slog" -) - -// newLinuxNetJail is not available on non-Linux platforms -func newLinuxNetJail(config Config, logger *slog.Logger) (NetJail, error) { - return nil, fmt.Errorf("Linux network jail not supported on this platform") -} diff --git a/netjail/macos_stub.go b/netjail/macos_stub.go deleted file mode 100644 index 99ac3d7..0000000 --- a/netjail/macos_stub.go +++ /dev/null @@ -1,10 +0,0 @@ -//go:build !darwin - -package netjail - -import "log/slog" - -// newMacOSNetJail is not available on non-macOS platforms -func newMacOSNetJail(config Config, logger *slog.Logger) (NetJail, error) { - panic("macOS network jail not available on this platform") -} diff --git a/netjail/linux.go b/network/linux.go similarity index 88% rename from netjail/linux.go rename to network/linux.go index 0371c5d..82eacc2 100644 --- a/netjail/linux.go +++ b/network/linux.go @@ -1,6 +1,6 @@ //go:build linux -package netjail +package network import ( "fmt" @@ -11,19 +11,23 @@ import ( "time" ) -// LinuxNetJail implements NetJail using Linux network namespaces -type LinuxNetJail struct { - config Config +const ( + namespacePrefix = "coder_jail" +) + +// LinuxJail implements NetJail using Linux network namespaces +type LinuxJail struct { + config JailConfig namespace string logger *slog.Logger } -// newLinuxNetJail creates a new Linux network jail instance -func newLinuxNetJail(config Config, logger *slog.Logger) (*LinuxNetJail, error) { +// newLinuxJail creates a new Linux network jail instance +func newLinuxJail(config JailConfig, logger *slog.Logger) (*LinuxJail, error) { // Generate unique namespace name - namespace := fmt.Sprintf("boundary_%d", time.Now().UnixNano()%10000000) + namespace := fmt.Sprintf("%s_%d", namespacePrefix, time.Now().UnixNano()%10000000) - return &LinuxNetJail{ + return &LinuxJail{ config: config, namespace: namespace, logger: logger, @@ -31,7 +35,7 @@ func newLinuxNetJail(config Config, logger *slog.Logger) (*LinuxNetJail, error) } // Setup creates network namespace and configures iptables rules -func (l *LinuxNetJail) Setup(httpPort, httpsPort int) error { +func (l *LinuxJail) Setup(httpPort, httpsPort int) error { l.logger.Debug("Setup called", "httpPort", httpPort, "httpsPort", httpsPort) l.config.HTTPPort = httpPort l.config.HTTPSPort = httpsPort @@ -70,7 +74,7 @@ func (l *LinuxNetJail) Setup(httpPort, httpsPort int) error { } // Execute runs a command within the network namespace -func (l *LinuxNetJail) Execute(command []string, extraEnv map[string]string) error { +func (l *LinuxJail) Execute(command []string, extraEnv map[string]string) error { l.logger.Debug("Execute called", "command", command) if len(command) == 0 { return fmt.Errorf("no command specified") @@ -81,7 +85,7 @@ func (l *LinuxNetJail) Execute(command []string, extraEnv map[string]string) err cmdArgs := []string{"ip", "netns", "exec", l.namespace} cmdArgs = append(cmdArgs, command...) l.logger.Debug("Full command args", "args", cmdArgs) - + cmd := exec.Command("ip", cmdArgs[1:]...) // Set up environment @@ -124,7 +128,7 @@ func (l *LinuxNetJail) Execute(command []string, extraEnv map[string]string) err } // Cleanup removes the network namespace and iptables rules -func (l *LinuxNetJail) Cleanup() error { +func (l *LinuxJail) Cleanup() error { if l.config.SkipCleanup { return nil } @@ -152,7 +156,7 @@ func (l *LinuxNetJail) Cleanup() error { } // createNamespace creates a new network namespace -func (l *LinuxNetJail) createNamespace() error { +func (l *LinuxJail) createNamespace() error { cmd := exec.Command("ip", "netns", "add", l.namespace) if err := cmd.Run(); err != nil { return fmt.Errorf("failed to create namespace: %v", err) @@ -161,12 +165,12 @@ func (l *LinuxNetJail) createNamespace() error { } // setupNetworking configures networking within the namespace -func (l *LinuxNetJail) setupNetworking() error { +func (l *LinuxJail) setupNetworking() error { // Create veth pair with short names (Linux interface names limited to 15 chars) // Generate unique ID to avoid conflicts uniqueID := fmt.Sprintf("%d", time.Now().UnixNano()%10000000) // 7 digits max - vethHost := fmt.Sprintf("veth_h_%s", uniqueID) // veth_h_1234567 = 14 chars - vethNetJail := fmt.Sprintf("veth_n_%s", uniqueID) // veth_n_1234567 = 14 chars + vethHost := fmt.Sprintf("veth_h_%s", uniqueID) // veth_h_1234567 = 14 chars + vethNetJail := fmt.Sprintf("veth_n_%s", uniqueID) // veth_n_1234567 = 14 chars cmd := exec.Command("ip", "link", "add", vethHost, "type", "veth", "peer", "name", vethNetJail) if err := cmd.Run(); err != nil { @@ -218,7 +222,7 @@ func (l *LinuxNetJail) setupNetworking() error { // setupDNS configures DNS resolution for the namespace // This ensures reliable DNS resolution by using public DNS servers // instead of relying on the host's potentially complex DNS configuration -func (l *LinuxNetJail) setupDNS() error { +func (l *LinuxJail) setupDNS() error { // Always create namespace-specific resolv.conf with reliable public DNS servers // This avoids issues with systemd-resolved, Docker DNS, and other complex setups netnsEtc := fmt.Sprintf("/etc/netns/%s", l.namespace) @@ -228,7 +232,7 @@ func (l *LinuxNetJail) setupDNS() error { // Write custom resolv.conf with multiple reliable public DNS servers resolvConfPath := fmt.Sprintf("%s/resolv.conf", netnsEtc) - dnsConfig := `# Custom DNS for boundary namespace + dnsConfig := `# Custom DNS for network namespace nameserver 8.8.8.8 nameserver 8.8.4.4 nameserver 1.1.1.1 @@ -244,7 +248,7 @@ options timeout:2 attempts:2 } // setupIptables configures iptables rules for traffic redirection -func (l *LinuxNetJail) setupIptables() error { +func (l *LinuxJail) setupIptables() error { // Enable IP forwarding cmd := exec.Command("sysctl", "-w", "net.ipv4.ip_forward=1") cmd.Run() // Ignore error @@ -273,7 +277,7 @@ func (l *LinuxNetJail) setupIptables() error { } // removeIptables removes iptables rules -func (l *LinuxNetJail) removeIptables() error { +func (l *LinuxJail) removeIptables() error { // Remove NAT rule cmd := exec.Command("iptables", "-t", "nat", "-D", "POSTROUTING", "-s", "192.168.100.0/24", "-j", "MASQUERADE") cmd.Run() // Ignore errors during cleanup @@ -282,10 +286,10 @@ func (l *LinuxNetJail) removeIptables() error { } // removeNamespace removes the network namespace -func (l *LinuxNetJail) removeNamespace() error { +func (l *LinuxJail) removeNamespace() error { cmd := exec.Command("ip", "netns", "del", l.namespace) if err := cmd.Run(); err != nil { return fmt.Errorf("failed to remove namespace: %v", err) } return nil -} \ No newline at end of file +} diff --git a/network/linux_stub.go b/network/linux_stub.go new file mode 100644 index 0000000..3155cb4 --- /dev/null +++ b/network/linux_stub.go @@ -0,0 +1,13 @@ +//go:build !linux + +package network + +import ( + "fmt" + "log/slog" +) + +// newLinuxJail is not available on non-Linux platforms +func newLinuxJail(_ JailConfig, _ *slog.Logger) (Jail, error) { + return nil, fmt.Errorf("linux network jail not supported on this platform") +} diff --git a/netjail/macos.go b/network/macos.go similarity index 95% rename from netjail/macos.go rename to network/macos.go index e9256f9..79f8097 100644 --- a/netjail/macos.go +++ b/network/macos.go @@ -1,10 +1,9 @@ //go:build darwin -package netjail +package network import ( "fmt" - "io/ioutil" "log/slog" "os" "os/exec" @@ -14,21 +13,21 @@ import ( ) const ( - PF_ANCHOR_NAME = "boundary" - GROUP_NAME = "boundary" + PF_ANCHOR_NAME = "network" + GROUP_NAME = "network" ) // MacOSNetJail implements network jail using macOS PF (Packet Filter) and group-based isolation type MacOSNetJail struct { - config Config + config JailConfig groupID int pfRulesPath string mainRulesPath string logger *slog.Logger } -// newMacOSNetJail creates a new macOS network jail instance -func newMacOSNetJail(config Config, logger *slog.Logger) (*MacOSNetJail, error) { +// newMacOSJail creates a new macOS network jail instance +func newMacOSJail(config JailConfig, logger *slog.Logger) (*MacOSNetJail, error) { pfRulesPath := fmt.Sprintf("/tmp/%s.pf", config.NetJailName) mainRulesPath := fmt.Sprintf("/tmp/%s_main.pf", config.NetJailName) @@ -266,7 +265,7 @@ func (m *MacOSNetJail) setupPFRules() error { } // Write rules to temp file - if err := ioutil.WriteFile(m.pfRulesPath, []byte(rules), 0644); err != nil { + if err := os.WriteFile(m.pfRulesPath, []byte(rules), 0644); err != nil { return fmt.Errorf("failed to write PF rules file: %v", err) } @@ -297,7 +296,7 @@ anchor "%s" `, PF_ANCHOR_NAME, PF_ANCHOR_NAME) // Write and load the main ruleset - if err := ioutil.WriteFile(m.mainRulesPath, []byte(mainRules), 0644); err != nil { + if err := os.WriteFile(m.mainRulesPath, []byte(mainRules), 0644); err != nil { return fmt.Errorf("failed to write main PF rules: %v", err) } @@ -335,4 +334,4 @@ func (m *MacOSNetJail) cleanupTempFiles() { if m.mainRulesPath != "" { os.Remove(m.mainRulesPath) } -} \ No newline at end of file +} diff --git a/network/macos_stub.go b/network/macos_stub.go new file mode 100644 index 0000000..4dd2580 --- /dev/null +++ b/network/macos_stub.go @@ -0,0 +1,10 @@ +//go:build !darwin + +package network + +import "log/slog" + +// newMacOSJail is not available on non-macOS platforms +func newMacOSJail(config JailConfig, logger *slog.Logger) (Jail, error) { + panic("macOS network jail not available on this platform") +} diff --git a/netjail/netjail.go b/network/network.go similarity index 51% rename from netjail/netjail.go rename to network/network.go index b44961e..cb8f3f6 100644 --- a/netjail/netjail.go +++ b/network/network.go @@ -1,4 +1,4 @@ -package netjail +package network import ( "fmt" @@ -6,8 +6,8 @@ import ( "runtime" ) -// NetJail represents a network isolation mechanism -type NetJail interface { +// Jail represents a network isolation mechanism +type Jail interface { // Setup configures the network jail for the given proxy ports Setup(httpPort, httpsPort int) error @@ -18,22 +18,22 @@ type NetJail interface { Cleanup() error } -// Config holds configuration for network jail -type Config struct { - HTTPPort int - HTTPSPort int - NetJailName string - SkipCleanup bool +// JailConfig holds configuration for network jail +type JailConfig struct { + HTTPPort int + HTTPSPort int + NetJailName string + SkipCleanup bool } -// NewNetJail creates a new NetJail instance for the current platform -func NewNetJail(config Config, logger *slog.Logger) (NetJail, error) { +// NewJail creates a new NetJail instance for the current platform +func NewJail(config JailConfig, logger *slog.Logger) (Jail, error) { switch runtime.GOOS { case "darwin": - return newMacOSNetJail(config, logger) + return newMacOSJail(config, logger) case "linux": - return newLinuxNetJail(config, logger) + return newLinuxJail(config, logger) default: return nil, fmt.Errorf("unsupported platform: %s", runtime.GOOS) } -} \ No newline at end of file +} diff --git a/proxy/proxy.go b/proxy/proxy.go index 0e25cea..8f59077 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -144,7 +144,6 @@ func (p *ProxyServer) forwardHTTPRequest(w http.ResponseWriter, r *http.Request) // Create HTTP client client := &http.Client{ - Timeout: 30 * time.Second, CheckRedirect: func(req *http.Request, via []*http.Request) error { return http.ErrUseLastResponse // Don't follow redirects }, diff --git a/rules/rules.go b/rules/rules.go index ba4d9ba..2f4810d 100644 --- a/rules/rules.go +++ b/rules/rules.go @@ -6,72 +6,30 @@ import ( "strings" ) -// Action represents whether to allow or deny a request +// Action represents whether to allow a request type Action int const ( Allow Action = iota - Deny + Deny // Default deny when no allow rules match ) func (a Action) String() string { switch a { case Allow: return "ALLOW" - case Deny: - return "DENY" default: - return "UNKNOWN" + return "DENY" } } -// Rule represents a filtering rule with optional HTTP method restrictions +// Rule represents an allow rule with optional HTTP method restrictions type Rule struct { - Action Action - Pattern string // wildcard pattern for matching - Methods map[string]bool // nil means all methods allowed - Raw string // rule string for logging + Pattern string // wildcard pattern for matching + Methods map[string]bool // nil means all methods allowed + Raw string // rule string for logging } -// newRule creates a new rule from a string format like "allow: github.com" or "deny-post: telemetry.*" -func newRule(ruleStr string) (*Rule, error) { - parts := strings.SplitN(ruleStr, ":", 2) - if len(parts) != 2 { - return nil, fmt.Errorf("invalid rule format: %s (expected 'action[-method]: pattern')", ruleStr) - } - - actionPart := strings.TrimSpace(parts[0]) - pattern := strings.TrimSpace(parts[1]) - - // Parse action and optional method - var action Action - var methods map[string]bool - - actionParts := strings.Split(actionPart, "-") - switch strings.ToLower(actionParts[0]) { - case "allow": - action = Allow - case "deny": - action = Deny - default: - return nil, fmt.Errorf("invalid action: %s (must be 'allow' or 'deny')", actionParts[0]) - } - - // Parse optional method restriction - if len(actionParts) > 1 { - methods = make(map[string]bool) - for _, method := range actionParts[1:] { - methods[strings.ToUpper(method)] = true - } - } - - return &Rule{ - Action: action, - Pattern: pattern, - Methods: methods, - Raw: ruleStr, - }, nil -} // Matches checks if the rule matches the given method and URL using wildcard patterns func (r *Rule) Matches(method, url string) bool { @@ -95,12 +53,12 @@ func (r *Rule) Matches(method, url string) bool { } else if strings.HasPrefix(url, "http://") { urlWithoutProtocol = url[7:] // Remove "http://" } - + // Try matching against URL without protocol if wildcardMatch(r.Pattern, urlWithoutProtocol) { return true } - + // Also try matching just the domain part domainEnd := strings.Index(urlWithoutProtocol, "/") if domainEnd > 0 { @@ -167,43 +125,38 @@ func wildcardMatchRecursive(pattern, text string, p, t int) bool { // RuleEngine evaluates HTTP requests against a set of rules type RuleEngine struct { - rules []*Rule - logger *slog.Logger + rules []*Rule + logger *slog.Logger } // NewRuleEngine creates a new rule engine func NewRuleEngine(rules []*Rule, logger *slog.Logger) *RuleEngine { return &RuleEngine{ - rules: rules, - logger: logger, + rules: rules, + logger: logger, } } -// Evaluate evaluates a request against all rules and returns the action to take +// Evaluate evaluates a request against all allow rules and returns the action to take func (re *RuleEngine) Evaluate(method, url string) Action { - // Evaluate rules in order + // Check if any allow rule matches for _, rule := range re.rules { if rule.Matches(method, url) { - switch rule.Action { - case Allow: - re.logger.Info("ALLOW", "method", method, "url", url, "rule", rule.Raw) - return Allow - case Deny: - re.logger.Warn("DENY", "method", method, "url", url, "rule", rule.Raw) - return Deny - } + re.logger.Info("ALLOW", "method", method, "url", url, "rule", rule.Raw) + return Allow } } - // Default deny if no rules match - re.logger.Warn("DENY", "method", method, "url", url, "reason", "no matching rules") + // Default deny if no allow rules match + re.logger.Warn("DENY", "method", method, "url", url, "reason", "no matching allow rules") return Deny } // newAllowRule creates an allow Rule from a spec string used by --allow. // Supported formats: -// "pattern" -> allow all methods to pattern -// "GET,HEAD pattern" -> allow only listed methods to pattern +// +// "pattern" -> allow all methods to pattern +// "GET,HEAD pattern" -> allow only listed methods to pattern func newAllowRule(spec string) (*Rule, error) { s := strings.TrimSpace(spec) if s == "" { @@ -240,7 +193,6 @@ func newAllowRule(spec string) (*Rule, error) { } return &Rule{ - Action: Allow, Pattern: pattern, Methods: methods, Raw: "allow " + spec, @@ -258,4 +210,4 @@ func ParseAllowSpecs(allowStrings []string) ([]*Rule, error) { out = append(out, r) } return out, nil -} \ No newline at end of file +} diff --git a/rules/rules_test.go b/rules/rules_test.go index a09ca6d..fdcd997 100644 --- a/rules/rules_test.go +++ b/rules/rules_test.go @@ -5,70 +5,64 @@ import ( "testing" ) -func TestNewRule(t *testing.T) { +func TestNewAllowRule(t *testing.T) { tests := []struct { name string - ruleStr string + spec string expectError bool - expAction Action expMethods map[string]bool expPattern string }{ { name: "simple allow rule", - ruleStr: "allow: github.com", + spec: "github.com", expectError: false, - expAction: Allow, expMethods: nil, expPattern: "github.com", }, { - name: "simple deny rule with wildcard", - ruleStr: "deny: telemetry.*", + name: "wildcard pattern", + spec: "api.*", expectError: false, - expAction: Deny, expMethods: nil, - expPattern: "telemetry.*", + expPattern: "api.*", }, { name: "method-specific allow rule", - ruleStr: "allow-get: api.github.com", + spec: "GET api.github.com", expectError: false, - expAction: Allow, expMethods: map[string]bool{"GET": true}, expPattern: "api.github.com", }, { - name: "multiple methods deny rule", - ruleStr: "deny-post-put: upload.*", + name: "multiple methods rule", + spec: "GET,POST,PUT api.*", expectError: false, - expAction: Deny, - expMethods: map[string]bool{"POST": true, "PUT": true}, - expPattern: "upload.*", + expMethods: map[string]bool{"GET": true, "POST": true, "PUT": true}, + expPattern: "api.*", }, { - name: "wildcard allow all", - ruleStr: "allow: *", + name: "allow all wildcard", + spec: "*", expectError: false, - expAction: Allow, expMethods: nil, expPattern: "*", }, { - name: "invalid format", - ruleStr: "invalid rule", + name: "empty spec", + spec: "", expectError: true, }, { - name: "invalid action", - ruleStr: "invalid: pattern", + name: "only spaces", + spec: " ", expectError: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - rule, err := newRule(tt.ruleStr) + rule, err := newAllowRule(tt.spec) if tt.expectError { if err == nil { t.Errorf("expected error but got none") @@ -81,10 +75,6 @@ func TestNewRule(t *testing.T) { return } - if rule.Action != tt.expAction { - t.Errorf("expected action %v, got %v", tt.expAction, rule.Action) - } - if rule.Pattern != tt.expPattern { t.Errorf("expected pattern %s, got %s", tt.expPattern, rule.Pattern) } @@ -103,6 +93,60 @@ func TestNewRule(t *testing.T) { } } +func TestParseAllowSpecs(t *testing.T) { + tests := []struct { + name string + allowStrings []string + expectError bool + expRuleCount int + }{ + { + name: "single allow rule", + allowStrings: []string{"github.com"}, + expectError: false, + expRuleCount: 1, + }, + { + name: "multiple allow rules", + allowStrings: []string{"github.com", "GET api.*", "POST,PUT upload.*"}, + expectError: false, + expRuleCount: 3, + }, + { + name: "empty list", + allowStrings: []string{}, + expectError: false, + expRuleCount: 0, + }, + { + name: "invalid rule in list", + allowStrings: []string{"github.com", ""}, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rules, err := ParseAllowSpecs(tt.allowStrings) + if tt.expectError { + if err == nil { + t.Errorf("expected error but got none") + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + if len(rules) != tt.expRuleCount { + t.Errorf("expected %d rules, got %d", tt.expRuleCount, len(rules)) + } + }) + } +} + func TestWildcardMatch(t *testing.T) { tests := []struct { name string @@ -150,7 +194,7 @@ func TestWildcardMatch(t *testing.T) { } func TestRuleMatches(t *testing.T) { - rule, err := newRule("allow-get-post: api.github.*") + rule, err := newAllowRule("GET,POST api.github.*") if err != nil { t.Fatalf("failed to create rule: %v", err) } @@ -181,8 +225,8 @@ func TestRuleMatches(t *testing.T) { func TestRuleEngine(t *testing.T) { rules := []*Rule{ - {Action: Allow, Pattern: "github.com", Methods: nil, Raw: "allow: github.com"}, - {Action: Deny, Pattern: "*", Methods: nil, Raw: "deny: *"}, + {Pattern: "github.com", Methods: nil, Raw: "allow github.com"}, + {Pattern: "api.*", Methods: map[string]bool{"GET": true}, Raw: "allow GET api.*"}, } // Create a logger that discards output during tests @@ -199,6 +243,8 @@ func TestRuleEngine(t *testing.T) { expected Action }{ {"allow github", "GET", "https://github.com/user/repo", Allow}, + {"allow api GET", "GET", "https://api.example.com", Allow}, + {"deny api POST", "POST", "https://api.example.com", Deny}, {"deny other", "GET", "https://example.com", Deny}, } @@ -214,8 +260,8 @@ func TestRuleEngine(t *testing.T) { func TestRuleEngineWildcardRules(t *testing.T) { rules := []*Rule{ - {Action: Deny, Pattern: "telemetry.*", Methods: nil, Raw: "deny: telemetry.*"}, - {Action: Allow, Pattern: "*", Methods: nil, Raw: "allow: *"}, + {Pattern: "github.*", Methods: nil, Raw: "allow github.*"}, + {Pattern: "api.*.com", Methods: map[string]bool{"GET": true}, Raw: "allow GET api.*.com"}, } // Create a logger that discards output during tests @@ -231,9 +277,11 @@ func TestRuleEngineWildcardRules(t *testing.T) { url string expected Action }{ - {"deny telemetry", "GET", "https://telemetry.example.com", Deny}, - {"allow other", "GET", "https://api.github.com", Allow}, - {"deny telemetry subdomain", "POST", "https://telemetry.analytics.com", Deny}, + {"allow github", "GET", "https://github.com", Allow}, + {"allow github subdomain", "POST", "https://github.io", Allow}, + {"allow api GET", "GET", "https://api.example.com", Allow}, + {"deny api POST", "POST", "https://api.example.com", Deny}, + {"deny unmatched", "GET", "https://example.org", Deny}, } for _, tt := range tests {