11import { Response } from 'express'
22import statsd from '@/observability/lib/statsd'
3- import got from 'got '
3+ import { fetchStream } from '@/frame/lib/fetch-utils '
44import { getHmacWithEpoch } from '@/search/lib/helpers/get-cse-copilot-auth'
55import { getCSECopilotSource } from '@/search/lib/helpers/cse-copilot-docs-versions'
66import type { ExtendedRequest } from '@/types'
@@ -56,56 +56,76 @@ export const aiSearchProxy = async (req: ExtendedRequest, res: Response) => {
5656 stream : true ,
5757 }
5858
59+ let reader : ReadableStreamDefaultReader < Uint8Array > | null = null
60+
5961 try {
6062 // TODO: We temporarily add ?ai_search=1 to use a new pattern in cgs-copilot production
61- const stream = got . stream . post ( `${ process . env . CSE_COPILOT_ENDPOINT } /answers?ai_search=1` , {
62- json : body ,
63- headers : {
64- Authorization : getHmacWithEpoch ( ) ,
65- 'Content-Type' : 'application/json' ,
63+ const response = await fetchStream (
64+ `${ process . env . CSE_COPILOT_ENDPOINT } /answers?ai_search=1` ,
65+ {
66+ method : 'POST' ,
67+ body : JSON . stringify ( body ) ,
68+ headers : {
69+ Authorization : getHmacWithEpoch ( ) ,
70+ 'Content-Type' : 'application/json' ,
71+ } ,
6672 } ,
67- } )
68-
69- // Listen for data events to count characters
70- stream . on ( 'data' , ( chunk : Buffer | string ) => {
71- // Ensure we have a string for proper character count
72- const dataStr = typeof chunk === 'string' ? chunk : chunk . toString ( )
73- totalChars += dataStr . length
74- } )
75-
76- // Handle the upstream response before piping
77- stream . on ( 'response' , ( upstreamResponse ) => {
78- if ( upstreamResponse . statusCode !== 200 ) {
79- const errorMessage = `Upstream server responded with status code ${ upstreamResponse . statusCode } `
80- console . error ( errorMessage )
81- statsd . increment ( 'ai-search.stream_response_error' , 1 , diagnosticTags )
82- res . status ( upstreamResponse . statusCode ) . json ( {
83- errors : [ { message : errorMessage } ] ,
84- upstreamStatus : upstreamResponse . statusCode ,
85- } )
86- stream . destroy ( )
87- } else {
88- // Set response headers
89- res . setHeader ( 'Content-Type' , 'application/x-ndjson' )
90- res . flushHeaders ( )
91-
92- // Pipe the got stream directly to the response
93- stream . pipe ( res )
73+ {
74+ throwHttpErrors : false ,
75+ } ,
76+ )
77+
78+ if ( ! response . ok ) {
79+ const errorMessage = `Upstream server responded with status code ${ response . status } `
80+ console . error ( errorMessage )
81+ statsd . increment ( 'ai-search.stream_response_error' , 1 , diagnosticTags )
82+ res . status ( response . status ) . json ( {
83+ errors : [ { message : errorMessage } ] ,
84+ upstreamStatus : response . status ,
85+ } )
86+ return
87+ }
88+
89+ // Set response headers
90+ res . setHeader ( 'Content-Type' , 'application/x-ndjson' )
91+ res . flushHeaders ( )
92+
93+ // Stream the response body
94+ if ( ! response . body ) {
95+ res . status ( 500 ) . json ( { errors : [ { message : 'No response body' } ] } )
96+ return
97+ }
98+
99+ reader = response . body . getReader ( )
100+ const decoder = new TextDecoder ( )
101+
102+ try {
103+ while ( true ) {
104+ const { done, value } = await reader . read ( )
105+
106+ if ( done ) {
107+ break
108+ }
109+
110+ // Decode chunk and count characters
111+ const chunk = decoder . decode ( value , { stream : true } )
112+ totalChars += chunk . length
113+
114+ // Write chunk to response
115+ res . write ( chunk )
94116 }
95- } )
96117
97- // Handle stream errors
98- stream . on ( 'error' , ( error : any ) => {
99- console . error ( 'Error streaming from cse-copilot:' , error )
118+ // Calculate metrics on stream end
119+ const totalResponseTime = Date . now ( ) - startTime // in ms
120+ const charPerMsRatio = totalResponseTime > 0 ? totalChars / totalResponseTime : 0 // chars per ms
100121
101- if ( error ?. code === 'ERR_NON_2XX_3XX_RESPONSE' ) {
102- const upstreamStatus = error ?. response ?. statusCode || 500
103- return res . status ( upstreamStatus ) . json ( {
104- errors : [ { message : 'Upstream server error' } ] ,
105- upstreamStatus,
106- } )
107- }
122+ statsd . gauge ( 'ai-search.total_response_time' , totalResponseTime , diagnosticTags )
123+ statsd . gauge ( 'ai-search.response_chars_per_ms' , charPerMsRatio , diagnosticTags )
108124
125+ statsd . increment ( 'ai-search.success_stream_end' , 1 , diagnosticTags )
126+ res . end ( )
127+ } catch ( streamError ) {
128+ console . error ( 'Error streaming from cse-copilot:' , streamError )
109129 statsd . increment ( 'ai-search.stream_error' , 1 , diagnosticTags )
110130
111131 if ( ! res . headersSent ) {
@@ -117,22 +137,20 @@ export const aiSearchProxy = async (req: ExtendedRequest, res: Response) => {
117137 res . write ( errorMessage )
118138 res . end ( )
119139 }
120- } )
121-
122- // Calculate metrics on stream end
123- stream . on ( 'end' , ( ) => {
124- const totalResponseTime = Date . now ( ) - startTime // in ms
125- const charPerMsRatio = totalResponseTime > 0 ? totalChars / totalResponseTime : 0 // chars per ms
126-
127- statsd . gauge ( 'ai-search.total_response_time' , totalResponseTime , diagnosticTags )
128- statsd . gauge ( 'ai-search.response_chars_per_ms' , charPerMsRatio , diagnosticTags )
129-
130- statsd . increment ( 'ai-search.success_stream_end' , 1 , diagnosticTags )
131- res . end ( )
132- } )
140+ } finally {
141+ if ( reader ) {
142+ reader . releaseLock ( )
143+ reader = null
144+ }
145+ }
133146 } catch ( error ) {
134147 statsd . increment ( 'ai-search.route_error' , 1 , diagnosticTags )
135148 console . error ( 'Error posting /answers to cse-copilot:' , error )
136149 res . status ( 500 ) . json ( { errors : [ { message : 'Internal server error' } ] } )
150+ } finally {
151+ // Ensure reader lock is always released
152+ if ( reader ) {
153+ reader . releaseLock ( )
154+ }
137155 }
138156}
0 commit comments