@@ -8,9 +8,11 @@ import {
   PRECONDITION_CHECK_FAILED_STATUS_CODE,
   GOOGLE_VERTEX_AI,
 } from '../globals';
+import { HookSpan } from '../middlewares/hooks';
 import { VertexLlamaChatCompleteStreamChunkTransform } from '../providers/google-vertex-ai/chatComplete';
 import { OpenAIChatCompleteResponse } from '../providers/openai/chatComplete';
 import { OpenAICompleteResponse } from '../providers/openai/complete';
+import { endpointStrings } from '../providers/types';
 import { Params } from '../types/requestBody';
 import { getStreamModeSplitPattern, type SplitPatternType } from '../utils';

@@ -24,6 +26,15 @@ function readUInt32BE(buffer: Uint8Array, offset: number) {
   ); // Ensure the result is an unsigned integer
 }

+const shouldSendHookResultChunk = (
+  strictOpenAiCompliance: boolean,
+  hooksResult: HookSpan['hooksResult']
+) => {
+  return (
+    !strictOpenAiCompliance && hooksResult?.beforeRequestHooksResult?.length > 0
+  );
+};
+
 function getPayloadFromAWSChunk(chunk: Uint8Array): string {
   const decoder = new TextDecoder();
   const chunkLength = readUInt32BE(chunk, 0);
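
As a side note, the new gate only passes when strict OpenAI compliance is off and at least one before-request hook actually produced a result. A minimal standalone sketch of that behaviour (the `HooksResultLike` type and the fixtures below are illustrative stand-ins, not the gateway's real `HookSpan` shape):

```ts
// Illustration only: hypothetical fixtures, not real HookSpan data from the gateway.
type HooksResultLike = { beforeRequestHooksResult?: { verdict: boolean }[] };

const gate = (strictOpenAiCompliance: boolean, hooksResult?: HooksResultLike) =>
  !strictOpenAiCompliance &&
  (hooksResult?.beforeRequestHooksResult?.length ?? 0) > 0;

console.log(gate(false, { beforeRequestHooksResult: [{ verdict: true }] })); // true: chunk is written
console.log(gate(true, { beforeRequestHooksResult: [{ verdict: true }] })); // false: strict mode suppresses it
console.log(gate(false, { beforeRequestHooksResult: [] })); // false: nothing to report
```
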
@@ -292,7 +303,9 @@ export function handleStreamingMode(
   responseTransformer: Function | undefined,
   requestURL: string,
   strictOpenAiCompliance: boolean,
-  gatewayRequest: Params
+  gatewayRequest: Params,
+  fn: endpointStrings,
+  hooksResult: HookSpan['hooksResult']
 ): Response {
   const splitPattern = getStreamModeSplitPattern(proxyProvider, requestURL);
   // If the provider doesn't supply completion id,
@@ -311,6 +324,12 @@ export function handleStreamingMode(
   if (proxyProvider === BEDROCK) {
     (async () => {
       try {
+        if (shouldSendHookResultChunk(strictOpenAiCompliance, hooksResult)) {
+          const hookResultChunk = constructHookResultChunk(hooksResult, fn);
+          if (hookResultChunk) {
+            await writer.write(encoder.encode(hookResultChunk));
+          }
+        }
         for await (const chunk of readAWSStream(
           reader,
           responseTransformer,
@@ -337,6 +356,12 @@ export function handleStreamingMode(
   } else {
     (async () => {
       try {
+        if (shouldSendHookResultChunk(strictOpenAiCompliance, hooksResult)) {
+          const hookResultChunk = constructHookResultChunk(hooksResult, fn);
+          if (hookResultChunk) {
+            await writer.write(encoder.encode(hookResultChunk));
+          }
+        }
         for await (const chunk of readStream(
           reader,
           splitPattern,
@@ -389,7 +414,10 @@ export function handleStreamingMode(
 export async function handleJSONToStreamResponse(
   response: Response,
   provider: string,
-  responseTransformerFunction: Function
+  responseTransformerFunction: Function,
+  strictOpenAiCompliance: boolean,
+  fn: endpointStrings,
+  hooksResult: HookSpan['hooksResult']
 ): Promise<Response> {
   const { readable, writable } = new TransformStream();
   const writer = writable.getWriter();
@@ -403,6 +431,12 @@ export async function handleJSONToStreamResponse(
   ) {
     const generator = responseTransformerFunction(responseJSON, provider);
     (async () => {
+      if (shouldSendHookResultChunk(strictOpenAiCompliance, hooksResult)) {
+        const hookResultChunk = constructHookResultChunk(hooksResult, fn);
+        if (hookResultChunk) {
+          await writer.write(encoder.encode(hookResultChunk));
+        }
+      }
       while (true) {
         const chunk = generator.next();
         if (chunk.done) {
@@ -418,6 +452,12 @@ export async function handleJSONToStreamResponse(
         provider
       );
       (async () => {
+        if (shouldSendHookResultChunk(strictOpenAiCompliance, hooksResult)) {
+          const hookResultChunk = constructHookResultChunk(hooksResult, fn);
+          if (hookResultChunk) {
+            await writer.write(encoder.encode(hookResultChunk));
+          }
+        }
         for (const chunk of streamChunkArray) {
           await writer.write(encoder.encode(chunk));
         }
@@ -434,3 +474,21 @@ export async function handleJSONToStreamResponse(
     statusText: response.statusText,
   });
 }
+
+const constructHookResultChunk = (
+  hooksResult: HookSpan['hooksResult'],
+  fn: endpointStrings
+) => {
+  if (fn === 'messages') {
+    return `event: hook_results\ndata: ${JSON.stringify({
+      hook_results: {
+        before_request_hooks: hooksResult.beforeRequestHooksResult,
+      },
+    })}\n\n`;
+  }
+  return `data: ${JSON.stringify({
+    hook_results: {
+      before_request_hooks: hooksResult.beforeRequestHooksResult,
+    },
+  })}\n\n`;
+};
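
For reference, `constructHookResultChunk` produces two wire formats: the `messages` endpoint gets an Anthropic-style named SSE event, every other endpoint gets a plain `data:` line. A rough sketch of what a client sees ahead of the provider chunks (the hook result payload below is invented for illustration):

```ts
// Hypothetical before-request hook result; not real guardrail output.
const beforeRequestHooksResult = [{ id: 'guardrail-check', verdict: true }];

const payload = JSON.stringify({
  hook_results: { before_request_hooks: beforeRequestHooksResult },
});

// fn === 'messages' (Anthropic-style SSE): a named event precedes the model events.
console.log(`event: hook_results\ndata: ${payload}\n\n`);

// Any other endpoint (OpenAI-style SSE): a bare data line precedes the model chunks.
console.log(`data: ${payload}\n\n`);
```
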