@@ -94,6 +94,22 @@ export function extractFunctionCallSettingsFromJinjaTemplate({
9494 modelMessage2
9595 ]
9696 } ] ;
97+ const chatHistoryOnlyCall : ChatHistoryItem [ ] = [ ...baseChatHistory , {
98+ type : "model" ,
99+ response : [
100+ {
101+ type : "functionCall" ,
102+ name : func1name ,
103+
104+ // convert to number since this will go through JSON.stringify,
105+ // and we want to avoid escaping characters in the rendered output
106+ params : Number ( func1params ) ,
107+ result : Number ( func1result ) ,
108+ startsNewChunk : true
109+ } ,
110+ modelMessage2
111+ ]
112+ } ] ;
97113 const chatHistory2Calls : ChatHistoryItem [ ] = [ ...baseChatHistory , {
98114 type : "model" ,
99115 response : [
@@ -257,6 +273,17 @@ export function extractFunctionCallSettingsFromJinjaTemplate({
257273 stringifyFunctionResults : stringifyResult ,
258274 combineModelMessageAndToolCalls
259275 } ) ;
276+ const renderedOnlyCall = getFirstValidResult ( [
277+ ( ) => renderTemplate ( {
278+ chatHistory : chatHistoryOnlyCall ,
279+ functions : functions1 ,
280+ additionalParams,
281+ stringifyFunctionParams : stringifyParams ,
282+ stringifyFunctionResults : stringifyResult ,
283+ combineModelMessageAndToolCalls
284+ } ) ,
285+ ( ) => undefined
286+ ] ) ;
260287 const rendered2Calls = getFirstValidResult ( [
261288 ( ) => renderTemplate ( {
262289 chatHistory : chatHistory2Calls ,
@@ -411,14 +438,46 @@ export function extractFunctionCallSettingsFromJinjaTemplate({
411438 parallelismResultPrefix
412439 } = resolveParallelismBetweenSectionsParts ( func2ParamsToFunc1Result . text . slice ( callSuffixLength , - resultPrefixLength ) ) ;
413440
441+ let revivedCallPrefix = reviveSeparatorText ( callPrefixText , idToStaticContent , contentIds ) ;
442+ const revivedParallelismCallSectionPrefix = removeCommonRevivedPrefix (
443+ reviveSeparatorText ( parallelismCallPrefix , idToStaticContent , contentIds ) ,
444+ ! combineModelMessageAndToolCalls
445+ ? textBetween2TextualModelResponses
446+ : LlamaText ( )
447+ ) ;
448+ let revivedParallelismCallBetweenCalls = reviveSeparatorText ( parallelismBetweenCallsText , idToStaticContent , contentIds ) ;
449+
450+ if ( revivedParallelismCallSectionPrefix . values . length === 0 && renderedOnlyCall != null ) {
451+ const userMessage1ToModelMessage1Start = getTextBetweenIds ( rendered1Call , userMessage1 , modelMessage1 ) ;
452+ const onlyCallUserMessage1ToFunc1Name = getTextBetweenIds ( renderedOnlyCall , userMessage1 , func1name ) ;
453+
454+ if ( userMessage1ToModelMessage1Start . text != null && onlyCallUserMessage1ToFunc1Name . text != null ) {
455+ const onlyCallModelMessagePrefixLength = findCommandStartLength (
456+ userMessage1ToModelMessage1Start . text ,
457+ onlyCallUserMessage1ToFunc1Name . text
458+ ) ;
459+ const onlyCallCallPrefixText = onlyCallUserMessage1ToFunc1Name . text . slice ( onlyCallModelMessagePrefixLength ) ;
460+ const revivedOnlyCallCallPrefixText = reviveSeparatorText ( onlyCallCallPrefixText , idToStaticContent , contentIds ) ;
461+
462+ const optionalCallPrefix = removeCommonRevivedSuffix ( revivedCallPrefix , revivedOnlyCallCallPrefixText ) ;
463+ if ( optionalCallPrefix . values . length > 0 ) {
464+ revivedCallPrefix = removeCommonRevivedPrefix ( revivedCallPrefix , optionalCallPrefix ) ;
465+ revivedParallelismCallBetweenCalls = LlamaText ( [
466+ optionalCallPrefix ,
467+ revivedParallelismCallBetweenCalls
468+ ] ) ;
469+ }
470+ }
471+ }
472+
414473 return {
415474 stringifyParams,
416475 stringifyResult,
417476 combineModelMessageAndToolCalls,
418477 settings : {
419478 call : {
420479 optionalPrefixSpace : true ,
421- prefix : reviveSeparatorText ( callPrefixText , idToStaticContent , contentIds ) ,
480+ prefix : revivedCallPrefix ,
422481 paramsPrefix : reviveSeparatorText ( callParamsPrefixText , idToStaticContent , contentIds ) ,
423482 suffix : reviveSeparatorText ( callSuffixText , idToStaticContent , contentIds ) ,
424483 emptyCallParamsPlaceholder : { }
@@ -445,13 +504,8 @@ export function extractFunctionCallSettingsFromJinjaTemplate({
445504 } ,
446505 parallelism : {
447506 call : {
448- sectionPrefix : removeCommonRevivedPrefix (
449- reviveSeparatorText ( parallelismCallPrefix , idToStaticContent , contentIds ) ,
450- ! combineModelMessageAndToolCalls
451- ? textBetween2TextualModelResponses
452- : LlamaText ( )
453- ) ,
454- betweenCalls : reviveSeparatorText ( parallelismBetweenCallsText , idToStaticContent , contentIds ) ,
507+ sectionPrefix : revivedParallelismCallSectionPrefix ,
508+ betweenCalls : revivedParallelismCallBetweenCalls ,
455509 sectionSuffix : reviveSeparatorText ( parallelismCallSuffixText , idToStaticContent , contentIds )
456510 } ,
457511 result : {
@@ -524,14 +578,48 @@ function removeCommonRevivedPrefix(target: LlamaText, matchStart: LlamaText) {
524578 } else if ( targetValue instanceof SpecialToken && matchStartValue instanceof SpecialToken ) {
525579 if ( targetValue . value === matchStartValue . value )
526580 continue ;
527- }
581+ } else if ( LlamaText ( targetValue ?? "" ) . compare ( LlamaText ( matchStartValue ?? "" ) ) )
582+ continue ;
528583
529584 return LlamaText ( target . values . slice ( commonStartLength ) ) ;
530585 }
531586
532587 return LlamaText ( target . values . slice ( matchStart . values . length ) ) ;
533588}
534589
590+ function removeCommonRevivedSuffix ( target : LlamaText , matchEnd : LlamaText ) {
591+ for (
592+ let commonEndLength = 0 ;
593+ commonEndLength < target . values . length && commonEndLength < matchEnd . values . length ;
594+ commonEndLength ++
595+ ) {
596+ const targetValue = target . values [ target . values . length - commonEndLength - 1 ] ;
597+ const matchEndValue = matchEnd . values [ matchEnd . values . length - commonEndLength - 1 ] ;
598+
599+ if ( typeof targetValue === "string" && typeof matchEndValue === "string" ) {
600+ if ( targetValue === matchEndValue )
601+ continue ;
602+ } else if ( targetValue instanceof SpecialTokensText && matchEndValue instanceof SpecialTokensText ) {
603+ const commonLength = findCommonEndLength ( targetValue . value , matchEndValue . value ) ;
604+ if ( commonLength === targetValue . value . length && commonLength === matchEndValue . value . length )
605+ continue ;
606+
607+ return LlamaText ( [
608+ ...target . values . slice ( 0 , target . values . length - commonEndLength - 1 ) ,
609+ new SpecialTokensText ( targetValue . value . slice ( 0 , targetValue . value . length - commonLength ) )
610+ ] ) ;
611+ } else if ( targetValue instanceof SpecialToken && matchEndValue instanceof SpecialToken ) {
612+ if ( targetValue . value === matchEndValue . value )
613+ continue ;
614+ } else if ( LlamaText ( targetValue ?? "" ) . compare ( LlamaText ( matchEndValue ?? "" ) ) )
615+ continue ;
616+
617+ return LlamaText ( target . values . slice ( 0 , target . values . length - commonEndLength - 1 ) ) ;
618+ }
619+
620+ return LlamaText ( target . values . slice ( 0 , target . values . length - matchEnd . values . length ) ) ;
621+ }
622+
535623function findCommandStartLength ( text1 : string , text2 : string ) {
536624 let commonStartLength = 0 ;
537625 while ( commonStartLength < text1 . length && commonStartLength < text2 . length ) {
0 commit comments