Skip to content

Commit 3a578be

Browse files
committed
split engine and rules code into separate files to make easier to read through
1 parent ced3bc8 commit 3a578be

File tree

8 files changed

+410
-399
lines changed

8 files changed

+410
-399
lines changed

boundary.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@ import (
1111
"github.com/coder/boundary/audit"
1212
"github.com/coder/boundary/jail"
1313
"github.com/coder/boundary/proxy"
14-
"github.com/coder/boundary/rules"
14+
"github.com/coder/boundary/rulesengine"
1515
)
1616

1717
type Config struct {
18-
RuleEngine rules.Engine
18+
RuleEngine rulesengine.Engine
1919
Auditor audit.Auditor
2020
TLSConfig *tls.Config
2121
Logger *slog.Logger

cli/cli.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ import (
1212
"github.com/coder/boundary"
1313
"github.com/coder/boundary/audit"
1414
"github.com/coder/boundary/jail"
15-
"github.com/coder/boundary/rules"
15+
"github.com/coder/boundary/rulesengine"
1616
"github.com/coder/boundary/tls"
1717
"github.com/coder/boundary/util"
1818
"github.com/coder/serpent"
@@ -101,14 +101,14 @@ func Run(ctx context.Context, config Config, args []string) error {
101101
}
102102

103103
// Parse allow rules
104-
allowRules, err := rules.ParseAllowSpecs(config.AllowStrings)
104+
allowRules, err := rulesengine.ParseAllowSpecs(config.AllowStrings)
105105
if err != nil {
106106
logger.Error("Failed to parse allow rules", "error", err)
107107
return fmt.Errorf("failed to parse allow rules: %v", err)
108108
}
109109

110110
// Create rule engine
111-
ruleEngine := rules.NewRuleEngine(allowRules, logger)
111+
ruleEngine := rulesengine.NewRuleEngine(allowRules, logger)
112112

113113
// Create auditor
114114
auditor := audit.NewLogAuditor(logger)

proxy/proxy.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@ import (
1515
"sync/atomic"
1616

1717
"github.com/coder/boundary/audit"
18-
"github.com/coder/boundary/rules"
18+
"github.com/coder/boundary/rulesengine"
1919
)
2020

2121
// Server handles HTTP and HTTPS requests with rule-based filtering
2222
type Server struct {
23-
ruleEngine rules.Engine
23+
ruleEngine rulesengine.Engine
2424
auditor audit.Auditor
2525
logger *slog.Logger
2626
tlsConfig *tls.Config
@@ -33,7 +33,7 @@ type Server struct {
3333
// Config holds configuration for the proxy server
3434
type Config struct {
3535
HTTPPort int
36-
RuleEngine rules.Engine
36+
RuleEngine rulesengine.Engine
3737
Auditor audit.Auditor
3838
Logger *slog.Logger
3939
TLSConfig *tls.Config

proxy/proxy_test.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ import (
1717
"github.com/stretchr/testify/require"
1818

1919
"github.com/coder/boundary/audit"
20-
"github.com/coder/boundary/rules"
20+
"github.com/coder/boundary/rulesengine"
2121
)
2222

2323
// mockAuditor is a simple mock auditor for testing
@@ -35,13 +35,13 @@ func TestProxyServerBasicHTTP(t *testing.T) {
3535
}))
3636

3737
// Create test rules (allow all for testing)
38-
testRules, err := rules.ParseAllowSpecs([]string{"method=*"})
38+
testRules, err := rulesengine.ParseAllowSpecs([]string{"method=*"})
3939
if err != nil {
4040
t.Fatalf("Failed to parse test rules: %v", err)
4141
}
4242

4343
// Create rule engine
44-
ruleEngine := rules.NewRuleEngine(testRules, logger)
44+
ruleEngine := rulesengine.NewRuleEngine(testRules, logger)
4545

4646
// Create mock auditor
4747
auditor := &mockAuditor{}
@@ -116,13 +116,13 @@ func TestProxyServerBasicHTTPS(t *testing.T) {
116116
}))
117117

118118
// Create test rules (allow all for testing)
119-
testRules, err := rules.ParseAllowSpecs([]string{"method=*"})
119+
testRules, err := rulesengine.ParseAllowSpecs([]string{"method=*"})
120120
if err != nil {
121121
t.Fatalf("Failed to parse test rules: %v", err)
122122
}
123123

124124
// Create rule engine
125-
ruleEngine := rules.NewRuleEngine(testRules, logger)
125+
ruleEngine := rulesengine.NewRuleEngine(testRules, logger)
126126

127127
// Create mock auditor
128128
auditor := &mockAuditor{}
@@ -210,13 +210,13 @@ func TestProxyServerCONNECT(t *testing.T) {
210210
}))
211211

212212
// Create test rules (allow all for testing)
213-
testRules, err := rules.ParseAllowSpecs([]string{"method=*"})
213+
testRules, err := rulesengine.ParseAllowSpecs([]string{"method=*"})
214214
if err != nil {
215215
t.Fatalf("Failed to parse test rules: %v", err)
216216
}
217217

