Skip to content

Commit c24c3b7

Browse files
fix: enable service injection in A2ASendMessageExecutor
Updated A2ASendMessageExecutor and A2ARemoteConfiguration to support dependency injection of session, artifact, and memory services. This allows persistent service implementations to replace hard-coded in-memory versions. Changes: - Modified A2ASendMessageExecutor constructor to accept BaseSessionService, BaseArtifactService, and BaseMemoryService parameters - Updated A2ARemoteConfiguration to autowire service beans with fallback to in-memory defaults when custom beans are not provided - Added comprehensive test coverage for service injection scenarios Tasks: [x] Update A2ASendMessageExecutor to accept injected services [x] Modify A2ARemoteConfiguration to autowire service beans [x] Add unit tests for A2ASendMessageExecutor [x] Add integration tests for service injection [x] Add Spring configuration tests for custom services
1 parent b66e4a5 commit c24c3b7

File tree

9 files changed

+635
-15
lines changed

9 files changed

+635
-15
lines changed

a2a/src/main/java/com/google/adk/a2a/A2ASendMessageExecutor.java

Lines changed: 47 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@
88
import com.google.adk.a2a.converters.ResponseConverter;
99
import com.google.adk.agents.BaseAgent;
1010
import com.google.adk.agents.RunConfig;
11-
import com.google.adk.artifacts.InMemoryArtifactService;
11+
import com.google.adk.artifacts.BaseArtifactService;
1212
import com.google.adk.events.Event;
13-
import com.google.adk.memory.InMemoryMemoryService;
13+
import com.google.adk.memory.BaseMemoryService;
1414
import com.google.adk.runner.Runner;
15-
import com.google.adk.sessions.InMemorySessionService;
15+
import com.google.adk.sessions.BaseSessionService;
1616
import com.google.adk.sessions.Session;
1717
import com.google.common.collect.ImmutableList;
1818
import com.google.genai.types.Content;
@@ -51,29 +51,63 @@ Single<ImmutableList<Event>> execute(
5151
String invocationId);
5252
}
5353

54-
private final InMemorySessionService sessionService;
54+
private final BaseSessionService sessionService;
5555
private final String appName;
5656
@Nullable private final Runner runner;
5757
@Nullable private final Duration agentTimeout;
5858
private static final RunConfig DEFAULT_RUN_CONFIG =
5959
RunConfig.builder().setStreamingMode(RunConfig.StreamingMode.NONE).setMaxLlmCalls(20).build();
6060

