Skip to content

Commit 69951f5

Browse files
committed
Add --discovery-base-url flag to policy server
Enables serving discovery URLs through a CDN by allowing the policy server to return absolute URLs instead of relative ones. When set, both the Link header and redirect Location use the configured base URL.
1 parent ed83619 commit 69951f5

File tree

6 files changed

+184
-9
lines changed

6 files changed

+184
-9
lines changed

.tasks/closed/307ktjx8.md

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
---
2+
yatl_version: 1
3+
title: Add --discovery-base-url flag to policy server CLI
4+
id: 307ktjx8
5+
created: 2025-12-28T00:14:59.357075Z
6+
updated: 2025-12-28T00:20:29.291136Z
7+
author: Brian McCallister
8+
priority: high
9+
tags:
10+
- feature
11+
---
12+
13+
---
14+
# Log: 2025-12-28T00:14:59Z Brian McCallister
15+
16+
Created task.
17+
18+
---
19+
# Log: 2025-12-28T00:17:17Z Brian McCallister
20+
21+
Started working.
22+
23+
---
24+
# Log: 2025-12-28T00:20:29Z Brian McCallister
25+
26+
Closed: Implemented --discovery-base-url flag for CDN support

cmd/epithet/policy.go

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ type PolicyServerCLI struct {
4040

4141
// Default expiration
4242
DefaultExpiration string `help:"Default certificate expiration (e.g., 5m)" name:"default-expiration"`
43+
44+
// Discovery base URL for CDN support
45+
DiscoveryBaseURL string `help:"Base URL for discovery endpoints (e.g., https://cdn.example.com)" name:"discovery-base-url"`
4346
}
4447