218218
// Create rule engine
219-
ruleEngine := rules.NewRuleEngine(testRules, logger)
219+
ruleEngine := rulesengine.NewRuleEngine(testRules, logger)
220220

221221
// Create mock auditor
222222
auditor := &mockAuditor{}

rulesengine/engine.go

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
package rulesengine
2+
3+
import (
4+
"log/slog"
5+
neturl "net/url"
6+
"strings"
7+
)
8+
9+
// Engine evaluates HTTP requests against a set of rules.
10+
type Engine struct {
11+
rules []Rule
12+
logger *slog.Logger
13+
}
14+
15+
// NewRuleEngine creates a new rule engine
16+
func NewRuleEngine(rules []Rule, logger *slog.Logger) Engine {
17+
return Engine{
18+
rules: rules,
19+
logger: logger,
20+
}
21+
}
22+
23+
// Result contains the result of rule evaluation
24+
type Result struct {
25+
Allowed bool
26+
Rule string // The rule that matched (if any)
27+
}
28+
29+
// Evaluate evaluates a request and returns both result and matching rule
30+
func (re *Engine) Evaluate(method, url string) Result {
31+
// Check if any allow rule matches
32+
for _, rule := range re.rules {
33+
if re.matches(rule, method, url) {
34+
return Result{
35+
Allowed: true,
36+
Rule: rule.Raw,
37+
}
38+
}
39+
}
40+
41+
// Default deny if no allow rules match
42+
return Result{
43+
Allowed: false,
44+
Rule: "",
45+
}
46+
}
47+
48+
// Matches checks if the rule matches the given method and URL using wildcard patterns
49+
func (re *Engine) matches(r Rule, method, url string) bool {
50+
51+
// Check method patterns if they exist
52+
if r.MethodPatterns != nil {
53+
methodMatches := false
54+
for mp := range r.MethodPatterns {
55+
if string(mp) == method || string(mp) == "*" {
56+
methodMatches = true
57+
break
58+
}
59+
}
60+
if !methodMatches {
61+
re.logger.Debug("rule does not match", "reason", "method pattern mismatch", "rule", r.Raw, "method", method, "url", url)
62+
return false
63+
}
64+
}
65+
66+
parsedUrl, err := neturl.Parse(url)
67+
if err != nil {
68+
re.logger.Debug("rule does not match", "reason", "invalid URL", "rule", r.Raw, "method", method, "url", url, "error", err)
69+
return false
70+
}
71+
72+
if r.HostPattern != nil {
73+
// For a host pattern to match, every label has to match or be an `*`.
74+
// Subdomains also match automatically, meaning if the pattern is "example.com"
75+
// and the real is "api.example.com", it should match. We check this by comparing
76+
// from the end of the actual hostname with the pattern (which is in normal order).
77+
78+
labels := strings.Split(parsedUrl.Hostname(), ".")
79+
80+
// If the host pattern is longer than the actual host, it's definitely not a match
81+
if len(r.HostPattern) > len(labels) {
82+
re.logger.Debug("rule does not match", "reason", "host pattern too long", "rule", r.Raw, "method", method, "url", url, "pattern_length", len(r.HostPattern), "hostname_labels", len(labels))
83+
return false
84+
}
85+
86+
// Since host patterns cannot end with asterisk, we only need to handle:
87+
// "example.com" or "*.example.com" - match from the end (allowing subdomains)
88+
for i, lp := range r.HostPattern {
89+
labelIndex := len(labels) - len(r.HostPattern) + i
90+
if string(lp) != labels[labelIndex] && lp != "*" {
91+
re.logger.Debug("rule does not match", "reason", "host pattern label mismatch", "rule", r.Raw, "method", method, "url", url, "expected", string(lp), "actual", labels[labelIndex])
92+
return false
93+
}
94+
}
95+
}
96+
97+
if r.PathPattern != nil {
98+
segments := strings.Split(parsedUrl.Path, "/")
99+
100+
// Skip the first empty segment if the path starts with "/"
101+
if len(segments) > 0 && segments[0] == "" {
102+
segments = segments[1:]
103+
}
104+
105+
// If the path pattern is longer than the actual path, definitely not a match
106+
if len(r.PathPattern) > len(segments) {
107+
re.logger.Debug("rule does not match", "reason", "path pattern too long", "rule", r.Raw, "method", method, "url", url, "pattern_length", len(r.PathPattern), "path_segments", len(segments))
108+
return false
109+
}
110+
111+
// Each segment in the pattern must be either as asterisk or match the actual path segment
112+
for i, sp := range r.PathPattern {
113+
if string(sp) != segments[i] && sp != "*" {
114+
re.logger.Debug("rule does not match", "reason", "path pattern segment mismatch", "rule", r.Raw, "method", method, "url", url, "expected", string(sp), "actual", segments[i])
115+
return false
116+
}
117+
}
118+
}
119+
120+
re.logger.Debug("rule matches", "reason", "all patterns matched", "rule", r.Raw, "method", method, "url", url)
121+
return true
122+
}

0 commit comments

Comments
 (0)