|
21 | 21 | import static com.google.common.truth.Truth.assertThat; |
22 | 22 | import static org.junit.Assert.assertThrows; |
23 | 23 |
|
| 24 | +import com.google.adk.agents.Callbacks.AfterAgentCallback; |
24 | 25 | import com.google.adk.agents.ConfigAgentUtils.ConfigurationException; |
25 | 26 | import com.google.adk.agents.InvocationContext; |
26 | 27 | import com.google.adk.agents.LlmAgent; |
|
35 | 36 | import com.google.genai.types.Part; |
36 | 37 | import com.google.genai.types.Schema; |
37 | 38 | import io.reactivex.rxjava3.core.Flowable; |
| 39 | +import io.reactivex.rxjava3.core.Maybe; |
38 | 40 | import java.util.Map; |
39 | 41 | import java.util.Optional; |
40 | 42 | import org.junit.Test; |
@@ -421,6 +423,35 @@ public void call_withoutInputSchema_requestIsSentToAgent() throws Exception { |
421 | 423 | .containsExactly(Content.fromParts(Part.fromText("magic"))); |
422 | 424 | } |
423 | 425 |
|
| 426 | + @Test |
| 427 | + public void call_withStateDeltaInResponse_propagatesStateDelta() throws Exception { |
| 428 | + AfterAgentCallback afterAgentCallback = |
| 429 | + (callbackContext) -> { |
| 430 | + callbackContext.state().put("test_key", "test_value"); |
| 431 | + return Maybe.empty(); |
| 432 | + }; |
| 433 | + TestLlm testLlm = |
| 434 | + createTestLlm( |
| 435 | + LlmResponse.builder() |
| 436 | + .content(Content.fromParts(Part.fromText("test response"))) |
| 437 | + .build()); |
| 438 | + LlmAgent testAgent = |
| 439 | + createTestAgentBuilder(testLlm) |
| 440 | + .name("agent name") |
| 441 | + .description("agent description") |
| 442 | + .afterAgentCallback(afterAgentCallback) |
| 443 | + .build(); |
| 444 | + AgentTool agentTool = AgentTool.create(testAgent); |
| 445 | + ToolContext toolContext = createToolContext(testAgent); |
| 446 | + |
| 447 | + assertThat(toolContext.state()).doesNotContainKey("test_key"); |
| 448 | + |
| 449 | + Map<String, Object> unused = |
| 450 | + agentTool.runAsync(ImmutableMap.of("request", "magic"), toolContext).blockingGet(); |
| 451 | + |
| 452 | + assertThat(toolContext.state()).containsEntry("test_key", "test_value"); |
| 453 | + } |
| 454 | + |
424 | 455 | private static ToolContext createToolContext(LlmAgent agent) { |
425 | 456 | return ToolContext.builder( |
426 | 457 | new InvocationContext( |
|
0 commit comments