diff --git a/pkg/inference/models/manager.go b/pkg/inference/models/manager.go index 43389b5cc..c87efa831 100644 --- a/pkg/inference/models/manager.go +++ b/pkg/inference/models/manager.go @@ -10,6 +10,7 @@ import ( "path" "strconv" "strings" + "sync" "github.com/docker/model-distribution/distribution" "github.com/docker/model-distribution/registry" @@ -39,6 +40,8 @@ type Manager struct { distributionClient *distribution.Client // registryClient is the client for model registry. registryClient *registry.Client + // lock is used to synchronize access to the models manager's router. + lock sync.Mutex } type ClientConfig struct { @@ -100,6 +103,20 @@ func NewManager(log logging.Logger, c ClientConfig, allowedOrigins []string) *Ma return m } +func (m *Manager) RebuildRoutes(allowedOrigins []string) { + m.lock.Lock() + defer m.lock.Unlock() + // Clear existing routes and re-register them. + m.router = http.NewServeMux() + // Register routes. + m.router.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) { + http.Error(w, "not found", http.StatusNotFound) + }) + for route, handler := range m.routeHandlers(allowedOrigins) { + m.router.HandleFunc(route, handler) + } +} + func (m *Manager) routeHandlers(allowedOrigins []string) map[string]http.HandlerFunc { handlers := map[string]http.HandlerFunc{ "POST " + inference.ModelsPrefix + "/create": m.handleCreateModel, @@ -494,6 +511,8 @@ func (m *Manager) GetDiskUsage() (int64, error, int) { // ServeHTTP implement net/http.Handler.ServeHTTP. func (m *Manager) ServeHTTP(w http.ResponseWriter, r *http.Request) { + m.lock.Lock() + defer m.lock.Unlock() m.router.ServeHTTP(w, r) } diff --git a/pkg/inference/scheduling/scheduler.go b/pkg/inference/scheduling/scheduler.go index 3e5e232b3..40e536ed6 100644 --- a/pkg/inference/scheduling/scheduler.go +++ b/pkg/inference/scheduling/scheduler.go @@ -9,6 +9,7 @@ import ( "io" "net/http" "strings" + "sync" "time" "github.com/docker/model-distribution/distribution" @@ -38,6 +39,8 @@ type Scheduler struct { router *http.ServeMux // tracker is the metrics tracker. tracker *metrics.Tracker + // lock is used to synchronize access to the scheduler's router. + lock sync.Mutex } // NewScheduler creates a new inference scheduler. @@ -75,6 +78,20 @@ func NewScheduler( return s } +func (s *Scheduler) RebuildRoutes(allowedOrigins []string) { + s.lock.Lock() + defer s.lock.Unlock() + // Clear existing routes and re-register them. + s.router = http.NewServeMux() + // Register routes. + s.router.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) { + http.Error(w, "not found", http.StatusNotFound) + }) + for route, handler := range s.routeHandlers(allowedOrigins) { + s.router.HandleFunc(route, handler) + } +} + func (s *Scheduler) routeHandlers(allowedOrigins []string) map[string]http.HandlerFunc { openAIRoutes := []string{ "POST " + inference.InferencePrefix + "/{backend}/v1/chat/completions", @@ -332,5 +349,7 @@ func (s *Scheduler) Unload(w http.ResponseWriter, r *http.Request) { // ServeHTTP implements net/http.Handler.ServeHTTP. func (s *Scheduler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + s.lock.Lock() + defer s.lock.Unlock() s.router.ServeHTTP(w, r) }