Skip to content

Commit 033cf83

Browse files
committed
feat: add streaming endpoint for inference requests
Signed-off-by: Dorin Geman <dorin.geman@docker.com>
1 parent 38bb017 commit 033cf83

File tree

2 files changed

+111
-0
lines changed

2 files changed

+111
-0
lines changed

pkg/inference/scheduling/scheduler.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ func (s *Scheduler) routeHandlers() map[string]http.HandlerFunc {
118118
m["POST "+inference.InferencePrefix+"/{backend}/_configure"] = s.Configure
119119
m["POST "+inference.InferencePrefix+"/_configure"] = s.Configure
120120
m["GET "+inference.InferencePrefix+"/requests"] = s.openAIRecorder.GetRecordsHandler()
121+
m["GET "+inference.InferencePrefix+"/requests/stream"] = s.openAIRecorder.StreamRequestsHandler()
121122
return m
122123
}
123124

pkg/metrics/openai_recorder.go

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,13 +69,18 @@ type OpenAIRecorder struct {
6969
records map[string]*ModelData // key is model ID
7070
modelManager *models.Manager // for resolving model tags to IDs
7171
m sync.RWMutex
72+
73+
// streaming
74+
subscribers map[string]chan *RequestResponsePair
75+
subMutex sync.RWMutex
7276
}
7377

7478
func NewOpenAIRecorder(log logging.Logger, modelManager *models.Manager) *OpenAIRecorder {
7579
return &OpenAIRecorder{
7680
log: log,
7781
modelManager: modelManager,
7882
records: make(map[string]*ModelData),
83+
subscribers: make(map[string]chan *RequestResponsePair),
7984
}
8085
}
8186

@@ -188,6 +193,7 @@ func (r *OpenAIRecorder) RecordResponse(id, model string, rw http.ResponseWriter
188193
record.Response = response
189194
record.Error = "" // Ensure Error is empty for successful responses
190195
}
196+
go r.broadcastToSubscribers(record)
191197
return
192198
}
193199
}
@@ -352,6 +358,110 @@ func (r *OpenAIRecorder) getRecordsByModel(model string) []ModelRecordsResponse
352358
return nil
353359
}
354360

361+
func (r *OpenAIRecorder) broadcastToSubscribers(record *RequestResponsePair) {
362+
r.subMutex.RLock()
363+
defer r.subMutex.RUnlock()
364+
365+
for _, ch := range r.subscribers {
366+
select {
367+
case ch <- record:
368+
default:
369+
// The channel is full, skip this subscriber.
370+
}
371+
}
372+
}
373+
374+
func (r *OpenAIRecorder) StreamRequestsHandler() http.HandlerFunc {
375+
return func(w http.ResponseWriter, req *http.Request) {
376+
// Set SSE headers.
377+
w.Header().Set("Content-Type", "text/event-stream")
378+
w.Header().Set("Cache-Control", "no-cache")
379+
w.Header().Set("Connection", "keep-alive")
380+
381+
// Create subscriber channel.
382+
subscriberID := fmt.Sprintf("sub_%d", time.Now().UnixNano())
383+
ch := make(chan *RequestResponsePair, 100)
384+
385+
// Register subscriber.
386+
r.subMutex.Lock()
387+
r.subscribers[subscriberID] = ch
388+
r.subMutex.Unlock()
389+
390+
// Clean up on disconnect.
391+
defer func() {
392+
r.subMutex.Lock()
393+
delete(r.subscribers, subscriberID)
394+
close(ch)
395+
r.subMutex.Unlock()
396+
}()
397+
398+
// Optional: Send existing records first.
399+
model := req.URL.Query().Get("model")
400+
if includeExisting := req.URL.Query().Get("include_existing"); includeExisting == "true" {
401+
r.sendExistingRecords(w, model)
402+
}
403+
404+
flusher, ok := w.(http.Flusher)
405+
if !ok {
406+
http.Error(w, "Streaming not supported", http.StatusInternalServerError)
407+
return
408+
}
409+
410+
// Send heartbeat to establish connection.
411+
fmt.Fprintf(w, "event: connected\ndata: {\"status\": \"connected\"}\n\n")
412+
flusher.Flush()
413+
414+
for {
415+
select {
416+
case record, ok := <-ch:
417+
if !ok {
418+
return
419+
}
420+
421+
// Filter by model if specified.
422+
if model != "" && record.Model != model {
423+
continue
424+
}
425+
426+
// Send as SSE event.
427+
jsonData, err := json.Marshal(record)
428+
if err != nil {
429+
continue
430+
}
431+
432+
fmt.Fprintf(w, "event: new_request\ndata: %s\n\n", jsonData)
433+
flusher.Flush()
434+
435+
case <-req.Context().Done():
436+
// Client disconnected.
437+
return
438+
}
439+
}
440+
}
441+
}
442+
443+
func (r *OpenAIRecorder) sendExistingRecords(w http.ResponseWriter, model string) {
444+
var records []ModelRecordsResponse
445+
446+
if model == "" {
447+
records = r.getAllRecords()
448+
} else {
449+
records = r.getRecordsByModel(model)
450+
}
451+
452+
if records != nil {
453+
for _, modelResponse := range records {
454+
for _, record := range modelResponse.Records {
455+
jsonData, err := json.Marshal(record)
456+
if err != nil {
457+
continue
458+
}
459+
fmt.Fprintf(w, "event: existing_request\ndata: %s\n\n", jsonData)
460+
}
461+
}
462+
}
463+
}
464+
355465
func (r *OpenAIRecorder) RemoveModel(model string) {
356466
modelID := r.modelManager.ResolveModelID(model)
357467

0 commit comments

Comments
 (0)