Skip to content

Commit 8c7c640

Browse files
authored
Merge pull request #74 from doringeman/rebuild-routes-cors
Rebuild ModelsManager and Scheduler routers with new CORS origins
2 parents 9861625 + 60166a9 commit 8c7c640

File tree

2 files changed

+38
-0
lines changed

2 files changed

+38
-0
lines changed

pkg/inference/models/manager.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"path"
1111
"strconv"
1212
"strings"
13+
"sync"
1314

1415
"github.com/docker/model-distribution/distribution"
1516
"github.com/docker/model-distribution/registry"
@@ -39,6 +40,8 @@ type Manager struct {
3940
distributionClient *distribution.Client
4041
// registryClient is the client for model registry.
4142
registryClient *registry.Client
43+
// lock is used to synchronize access to the models manager's router.
44+
lock sync.Mutex
4245
}
4346

4447
type ClientConfig struct {
@@ -100,6 +103,20 @@ func NewManager(log logging.Logger, c ClientConfig, allowedOrigins []string) *Ma
100103
return m
101104
}
102105

106+
func (m *Manager) RebuildRoutes(allowedOrigins []string) {
107+
m.lock.Lock()
108+
defer m.lock.Unlock()
109+
// Clear existing routes and re-register them.
110+
m.router = http.NewServeMux()
111+
// Register routes.
112+
m.router.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) {
113+
http.Error(w, "not found", http.StatusNotFound)
114+
})
115+
for route, handler := range m.routeHandlers(allowedOrigins) {
116+
m.router.HandleFunc(route, handler)
117+
}
118+
}
119+
103120
func (m *Manager) routeHandlers(allowedOrigins []string) map[string]http.HandlerFunc {
104121
handlers := map[string]http.HandlerFunc{
105122
"POST " + inference.ModelsPrefix + "/create": m.handleCreateModel,
@@ -494,6 +511,8 @@ func (m *Manager) GetDiskUsage() (int64, error, int) {
494511

495512
// ServeHTTP implement net/http.Handler.ServeHTTP.
496513
func (m *Manager) ServeHTTP(w http.ResponseWriter, r *http.Request) {
514+
m.lock.Lock()
515+
defer m.lock.Unlock()
497516
m.router.ServeHTTP(w, r)
498517
}
499518

pkg/inference/scheduling/scheduler.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"io"
1010
"net/http"
1111
"strings"
12+
"sync"
1213
"time"
1314

1415
"github.com/docker/model-distribution/distribution"
@@ -38,6 +39,8 @@ type Scheduler struct {
3839
router *http.ServeMux
3940
// tracker is the metrics tracker.
4041
tracker *metrics.Tracker
42+
// lock is used to synchronize access to the scheduler's router.
43+
lock sync.Mutex
4144
}
4245

4346
// NewScheduler creates a new inference scheduler.
@@ -75,6 +78,20 @@ func NewScheduler(
7578
return s
7679
}
7780

81+
func (s *Scheduler) RebuildRoutes(allowedOrigins []string) {
82+
s.lock.Lock()
83+
defer s.lock.Unlock()
84+
// Clear existing routes and re-register them.
85+
s.router = http.NewServeMux()
86+
// Register routes.
87+
s.router.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) {
88+
http.Error(w, "not found", http.StatusNotFound)
89+
})
90+
for route, handler := range s.routeHandlers(allowedOrigins) {
91+
s.router.HandleFunc(route, handler)
92+
}
93+
}
94+
7895
func (s *Scheduler) routeHandlers(allowedOrigins []string) map[string]http.HandlerFunc {
7996
openAIRoutes := []string{
8097
"POST " + inference.InferencePrefix + "/{backend}/v1/chat/completions",
@@ -332,5 +349,7 @@ func (s *Scheduler) Unload(w http.ResponseWriter, r *http.Request) {
332349

333350
// ServeHTTP implements net/http.Handler.ServeHTTP.
334351
func (s *Scheduler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
352+
s.lock.Lock()
353+
defer s.lock.Unlock()
335354
s.router.ServeHTTP(w, r)
336355
}

0 commit comments

Comments
 (0)