Skip to content

Commit b112d43

Browse files
committed
Add CORS configuration to dev server API endpoints
1 parent 595516e commit b112d43

File tree

6 files changed

+146
-0
lines changed

6 files changed

+146
-0
lines changed

cmd/cliflags/flags.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ const (
88
AccessTokenFlag = "access-token"
99
AnalyticsOptOut = "analytics-opt-out"
1010
BaseURIFlag = "base-uri"
11+
CorsEnabledFlag = "cors-enabled"
12+
CorsOriginFlag = "cors-origin"
1113
DataFlag = "data"
1214
DevStreamURIFlag = "dev-stream-uri"
1315
EmailsFlag = "emails"
@@ -22,6 +24,8 @@ const (
2224
AccessTokenFlagDescription = "LaunchDarkly access token with write-level access"
2325
AnalyticsOptOutDescription = "Opt out of analytics tracking"
2426
BaseURIFlagDescription = "LaunchDarkly base URI"
27+
CorsEnabledFlagDescription = "Enable CORS headers for browser-based developer tools (default: true)"
28+
CorsOriginFlagDescription = "Allowed CORS origin. Use '*' for all origins (default: '*')"
2529
DevStreamURIDescription = "Streaming service endpoint that the dev server uses to obtain authoritative flag data. This may be a LaunchDarkly or Relay Proxy endpoint"
2630
EnvironmentFlagDescription = "Default environment key"
2731
FlagFlagDescription = "Default feature flag key"
@@ -36,6 +40,8 @@ func AllFlagsHelp() map[string]string {
3640
AccessTokenFlag: AccessTokenFlagDescription,
3741
AnalyticsOptOut: AnalyticsOptOutDescription,
3842
BaseURIFlag: BaseURIFlagDescription,
43+
CorsEnabledFlag: CorsEnabledFlagDescription,
44+
CorsOriginFlag: CorsOriginFlagDescription,
3945
DevStreamURIFlag: DevStreamURIDescription,
4046
EnvironmentFlag: EnvironmentFlagDescription,
4147
FlagFlag: FlagFlagDescription,

cmd/dev_server/dev_server.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,20 @@ func NewDevServerCmd(client resources.Client, analyticsTrackerFn analytics.Track
5050

5151
_ = viper.BindPFlag(cliflags.PortFlag, cmd.PersistentFlags().Lookup(cliflags.PortFlag))
5252

53+
cmd.PersistentFlags().Bool(
54+
cliflags.CorsEnabledFlag,
55+
true,
56+
cliflags.CorsEnabledFlagDescription,
57+
)
58+
_ = viper.BindPFlag(cliflags.CorsEnabledFlag, cmd.PersistentFlags().Lookup(cliflags.CorsEnabledFlag))
59+
60+
cmd.PersistentFlags().String(
61+
cliflags.CorsOriginFlag,
62+
"*",
63+
cliflags.CorsOriginFlagDescription,
64+
)
65+
_ = viper.BindPFlag(cliflags.CorsOriginFlag, cmd.PersistentFlags().Lookup(cliflags.CorsOriginFlag))
66+
5367
// Add subcommands here
5468
cmd.AddGroup(&cobra.Group{ID: "projects", Title: "Project commands:"})
5569
cmd.AddCommand(NewListProjectsCmd(client))

cmd/dev_server/start_server.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ func startServer(client dev_server.Client) func(*cobra.Command, []string) error
8989
BaseURI: viper.GetString(cliflags.BaseURIFlag),
9090
DevStreamURI: viper.GetString(cliflags.DevStreamURIFlag),
9191
Port: viper.GetString(cliflags.PortFlag),
92+
CorsEnabled: viper.GetBool(cliflags.CorsEnabledFlag),
93+
CorsOrigin: viper.GetString(cliflags.CorsOriginFlag),
9294
InitialProjectSettings: initialSetting,
9395
}
9496

internal/dev_server/dev_server.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ type ServerParams struct {
2929
BaseURI string
3030
DevStreamURI string
3131
Port string
32+
CorsEnabled bool
33+
CorsOrigin string
3234
InitialProjectSettings model.InitialProjectSettings
3335
}
3436

@@ -65,6 +67,7 @@ func (c LDClient) RunServer(ctx context.Context, serverParams ServerParams) {
6567
r.PathPrefix("/ui/").Handler(http.StripPrefix("/ui/", ui.AssetHandler))
6668
sdk.BindRoutes(r)
6769
handler := api.HandlerFromMux(apiServer, r)
70+
handler = sdk.ApiCorsHeadersWithConfig(serverParams.CorsEnabled, serverParams.CorsOrigin)(handler)
6871
handler = handlers.CombinedLoggingHandler(os.Stdout, handler)
6972
handler = handlers.RecoveryHandler(handlers.PrintRecoveryStack(true))(handler)
7073

internal/dev_server/sdk/cors.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,48 @@ func EventsCorsHeaders(handler http.Handler) http.Handler {
2525
handler.ServeHTTP(writer, request)
2626
})
2727
}
28+
29+
// ApiCorsHeaders provides CORS support for the dev-server API endpoints (/dev/*)
30+
// Supports all HTTP methods needed by browser-based developer tools
31+
func ApiCorsHeaders(handler http.Handler) http.Handler {
32+
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
33+
writer.Header().Set("Access-Control-Allow-Origin", "*")
34+
writer.Header().Set("Access-Control-Allow-Methods", "GET,POST,PUT,PATCH,DELETE,OPTIONS")
35+
writer.Header().Set("Access-Control-Allow-Credentials", "true")
36+
writer.Header().Set("Access-Control-Allow-Headers", "Accept,Content-Type,Content-Length,Accept-Encoding,Authorization,X-Requested-With")
37+
writer.Header().Set("Access-Control-Expose-Headers", "Date,Content-Length")
38+
writer.Header().Set("Access-Control-Max-Age", "300")
39+
40+
// Handle preflight OPTIONS requests
41+
if request.Method == http.MethodOptions {
42+
writer.WriteHeader(http.StatusOK)
43+
return
44+
}
45+
46+
handler.ServeHTTP(writer, request)
47+
})
48+
}
49+
50+
// ApiCorsHeadersWithConfig provides configurable CORS support for the dev-server API endpoints
51+
func ApiCorsHeadersWithConfig(enabled bool, origin string) func(http.Handler) http.Handler {
52+
return func(handler http.Handler) http.Handler {
53+
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
54+
if enabled {
55+
writer.Header().Set("Access-Control-Allow-Origin", origin)
56+
writer.Header().Set("Access-Control-Allow-Methods", "GET,POST,PUT,PATCH,DELETE,OPTIONS")
57+
writer.Header().Set("Access-Control-Allow-Credentials", "true")
58+
writer.Header().Set("Access-Control-Allow-Headers", "Accept,Content-Type,Content-Length,Accept-Encoding,Authorization,X-Requested-With")
59+
writer.Header().Set("Access-Control-Expose-Headers", "Date,Content-Length")
60+
writer.Header().Set("Access-Control-Max-Age", "300")
61+
62+
// Handle preflight OPTIONS requests
63+
if request.Method == http.MethodOptions {
64+
writer.WriteHeader(http.StatusOK)
65+
return
66+
}
67+
}
68+
69+
handler.ServeHTTP(writer, request)
70+
})
71+
}
72+
}
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
package sdk
2+
3+
import (
4+
"net/http"
5+
"net/http/httptest"
6+
"testing"
7+
)
8+
9+
func TestApiCorsHeadersWithConfig_Enabled(t *testing.T) {
10+
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
11+
w.WriteHeader(http.StatusOK)
12+
w.Write([]byte("test response"))
13+
})
14+
15+
corsHandler := ApiCorsHeadersWithConfig(true, "*")(handler)
16+
17+
// Test GET request
18+
req := httptest.NewRequest("GET", "/dev/projects", nil)
19+
w := httptest.NewRecorder()
20+
corsHandler.ServeHTTP(w, req)
21+
22+
if w.Header().Get("Access-Control-Allow-Origin") != "*" {
23+
t.Errorf("Expected Access-Control-Allow-Origin to be '*', got '%s'", w.Header().Get("Access-Control-Allow-Origin"))
24+
}
25+
26+
if w.Header().Get("Access-Control-Allow-Methods") != "GET,POST,PUT,PATCH,DELETE,OPTIONS" {
27+
t.Errorf("Expected Access-Control-Allow-Methods to include all methods, got '%s'", w.Header().Get("Access-Control-Allow-Methods"))
28+
}
29+
30+
if w.Code != http.StatusOK {
31+
t.Errorf("Expected status code 200, got %d", w.Code)
32+
}
33+
}
34+
35+
func TestApiCorsHeadersWithConfig_OptionsRequest(t *testing.T) {
36+
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
37+
t.Error("Handler should not be called for OPTIONS request")
38+
})
39+
40+
corsHandler := ApiCorsHeadersWithConfig(true, "https://example.com")(handler)
41+
42+
// Test OPTIONS preflight request
43+
req := httptest.NewRequest("OPTIONS", "/dev/projects", nil)
44+
w := httptest.NewRecorder()
45+
corsHandler.ServeHTTP(w, req)
46+
47+
if w.Header().Get("Access-Control-Allow-Origin") != "https://example.com" {
48+
t.Errorf("Expected Access-Control-Allow-Origin to be 'https://example.com', got '%s'", w.Header().Get("Access-Control-Allow-Origin"))
49+
}
50+
51+
if w.Code != http.StatusOK {
52+
t.Errorf("Expected status code 200 for OPTIONS request, got %d", w.Code)
53+
}
54+
}
55+
56+
func TestApiCorsHeadersWithConfig_Disabled(t *testing.T) {
57+
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
58+
w.WriteHeader(http.StatusOK)
59+
w.Write([]byte("test response"))
60+
})
61+
62+
corsHandler := ApiCorsHeadersWithConfig(false, "*")(handler)
63+
64+
// Test GET request with CORS disabled
65+
req := httptest.NewRequest("GET", "/dev/projects", nil)
66+
w := httptest.NewRecorder()
67+
corsHandler.ServeHTTP(w, req)
68+
69+
if w.Header().Get("Access-Control-Allow-Origin") != "" {
70+
t.Errorf("Expected no CORS headers when disabled, but got Access-Control-Allow-Origin: '%s'", w.Header().Get("Access-Control-Allow-Origin"))
71+
}
72+
73+
if w.Code != http.StatusOK {
74+
t.Errorf("Expected status code 200, got %d", w.Code)
75+
}
76+
}

0 commit comments

Comments
 (0)