@@ -18,6 +18,9 @@ import (
1818// per model.
1919const maximumRecordsPerModel = 10
2020
21+ // subscriberChannelBuffer is the capacity of each subscriber channel; when a channel is full, broadcastToSubscribers skips that subscriber instead of blocking.
22+ const subscriberChannelBuffer = 100
23+
2124type responseRecorder struct {
2225 http.ResponseWriter
2326 body * bytes.Buffer
@@ -71,7 +74,7 @@ type OpenAIRecorder struct {
7174 m sync.RWMutex
7275
7376 // streaming
74- subscribers map [string ]chan * RequestResponsePair
77+ subscribers map [string ]chan [] ModelRecordsResponse
7578 subMutex sync.RWMutex
7679}
7780
@@ -80,7 +83,7 @@ func NewOpenAIRecorder(log logging.Logger, modelManager *models.Manager) *OpenAI
8083 log : log ,
8184 modelManager : modelManager ,
8285 records : make (map [string ]* ModelData ),
83- subscribers : make (map [string ]chan * RequestResponsePair ),
86+ subscribers : make (map [string ]chan [] ModelRecordsResponse ),
8487 }
8588}
8689
@@ -193,7 +196,18 @@ func (r *OpenAIRecorder) RecordResponse(id, model string, rw http.ResponseWriter
193196 record .Response = response
194197 record .Error = "" // Ensure Error is empty for successful responses
195198 }
196- go r .broadcastToSubscribers (record )
199+ // Create ModelRecordsResponse with this single updated record to match
200+ // what the non-streaming endpoint returns - []ModelRecordsResponse.
201+ // See getAllRecords and getRecordsByModel.
202+ modelResponse := []ModelRecordsResponse {{
203+ Count : 1 ,
204+ Model : model ,
205+ ModelData : ModelData {
206+ Config : modelData .Config ,
207+ Records : []* RequestResponsePair {record },
208+ },
209+ }}
210+ go r .broadcastToSubscribers (modelResponse )
197211 return
198212 }
199213 }
@@ -335,7 +349,7 @@ func (r *OpenAIRecorder) handleStreamingRequests(w http.ResponseWriter, req *htt
335349
336350 // Create subscriber channel.
337351 subscriberID := fmt .Sprintf ("sub_%d" , time .Now ().UnixNano ())
338- ch := make (chan * RequestResponsePair , 100 )
352+ ch := make (chan [] ModelRecordsResponse , subscriberChannelBuffer )
339353
340354 // Register subscriber.
341355 r .subMutex .Lock ()
@@ -368,18 +382,18 @@ func (r *OpenAIRecorder) handleStreamingRequests(w http.ResponseWriter, req *htt
368382
369383 for {
370384 select {
371- case record , ok := <- ch :
385+ case modelRecords , ok := <- ch :
372386 if ! ok {
373387 return
374388 }
375389
376390 // Filter by model if specified.
377- if model != "" && record .Model != model {
391+ if model != "" && len ( modelRecords ) > 0 && modelRecords [ 0 ] .Model != model {
378392 continue
379393 }
380394
381395 // Send as SSE event.
382- jsonData , err := json .Marshal (record )
396+ jsonData , err := json .Marshal (modelRecords )
383397 if err != nil {
384398 continue
385399 }
@@ -438,13 +452,13 @@ func (r *OpenAIRecorder) getRecordsByModel(model string) []ModelRecordsResponse
438452 return nil
439453}
440454
441- func (r * OpenAIRecorder ) broadcastToSubscribers (record * RequestResponsePair ) {
455+ func (r * OpenAIRecorder ) broadcastToSubscribers (modelResponses [] ModelRecordsResponse ) {
442456 r .subMutex .RLock ()
443457 defer r .subMutex .RUnlock ()
444458
445459 for _ , ch := range r .subscribers {
446460 select {
447- case ch <- record :
461+ case ch <- modelResponses :
448462 default :
449463 // The channel is full, skip this subscriber.
450464 }
@@ -461,13 +475,24 @@ func (r *OpenAIRecorder) sendExistingRecords(w http.ResponseWriter, model string
461475 }
462476
463477 if records != nil {
464- for _ , modelResponse := range records {
465- for _ , record := range modelResponse .Records {
466- jsonData , err := json .Marshal (record )
467- if err != nil {
468- continue
478+ // Send each individual request-response pair as a separate event.
479+ for _ , modelRecord := range records {
480+ for _ , requestRecord := range modelRecord .Records {
481+ // Create a ModelRecordsResponse with a single record to match
482+ // what the non-streaming endpoint returns - []ModelRecordsResponse.
483+ // See getAllRecords and getRecordsByModel.
484+ singleRecord := []ModelRecordsResponse {{
485+ Count : 1 ,
486+ Model : modelRecord .Model ,
487+ ModelData : ModelData {
488+ Config : modelRecord .Config ,
489+ Records : []* RequestResponsePair {requestRecord },
490+ },
491+ }}
492+ jsonData , err := json .Marshal (singleRecord )
493+ if err == nil {
494+ fmt .Fprintf (w , "event: existing_request\n data: %s\n \n " , jsonData )
469495 }
470- fmt .Fprintf (w , "event: existing_request\n data: %s\n \n " , jsonData )
471496 }
472497 }
473498 }
0 commit comments