Skip to content

Commit c818aab

Browse files
authored
Merge pull request #46 from doringeman/unload
Add /engines/unload
2 parents 7ddea9d + 47a0fae commit c818aab

File tree

3 files changed

+76
-0
lines changed

3 files changed

+76
-0
lines changed

pkg/inference/scheduling/api.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,3 +61,15 @@ type DiskUsage struct {
6161
ModelsDiskUsage float64 `json:"models_disk_usage"`
6262
DefaultBackendDiskUsage float64 `json:"default_backend_disk_usage"`
6363
}
64+
65+
// UnloadRequest is the JSON request body that selects which runners
// (backend, model pairs) should be unloaded.
type UnloadRequest struct {
	// All requests eviction across every runner; when it is true the loader
	// does not consult Backend or Model.
	All bool `json:"all"`
	// Backend restricts the unload to runners of the named backend. An empty
	// value matches any backend.
	Backend string `json:"backend"`
	// Model is the model whose runners are unloaded. It is compared for exact
	// equality against each runner's model.
	Model string `json:"model"`
}
71+
72+
// UnloadResponse is the JSON reply to an unload request. It reports how many
// runners (backend, model pairs) were unloaded.
type UnloadResponse struct {
	// UnloadedRunners is the count of runners removed by the request.
	UnloadedRunners int `json:"unloaded_runners"`
}

pkg/inference/scheduling/loader.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,42 @@ func (l *loader) evict(idleOnly bool) int {
177177
return len(l.runners)
178178
}
179179

180+
// evictRunner evicts a specific runner. The caller must hold the loader lock.
181+
// It returns the number of remaining runners.
182+
func (l *loader) evictRunner(backend, model string) int {
183+
allBackends := backend == ""
184+
for r, slot := range l.runners {
185+
if (allBackends || r.backend == backend) && r.model == model {
186+
l.log.Infof("Evicting %s backend runner with model %s in %s mode",
187+
r.backend, r.model, r.mode,
188+
)
189+
l.slots[slot].terminate()
190+
l.slots[slot] = nil
191+
l.availableMemory += l.allocations[slot]
192+
l.allocations[slot] = 0
193+
l.timestamps[slot] = time.Time{}
194+
delete(l.runners, r)
195+
}
196+
}
197+
return len(l.runners)
198+
}
199+
200+
// Unload unloads runners and returns the number of unloaded runners.
201+
func (l *loader) Unload(ctx context.Context, unload UnloadRequest) int {
202+
if !l.lock(ctx) {
203+
return 0
204+
}
205+
defer l.unlock()
206+
207+
return len(l.runners) - func() int {
208+
if unload.All {
209+
return l.evict(false)
210+
} else {
211+
return l.evictRunner(unload.Backend, unload.Model)
212+
}
213+
}()
214+
}
215+
180216
// stopAndDrainTimer stops and drains a timer without knowing if it was running.
181217
func stopAndDrainTimer(timer *time.Timer) {
182218
timer.Stop()

pkg/inference/scheduling/scheduler.go

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ func (s *Scheduler) routeHandlers() map[string]http.HandlerFunc {
8484
m["GET "+inference.InferencePrefix+"/status"] = s.GetBackendStatus
8585
m["GET "+inference.InferencePrefix+"/ps"] = s.GetRunningBackends
8686
m["GET "+inference.InferencePrefix+"/df"] = s.GetDiskUsage
87+
m["POST "+inference.InferencePrefix+"/unload"] = s.Unload
8788
return m
8889
}
8990

@@ -289,6 +290,33 @@ func (s *Scheduler) GetDiskUsage(w http.ResponseWriter, _ *http.Request) {
289290
}
290291
}
291292

293+
// Unload unloads the specified runners (backend, model) from the backend.
294+
// Currently, this doesn't work for runners that are handling an OpenAI request.
295+
func (s *Scheduler) Unload(w http.ResponseWriter, r *http.Request) {
296+
body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maximumOpenAIInferenceRequestSize))
297+
if err != nil {
298+
if _, ok := err.(*http.MaxBytesError); ok {
299+
http.Error(w, "request too large", http.StatusBadRequest)
300+
} else {
301+
http.Error(w, "unknown error", http.StatusInternalServerError)
302+
}
303+
return
304+
}
305+
306+
var unloadRequest UnloadRequest
307+
if err := json.Unmarshal(body, &unloadRequest); err != nil {
308+
http.Error(w, "invalid request", http.StatusBadRequest)
309+
return
310+
}
311+
312+
unloadedRunners := UnloadResponse{s.loader.Unload(r.Context(), unloadRequest)}
313+
w.Header().Set("Content-Type", "application/json")
314+
if err := json.NewEncoder(w).Encode(unloadedRunners); err != nil {
315+
http.Error(w, fmt.Sprintf("Failed to encode response: %v", err), http.StatusInternalServerError)
316+
return
317+
}
318+
}
319+
292320
// ServeHTTP implements net/http.Handler.ServeHTTP.
293321
func (s *Scheduler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
294322
s.router.ServeHTTP(w, r)

0 commit comments

Comments
 (0)