61-
public A2ASendMessageExecutor(InMemorySessionService sessionService, String appName) {
61+
public A2ASendMessageExecutor(BaseSessionService sessionService, String appName) {
6262
this.sessionService = sessionService;
6363
this.appName = appName;
6464
this.runner = null;
6565
this.agentTimeout = null;
6666
}
6767

68-
public A2ASendMessageExecutor(BaseAgent agent, String appName, Duration agentTimeout) {
69-
InMemorySessionService sessionService = new InMemorySessionService();
68+
/**
69+
* Creates an A2A send message executor with explicit service dependencies.
70+
*
71+
* <p>This constructor requires all service implementations to be provided explicitly, enabling
72+
* flexible deployment configurations (e.g., persistent sessions, distributed artifacts).
73+
*
74+
* <p><strong>Note:</strong> In version 0.5.1, the constructor signature changed to require
75+
* explicit service injection. Previously, services were created internally as in-memory
76+
* implementations.
77+
*
78+
* <p><strong>For Spring Boot applications:</strong> Use {@link
79+
* com.google.adk.webservice.A2ARemoteConfiguration} which automatically provides service beans
80+
* with sensible defaults. Direct instantiation is typically only needed for custom frameworks or
81+
* testing.
82+
*
83+
* <p>Example usage:
84+
*
85+
* <pre>{@code
86+
* A2ASendMessageExecutor executor = new A2ASendMessageExecutor(
87+
* myAgent,
88+
* "my-app",
89+
* Duration.ofSeconds(30),
90+
* new InMemorySessionService(), // or DatabaseSessionService for persistence
91+
* new InMemoryArtifactService(), // or S3ArtifactService for distributed storage
92+
* new InMemoryMemoryService()); // or RedisMemoryService for shared state
93+
* }</pre>
94+
*
95+
* @param agent the agent to execute when processing messages
96+
* @param appName the application name used for session identification
97+
* @param agentTimeout maximum duration to wait for agent execution before timing out
98+
* @param sessionService service for managing conversation sessions (required, non-null)
99+
* @param artifactService service for storing and retrieving artifacts (required, non-null)
100+
* @param memoryService service for managing agent memory/state (required, non-null)
101+
*/
102+
public A2ASendMessageExecutor(
103+
BaseAgent agent,
104+
String appName,
105+
Duration agentTimeout,
106+
BaseSessionService sessionService,
107+
BaseArtifactService artifactService,
108+
BaseMemoryService memoryService) {
70109
Runner runnerInstance =
71-
new Runner(
72-
agent,
73-
appName,
74-
new InMemoryArtifactService(),
75-
sessionService,
76-
new InMemoryMemoryService());
110+
new Runner(agent, appName, artifactService, sessionService, memoryService);
77111
this.sessionService = sessionService;
78112
this.appName = appName;
79113
this.runner = runnerInstance;
Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
package com.google.adk.a2a;
2+
3+
import static com.google.common.truth.Truth.assertThat;
4+
5+
import com.google.adk.agents.BaseAgent;
6+
import com.google.adk.agents.InvocationContext;
7+
import com.google.adk.artifacts.InMemoryArtifactService;
8+
import com.google.adk.events.Event;
9+
import com.google.adk.memory.InMemoryMemoryService;
10+
import com.google.adk.sessions.InMemorySessionService;
11+
import com.google.common.collect.ImmutableList;
12+
import com.google.genai.types.Content;
13+
import com.google.genai.types.Part;
14+
import io.a2a.spec.Message;
15+
import io.a2a.spec.TextPart;
16+
import io.reactivex.rxjava3.core.Flowable;
17+
import io.reactivex.rxjava3.core.Single;
18+
import java.time.Duration;
19+
import java.util.List;
20+
import java.util.UUID;
21+
import org.junit.Test;
22+
import org.junit.runner.RunWith;
23+
import org.junit.runners.JUnit4;
24+
25+
@RunWith(JUnit4.class)
26+
public class A2ASendMessageExecutorAdvancedTest {
27+
28+
private InMemorySessionService sessionService;
29+
30+
@Test
31+
public void execute_withCustomStrategy_usesStrategy() {
32+
InMemorySessionService sessionService = new InMemorySessionService();
33+
34+
A2ASendMessageExecutor executor = new A2ASendMessageExecutor(sessionService, "test-app");
35+
36+
A2ASendMessageExecutor.AgentExecutionStrategy customStrategy =
37+
(userId, sessionId, userContent, runConfig, invocationId) -> {
38+
Event customEvent =
39+
Event.builder()
40+
.id(UUID.randomUUID().toString())
41+
.invocationId(invocationId)
42+
.author("agent")
43+
.content(
44+
Content.builder()
45+
.role("model")
46+
.parts(
47+
ImmutableList.of(
48+
Part.builder().text("Custom strategy response").build()))
49+
.build())
50+
.build();
51+
return Single.just(ImmutableList.of(customEvent));
52+
};
53+
54+
Message request =
55+
new Message.Builder()
56+
.messageId("msg-1")
57+
.contextId("ctx-1")
58+
.role(Message.Role.USER)
59+
.parts(List.of(new TextPart("Test")))
60+
.build();
61+
62+
Message response = executor.execute(request, customStrategy).blockingGet();
63+
64+
assertThat(response).isNotNull();
65+
assertThat(response.getParts()).isNotEmpty();
66+
assertThat(((TextPart) response.getParts().get(0)).getText())
67+
.contains("Custom strategy response");
68+
}
69+
70+
private A2ASendMessageExecutor createExecutorWithAgent() {
71+
BaseAgent agent = createSimpleAgent();
72+
sessionService = new InMemorySessionService();
73+
return new A2ASendMessageExecutor(
74+
agent,
75+
"test-app",
76+
Duration.ofSeconds(30),
77+
sessionService,
78+
new InMemoryArtifactService(),
79+
new InMemoryMemoryService());
80+
}
81+
82+
@Test
83+
public void execute_withNullMessage_generatesDefaultContext() {
84+
A2ASendMessageExecutor executor = createExecutorWithAgent();
85+
86+
Message response = executor.execute(null).blockingGet();
87+
88+
assertThat(response).isNotNull();
89+
assertThat(response.getContextId()).isNotNull();
90+
assertThat(response.getContextId()).isNotEmpty();
91+
}
92+
93+
@Test
94+
public void execute_withEmptyContextId_generatesNewContext() {
95+
A2ASendMessageExecutor executor = createExecutorWithAgent();
96+
97+
Message request =
98+
new Message.Builder()
99+
.messageId("msg-1")
100+
.role(Message.Role.USER)
101+
.parts(List.of(new TextPart("Test")))
102+
.build();
103+
104+
Message response = executor.execute(request).blockingGet();
105+
106+
assertThat(response).isNotNull();
107+
assertThat(response.getContextId()).isNotNull();
108+
assertThat(response.getContextId()).isNotEmpty();
109+
}
110+
111+
@Test
112+
public void execute_withProvidedContextId_preservesContext() {
113+
A2ASendMessageExecutor executor = createExecutorWithAgent();
114+
115+
String contextId = "my-custom-context";
116+
Message request =
117+
new Message.Builder()
118+
.messageId("msg-1")
119+
.contextId(contextId)
120+
.role(Message.Role.USER)
121+
.parts(List.of(new TextPart("Test")))
122+
.build();
123+
124+
Message response = executor.execute(request).blockingGet();
125+
126+
assertThat(response).isNotNull();
127+
assertThat(response.getContextId()).isEqualTo(contextId);
128+
}
129+
130+
@Test
131+
public void execute_multipleRequests_maintainsSession() {
132+
A2ASendMessageExecutor executor = createExecutorWithAgent();
133+
134+
String contextId = "persistent-context";
135+
136+
Message request1 =
137+
new Message.Builder()
138+
.messageId("msg-1")
139+
.contextId(contextId)
140+
.role(Message.Role.USER)
141+
.parts(List.of(new TextPart("First message")))
142+
.build();
143+
144+
Message response1 = executor.execute(request1).blockingGet();
145+
assertThat(response1.getContextId()).isEqualTo(contextId);
146+
147+
Message request2 =
148+
new Message.Builder()
149+
.messageId("msg-2")
150+
.contextId(contextId)
151+
.role(Message.Role.USER)
152+
.parts(List.of(new TextPart("Second message")))
153+
.build();
154+
155+
Message response2 = executor.execute(request2).blockingGet();
156+
assertThat(response2.getContextId()).isEqualTo(contextId);
157+
}
158+
159+
@Test
160+
public void execute_withoutRunnerConfig_throwsException() {
161+
InMemorySessionService sessionService = new InMemorySessionService();
162+
163+
A2ASendMessageExecutor executor = new A2ASendMessageExecutor(sessionService, "test-app");
164+
165+
Message request =
166+
new Message.Builder()
167+
.messageId("msg-1")
168+
.contextId("ctx-1")
169+
.role(Message.Role.USER)
170+
.parts(List.of(new TextPart("Test")))
171+
.build();
172+
173+
try {
174+
executor.execute(request).blockingGet();
175+
assertThat(false).isTrue();
176+
} catch (IllegalStateException e) {
177+
assertThat(e.getMessage()).contains("Runner-based handle invoked without configured runner");
178+
}
179+
}
180+
181+
@Test
182+
public void execute_errorInStrategy_returnsErrorResponse() {
183+
InMemorySessionService sessionService = new InMemorySessionService();
184+
185+
A2ASendMessageExecutor executor = new A2ASendMessageExecutor(sessionService, "test-app");
186+
187+
A2ASendMessageExecutor.AgentExecutionStrategy failingStrategy =
188+
(userId, sessionId, userContent, runConfig, invocationId) -> {
189+
return Single.error(new RuntimeException("Strategy failed"));
190+
};
191+
192+
Message request =
193+
new Message.Builder()
194+
.messageId("msg-1")
195+
.contextId("ctx-1")
196+
.role(Message.Role.USER)
197+
.parts(List.of(new TextPart("Test")))
198+
.build();
199+
200+
Message response = executor.execute(request, failingStrategy).blockingGet();
201+
202+
assertThat(response).isNotNull();
203+
assertThat(response.getParts()).isNotEmpty();
204+
assertThat(((TextPart) response.getParts().get(0)).getText()).contains("Error:");
205+
assertThat(((TextPart) response.getParts().get(0)).getText()).contains("Strategy failed");
206+
}
207+
208+
private BaseAgent createSimpleAgent() {
209+
return new BaseAgent("test", "test agent", ImmutableList.of(), null, null) {
210+
@Override
211+
protected Flowable<Event> runAsyncImpl(InvocationContext ctx) {
212+
return Flowable.just(
213+
Event.builder()
214+
.content(
215+
Content.builder()
216+
.role("model")
217+
.parts(
218+
ImmutableList.of(
219+
com.google.genai.types.Part.builder().text("Response").build()))
220+
.build())
221+
.build());
222+
}
223+
224+
@Override
225+
protected Flowable<Event> runLiveImpl(InvocationContext ctx) {
226+
return Flowable.empty();
227+
}
228+
};
229+
}
230+
}

0 commit comments

Comments
 (0)