1919import static com .google .adk .flows .llmflows .Functions .REQUEST_CONFIRMATION_FUNCTION_CALL_NAME ;
2020import static com .google .common .collect .ImmutableList .toImmutableList ;
2121import static com .google .common .collect .ImmutableMap .toImmutableMap ;
22+ import static com .google .common .collect .ImmutableSet .toImmutableSet ;
2223
2324import com .fasterxml .jackson .core .JsonProcessingException ;
2425import com .fasterxml .jackson .databind .ObjectMapper ;
25- import com .fasterxml . jackson . datatype . jdk8 . Jdk8Module ;
26+ import com .google . adk . JsonBaseModel ;
2627import com .google .adk .agents .InvocationContext ;
2728import com .google .adk .agents .LlmAgent ;
2829import com .google .adk .events .Event ;
3132import com .google .adk .tools .ToolConfirmation ;
3233import com .google .common .collect .ImmutableList ;
3334import com .google .common .collect .ImmutableMap ;
35+ import com .google .common .collect .ImmutableSet ;
3436import com .google .genai .types .Content ;
3537import com .google .genai .types .FunctionCall ;
3638import com .google .genai .types .FunctionResponse ;
3739import com .google .genai .types .Part ;
3840import io .reactivex .rxjava3 .core .Maybe ;
3941import io .reactivex .rxjava3 .core .Single ;
4042import java .util .Collection ;
41- import java .util .List ;
43+ import java .util .HashMap ;
4244import java .util .Map ;
4345import java .util .Objects ;
4446import java .util .Optional ;
4951public class RequestConfirmationLlmRequestProcessor implements RequestProcessor {
5052 private static final Logger logger =
5153 LoggerFactory .getLogger (RequestConfirmationLlmRequestProcessor .class );
52- private final ObjectMapper objectMapper ;
53-
54- public RequestConfirmationLlmRequestProcessor () {
55- objectMapper = new ObjectMapper ().registerModule (new Jdk8Module ());
56- }
54+ private static final ObjectMapper OBJECT_MAPPER = JsonBaseModel .getMapper ();
5755
5856 @ Override
5957 public Single <RequestProcessor .RequestProcessingResult > processRequest (
6058 InvocationContext invocationContext , LlmRequest llmRequest ) {
61- List <Event > events = invocationContext .session ().events ();
59+ ImmutableList <Event > events = ImmutableList . copyOf ( invocationContext .session ().events () );
6260 if (events .isEmpty ()) {
6361 logger .info (
6462 "No events are present in the session. Skipping request confirmation processing." );
6563 return Single .just (RequestProcessingResult .create (llmRequest , ImmutableList .of ()));
6664 }
6765
68- ImmutableMap <String , ToolConfirmation > requestConfirmationFunctionResponses =
69- filterRequestConfirmationFunctionResponses (events );
66+ ImmutableMap <String , ToolConfirmation > responses = ImmutableMap .of ();
67+ int confirmationEventIndex = -1 ;
68+ for (int i = events .size () - 1 ; i >= 0 ; i --) {
69+ Event event = events .get (i );
70+ if (!Objects .equals (event .author (), "user" )) {
71+ continue ;
72+ }
73+ if (event .functionResponses ().isEmpty ()) {
74+ continue ;
75+ }
76+ responses =
77+ event .functionResponses ().stream ()
78+ .filter (functionResponse -> functionResponse .id ().isPresent ())
79+ .filter (
80+ functionResponse ->
81+ Objects .equals (
82+ functionResponse .name ().orElse (null ),
83+ REQUEST_CONFIRMATION_FUNCTION_CALL_NAME ))
84+ .map (this ::maybeCreateToolConfirmationEntry )
85+ .flatMap (Optional ::stream )
86+ .collect (toImmutableMap (Map .Entry ::getKey , Map .Entry ::getValue ));
87+ confirmationEventIndex = i ;
88+ break ;
89+ }
90+
91+ // Make it final to enable access from lambda expressions.
92+ final ImmutableMap <String , ToolConfirmation > requestConfirmationFunctionResponses = responses ;
93+
7094 if (requestConfirmationFunctionResponses .isEmpty ()) {
7195 logger .info ("No request confirmation function responses found." );
7296 return Single .just (RequestProcessingResult .create (llmRequest , ImmutableList .of ()));
7397 }
7498
75- for (ImmutableList <FunctionCall > functionCalls :
76- events .stream ()
77- .map (Event ::functionCalls )
78- .filter (fc -> !fc .isEmpty ())
79- .collect (toImmutableList ())) {
99+ for (int i = events .size () - 2 ; i >= 0 ; i --) {
100+ Event event = events .get (i );
101+ if (event .functionCalls ().isEmpty ()) {
102+ continue ;
103+ }
104+
105+ Map <String , ToolConfirmation > toolsToResumeWithConfirmation = new HashMap <>();
106+ Map <String , FunctionCall > toolsToResumeWithArgs = new HashMap <>();
107+
108+ event .functionCalls ().stream ()
109+ .filter (
110+ fc ->
111+ fc .id ().isPresent ()
112+ && requestConfirmationFunctionResponses .containsKey (fc .id ().get ()))
113+ .forEach (
114+ fc ->
115+ getOriginalFunctionCall (fc )
116+ .ifPresent (
117+ ofc -> {
118+ toolsToResumeWithConfirmation .put (
119+ ofc .id ().get (),
120+ requestConfirmationFunctionResponses .get (fc .id ().get ()));
121+ toolsToResumeWithArgs .put (ofc .id ().get (), ofc );
122+ }));
123+
124+ if (toolsToResumeWithConfirmation .isEmpty ()) {
125+ continue ;
126+ }
127+
128+ // Remove the tools that have already been confirmed.
129+ ImmutableSet <String > alreadyConfirmedIds =
130+ events .subList (confirmationEventIndex + 1 , events .size ()).stream ()
131+ .flatMap (e -> e .functionResponses ().stream ())
132+ .map (FunctionResponse ::id )
133+ .flatMap (Optional ::stream )
134+ .collect (toImmutableSet ());
135+ toolsToResumeWithConfirmation .keySet ().removeAll (alreadyConfirmedIds );
136+ toolsToResumeWithArgs .keySet ().removeAll (alreadyConfirmedIds );
80137
81- ImmutableMap <String , FunctionCall > toolsToResumeWithArgs =
82- filterToolsToResumeWithArgs (functionCalls , requestConfirmationFunctionResponses );
83- ImmutableMap <String , ToolConfirmation > toolsToResumeWithConfirmation =
84- toolsToResumeWithArgs .keySet ().stream ()
85- .filter (
86- id ->
87- events .stream ()
88- .flatMap (e -> e .functionResponses ().stream ())
89- .anyMatch (fr -> Objects .equals (fr .id ().orElse (null ), id )))
90- .collect (toImmutableMap (k -> k , requestConfirmationFunctionResponses ::get ));
91138 if (toolsToResumeWithConfirmation .isEmpty ()) {
92- logger .info ("No tools to resume with confirmation." );
93139 continue ;
94140 }
95141
96142 return assembleEvent (
97- invocationContext , toolsToResumeWithArgs .values (), toolsToResumeWithConfirmation )
98- .map (event -> RequestProcessingResult .create (llmRequest , ImmutableList .of (event )))
143+ invocationContext ,
144+ toolsToResumeWithArgs .values (),
145+ ImmutableMap .copyOf (toolsToResumeWithConfirmation ))
146+ .map (e -> RequestProcessingResult .create (llmRequest , ImmutableList .of (e )))
99147 .toSingle ();
100148 }
101149
102150 return Single .just (RequestProcessingResult .create (llmRequest , ImmutableList .of ()));
103151 }
104152
153+ private Optional <FunctionCall > getOriginalFunctionCall (FunctionCall functionCall ) {
154+ if (!functionCall .args ().orElse (ImmutableMap .of ()).containsKey ("originalFunctionCall" )) {
155+ return Optional .empty ();
156+ }
157+ try {
158+ FunctionCall originalFunctionCall =
159+ OBJECT_MAPPER .convertValue (
160+ functionCall .args ().get ().get ("originalFunctionCall" ), FunctionCall .class );
161+ if (originalFunctionCall .id ().isEmpty ()) {
162+ return Optional .empty ();
163+ }
164+ return Optional .of (originalFunctionCall );
165+ } catch (IllegalArgumentException e ) {
166+ logger .warn ("Failed to convert originalFunctionCall argument." , e );
167+ return Optional .empty ();
168+ }
169+ }
170+
105171 private Maybe <Event > assembleEvent (
106172 InvocationContext invocationContext ,
107173 Collection <FunctionCall > functionCalls ,
108174 Map <String , ToolConfirmation > toolConfirmations ) {
109- ImmutableMap . Builder <String , BaseTool > toolsBuilder = ImmutableMap . builder () ;
175+ Single < ImmutableMap <String , BaseTool >> toolsMapSingle ;
110176 if (invocationContext .agent () instanceof LlmAgent llmAgent ) {
111- for (BaseTool tool : llmAgent .tools ()) {
112- toolsBuilder .put (tool .name (), tool );
113- }
177+ toolsMapSingle =
178+ llmAgent
179+ .tools ()
180+ .map (
181+ toolList ->
182+ toolList .stream ().collect (toImmutableMap (BaseTool ::name , tool -> tool )));
183+ } else {
184+ toolsMapSingle = Single .just (ImmutableMap .of ());
114185 }
115186
116187 var functionCallEvent =
@@ -124,23 +195,10 @@ private Maybe<Event> assembleEvent(
124195 .build ())
125196 .build ();
126197
127- return Functions .handleFunctionCalls (
128- invocationContext , functionCallEvent , toolsBuilder .buildOrThrow (), toolConfirmations );
129- }
130-
131- private ImmutableMap <String , ToolConfirmation > filterRequestConfirmationFunctionResponses (
132- List <Event > events ) {
133- return events .stream ()
134- .filter (event -> Objects .equals (event .author (), "user" ))
135- .flatMap (event -> event .functionResponses ().stream ())
136- .filter (functionResponse -> functionResponse .id ().isPresent ())
137- .filter (
138- functionResponse ->
139- Objects .equals (
140- functionResponse .name ().orElse (null ), REQUEST_CONFIRMATION_FUNCTION_CALL_NAME ))
141- .map (this ::maybeCreateToolConfirmationEntry )
142- .flatMap (Optional ::stream )
143- .collect (toImmutableMap (Map .Entry ::getKey , Map .Entry ::getValue ));
198+ return toolsMapSingle .flatMapMaybe (
199+ toolsMap ->
200+ Functions .handleFunctionCalls (
201+ invocationContext , functionCallEvent , toolsMap , toolConfirmations ));
144202 }
145203
146204 private Optional <Map .Entry <String , ToolConfirmation >> maybeCreateToolConfirmationEntry (
@@ -150,36 +208,19 @@ private Optional<Map.Entry<String, ToolConfirmation>> maybeCreateToolConfirmatio
150208 return Optional .of (
151209 Map .entry (
152210 functionResponse .id ().get (),
153- objectMapper .convertValue (responseMap , ToolConfirmation .class )));
211+ OBJECT_MAPPER .convertValue (responseMap , ToolConfirmation .class )));
154212 }
155213
156214 try {
157215 return Optional .of (
158216 Map .entry (
159217 functionResponse .id ().get (),
160- objectMapper .readValue (
218+ OBJECT_MAPPER .readValue (
161219 (String ) responseMap .get ("response" ), ToolConfirmation .class )));
162220 } catch (JsonProcessingException e ) {
163221 logger .error ("Failed to parse tool confirmation response" , e );
164222 }
165223
166224 return Optional .empty ();
167225 }
168-
169- private ImmutableMap <String , FunctionCall > filterToolsToResumeWithArgs (
170- ImmutableList <FunctionCall > functionCalls ,
171- Map <String , ToolConfirmation > requestConfirmationFunctionResponses ) {
172- return functionCalls .stream ()
173- .filter (fc -> fc .id ().isPresent ())
174- .filter (fc -> requestConfirmationFunctionResponses .containsKey (fc .id ().get ()))
175- .filter (
176- fc -> Objects .equals (fc .name ().orElse (null ), REQUEST_CONFIRMATION_FUNCTION_CALL_NAME ))
177- .filter (fc -> fc .args ().orElse (ImmutableMap .of ()).containsKey ("originalFunctionCall" ))
178- .collect (
179- toImmutableMap (
180- fc -> fc .id ().get (),
181- fc ->
182- objectMapper .convertValue (
183- fc .args ().get ().get ("originalFunctionCall" ), FunctionCall .class )));
184- }
185226}
0 commit comments