Skip to content

Commit 4983747

Browse files
google-genai-botcopybara-github
authored andcommitted
feat: introduce an experimental parameter to limit number of steps LlmAgent can take
PiperOrigin-RevId: 778386550
1 parent 21c09ac commit 4983747

File tree

7 files changed

+160
-57
lines changed

7 files changed

+160
-57
lines changed

core/src/main/java/com/google/adk/agents/LlmAgent.java

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ public enum IncludeContents {
8787
private final IncludeContents includeContents;
8888

8989
private final boolean planning;
90+
private final Optional<Integer> maxSteps;
9091
private final boolean disallowTransferToParent;
9192
private final boolean disallowTransferToPeers;
9293
private final Optional<List<BeforeModelCallback>> beforeModelCallback;
@@ -118,6 +119,7 @@ protected LlmAgent(Builder builder) {
118119
this.includeContents =
119120
builder.includeContents != null ? builder.includeContents : IncludeContents.DEFAULT;
120121
this.planning = builder.planning != null && builder.planning;
122+
this.maxSteps = Optional.ofNullable(builder.maxSteps);
121123
this.disallowTransferToParent = builder.disallowTransferToParent;
122124
this.disallowTransferToPeers = builder.disallowTransferToPeers;
123125
this.beforeModelCallback = Optional.ofNullable(builder.beforeModelCallback);
@@ -156,6 +158,7 @@ public static class Builder {
156158
private BaseExampleProvider exampleProvider;
157159
private IncludeContents includeContents;
158160
private Boolean planning;
161+
private Integer maxSteps;
159162
private Boolean disallowTransferToParent;
160163
private Boolean disallowTransferToPeers;
161164
private ImmutableList<BeforeModelCallback> beforeModelCallback;
@@ -290,6 +293,12 @@ public Builder planning(boolean planning) {
290293
return this;
291294
}
292295

296+
@CanIgnoreReturnValue
297+
public Builder maxSteps(int maxSteps) {
298+
this.maxSteps = maxSteps;
299+
return this;
300+
}
301+
293302
@CanIgnoreReturnValue
294303
public Builder disallowTransferToParent(boolean disallowTransferToParent) {
295304
this.disallowTransferToParent = disallowTransferToParent;
@@ -588,9 +597,9 @@ public LlmAgent build() {
588597

589598
protected BaseLlmFlow determineLlmFlow() {
590599
if (disallowTransferToParent() && disallowTransferToPeers() && subAgents().isEmpty()) {
591-
return new SingleFlow();
600+
return new SingleFlow(maxSteps);
592601
} else {
593-
return new AutoFlow();
602+
return new AutoFlow(maxSteps);
594603
}
595604
}
596605

core/src/main/java/com/google/adk/flows/llmflows/AutoFlow.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package com.google.adk.flows.llmflows;
1818

1919
import com.google.common.collect.ImmutableList;
20+
import java.util.Optional;
2021

2122
/** LLM flow with automatic agent transfer support. */
2223
public class AutoFlow extends SingleFlow {
@@ -32,6 +33,10 @@ public class AutoFlow extends SingleFlow {
3233
private static final ImmutableList<ResponseProcessor> RESPONSE_PROCESSORS = ImmutableList.of();
3334

3435
public AutoFlow() {
35-
super(REQUEST_PROCESSORS, RESPONSE_PROCESSORS);
36+
this(/* maxSteps= */ Optional.empty());
37+
}
38+
39+
public AutoFlow(Optional<Integer> maxSteps) {
40+
super(REQUEST_PROCESSORS, RESPONSE_PROCESSORS, maxSteps);
3641
}
3742
}

core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,23 @@ public abstract class BaseLlmFlow implements BaseFlow {
6565
protected final List<RequestProcessor> requestProcessors;
6666
protected final List<ResponseProcessor> responseProcessors;
6767

68+
// Warning: This is local, in-process state that won't be preserved if the runtime is restarted.
69+
// "Max steps" is experimental and may evolve in the future (e.g., to support persistence).
70+
protected int stepsCompleted = 0;
71+
protected final int maxSteps;
72+
6873
public BaseLlmFlow(
6974
List<RequestProcessor> requestProcessors, List<ResponseProcessor> responseProcessors) {
75+
this(requestProcessors, responseProcessors, /* maxSteps= */ Optional.empty());
76+
}
77+
78+
public BaseLlmFlow(
79+
List<RequestProcessor> requestProcessors,
80+
List<ResponseProcessor> responseProcessors,
81+
Optional<Integer> maxSteps) {
7082
this.requestProcessors = requestProcessors;
7183
this.responseProcessors = responseProcessors;
84+
this.maxSteps = maxSteps.orElse(Integer.MAX_VALUE);
7285
}
7386

7487
/**
@@ -407,6 +420,11 @@ private Flowable<Event> runOneStep(InvocationContext context) {
407420
@Override
408421
public Flowable<Event> run(InvocationContext invocationContext) {
409422
Flowable<Event> currentStepEvents = runOneStep(invocationContext).cache();
423+
if (++stepsCompleted >= maxSteps) {
424+
logger.debug("Ending flow execution because max steps reached.");
425+
return currentStepEvents;
426+
}
427+
410428
return currentStepEvents.concatWith(
411429
currentStepEvents
412430
.toList()

core/src/main/java/com/google/adk/flows/llmflows/SingleFlow.java

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import com.google.common.collect.ImmutableList;
2020
import java.util.List;
21+
import java.util.Optional;
2122

2223
/** Basic LLM flow with fixed request processors and no response post-processing. */
2324
public class SingleFlow extends BaseLlmFlow {
@@ -29,11 +30,17 @@ public class SingleFlow extends BaseLlmFlow {
2930
protected static final ImmutableList<ResponseProcessor> RESPONSE_PROCESSORS = ImmutableList.of();
3031

3132
public SingleFlow() {
32-
super(REQUEST_PROCESSORS, RESPONSE_PROCESSORS);
33+
this(/* maxSteps= */ Optional.empty());
34+
}
35+
36+
public SingleFlow(Optional<Integer> maxSteps) {
37+
this(REQUEST_PROCESSORS, RESPONSE_PROCESSORS, maxSteps);
3338
}
3439

3540
protected SingleFlow(
36-
List<RequestProcessor> requestProcessors, List<ResponseProcessor> responseProcessors) {
37-
super(requestProcessors, responseProcessors);
41+
List<RequestProcessor> requestProcessors,
42+
List<ResponseProcessor> responseProcessors,
43+
Optional<Integer> maxSteps) {
44+
super(requestProcessors, responseProcessors, maxSteps);
3845
}
3946
}

core/src/test/java/com/google/adk/agents/LlmAgentTest.java

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package com.google.adk.agents;
1818

19+
import static com.google.adk.testing.TestUtils.assertEqualIgnoringFunctionIds;
1920
import static com.google.adk.testing.TestUtils.createInvocationContext;
2021
import static com.google.adk.testing.TestUtils.createLlmResponse;
2122
import static com.google.adk.testing.TestUtils.createTestAgent;
@@ -28,6 +29,7 @@
2829
import com.google.adk.events.Event;
2930
import com.google.adk.models.LlmResponse;
3031
import com.google.adk.testing.TestLlm;
32+
import com.google.adk.testing.TestUtils.EchoTool;
3133
import com.google.adk.tools.BaseTool;
3234
import com.google.common.collect.ImmutableList;
3335
import com.google.common.collect.ImmutableMap;
@@ -108,6 +110,33 @@ public void testRun_withoutOutputKey_doesNotSaveState() {
108110
assertThat(events.get(0).actions().stateDelta()).isEmpty();
109111
}
110112

113+
@Test
114+
public void run_withToolsAndMaxSteps_stopsAfterMaxSteps() {
115+
ImmutableMap<String, Object> echoArgs = ImmutableMap.of("arg", "value");
116+
Content contentWithFunctionCall =
117+
Content.fromParts(Part.fromText("text"), Part.fromFunctionCall("echo_tool", echoArgs));
118+
Content unreachableContent = Content.fromParts(Part.fromText("This should never be returned."));
119+
TestLlm testLlm =
120+
createTestLlm(
121+
createLlmResponse(contentWithFunctionCall),
122+
createLlmResponse(contentWithFunctionCall),
123+
createLlmResponse(unreachableContent));
124+
LlmAgent agent = createTestAgentBuilder(testLlm).tools(new EchoTool()).maxSteps(2).build();
125+
InvocationContext invocationContext = createInvocationContext(agent);
126+
127+
List<Event> events = agent.runAsync(invocationContext).toList().blockingGet();
128+
129+
Content expectedFunctionResponseContent =
130+
Content.fromParts(
131+
Part.fromFunctionResponse(
132+
"echo_tool", ImmutableMap.<String, Object>of("result", echoArgs)));
133+
assertThat(events).hasSize(4);
134+
assertEqualIgnoringFunctionIds(events.get(0).content().get(), contentWithFunctionCall);
135+
assertEqualIgnoringFunctionIds(events.get(1).content().get(), expectedFunctionResponseContent);
136+
assertEqualIgnoringFunctionIds(events.get(2).content().get(), contentWithFunctionCall);
137+
assertEqualIgnoringFunctionIds(events.get(3).content().get(), expectedFunctionResponseContent);
138+
}
139+
111140
@Test
112141
public void build_withOutputSchemaAndTools_throwsIllegalArgumentException() {
113142
BaseTool tool =

core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java

Lines changed: 51 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package com.google.adk.flows.llmflows;
1818

19+
import static com.google.adk.testing.TestUtils.assertEqualIgnoringFunctionIds;
1920
import static com.google.adk.testing.TestUtils.createInvocationContext;
2021
import static com.google.adk.testing.TestUtils.createLlmResponse;
2122
import static com.google.adk.testing.TestUtils.createTestAgent;
@@ -36,9 +37,7 @@
3637
import com.google.common.collect.ImmutableList;
3738
import com.google.common.collect.ImmutableMap;
3839
import com.google.genai.types.Content;
39-
import com.google.genai.types.FunctionCall;
4040
import com.google.genai.types.FunctionDeclaration;
41-
import com.google.genai.types.FunctionResponse;
4241
import com.google.genai.types.Part;
4342
import io.reactivex.rxjava3.core.Flowable;
4443
import io.reactivex.rxjava3.core.Single;
@@ -91,21 +90,51 @@ public void run_withFunctionCall_returnsCorrectEvents() {
9190
List<Event> events = baseLlmFlow.run(invocationContext).toList().blockingGet();
9291

9392
assertThat(events).hasSize(3);
94-
assertContentIgnoringFunctionId(events.get(0).content().get(), firstContent);
95-
assertContentIgnoringFunctionId(
93+
assertEqualIgnoringFunctionIds(events.get(0).content().get(), firstContent);
94+
assertEqualIgnoringFunctionIds(
9695
events.get(1).content().get(),
97-
Content.fromParts(
98-
Part.builder()
99-
.functionResponse(
100-
FunctionResponse.builder()
101-
.id("")
102-
.name("my_function")
103-
.response(testResponse)
104-
.build())
105-
.build()));
96+
Content.fromParts(Part.fromFunctionResponse("my_function", testResponse)));
10697
assertThat(events.get(2).content()).hasValue(secondContent);
10798
}
10899

100+
@Test
101+
public void run_withFunctionCallsAndMaxSteps_stopsAfterMaxSteps() {
102+
Content contentWithFunctionCall =
103+
Content.fromParts(
104+
Part.fromText("LLM response with function call"),
105+
Part.fromFunctionCall("my_function", ImmutableMap.of("arg1", "value1")));
106+
Content unreachableContent = Content.fromParts(Part.fromText("This should never be returned."));
107+
TestLlm testLlm =
108+
createTestLlm(
109+
Flowable.just(createLlmResponse(contentWithFunctionCall)),
110+
Flowable.just(createLlmResponse(contentWithFunctionCall)),
111+
Flowable.just(createLlmResponse(unreachableContent)));
112+
ImmutableMap<String, Object> testResponse =
113+
ImmutableMap.<String, Object>of("response", "response for my_function");
114+
InvocationContext invocationContext =
115+
createInvocationContext(
116+
createTestAgentBuilder(testLlm)
117+
.tools(ImmutableList.of(new TestTool("my_function", testResponse)))
118+
.build());
119+
BaseLlmFlow baseLlmFlow =
120+
createBaseLlmFlow(
121+
/* requestProcessors= */ ImmutableList.of(),
122+
/* responseProcessors= */ ImmutableList.of(),
123+
/* maxSteps= */ Optional.of(2));
124+
125+
List<Event> events = baseLlmFlow.run(invocationContext).toList().blockingGet();
126+
127+
assertThat(events).hasSize(4);
128+
assertEqualIgnoringFunctionIds(events.get(0).content().get(), contentWithFunctionCall);
129+
assertEqualIgnoringFunctionIds(
130+
events.get(1).content().get(),
131+
Content.fromParts(Part.fromFunctionResponse("my_function", testResponse)));
132+
assertEqualIgnoringFunctionIds(events.get(2).content().get(), contentWithFunctionCall);
133+
assertEqualIgnoringFunctionIds(
134+
events.get(3).content().get(),
135+
Content.fromParts(Part.fromFunctionResponse("my_function", testResponse)));
136+
}
137+
109138
@Test
110139
public void run_withRequestProcessor_doesNotModifyRequest() {
111140
Content content = Content.fromParts(Part.fromText("LLM response"));
@@ -206,7 +235,15 @@ private static BaseLlmFlow createBaseLlmFlowWithoutProcessors() {
206235

207236
private static BaseLlmFlow createBaseLlmFlow(
208237
List<RequestProcessor> requestProcessors, List<ResponseProcessor> responseProcessors) {
209-
return new BaseLlmFlow(requestProcessors, responseProcessors) {};
238+
return createBaseLlmFlow(
239+
requestProcessors, responseProcessors, /* maxSteps= */ Optional.empty());
240+
}
241+
242+
private static BaseLlmFlow createBaseLlmFlow(
243+
List<RequestProcessor> requestProcessors,
244+
List<ResponseProcessor> responseProcessors,
245+
Optional<Integer> maxSteps) {
246+
return new BaseLlmFlow(requestProcessors, responseProcessors, maxSteps) {};
210247
}
211248

212249
private static RequestProcessor createRequestProcessor() {
@@ -256,41 +293,4 @@ public Single<Map<String, Object>> runAsync(Map<String, Object> args, ToolContex
256293
return Single.just(response);
257294
}
258295
}
259-
260-
private static void assertContentIgnoringFunctionId(
261-
Content actualContent, Content expectedContent) {
262-
263-
assertThat(actualContent.role()).isEqualTo(expectedContent.role());
264-
265-
Optional<List<Part>> actualPartsOpt = actualContent.parts();
266-
Optional<List<Part>> expectedPartsOpt = expectedContent.parts();
267-
assertThat(actualPartsOpt.isPresent()).isEqualTo(expectedPartsOpt.isPresent());
268-
269-
if (expectedPartsOpt.isPresent()) {
270-
List<Part> actualParts = actualPartsOpt.get();
271-
List<Part> expectedParts = expectedPartsOpt.get();
272-
assertThat(actualParts).hasSize(expectedParts.size());
273-
274-
for (int i = 0; i < expectedParts.size(); i++) {
275-
Part actualPart = actualParts.get(i);
276-
Part expectedPart = expectedParts.get(i);
277-
278-
if (expectedPart.functionCall().isPresent()) {
279-
assertThat(actualPart.functionCall()).isPresent();
280-
FunctionCall actualFc = actualPart.functionCall().get();
281-
FunctionCall expectedFc = expectedPart.functionCall().get();
282-
assertThat(actualFc.name()).isEqualTo(expectedFc.name());
283-
assertThat(actualFc.args()).isEqualTo(expectedFc.args());
284-
} else if (expectedPart.functionResponse().isPresent()) {
285-
assertThat(actualPart.functionResponse()).isPresent();
286-
FunctionResponse actualFr = actualPart.functionResponse().get();
287-
FunctionResponse expectedFr = expectedPart.functionResponse().get();
288-
assertThat(actualFr.name()).isEqualTo(expectedFr.name());
289-
assertThat(actualFr.response()).isEqualTo(expectedFr.response());
290-
} else {
291-
assertThat(actualPart).isEqualTo(expectedPart);
292-
}
293-
}
294-
}
295-
}
296296
}

core/src/test/java/com/google/adk/testing/TestUtils.java

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package com.google.adk.testing;
1818

1919
import static com.google.common.collect.ImmutableList.toImmutableList;
20+
import static com.google.common.truth.Truth.assertThat;
2021
import static java.util.stream.Collectors.joining;
2122

2223
import com.google.adk.agents.BaseAgent;
@@ -142,6 +143,40 @@ public static ImmutableList<Object> simplifyEvents(List<Event> events) {
142143
.collect(toImmutableList());
143144
}
144145

146+
public static void assertEqualIgnoringFunctionIds(
147+
Content actualContent, Content expectedContent) {
148+
assertThat(overwriteFunctionIdsInContent(actualContent))
149+
.isEqualTo(overwriteFunctionIdsInContent(expectedContent));
150+
}
151+
152+
private static Content overwriteFunctionIdsInContent(Content content) {
153+
if (content.parts().isEmpty()) {
154+
return content;
155+
}
156+
return content.toBuilder()
157+
.parts(
158+
content.parts().get().stream()
159+
.map(TestUtils::overwriteFunctionIdsInPart)
160+
.collect(toImmutableList()))
161+
.build();
162+
}
163+
164+
private static Part overwriteFunctionIdsInPart(Part part) {
165+
if (part.functionCall().isPresent()) {
166+
return part.toBuilder()
167+
.functionCall(
168+
part.functionCall().get().toBuilder().id("<overwritten by TestUtils>").build())
169+
.build();
170+
}
171+
if (part.functionResponse().isPresent()) {
172+
return part.toBuilder()
173+
.functionResponse(
174+
part.functionResponse().get().toBuilder().id("<overwritten by TestUtils>").build())
175+
.build();
176+
}
177+
return part;
178+
}
179+
145180
public static TestBaseAgent createRootAgent(BaseAgent... subAgents) {
146181
return createRootAgent(Arrays.asList(subAgents));
147182
}

0 commit comments

Comments
 (0)