@@ -399,3 +399,98 @@ func TestHandler_DiscoveryLinkHeader_NotSetWhenEmpty(t *testing.T) {
399399 t .Errorf ("expected no Link header, got %q" , link )
400400 }
401401}
402+
403+ func TestHandler_DiscoveryLinkHeader_WithBaseURL (t * testing.T ) {
404+ evaluator := & mockEvaluator {
405+ response : & policyserver.Response {
406+ CertParams : ca.CertParams {
407+ Identity : "test@example.com" ,
408+ Names : []string {"testuser" },
409+ Expiration : 5 * time .Minute ,
410+ },
411+ Policy : policy.Policy {
412+ HostUsers : map [string ][]string {
413+ "*" : {"testuser" },
414+ },
415+ },
416+ },
417+ }
418+
419+ handler := policyserver .NewHandler (policyserver.Config {
420+ Validator : & mockValidator {},
421+ Evaluator : evaluator ,
422+ DiscoveryHash : "abc123def456" ,
423+ DiscoveryBaseURL : "https://cdn.example.com" ,
424+ })
425+
426+ req := policyserver.Request {
427+ Token : encodeToken ("test-token" ),
428+ Connection : policy.Connection {
429+ RemoteHost : "server.example.com" ,
430+ RemoteUser : "testuser" ,
431+ Port : 22 ,
432+ },
433+ }
434+ body , _ := json .Marshal (req )
435+
436+ httpReq := httptest .NewRequest (http .MethodPost , "/" , bytes .NewReader (body ))
437+ w := httptest .NewRecorder ()
438+
439+ handler (w , httpReq )
440+
441+ if w .Code != http .StatusOK {
442+ t .Errorf ("expected status 200, got %d: %s" , w .Code , w .Body .String ())
443+ }
444+
445+ link := w .Header ().Get ("Link" )
446+ // Link header uses absolute URL when DiscoveryBaseURL is set
447+ expected := "<https://cdn.example.com/d/current>; rel=\" discovery\" "
448+ if link != expected {
449+ t .Errorf ("expected Link header %q, got %q" , expected , link )
450+ }
451+ }
452+
453+ func TestHandler_DiscoveryLinkHeader_WithBaseURLTrailingSlash (t * testing.T ) {
454+ evaluator := & mockEvaluator {
455+ response : & policyserver.Response {
456+ CertParams : ca.CertParams {
457+ Identity : "test@example.com" ,
458+ Names : []string {"testuser" },
459+ Expiration : 5 * time .Minute ,
460+ },
461+ },
462+ }
463+
464+ handler := policyserver .NewHandler (policyserver.Config {
465+ Validator : & mockValidator {},
466+ Evaluator : evaluator ,
467+ DiscoveryHash : "abc123def456" ,
468+ DiscoveryBaseURL : "https://cdn.example.com/" , // trailing slash
469+ })
470+
471+ req := policyserver.Request {
472+ Token : encodeToken ("test-token" ),
473+ Connection : policy.Connection {
474+ RemoteHost : "server.example.com" ,
475+ RemoteUser : "testuser" ,
476+ Port : 22 ,
477+ },
478+ }
479+ body , _ := json .Marshal (req )
480+
481+ httpReq := httptest .NewRequest (http .MethodPost , "/" , bytes .NewReader (body ))
482+ w := httptest .NewRecorder ()
483+
484+ handler (w , httpReq )
485+
486+ if w .Code != http .StatusOK {
487+ t .Errorf ("expected status 200, got %d: %s" , w .Code , w .Body .String ())
488+ }
489+
490+ link := w .Header ().Get ("Link" )
491+ // Trailing slash should be stripped to avoid double slashes
492+ expected := "<https://cdn.example.com/d/current>; rel=\" discovery\" "
493+ if link != expected {
494+ t .Errorf ("expected Link header %q, got %q" , expected , link )
495+ }
496+ }
0 commit comments