Skip to content

Commit 937f6e2

Browse files
google-genai-botcopybara-github
authored andcommitted
feat(HITL): Let ADK resume after HITL approval is present
feat(HITL): Declining a proposal now correctly intercepts the run fix: Events for HITL are now emitted correctly fix: HITL endless loop when asking for approvals PiperOrigin-RevId: 839858592
1 parent 441c9a6 commit 937f6e2

File tree

9 files changed

+422
-108
lines changed

9 files changed

+422
-108
lines changed

core/src/main/java/com/google/adk/agents/LlmAgent.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -736,8 +736,8 @@ public IncludeContents includeContents() {
736736
return includeContents;
737737
}
738738

739-
public List<BaseTool> tools() {
740-
return canonicalTools().toList().blockingGet();
739+
public Single<List<BaseTool>> tools() {
740+
return canonicalTools().toList();
741741
}
742742

743743
public List<Object> toolsUnion() {

core/src/main/java/com/google/adk/flows/llmflows/Functions.java

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
package com.google.adk.flows.llmflows;
1919

20+
import static com.google.common.collect.ImmutableList.toImmutableList;
2021
import static com.google.common.collect.ImmutableMap.toImmutableMap;
2122

2223
import com.google.adk.Telemetry;
@@ -64,7 +65,7 @@
6465
public final class Functions {
6566

6667
private static final String AF_FUNCTION_CALL_ID_PREFIX = "adk-";
67-
static final String REQUEST_CONFIRMATION_FUNCTION_CALL_NAME = "adk_request_confirmation";
68+
public static final String REQUEST_CONFIRMATION_FUNCTION_CALL_NAME = "adk_request_confirmation";
6869
private static final Logger logger = LoggerFactory.getLogger(Functions.class);
6970

7071
/** Generates a unique ID for a function call. */
@@ -147,12 +148,22 @@ public static Maybe<Event> handleFunctionCalls(
147148
Function<FunctionCall, Maybe<Event>> functionCallMapper =
148149
functionCall -> {
149150
BaseTool tool = tools.get(functionCall.name().get());
151+
ToolConfirmation toolConfirmation = toolConfirmations.get(functionCall.id().orElse(null));
150152
ToolContext toolContext =
151153
ToolContext.builder(invocationContext)
152154
.functionCallId(functionCall.id().orElse(""))
153-
.toolConfirmation(toolConfirmations.get(functionCall.id().orElse(null)))
155+
.toolConfirmation(toolConfirmation)
154156
.build();
155157

158+
if (toolConfirmation != null && !toolConfirmation.confirmed()) {
159+
return Maybe.just(
160+
buildResponseEvent(
161+
tool,
162+
ImmutableMap.of("error", "User declined tool execution for " + tool.name()),
163+
toolContext,
164+
invocationContext));
165+
}
166+
156167
Map<String, Object> functionArgs = functionCall.args().orElse(ImmutableMap.of());
157168

158169
Maybe<Map<String, Object>> maybeFunctionResult =
@@ -241,6 +252,18 @@ public static Maybe<Event> handleFunctionCalls(
241252
*/
242253
public static Maybe<Event> handleFunctionCallsLive(
243254
InvocationContext invocationContext, Event functionCallEvent, Map<String, BaseTool> tools) {
255+
return handleFunctionCallsLive(invocationContext, functionCallEvent, tools, ImmutableMap.of());
256+
}
257+
258+
/**
259+
* Handles function calls in a live/streaming context with tool confirmations, supporting
260+
* background execution and stream termination.
261+
*/
262+
public static Maybe<Event> handleFunctionCallsLive(
263+
InvocationContext invocationContext,
264+
Event functionCallEvent,
265+
Map<String, BaseTool> tools,
266+
Map<String, ToolConfirmation> toolConfirmations) {
244267
ImmutableList<FunctionCall> functionCalls = functionCallEvent.functionCalls();
245268

246269
for (FunctionCall functionCall : functionCalls) {
@@ -255,7 +278,9 @@ public static Maybe<Event> handleFunctionCallsLive(
255278
ToolContext toolContext =
256279
ToolContext.builder(invocationContext)
257280
.functionCallId(functionCall.id().orElse(""))
281+
.toolConfirmation(toolConfirmations.get(functionCall.id().orElse(null)))
258282
.build();
283+
259284
Map<String, Object> functionArgs = functionCall.args().orElse(new HashMap<>());
260285

261286
Maybe<Map<String, Object>> maybeFunctionResult =
@@ -664,5 +689,29 @@ public static Optional<Event> generateRequestConfirmationEvent(
664689
.build());
665690
}
666691

692+
/**
693+
* Gets the ask user confirmation function calls from the event.
694+
*
695+
* @param event The event to extract function calls from.
696+
* @return A list of function calls for asking user confirmation.
697+
*/
698+
public static ImmutableList<FunctionCall> getAskUserConfirmationFunctionCalls(Event event) {
699+
return event
700+
.content()
701+
.flatMap(Content::parts)
702+
.map(
703+
parts ->
704+
parts.stream()
705+
.flatMap(part -> part.functionCall().stream())
706+
.filter(
707+
functionCall ->
708+
functionCall
709+
.name()
710+
.map(name -> name.equals(REQUEST_CONFIRMATION_FUNCTION_CALL_NAME))
711+
.orElse(false))
712+
.collect(toImmutableList()))
713+
.orElse(ImmutableList.of());
714+
}
715+
667716
private Functions() {}
668717
}

core/src/main/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessor.java

Lines changed: 109 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,11 @@
1919
import static com.google.adk.flows.llmflows.Functions.REQUEST_CONFIRMATION_FUNCTION_CALL_NAME;
2020
import static com.google.common.collect.ImmutableList.toImmutableList;
2121
import static com.google.common.collect.ImmutableMap.toImmutableMap;
22+
import static com.google.common.collect.ImmutableSet.toImmutableSet;
2223

2324
import com.fasterxml.jackson.core.JsonProcessingException;
2425
import com.fasterxml.jackson.databind.ObjectMapper;
25-
import com.fasterxml.jackson.datatype.jdk8.Jdk8Module;
26+
import com.google.adk.JsonBaseModel;
2627
import com.google.adk.agents.InvocationContext;
2728
import com.google.adk.agents.LlmAgent;
2829
import com.google.adk.events.Event;
@@ -31,14 +32,15 @@
3132
import com.google.adk.tools.ToolConfirmation;
3233
import com.google.common.collect.ImmutableList;
3334
import com.google.common.collect.ImmutableMap;
35+
import com.google.common.collect.ImmutableSet;
3436
import com.google.genai.types.Content;
3537
import com.google.genai.types.FunctionCall;
3638
import com.google.genai.types.FunctionResponse;
3739
import com.google.genai.types.Part;
3840
import io.reactivex.rxjava3.core.Maybe;
3941
import io.reactivex.rxjava3.core.Single;
4042
import java.util.Collection;
41-
import java.util.List;
43+
import java.util.HashMap;
4244
import java.util.Map;
4345
import java.util.Objects;
4446
import java.util.Optional;
@@ -49,68 +51,137 @@
4951
public class RequestConfirmationLlmRequestProcessor implements RequestProcessor {
5052
private static final Logger logger =
5153
LoggerFactory.getLogger(RequestConfirmationLlmRequestProcessor.class);
52-
private final ObjectMapper objectMapper;
53-
54-
public RequestConfirmationLlmRequestProcessor() {
55-
objectMapper = new ObjectMapper().registerModule(new Jdk8Module());
56-
}
54+
private static final ObjectMapper OBJECT_MAPPER = JsonBaseModel.getMapper();
5755

5856
@Override
5957
public Single<RequestProcessor.RequestProcessingResult> processRequest(
6058
InvocationContext invocationContext, LlmRequest llmRequest) {
61-
List<Event> events = invocationContext.session().events();
59+
ImmutableList<Event> events = ImmutableList.copyOf(invocationContext.session().events());
6260
if (events.isEmpty()) {
6361
logger.info(
6462
"No events are present in the session. Skipping request confirmation processing.");
6563
return Single.just(RequestProcessingResult.create(llmRequest, ImmutableList.of()));
6664
}
6765

68-
ImmutableMap<String, ToolConfirmation> requestConfirmationFunctionResponses =
69-
filterRequestConfirmationFunctionResponses(events);
66+
ImmutableMap<String, ToolConfirmation> responses = ImmutableMap.of();
67+
int confirmationEventIndex = -1;
68+
for (int i = events.size() - 1; i >= 0; i--) {
69+
Event event = events.get(i);
70+
if (!Objects.equals(event.author(), "user")) {
71+
continue;
72+
}
73+
if (event.functionResponses().isEmpty()) {
74+
continue;
75+
}
76+
responses =
77+
event.functionResponses().stream()
78+
.filter(functionResponse -> functionResponse.id().isPresent())
79+
.filter(
80+
functionResponse ->
81+
Objects.equals(
82+
functionResponse.name().orElse(null),
83+
REQUEST_CONFIRMATION_FUNCTION_CALL_NAME))
84+
.map(this::maybeCreateToolConfirmationEntry)
85+
.flatMap(Optional::stream)
86+
.collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));
87+
confirmationEventIndex = i;
88+
break;
89+
}
90+
91+
// Make it final to enable access from lambda expressions.
92+
final ImmutableMap<String, ToolConfirmation> requestConfirmationFunctionResponses = responses;
93+
7094
if (requestConfirmationFunctionResponses.isEmpty()) {
7195
logger.info("No request confirmation function responses found.");
7296
return Single.just(RequestProcessingResult.create(llmRequest, ImmutableList.of()));
7397
}
7498

75-
for (ImmutableList<FunctionCall> functionCalls :
76-
events.stream()
77-
.map(Event::functionCalls)
78-
.filter(fc -> !fc.isEmpty())
79-
.collect(toImmutableList())) {
99+
for (int i = events.size() - 2; i >= 0; i--) {
100+
Event event = events.get(i);
101+
if (event.functionCalls().isEmpty()) {
102+
continue;
103+
}
104+
105+
Map<String, ToolConfirmation> toolsToResumeWithConfirmation = new HashMap<>();
106+
Map<String, FunctionCall> toolsToResumeWithArgs = new HashMap<>();
107+
108+
event.functionCalls().stream()
109+
.filter(
110+
fc ->
111+
fc.id().isPresent()
112+
&& requestConfirmationFunctionResponses.containsKey(fc.id().get()))
113+
.forEach(
114+
fc ->
115+
getOriginalFunctionCall(fc)
116+
.ifPresent(
117+
ofc -> {
118+
toolsToResumeWithConfirmation.put(
119+
ofc.id().get(),
120+
requestConfirmationFunctionResponses.get(fc.id().get()));
121+
toolsToResumeWithArgs.put(ofc.id().get(), ofc);
122+
}));
123+
124+
if (toolsToResumeWithConfirmation.isEmpty()) {
125+
continue;
126+
}
127+
128+
// Remove the tools that have already been confirmed.
129+
ImmutableSet<String> alreadyConfirmedIds =
130+
events.subList(confirmationEventIndex + 1, events.size()).stream()
131+
.flatMap(e -> e.functionResponses().stream())
132+
.map(FunctionResponse::id)
133+
.flatMap(Optional::stream)
134+
.collect(toImmutableSet());
135+
toolsToResumeWithConfirmation.keySet().removeAll(alreadyConfirmedIds);
136+
toolsToResumeWithArgs.keySet().removeAll(alreadyConfirmedIds);
80137

81-
ImmutableMap<String, FunctionCall> toolsToResumeWithArgs =
82-
filterToolsToResumeWithArgs(functionCalls, requestConfirmationFunctionResponses);
83-
ImmutableMap<String, ToolConfirmation> toolsToResumeWithConfirmation =
84-
toolsToResumeWithArgs.keySet().stream()
85-
.filter(
86-
id ->
87-
events.stream()
88-
.flatMap(e -> e.functionResponses().stream())
89-
.anyMatch(fr -> Objects.equals(fr.id().orElse(null), id)))
90-
.collect(toImmutableMap(k -> k, requestConfirmationFunctionResponses::get));
91138
if (toolsToResumeWithConfirmation.isEmpty()) {
92-
logger.info("No tools to resume with confirmation.");
93139
continue;
94140
}
95141

96142
return assembleEvent(
97-
invocationContext, toolsToResumeWithArgs.values(), toolsToResumeWithConfirmation)
98-
.map(event -> RequestProcessingResult.create(llmRequest, ImmutableList.of(event)))
143+
invocationContext,
144+
toolsToResumeWithArgs.values(),
145+
ImmutableMap.copyOf(toolsToResumeWithConfirmation))
146+
.map(e -> RequestProcessingResult.create(llmRequest, ImmutableList.of(e)))
99147
.toSingle();
100148
}
101149

102150
return Single.just(RequestProcessingResult.create(llmRequest, ImmutableList.of()));
103151
}
104152

153+
private Optional<FunctionCall> getOriginalFunctionCall(FunctionCall functionCall) {
154+
if (!functionCall.args().orElse(ImmutableMap.of()).containsKey("originalFunctionCall")) {
155+
return Optional.empty();
156+
}
157+
try {
158+
FunctionCall originalFunctionCall =
159+
OBJECT_MAPPER.convertValue(
160+
functionCall.args().get().get("originalFunctionCall"), FunctionCall.class);
161+
if (originalFunctionCall.id().isEmpty()) {
162+
return Optional.empty();
163+
}
164+
return Optional.of(originalFunctionCall);
165+
} catch (IllegalArgumentException e) {
166+
logger.warn("Failed to convert originalFunctionCall argument.", e);
167+
return Optional.empty();
168+
}
169+
}
170+
105171
private Maybe<Event> assembleEvent(
106172
InvocationContext invocationContext,
107173
Collection<FunctionCall> functionCalls,
108174
Map<String, ToolConfirmation> toolConfirmations) {
109-
ImmutableMap.Builder<String, BaseTool> toolsBuilder = ImmutableMap.builder();
175+
Single<ImmutableMap<String, BaseTool>> toolsMapSingle;
110176
if (invocationContext.agent() instanceof LlmAgent llmAgent) {
111-
for (BaseTool tool : llmAgent.tools()) {
112-
toolsBuilder.put(tool.name(), tool);
113-
}
177+
toolsMapSingle =
178+
llmAgent
179+
.tools()
180+
.map(
181+
toolList ->
182+
toolList.stream().collect(toImmutableMap(BaseTool::name, tool -> tool)));
183+
} else {
184+
toolsMapSingle = Single.just(ImmutableMap.of());
114185
}
115186

116187
var functionCallEvent =
@@ -124,23 +195,10 @@ private Maybe<Event> assembleEvent(
124195
.build())
125196
.build();
126197

127-
return Functions.handleFunctionCalls(
128-
invocationContext, functionCallEvent, toolsBuilder.buildOrThrow(), toolConfirmations);
129-
}
130-
131-
private ImmutableMap<String, ToolConfirmation> filterRequestConfirmationFunctionResponses(
132-
List<Event> events) {
133-
return events.stream()
134-
.filter(event -> Objects.equals(event.author(), "user"))
135-
.flatMap(event -> event.functionResponses().stream())
136-
.filter(functionResponse -> functionResponse.id().isPresent())
137-
.filter(
138-
functionResponse ->
139-
Objects.equals(
140-
functionResponse.name().orElse(null), REQUEST_CONFIRMATION_FUNCTION_CALL_NAME))
141-
.map(this::maybeCreateToolConfirmationEntry)
142-
.flatMap(Optional::stream)
143-
.collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));
198+
return toolsMapSingle.flatMapMaybe(
199+
toolsMap ->
200+
Functions.handleFunctionCalls(
201+
invocationContext, functionCallEvent, toolsMap, toolConfirmations));
144202
}
145203

146204
private Optional<Map.Entry<String, ToolConfirmation>> maybeCreateToolConfirmationEntry(
@@ -150,36 +208,19 @@ private Optional<Map.Entry<String, ToolConfirmation>> maybeCreateToolConfirmatio
150208
return Optional.of(
151209
Map.entry(
152210
functionResponse.id().get(),
153-
objectMapper.convertValue(responseMap, ToolConfirmation.class)));
211+
OBJECT_MAPPER.convertValue(responseMap, ToolConfirmation.class)));
154212
}
155213

156214
try {
157215
return Optional.of(
158216
Map.entry(
159217
functionResponse.id().get(),
160-
objectMapper.readValue(
218+
OBJECT_MAPPER.readValue(
161219
(String) responseMap.get("response"), ToolConfirmation.class)));
162220
} catch (JsonProcessingException e) {
163221
logger.error("Failed to parse tool confirmation response", e);
164222
}
165223

166224
return Optional.empty();
167225
}
168-
169-
private ImmutableMap<String, FunctionCall> filterToolsToResumeWithArgs(
170-
ImmutableList<FunctionCall> functionCalls,
171-
Map<String, ToolConfirmation> requestConfirmationFunctionResponses) {
172-
return functionCalls.stream()
173-
.filter(fc -> fc.id().isPresent())
174-
.filter(fc -> requestConfirmationFunctionResponses.containsKey(fc.id().get()))
175-
.filter(
176-
fc -> Objects.equals(fc.name().orElse(null), REQUEST_CONFIRMATION_FUNCTION_CALL_NAME))
177-
.filter(fc -> fc.args().orElse(ImmutableMap.of()).containsKey("originalFunctionCall"))
178-
.collect(
179-
toImmutableMap(
180-
fc -> fc.id().get(),
181-
fc ->
182-
objectMapper.convertValue(
183-
fc.args().get().get("originalFunctionCall"), FunctionCall.class)));
184-
}
185226
}

core/src/main/java/com/google/adk/flows/llmflows/SingleFlow.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ public class SingleFlow extends BaseLlmFlow {
3131
new Identity(),
3232
new Contents(),
3333
new Examples(),
34+
new RequestConfirmationLlmRequestProcessor(),
3435
CodeExecution.requestProcessor);
3536

3637
protected static final ImmutableList<ResponseProcessor> RESPONSE_PROCESSORS =

0 commit comments

Comments
 (0)