@@ -13,6 +13,7 @@ use dynamo_parsers::tool_calling::{
1313} ;
1414use dynamo_runtime:: protocols:: annotated:: Annotated ;
1515use futures:: { Stream , StreamExt } ;
16+ use std:: collections:: HashMap ;
1617
1718use crate :: utils:: { MarkerMatcher , MatchResult } ;
1819
@@ -72,6 +73,8 @@ struct ChoiceJailState {
7273 accumulated_content : String ,
7374 /// Buffer for partial marker matches across chunks
7475 partial_match_buffer : String ,
76+ /// Stream finish reason
77+ stream_finish_reason : Option < FinishReason > ,
7578}
7679
7780fn create_choice_stream (
@@ -106,6 +109,7 @@ impl ChoiceJailState {
106109 is_jailed : false ,
107110 accumulated_content : String :: new ( ) ,
108111 partial_match_buffer : String :: new ( ) ,
112+ stream_finish_reason : None ,
109113 }
110114 }
111115
@@ -130,7 +134,6 @@ impl ChoiceJailState {
130134 jail_stream : & JailedStream ,
131135 ) -> Vec < ChoiceEmission > {
132136 let mut emissions = Vec :: new ( ) ;
133-
134137 if !self . is_jailed {
135138 // Use the marker matcher to detect complete/partial markers
136139 let match_result = jail_stream
@@ -152,7 +155,7 @@ impl ChoiceJailState {
152155 choice. delta . role ,
153156 & prefix,
154157 None ,
155- None ,
158+ choice . finish_reason ,
156159 choice. logprobs . clone ( ) ,
157160 ) ;
158161 emissions. push ( ChoiceEmission :: PassThrough ( prefix_choice) ) ;
@@ -192,7 +195,7 @@ impl ChoiceJailState {
192195 choice. delta . role ,
193196 trailing_part,
194197 None ,
195- None ,
198+ choice . finish_reason ,
196199 choice. logprobs . clone ( ) ,
197200 ) ;
198201 emissions. push ( ChoiceEmission :: Trailing ( trailing_choice) ) ;
@@ -224,7 +227,7 @@ impl ChoiceJailState {
224227 choice. delta . role ,
225228 & prefix,
226229 None ,
227- None ,
230+ choice . finish_reason ,
228231 choice. logprobs . clone ( ) ,
229232 ) ;
230233 emissions. push ( ChoiceEmission :: PassThrough ( prefix_choice) ) ;
@@ -267,7 +270,7 @@ impl ChoiceJailState {
267270 choice. delta . role ,
268271 & content,
269272 None ,
270- None ,
273+ choice . finish_reason ,
271274 choice. logprobs . clone ( ) ,
272275 ) ;
273276 emissions. push ( ChoiceEmission :: PassThrough ( pass_through_choice) ) ;
@@ -312,7 +315,7 @@ impl ChoiceJailState {
312315 choice. delta . role ,
313316 trailing_part,
314317 None ,
315- None ,
318+ choice . finish_reason ,
316319 choice. logprobs . clone ( ) ,
317320 ) ;
318321 emissions. push ( ChoiceEmission :: Trailing ( trailing_choice) ) ;
@@ -323,7 +326,6 @@ impl ChoiceJailState {
323326 }
324327 // If not unjailing, don't emit anything (still accumulating)
325328 }
326-
327329 emissions
328330 }
329331
@@ -342,7 +344,7 @@ impl ChoiceJailState {
342344 Some ( Role :: Assistant ) ,
343345 & self . accumulated_content ,
344346 None ,
345- None ,
347+ self . stream_finish_reason , // For the accumulated content, assign the original stream finish reason, otherwise it will get lost
346348 None ,
347349 ) ;
348350
@@ -428,6 +430,19 @@ impl JailedStream {
428430 JailedStreamBuilder :: new ( )
429431 }
430432
433+ /// Apply jail stream transformation with finish_reason fix
434+ /// This is a convenience method that applies both apply() and fix_finish_reason()
435+ pub fn apply_with_finish_reason < S > (
436+ self ,
437+ stream : S ,
438+ ) -> impl Stream < Item = Annotated < NvCreateChatCompletionStreamResponse > > + Send
439+ where
440+ S : Stream < Item = Annotated < NvCreateChatCompletionStreamResponse > > + Send + ' static ,
441+ {
442+ let jailed_stream = self . apply ( stream) ;
443+ JailedStream :: fix_finish_reason ( jailed_stream)
444+ }
445+
431446 /// Apply the jail transformation to a stream of chat completion responses
432447 /// Consumes self and returns the transformed stream
433448 pub fn apply < S > (
@@ -449,6 +464,7 @@ impl JailedStream {
449464 // Pin the stream for iteration (stack pinning is more efficient)
450465 tokio:: pin!( stream) ;
451466
467+
452468 // Process each item in the stream
453469 while let Some ( response) = stream. next( ) . await {
454470 if let Some ( chat_response) = response. data. as_ref( ) {
@@ -467,6 +483,9 @@ impl JailedStream {
467483 last_annotated_comment = response. comment. clone( ) ;
468484 }
469485
486+ // Track actual stream finish reason in the choice state
487+ choice_state. stream_finish_reason = choice. finish_reason;
488+
470489 // Process this choice and get emissions
471490 let emissions = choice_state. process_content( choice, content, & self ) . await ;
472491 all_emissions. extend( emissions) ;
@@ -707,16 +726,16 @@ impl JailedStream {
707726 } ) ,
708727 } )
709728 . collect ( ) ;
710-
711729 // Create choice with tool calls
712- return create_choice_stream (
730+ let choice = create_choice_stream (
713731 choice_index,
714732 Some ( Role :: Assistant ) ,
715733 normal_text. as_deref ( ) . unwrap_or ( "" ) ,
716734 Some ( tool_call_chunks) ,
717- Some ( FinishReason :: ToolCalls ) ,
735+ None ,
718736 None ,
719737 ) ;
738+ return choice;
720739 }
721740
722741 // No tool calls found or parsing failed, return content choice
@@ -725,7 +744,7 @@ impl JailedStream {
725744 Some ( Role :: Assistant ) ,
726745 accumulated_content,
727746 None ,
728- None ,
747+ base_choice . finish_reason ,
729748 base_choice. logprobs . clone ( ) ,
730749 )
731750 }
@@ -745,6 +764,44 @@ impl JailedStream {
745764 }
746765 false
747766 }
767+
/// Post-processing pass that rewrites a `Stop` finish reason to `ToolCalls` for every
/// choice that has carried tool calls earlier in, or within, the same stream.
///
/// Meant to be chained after `apply()`. Every response is yielded in its original order;
/// only the `finish_reason` of affected choices is modified. Finish reasons other than
/// `Stop` are left untouched.
pub fn fix_finish_reason<S>(
    input_stream: S,
) -> impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send
where
    S: Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send + 'static,
{
    stream! {
        tokio::pin!(input_stream);
        // Choice index -> whether that choice has produced a tool-call delta so far.
        let mut saw_tool_calls: HashMap<u32, bool> = HashMap::new();

        while let Some(mut response) = input_stream.next().await {
            if let Some(data) = response.data.as_mut() {
                // Record tool-call activity for the whole chunk first, so a chunk that
                // carries both tool calls and a finish reason is still rewritten below.
                for choice in data.choices.iter().filter(|c| c.delta.tool_calls.is_some()) {
                    saw_tool_calls.insert(choice.index, true);
                }

                // Then promote `Stop` to `ToolCalls` for choices known to have tool calls.
                for choice in data.choices.iter_mut() {
                    let had_tool_calls =
                        saw_tool_calls.get(&choice.index).copied().unwrap_or(false);
                    if had_tool_calls && matches!(choice.finish_reason, Some(FinishReason::Stop)) {
                        choice.finish_reason = Some(FinishReason::ToolCalls);
                    }
                }
            }

            yield response;
        }
    }
}
748805}
749806
750807/// Builder for configuring a JailedStream
0 commit comments