Skip to content

Commit a0ddf2e

Browse files
committed
refactor: combine the requests endpoints
Differentiate regular and streaming based on the Accept Header. Signed-off-by: Dorin Geman <dorin.geman@docker.com>
1 parent 033cf83 commit a0ddf2e

File tree

2 files changed

+104
-94
lines changed

2 files changed

+104
-94
lines changed

pkg/inference/scheduling/scheduler.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,6 @@ func (s *Scheduler) routeHandlers() map[string]http.HandlerFunc {
118118
m["POST "+inference.InferencePrefix+"/{backend}/_configure"] = s.Configure
119119
m["POST "+inference.InferencePrefix+"/_configure"] = s.Configure
120120
m["GET "+inference.InferencePrefix+"/requests"] = s.openAIRecorder.GetRecordsHandler()
121-
m["GET "+inference.InferencePrefix+"/requests/stream"] = s.openAIRecorder.StreamRequestsHandler()
122121
return m
123122
}
124123

pkg/metrics/openai_recorder.go

Lines changed: 104 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -280,36 +280,116 @@ func (r *OpenAIRecorder) convertStreamingResponse(streamingBody string) string {
280280

281281
func (r *OpenAIRecorder) GetRecordsHandler() http.HandlerFunc {
282282
return func(w http.ResponseWriter, req *http.Request) {
283-
w.Header().Set("Content-Type", "application/json")
283+
acceptHeader := req.Header.Get("Accept")
284284

285-
model := req.URL.Query().Get("model")
285+
// Check if client wants Server-Sent Events
286+
if acceptHeader == "text/event-stream" {
287+
r.handleStreamingRequests(w, req)
288+
return
289+
}
286290

287-
if model == "" {
288-
// Retrieve all records for all models.
289-
allRecords := r.getAllRecords()
290-
if allRecords == nil {
291-
// No records found.
292-
http.Error(w, "No records found", http.StatusNotFound)
293-
return
294-
}
295-
if err := json.NewEncoder(w).Encode(allRecords); err != nil {
296-
http.Error(w, fmt.Sprintf("Failed to encode all records: %v", err),
297-
http.StatusInternalServerError)
291+
// Default to JSON response
292+
r.handleJSONRequests(w, req)
293+
}
294+
}
295+
296+
func (r *OpenAIRecorder) handleJSONRequests(w http.ResponseWriter, req *http.Request) {
297+
w.Header().Set("Content-Type", "application/json")
298+
299+
model := req.URL.Query().Get("model")
300+
301+
if model == "" {
302+
// Retrieve all records for all models.
303+
allRecords := r.getAllRecords()
304+
if allRecords == nil {
305+
// No records found.
306+
http.Error(w, "No records found", http.StatusNotFound)
307+
return
308+
}
309+
if err := json.NewEncoder(w).Encode(allRecords); err != nil {
310+
http.Error(w, fmt.Sprintf("Failed to encode all records: %v", err),
311+
http.StatusInternalServerError)
312+
return
313+
}
314+
} else {
315+
// Retrieve records for the specified model.
316+
records := r.getRecordsByModel(model)
317+
if records == nil {
318+
// No records found for the specified model.
319+
http.Error(w, fmt.Sprintf("No records found for model '%s'", model), http.StatusNotFound)
320+
return
321+
}
322+
if err := json.NewEncoder(w).Encode(records); err != nil {
323+
http.Error(w, fmt.Sprintf("Failed to encode records for model '%s': %v", model, err),
324+
http.StatusInternalServerError)
325+
return
326+
}
327+
}
328+
}
329+
330+
func (r *OpenAIRecorder) handleStreamingRequests(w http.ResponseWriter, req *http.Request) {
331+
// Set SSE headers.
332+
w.Header().Set("Content-Type", "text/event-stream")
333+
w.Header().Set("Cache-Control", "no-cache")
334+
w.Header().Set("Connection", "keep-alive")
335+
336+
// Create subscriber channel.
337+
subscriberID := fmt.Sprintf("sub_%d", time.Now().UnixNano())
338+
ch := make(chan *RequestResponsePair, 100)
339+
340+
// Register subscriber.
341+
r.subMutex.Lock()
342+
r.subscribers[subscriberID] = ch
343+
r.subMutex.Unlock()
344+
345+
// Clean up on disconnect.
346+
defer func() {
347+
r.subMutex.Lock()
348+
delete(r.subscribers, subscriberID)
349+
close(ch)
350+
r.subMutex.Unlock()
351+
}()
352+
353+
// Optional: Send existing records first.
354+
model := req.URL.Query().Get("model")
355+
if includeExisting := req.URL.Query().Get("include_existing"); includeExisting == "true" {
356+
r.sendExistingRecords(w, model)
357+
}
358+
359+
flusher, ok := w.(http.Flusher)
360+
if !ok {
361+
http.Error(w, "Streaming not supported", http.StatusInternalServerError)
362+
return
363+
}
364+
365+
// Send heartbeat to establish connection.
366+
fmt.Fprintf(w, "event: connected\ndata: {\"status\": \"connected\"}\n\n")
367+
flusher.Flush()
368+
369+
for {
370+
select {
371+
case record, ok := <-ch:
372+
if !ok {
298373
return
299374
}
300-
} else {
301-
// Retrieve records for the specified model.
302-
records := r.getRecordsByModel(model)
303-
if records == nil {
304-
// No records found for the specified model.
305-
http.Error(w, fmt.Sprintf("No records found for model '%s'", model), http.StatusNotFound)
306-
return
375+
376+
// Filter by model if specified.
377+
if model != "" && record.Model != model {
378+
continue
307379
}
308-
if err := json.NewEncoder(w).Encode(records); err != nil {
309-
http.Error(w, fmt.Sprintf("Failed to encode records for model '%s': %v", model, err),
310-
http.StatusInternalServerError)
311-
return
380+
381+
// Send as SSE event.
382+
jsonData, err := json.Marshal(record)
383+
if err != nil {
384+
continue
312385
}
386+
387+
fmt.Fprintf(w, "event: new_request\ndata: %s\n\n", jsonData)
388+
flusher.Flush()
389+
390+
case <-req.Context().Done():
391+
// Client disconnected.
392+
return
313393
}
314394
}
315395
}
@@ -371,75 +451,6 @@ func (r *OpenAIRecorder) broadcastToSubscribers(record *RequestResponsePair) {
371451
}
372452
}
373453

374-
func (r *OpenAIRecorder) StreamRequestsHandler() http.HandlerFunc {
375-
return func(w http.ResponseWriter, req *http.Request) {
376-
// Set SSE headers.
377-
w.Header().Set("Content-Type", "text/event-stream")
378-
w.Header().Set("Cache-Control", "no-cache")
379-
w.Header().Set("Connection", "keep-alive")
380-
381-
// Create subscriber channel.
382-
subscriberID := fmt.Sprintf("sub_%d", time.Now().UnixNano())
383-
ch := make(chan *RequestResponsePair, 100)
384-
385-
// Register subscriber.
386-
r.subMutex.Lock()
387-
r.subscribers[subscriberID] = ch
388-
r.subMutex.Unlock()
389-
390-
// Clean up on disconnect.
391-
defer func() {
392-
r.subMutex.Lock()
393-
delete(r.subscribers, subscriberID)
394-
close(ch)
395-
r.subMutex.Unlock()
396-
}()
397-
398-
// Optional: Send existing records first.
399-
model := req.URL.Query().Get("model")
400-
if includeExisting := req.URL.Query().Get("include_existing"); includeExisting == "true" {
401-
r.sendExistingRecords(w, model)
402-
}
403-
404-
flusher, ok := w.(http.Flusher)
405-
if !ok {
406-
http.Error(w, "Streaming not supported", http.StatusInternalServerError)
407-
return
408-
}
409-
410-
// Send heartbeat to establish connection.
411-
fmt.Fprintf(w, "event: connected\ndata: {\"status\": \"connected\"}\n\n")
412-
flusher.Flush()
413-
414-
for {
415-
select {
416-
case record, ok := <-ch:
417-
if !ok {
418-
return
419-
}
420-
421-
// Filter by model if specified.
422-
if model != "" && record.Model != model {
423-
continue
424-
}
425-
426-
// Send as SSE event.
427-
jsonData, err := json.Marshal(record)
428-
if err != nil {
429-
continue
430-
}
431-
432-
fmt.Fprintf(w, "event: new_request\ndata: %s\n\n", jsonData)
433-
flusher.Flush()
434-
435-
case <-req.Context().Done():
436-
// Client disconnected.
437-
return
438-
}
439-
}
440-
}
441-
}
442-
443454
func (r *OpenAIRecorder) sendExistingRecords(w http.ResponseWriter, model string) {
444455
var records []ModelRecordsResponse
445456

0 commit comments

Comments
 (0)