4548
func (c *PolicyServerCLI) Run(logger *slog.Logger, tlsCfg tlsconfig.Config, unifiedConfig cue.Value) error {
@@ -82,10 +85,11 @@ func (c *PolicyServerCLI) Run(logger *slog.Logger, tlsCfg tlsconfig.Config, unif
8285

8386
// Create policy server handler
8487
handler := policyserver.NewHandler(policyserver.Config{
85-
CAPublicKey: sshcert.RawPublicKey(caPubkey),
86-
Validator: validator,
87-
Evaluator: eval,
88-
DiscoveryHash: cfg.DiscoveryHash(),
88+
CAPublicKey: sshcert.RawPublicKey(caPubkey),
89+
Validator: validator,
90+
Evaluator: eval,
91+
DiscoveryHash: cfg.DiscoveryHash(),
92+
DiscoveryBaseURL: c.DiscoveryBaseURL,
8993
})
9094

9195
// Create discovery handler
@@ -110,7 +114,7 @@ func (c *PolicyServerCLI) Run(logger *slog.Logger, tlsCfg tlsconfig.Config, unif
110114

111115
r.Post("/", handler)
112116
// Redirect endpoint: /d/current -> /d/{hash} (cached 5 min)
113-
r.Get("/d/current", policyserver.NewDiscoveryRedirectHandler(cfg.DiscoveryHash()))
117+
r.Get("/d/current", policyserver.NewDiscoveryRedirectHandler(cfg.DiscoveryHash(), c.DiscoveryBaseURL))
114118
// Content-addressed endpoint: /d/{hash} (immutable)
115119
r.Get("/d/*", discoveryHandler)
116120

pkg/policyserver/discovery.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,15 @@ func NewDiscoveryHandler(config DiscoveryConfig) http.HandlerFunc {
4343
// The redirect response is cached for 5 minutes to allow policy changes to propagate.
4444
// Clients should request /d/current and follow the redirect to /d/{hash}.
4545
// Uses 302 Found (temporary) rather than 301 (permanent) since the redirect target may change.
46-
func NewDiscoveryRedirectHandler(hash string) http.HandlerFunc {
46+
// If baseURL is set, redirects to an absolute URL on that base; otherwise uses relative URLs.
47+
func NewDiscoveryRedirectHandler(hash string, baseURL string) http.HandlerFunc {
48+
location := "/d/" + hash
49+
if baseURL != "" {
50+
location = strings.TrimSuffix(baseURL, "/") + "/d/" + hash
51+
}
4752
return func(w http.ResponseWriter, r *http.Request) {
4853
w.Header().Set("Cache-Control", "max-age=300")
49-
w.Header().Set("Location", "/d/"+hash)
54+
w.Header().Set("Location", location)
5055
w.WriteHeader(http.StatusFound)
5156
}
5257
}

pkg/policyserver/discovery_test.go

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ func TestDiscoveryHandler_EmptyPatterns(t *testing.T) {
175175
}
176176

177177
func TestDiscoveryRedirectHandler(t *testing.T) {
178-
handler := policyserver.NewDiscoveryRedirectHandler("abc123def456")
178+
handler := policyserver.NewDiscoveryRedirectHandler("abc123def456", "")
179179

180180
req := httptest.NewRequest(http.MethodGet, "/d/current", nil)
181181
w := httptest.NewRecorder()
@@ -199,3 +199,38 @@ func TestDiscoveryRedirectHandler(t *testing.T) {
199199
t.Errorf("expected Cache-Control 'max-age=300', got %q", cacheControl)
200200
}
201201
}
202+
203+
func TestDiscoveryRedirectHandler_WithBaseURL(t *testing.T) {
204+
handler := policyserver.NewDiscoveryRedirectHandler("abc123def456", "https://cdn.example.com")
205+
206+
req := httptest.NewRequest(http.MethodGet, "/d/current", nil)
207+
w := httptest.NewRecorder()
208+
209+
handler(w, req)
210+
211+
// Check status code is 302 Found (temporary redirect)
212+
if w.Code != http.StatusFound {
213+
t.Errorf("expected status 302, got %d", w.Code)
214+
}
215+
216+
// Check Location header points to absolute URL on base
217+
location := w.Header().Get("Location")
218+
if location != "https://cdn.example.com/d/abc123def456" {
219+
t.Errorf("expected Location 'https://cdn.example.com/d/abc123def456', got %q", location)
220+
}
221+
}
222+
223+
func TestDiscoveryRedirectHandler_WithBaseURLTrailingSlash(t *testing.T) {
224+
handler := policyserver.NewDiscoveryRedirectHandler("abc123def456", "https://cdn.example.com/")
225+
226+
req := httptest.NewRequest(http.MethodGet, "/d/current", nil)
227+
w := httptest.NewRecorder()
228+
229+
handler(w, req)
230+
231+
// Check Location header correctly handles trailing slash
232+
location := w.Header().Get("Location")
233+
if location != "https://cdn.example.com/d/abc123def456" {
234+
t.Errorf("expected Location 'https://cdn.example.com/d/abc123def456', got %q", location)
235+
}
236+
}

pkg/policyserver/policyserver.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,11 @@ type Config struct {
114114
// If empty, no Link header is set.
115115
// The path is hardcoded to "/d/" + hash.
116116
DiscoveryHash string
117+
118+
// DiscoveryBaseURL is the base URL for discovery endpoints.
119+
// If set, discovery URLs will be absolute URLs on this base (e.g., "https://cdn.example.com").
120+
// If empty, discovery URLs will be relative (e.g., "/d/current").
121+
DiscoveryBaseURL string
117122
}
118123

119124
// handler holds the config and implements the HTTP handler methods
@@ -217,9 +222,14 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
217222

218223
// setDiscoveryHeader sets the Link header for discovery if configured.
219224
// Always points to /d/current which redirects to the content-addressed URL.
225+
// If DiscoveryBaseURL is set, uses absolute URLs; otherwise uses relative URLs.
220226
func (h *handler) setDiscoveryHeader(w http.ResponseWriter) {
221227
if h.config.DiscoveryHash != "" {
222-
w.Header().Set("Link", "</d/current>; rel=\"discovery\"")
228+
url := "/d/current"
229+
if h.config.DiscoveryBaseURL != "" {
230+
url = strings.TrimSuffix(h.config.DiscoveryBaseURL, "/") + "/d/current"
231+
}
232+
w.Header().Set("Link", "<"+url+">; rel=\"discovery\"")
223233
}
224234
}
225235

pkg/policyserver/policyserver_test.go

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -399,3 +399,98 @@ func TestHandler_DiscoveryLinkHeader_NotSetWhenEmpty(t *testing.T) {
399399
t.Errorf("expected no Link header, got %q", link)
400400
}
401401
}
402+
403+
func TestHandler_DiscoveryLinkHeader_WithBaseURL(t *testing.T) {
404+
evaluator := &mockEvaluator{
405+
response: &policyserver.Response{
406+
CertParams: ca.CertParams{
407+
Identity: "test@example.com",
408+
Names: []string{"testuser"},
409+
Expiration: 5 * time.Minute,
410+
},
411+
Policy: policy.Policy{
412+
HostUsers: map[string][]string{
413+
"*": {"testuser"},
414+
},
415+
},
416+
},
417+
}
418+
419+
handler := policyserver.NewHandler(policyserver.Config{
420+
Validator: &mockValidator{},
421+
Evaluator: evaluator,
422+
DiscoveryHash: "abc123def456",
423+
DiscoveryBaseURL: "https://cdn.example.com",
424+
})
425+
426+
req := policyserver.Request{
427+
Token: encodeToken("test-token"),
428+
Connection: policy.Connection{
429+
RemoteHost: "server.example.com",
430+
RemoteUser: "testuser",
431+
Port: 22,
432+
},
433+
}
434+
body, _ := json.Marshal(req)
435+
436+
httpReq := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(body))
437+
w := httptest.NewRecorder()
438+
439+
handler(w, httpReq)
440+
441+
if w.Code != http.StatusOK {
442+
t.Errorf("expected status 200, got %d: %s", w.Code, w.Body.String())
443+
}
444+
445+
link := w.Header().Get("Link")
446+
// Link header uses absolute URL when DiscoveryBaseURL is set
447+
expected := "<https://cdn.example.com/d/current>; rel=\"discovery\""
448+
if link != expected {
449+
t.Errorf("expected Link header %q, got %q", expected, link)
450+
}
451+
}
452+
453+
func TestHandler_DiscoveryLinkHeader_WithBaseURLTrailingSlash(t *testing.T) {
454+
evaluator := &mockEvaluator{
455+
response: &policyserver.Response{
456+
CertParams: ca.CertParams{
457+
Identity: "test@example.com",
458+
Names: []string{"testuser"},
459+
Expiration: 5 * time.Minute,
460+
},
461+
},
462+
}
463+
464+
handler := policyserver.NewHandler(policyserver.Config{
465+
Validator: &mockValidator{},
466+
Evaluator: evaluator,
467+
DiscoveryHash: "abc123def456",
468+
DiscoveryBaseURL: "https://cdn.example.com/", // trailing slash
469+
})
470+
471+
req := policyserver.Request{
472+
Token: encodeToken("test-token"),
473+
Connection: policy.Connection{
474+
RemoteHost: "server.example.com",
475+
RemoteUser: "testuser",
476+
Port: 22,
477+
},
478+
}
479+
body, _ := json.Marshal(req)
480+
481+
httpReq := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(body))
482+
w := httptest.NewRecorder()
483+
484+
handler(w, httpReq)
485+
486+
if w.Code != http.StatusOK {
487+
t.Errorf("expected status 200, got %d: %s", w.Code, w.Body.String())
488+
}
489+
490+
link := w.Header().Get("Link")
491+
// Trailing slash should be stripped to avoid double slashes
492+
expected := "<https://cdn.example.com/d/current>; rel=\"discovery\""
493+
if link != expected {
494+
t.Errorf("expected Link header %q, got %q", expected, link)
495+
}
496+
}

0 commit comments

Comments
 (0)