Skip to content

Commit 784d6df

Browse files
committed
Rebuild ModelsManager and Scheduler routers with new CORS origins
Signed-off-by: Dorin Geman <[email protected]>
1 parent 9861625 commit 784d6df

File tree

2 files changed

+38
-0
lines changed

2 files changed

+38
-0
lines changed

pkg/inference/models/manager.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"path"
1111
"strconv"
1212
"strings"
13+
"sync"
1314

1415
"github.com/docker/model-distribution/distribution"
1516
"github.com/docker/model-distribution/registry"
@@ -26,6 +27,8 @@ const (
2627
maximumConcurrentModelPulls = 2
2728
)
2829

30+
// lock guards Manager.router: RebuildRoutes takes it when replacing the
// router and ServeHTTP takes it when accessing the router.
// NOTE(review): declared at package level, so it is shared by every Manager
// in the process — consider making it a field of Manager instead.
var lock sync.Mutex
31+
2932
// Manager manages inference model pulls and storage.
3033
type Manager struct {
3134
// log is the associated logger.
@@ -100,6 +103,20 @@ func NewManager(log logging.Logger, c ClientConfig, allowedOrigins []string) *Ma
100103
return m
101104
}
102105

106+
func (m *Manager) RebuildRoutes(allowedOrigins []string) {
107+
lock.Lock()
108+
defer lock.Unlock()
109+
// Clear existing routes and re-register them.
110+
m.router = http.NewServeMux()
111+
// Register routes.
112+
m.router.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) {
113+
http.Error(w, "not found", http.StatusNotFound)
114+
})
115+
for route, handler := range m.routeHandlers(allowedOrigins) {
116+
m.router.HandleFunc(route, handler)
117+
}
118+
}
119+
103120
func (m *Manager) routeHandlers(allowedOrigins []string) map[string]http.HandlerFunc {
104121
handlers := map[string]http.HandlerFunc{
105122
"POST " + inference.ModelsPrefix + "/create": m.handleCreateModel,
@@ -494,6 +511,8 @@ func (m *Manager) GetDiskUsage() (int64, error, int) {
494511

495512
// ServeHTTP implement net/http.Handler.ServeHTTP.
496513
func (m *Manager) ServeHTTP(w http.ResponseWriter, r *http.Request) {
514+
lock.Lock()
515+
defer lock.Unlock()
497516
m.router.ServeHTTP(w, r)
498517
}
499518

pkg/inference/scheduling/scheduler.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"io"
1010
"net/http"
1111
"strings"
12+
"sync"
1213
"time"
1314

1415
"github.com/docker/model-distribution/distribution"
@@ -19,6 +20,8 @@ import (
1920
"golang.org/x/sync/errgroup"
2021
)
2122

23+
// lock guards Scheduler.router: RebuildRoutes takes it when replacing the
// router and ServeHTTP takes it when accessing the router.
// NOTE(review): declared at package level, so it is shared by every Scheduler
// in the process — consider making it a field of Scheduler instead.
var lock sync.Mutex
24+
2225
// Scheduler is used to coordinate inference scheduling across multiple backends
2326
// and models.
2427
type Scheduler struct {
@@ -75,6 +78,20 @@ func NewScheduler(
7578
return s
7679
}
7780

81+
func (s *Scheduler) RebuildRoutes(allowedOrigins []string) {
82+
lock.Lock()
83+
defer lock.Unlock()
84+
// Clear existing routes and re-register them.
85+
s.router = http.NewServeMux()
86+
// Register routes.
87+
s.router.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) {
88+
http.Error(w, "not found", http.StatusNotFound)
89+
})
90+
for route, handler := range s.routeHandlers(allowedOrigins) {
91+
s.router.HandleFunc(route, handler)
92+
}
93+
}
94+
7895
func (s *Scheduler) routeHandlers(allowedOrigins []string) map[string]http.HandlerFunc {
7996
openAIRoutes := []string{
8097
"POST " + inference.InferencePrefix + "/{backend}/v1/chat/completions",
@@ -332,5 +349,7 @@ func (s *Scheduler) Unload(w http.ResponseWriter, r *http.Request) {
332349

333350
// ServeHTTP implements net/http.Handler.ServeHTTP.
334351
func (s *Scheduler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
352+
lock.Lock()
353+
defer lock.Unlock()
335354
s.router.ServeHTTP(w, r)
336355
}

0 commit comments

Comments
 (0)