Skip to content

Commit c7fe66a

Browse files
ilopezlunadoringemanCopilot
authored
Cors options preflight (#200)
* fix: also register OPTIONS for CORS preflight Signed-off-by: Dorin Geman <[email protected]> * feat: enhance CORS middleware to intercept OPTIONS requests and validate origins * feat: update CORS test cases to include DELETE method and adjust allowed methods * Update pkg/middleware/cors.go Co-authored-by: Copilot <[email protected]> --------- Signed-off-by: Dorin Geman <[email protected]> Co-authored-by: Dorin Geman <[email protected]> Co-authored-by: Copilot <[email protected]>
1 parent af6253a commit c7fe66a

File tree

5 files changed

+33
-28
lines changed

5 files changed

+33
-28
lines changed

main.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -122,12 +122,13 @@ func main() {
122122
)
123123

124124
router := routing.NewNormalizedServeMux()
125-
for _, route := range modelManager.GetRoutes() {
126-
router.Handle(route, modelManager)
127-
}
128-
for _, route := range scheduler.GetRoutes() {
129-
router.Handle(route, scheduler)
130-
}
125+
126+
// Register path prefixes to forward all HTTP methods (including OPTIONS) to components
127+
// Components handle method routing internally
128+
// Register both with and without trailing slash to avoid redirects
129+
router.Handle(inference.ModelsPrefix, modelManager)
130+
router.Handle(inference.ModelsPrefix+"/", modelManager)
131+
router.Handle(inference.InferencePrefix+"/", scheduler)
131132

132133
// Add metrics endpoint if enabled
133134
if os.Getenv("DISABLE_METRICS") != "1" {

pkg/inference/models/manager.go

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -136,15 +136,6 @@ func (m *Manager) routeHandlers() map[string]http.HandlerFunc {
136136
}
137137
}
138138

139-
func (m *Manager) GetRoutes() []string {
140-
routeHandlers := m.routeHandlers()
141-
routes := make([]string, 0, len(routeHandlers))
142-
for route := range routeHandlers {
143-
routes = append(routes, route)
144-
}
145-
return routes
146-
}
147-
148139
// handleCreateModel handles POST <inference-prefix>/models/create requests.
149140
func (m *Manager) handleCreateModel(w http.ResponseWriter, r *http.Request) {
150141
if m.distributionClient == nil {

pkg/inference/scheduling/scheduler.go

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -122,15 +122,6 @@ func (s *Scheduler) routeHandlers() map[string]http.HandlerFunc {
122122
return m
123123
}
124124

125-
func (s *Scheduler) GetRoutes() []string {
126-
routeHandlers := s.routeHandlers()
127-
routes := make([]string, 0, len(routeHandlers))
128-
for route := range routeHandlers {
129-
routes = append(routes, route)
130-
}
131-
return routes
132-
}
133-
134125
// Run is the scheduler's main run loop. By the time it returns, all inference
135126
// backends will have been unloaded from memory.
136127
func (s *Scheduler) Run(ctx context.Context) error {

pkg/middleware/cors.go

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ import (
88

99
// CorsMiddleware handles CORS and OPTIONS preflight requests with optional allowedOrigins.
1010
// If allowedOrigins is nil or empty, it falls back to getAllowedOrigins().
11+
// This middleware intercepts OPTIONS requests only if the Origin header is present and valid,
12+
// otherwise passing the request to the router (allowing 405/404 responses as appropriate).
1113
func CorsMiddleware(allowedOrigins []string, next http.Handler) http.Handler {
1214
if len(allowedOrigins) == 0 {
1315
allowedOrigins = getAllowedOrigins()
@@ -25,14 +27,26 @@ func CorsMiddleware(allowedOrigins []string, next http.Handler) http.Handler {
2527
}
2628

2729
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
28-
if origin := r.Header.Get("Origin"); origin != "" && (allowAll || originAllowed(origin, allowedSet)) {
30+
origin := r.Header.Get("Origin")
31+
32+
// Set CORS headers if origin is allowed
33+
if origin != "" && (allowAll || originAllowed(origin, allowedSet)) {
2934
w.Header().Set("Access-Control-Allow-Origin", origin)
3035
}
3136

32-
// Handle OPTIONS requests.
37+
// Handle OPTIONS requests with origin validation.
38+
// Only intercept OPTIONS if the origin is valid to prevent unauthorized preflight requests.
3339
if r.Method == http.MethodOptions {
40+
// Require valid Origin header for OPTIONS requests
41+
if origin == "" || !(allowAll || originAllowed(origin, allowedSet)) {
42+
// No origin or invalid origin - pass to router for proper 405/404 response
43+
next.ServeHTTP(w, r)
44+
return
45+
}
46+
47+
// Valid origin - handle OPTIONS with CORS headers
3448
w.Header().Set("Access-Control-Allow-Credentials", "true")
35-
w.Header().Set("Access-Control-Allow-Methods", "GET, POST")
49+
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, DELETE")
3650
w.Header().Set("Access-Control-Allow-Headers", "*")
3751
w.WriteHeader(http.StatusNoContent)
3852
return

pkg/middleware/cors_test.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,18 @@ func TestCorsMiddleware(t *testing.T) {
4949
wantStatus: http.StatusNoContent,
5050
wantHeaders: map[string]string{
5151
"Access-Control-Allow-Credentials": "true",
52-
"Access-Control-Allow-Methods": "GET, POST",
52+
"Access-Control-Allow-Methods": "GET, POST, DELETE",
5353
"Access-Control-Allow-Headers": "*",
5454
},
5555
},
56+
{
57+
name: "DeleteRequest",
58+
allowedOrigins: []string{"http://foo.com"},
59+
method: "DELETE",
60+
origin: "http://foo.com",
61+
wantStatus: http.StatusOK,
62+
wantHeaders: map[string]string{"Access-Control-Allow-Origin": "http://foo.com"},
63+
},
5664
{
5765
name: "NoOriginHeader",
5866
allowedOrigins: []string{"http://foo.com"},

0 commit comments

Comments
 (0)