@@ -18,6 +18,9 @@ import (
1818// per model.
1919const maximumRecordsPerModel = 10
2020
21+ // subscriberChannelBuffer is the capacity of each subscriber channel; broadcasts to a subscriber whose channel is full are dropped.
22+ const subscriberChannelBuffer = 100
23+
2124type responseRecorder struct {
2225 http.ResponseWriter
2326 body * bytes.Buffer
@@ -69,13 +72,18 @@ type OpenAIRecorder struct {
6972 records map [string ]* ModelData // key is model ID
7073 modelManager * models.Manager // for resolving model tags to IDs
7174 m sync.RWMutex
75+
76+ // streaming
77+ subscribers map [string ]chan []ModelRecordsResponse
78+ subMutex sync.RWMutex
7279}
7380
7481func NewOpenAIRecorder (log logging.Logger , modelManager * models.Manager ) * OpenAIRecorder {
7582 return & OpenAIRecorder {
7683 log : log ,
7784 modelManager : modelManager ,
7885 records : make (map [string ]* ModelData ),
86+ subscribers : make (map [string ]chan []ModelRecordsResponse ),
7987 }
8088}
8189
@@ -188,6 +196,18 @@ func (r *OpenAIRecorder) RecordResponse(id, model string, rw http.ResponseWriter
188196 record .Response = response
189197 record .Error = "" // Ensure Error is empty for successful responses
190198 }
199+ // Create ModelRecordsResponse with this single updated record to match
200+ // what the non-streaming endpoint returns - []ModelRecordsResponse.
201+ // See getAllRecords and getRecordsByModel.
202+ modelResponse := []ModelRecordsResponse {{
203+ Count : 1 ,
204+ Model : model ,
205+ ModelData : ModelData {
206+ Config : modelData .Config ,
207+ Records : []* RequestResponsePair {record },
208+ },
209+ }}
210+ go r .broadcastToSubscribers (modelResponse )
191211 return
192212 }
193213 }
@@ -274,36 +294,124 @@ func (r *OpenAIRecorder) convertStreamingResponse(streamingBody string) string {
274294
275295func (r * OpenAIRecorder ) GetRecordsHandler () http.HandlerFunc {
276296 return func (w http.ResponseWriter , req * http.Request ) {
277- w .Header ().Set ("Content-Type" , "application/json" )
297+ acceptHeader := req .Header .Get ("Accept" )
298+
299+ // Check if client wants Server-Sent Events
300+ if acceptHeader == "text/event-stream" {
301+ r .handleStreamingRequests (w , req )
302+ return
303+ }
304+
305+ // Default to JSON response
306+ r .handleJSONRequests (w , req )
307+ }
308+ }
309+
310+ func (r * OpenAIRecorder ) handleJSONRequests (w http.ResponseWriter , req * http.Request ) {
311+ w .Header ().Set ("Content-Type" , "application/json" )
312+
313+ model := req .URL .Query ().Get ("model" )
314+
315+ if model == "" {
316+ // Retrieve all records for all models.
317+ allRecords := r .getAllRecords ()
318+ if allRecords == nil {
319+ allRecords = []ModelRecordsResponse {}
320+ }
321+ if err := json .NewEncoder (w ).Encode (allRecords ); err != nil {
322+ http .Error (w , fmt .Sprintf ("Failed to encode all records: %v" , err ),
323+ http .StatusInternalServerError )
324+ return
325+ }
326+ } else {
327+ // Retrieve records for the specified model.
328+ records := r .getRecordsByModel (model )
329+ if records == nil {
330+ records = []ModelRecordsResponse {}
331+ }
332+ if err := json .NewEncoder (w ).Encode (records ); err != nil {
333+ http .Error (w , fmt .Sprintf ("Failed to encode records for model '%s': %v" , model , err ),
334+ http .StatusInternalServerError )
335+ return
336+ }
337+ }
338+ }
278339
279- model := req .URL .Query ().Get ("model" )
340+ func (r * OpenAIRecorder ) handleStreamingRequests (w http.ResponseWriter , req * http.Request ) {
341+ // Set SSE headers.
342+ w .Header ().Set ("Content-Type" , "text/event-stream" )
343+ w .Header ().Set ("Cache-Control" , "no-cache" )
344+ w .Header ().Set ("Connection" , "keep-alive" )
345+
346+ // Create subscriber channel.
347+ subscriberID := fmt .Sprintf ("sub_%d" , time .Now ().UnixNano ())
348+ ch := make (chan []ModelRecordsResponse , subscriberChannelBuffer )
349+
350+ // Register subscriber.
351+ r .subMutex .Lock ()
352+ r .subscribers [subscriberID ] = ch
353+ r .subMutex .Unlock ()
354+
355+ // Clean up on disconnect.
356+ defer func () {
357+ r .subMutex .Lock ()
358+ delete (r .subscribers , subscriberID )
359+ close (ch )
360+ r .subMutex .Unlock ()
361+ }()
362+
363+ // Optional: Send existing records first.
364+ model := req .URL .Query ().Get ("model" )
365+ if includeExisting := req .URL .Query ().Get ("include_existing" ); includeExisting == "true" {
366+ r .sendExistingRecords (w , model )
367+ }
280368
281- if model == "" {
282- // Retrieve all records for all models.
283- allRecords := r .getAllRecords ()
284- if allRecords == nil {
285- // No records found.
286- http .Error (w , "No records found" , http .StatusNotFound )
369+ flusher , ok := w .(http.Flusher )
370+ if ! ok {
371+ http .Error (w , "Streaming not supported" , http .StatusInternalServerError )
372+ return
373+ }
374+
375+ // Send heartbeat to establish connection.
376+ if _ , err := fmt .Fprintf (w , "event: connected\n data: {\" status\" : \" connected\" }\n \n " ); err != nil {
377+ r .log .Errorf ("Failed to write connected event to response: %v" , err )
378+ }
379+ flusher .Flush ()
380+
381+ for {
382+ select {
383+ case modelRecords , ok := <- ch :
384+ if ! ok {
287385 return
288386 }
289- if err := json .NewEncoder (w ).Encode (allRecords ); err != nil {
290- http .Error (w , fmt .Sprintf ("Failed to encode all records: %v" , err ),
291- http .StatusInternalServerError )
292- return
387+
388+ // Filter by model if specified.
389+ // modelRecords is assumed to have size 1 because that's how we call broadcastToSubscribers.
390+ // We do this so we don't need to query a 2nd time for the model config.
391+ if model != "" && len (modelRecords ) > 0 && modelRecords [0 ].Model != model {
392+ continue
293393 }
294- } else {
295- // Retrieve records for the specified model.
296- records := r .getRecordsByModel (model )
297- if records == nil {
298- // No records found for the specified model.
299- http .Error (w , fmt .Sprintf ("No records found for model '%s'" , model ), http .StatusNotFound )
300- return
394+
395+ // Send as SSE event.
396+ jsonData , err := json .Marshal (modelRecords )
397+ if err != nil {
398+ r .log .Errorf ("Failed to marshal record for streaming: %v" , err )
399+ errorMsg := fmt .Sprintf (`{"error": "Failed to marshal record: %v"}` , err )
400+ if _ , writeErr := fmt .Fprintf (w , "event: error\n data: %s\n \n " , errorMsg ); writeErr != nil {
401+ r .log .Errorf ("Failed to write error event to response: %v" , writeErr )
402+ }
403+ flusher .Flush ()
404+ continue
301405 }
302- if err := json .NewEncoder (w ).Encode (records ); err != nil {
303- http .Error (w , fmt .Sprintf ("Failed to encode records for model '%s': %v" , model , err ),
304- http .StatusInternalServerError )
305- return
406+
407+ if _ , err := fmt .Fprintf (w , "event: new_request\n data: %s\n \n " , jsonData ); err != nil {
408+ r .log .Errorf ("Failed to write new_request event to response: %v" , err )
306409 }
410+ flusher .Flush ()
411+
412+ case <- req .Context ().Done ():
413+ // Client disconnected.
414+ return
307415 }
308416 }
309417}
@@ -352,6 +460,60 @@ func (r *OpenAIRecorder) getRecordsByModel(model string) []ModelRecordsResponse
352460 return nil
353461}
354462
463+ func (r * OpenAIRecorder ) broadcastToSubscribers (modelResponses []ModelRecordsResponse ) {
464+ r .subMutex .RLock ()
465+ defer r .subMutex .RUnlock ()
466+
467+ for _ , ch := range r .subscribers {
468+ select {
469+ case ch <- modelResponses :
470+ default :
471+ // The channel is full, skip this subscriber.
472+ }
473+ }
474+ }
475+
476+ func (r * OpenAIRecorder ) sendExistingRecords (w http.ResponseWriter , model string ) {
477+ var records []ModelRecordsResponse
478+
479+ if model == "" {
480+ records = r .getAllRecords ()
481+ } else {
482+ records = r .getRecordsByModel (model )
483+ }
484+
485+ if records != nil {
486+ // Send each individual request-response pair as a separate event.
487+ for _ , modelRecord := range records {
488+ for _ , requestRecord := range modelRecord .Records {
489+ // Create a ModelRecordsResponse with a single record to match
490+ // what the non-streaming endpoint returns - []ModelRecordsResponse.
491+ // See getAllRecords and getRecordsByModel.
492+ singleRecord := []ModelRecordsResponse {{
493+ Count : 1 ,
494+ Model : modelRecord .Model ,
495+ ModelData : ModelData {
496+ Config : modelRecord .Config ,
497+ Records : []* RequestResponsePair {requestRecord },
498+ },
499+ }}
500+ jsonData , err := json .Marshal (singleRecord )
501+ if err != nil {
502+ r .log .Errorf ("Failed to marshal existing record for streaming: %v" , err )
503+ errorMsg := fmt .Sprintf (`{"error": "Failed to marshal existing record: %v"}` , err )
504+ if _ , writeErr := fmt .Fprintf (w , "event: error\n data: %s\n \n " , errorMsg ); writeErr != nil {
505+ r .log .Errorf ("Failed to write error event to response: %v" , writeErr )
506+ }
507+ } else {
508+ if _ , writeErr := fmt .Fprintf (w , "event: existing_request\n data: %s\n \n " , jsonData ); writeErr != nil {
509+ r .log .Errorf ("Failed to write existing_request event to response: %v" , writeErr )
510+ }
511+ }
512+ }
513+ }
514+ }
515+ }
516+
355517func (r * OpenAIRecorder ) RemoveModel (model string ) {
356518 modelID := r .modelManager .ResolveModelID (model )
357519
0 commit comments