Skip to content

Commit d8f9e4c

Browse files
authored
Merge pull request #107 from fjuma/spec
Use a ContainerRequestFilter to make sure streaming and non-streaming requests are handled properly
2 parents 7e35e86 + a026098 commit d8f9e4c

File tree

2 files changed

+133
-34
lines changed

2 files changed

+133
-34
lines changed
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
package io.a2a.server.apps;
2+
3+
import static io.a2a.spec.A2A.CANCEL_TASK_METHOD;
4+
import static io.a2a.spec.A2A.GET_TASK_METHOD;
5+
import static io.a2a.spec.A2A.GET_TASK_PUSH_NOTIFICATION_CONFIG_METHOD;
6+
import static io.a2a.spec.A2A.SEND_MESSAGE_METHOD;
7+
import static io.a2a.spec.A2A.SEND_STREAMING_MESSAGE_METHOD;
8+
import static io.a2a.spec.A2A.SEND_TASK_RESUBSCRIPTION_METHOD;
9+
import static io.a2a.spec.A2A.SET_TASK_PUSH_NOTIFICATION_CONFIG_METHOD;
10+
11+
import java.io.ByteArrayInputStream;
12+
import java.io.IOException;
13+
import java.io.InputStream;
14+
15+
import jakarta.ws.rs.container.ContainerRequestContext;
16+
import jakarta.ws.rs.container.ContainerRequestFilter;
17+
import jakarta.ws.rs.container.PreMatching;
18+
import jakarta.ws.rs.core.MediaType;
19+
import jakarta.ws.rs.ext.Provider;
20+
21+
@Provider
22+
@PreMatching
23+
public class A2ARequestFilter implements ContainerRequestFilter {
24+
25+
@Override
26+
public void filter(ContainerRequestContext requestContext) {
27+
if (requestContext.getMethod().equals("POST") && requestContext.hasEntity()) {
28+
try (InputStream entityInputStream = requestContext.getEntityStream()) {
29+
byte[] requestBodyBytes = entityInputStream.readAllBytes();
30+
String requestBody = new String(requestBodyBytes);
31+
// ensure the request is treated as a streaming request or a non-streaming request
32+
// based on the method in the request body
33+
if (isStreamingRequest(requestBody)) {
34+
putAcceptHeader(requestContext, MediaType.SERVER_SENT_EVENTS);
35+
} else if (isNonStreamingRequest(requestBody)) {
36+
putAcceptHeader(requestContext, MediaType.APPLICATION_JSON);
37+
}
38+
// reset the entity stream
39+
requestContext.setEntityStream(new ByteArrayInputStream(requestBodyBytes));
40+
} catch(IOException e){
41+
throw new RuntimeException("Unable to read the request body");
42+
}
43+
}
44+
}
45+
46+
private static boolean isStreamingRequest(String requestBody) {
47+
return requestBody.contains(SEND_STREAMING_MESSAGE_METHOD) ||
48+
requestBody.contains(SEND_TASK_RESUBSCRIPTION_METHOD);
49+
}
50+
51+
private static boolean isNonStreamingRequest(String requestBody) {
52+
return requestBody.contains(GET_TASK_METHOD) ||
53+
requestBody.contains(CANCEL_TASK_METHOD) ||
54+
requestBody.contains(SEND_MESSAGE_METHOD) ||
55+
requestBody.contains(SET_TASK_PUSH_NOTIFICATION_CONFIG_METHOD) ||
56+
requestBody.contains(GET_TASK_PUSH_NOTIFICATION_CONFIG_METHOD);
57+
}
58+
59+
private static void putAcceptHeader(ContainerRequestContext requestContext, String mediaType) {
60+
requestContext.getHeaders().putSingle("Accept", mediaType);
61+
}
62+
63+
}

src/test/java/io/a2a/server/apps/A2AServerResourceTest.java

Lines changed: 70 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@
6969
import io.a2a.spec.UnsupportedOperationError;
7070
import io.a2a.util.Utils;
7171
import io.quarkus.test.junit.QuarkusTest;
72+
import io.restassured.RestAssured;
73+
import io.restassured.specification.RequestSpecification;
7274

7375
import org.junit.jupiter.api.Test;
7476

@@ -113,12 +115,24 @@ public class A2AServerResourceTest {
113115

114116
@Test
115117
public void testGetTaskSuccess() {
118+
testGetTask();
119+
}
120+
121+
private void testGetTask() {
122+
testGetTask(null);
123+
}
124+
125+
private void testGetTask(String mediaType) {
116126
taskStore.save(MINIMAL_TASK);
117127
try {
118128
GetTaskRequest request = new GetTaskRequest("1", new TaskQueryParams(MINIMAL_TASK.getId()));
119-
GetTaskResponse response = given()
129+
RequestSpecification requestSpecification = RestAssured.given()
120130
.contentType(MediaType.APPLICATION_JSON)
121-
.body(request)
131+
.body(request);
132+
if (mediaType != null) {
133+
requestSpecification = requestSpecification.accept(mediaType);
134+
}
135+
GetTaskResponse response = requestSpecification
122136
.when()
123137
.post("/")
124138
.then()
@@ -285,34 +299,7 @@ public void testSendMessageExistingTaskSuccess() {
285299

286300
@Test
287301
public void testSendMessageStreamNewMessageSuccess() throws Exception {
288-
Message message = new Message.Builder(MESSAGE)
289-
.taskId(MINIMAL_TASK.getId())
290-
.contextId(MINIMAL_TASK.getContextId())
291-
.build();
292-
SendStreamingMessageRequest request = new SendStreamingMessageRequest(
293-
"1", new MessageSendParams(message, null, null));
294-
Client client = ClientBuilder.newClient();
295-
WebTarget target = client.target("http://localhost:8081/");
296-
Response response = target.request(MediaType.SERVER_SENT_EVENTS).post(Entity.json(request));
297-
InputStream inputStream = response.readEntity(InputStream.class);
298-
boolean dataRead = false;
299-
try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))) {
300-
String line;
301-
while ((line = reader.readLine()) != null) {
302-
if (line.startsWith("data: ")) {
303-
SendStreamingMessageResponse sendStreamingMessageResponse = Utils.OBJECT_MAPPER.readValue(line.substring("data: ".length()).trim(), SendStreamingMessageResponse.class);
304-
assertNull(sendStreamingMessageResponse.getError());
305-
Message messageResponse = (Message) sendStreamingMessageResponse.getResult();
306-
assertEquals(MESSAGE.getMessageId(), messageResponse.getMessageId());
307-
assertEquals(MESSAGE.getRole(), messageResponse.getRole());
308-
Part<?> part = messageResponse.getParts().get(0);
309-
assertEquals(Part.Kind.TEXT, part.getKind());
310-
assertEquals("test message", ((TextPart) part).getText());
311-
dataRead = true;
312-
}
313-
}
314-
}
315-
assertTrue(dataRead);
302+
testSendStreamingMessage();
316303
}
317304

318305
@Test
@@ -327,7 +314,7 @@ public void testSendMessageStreamExistingTaskSuccess() {
327314
"1", new MessageSendParams(message, null, null));
328315
Client client = ClientBuilder.newClient();
329316
WebTarget target = client.target("http://localhost:8081/");
330-
Response response = target.request(MediaType.SERVER_SENT_EVENTS).post(Entity.json(request));
317+
Response response = target.request().post(Entity.json(request));
331318
InputStream inputStream = response.readEntity(InputStream.class);
332319
boolean dataRead = false;
333320
try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))) {
@@ -445,7 +432,7 @@ public void testResubscribeExistingTaskSuccess() throws Exception {
445432
Client client = ClientBuilder.newClient();
446433
WebTarget target = client.target("http://localhost:8081/");
447434
taskResubscriptionRequestSent.countDown();
448-
Response response = target.request(MediaType.SERVER_SENT_EVENTS).post(Entity.json(taskResubscriptionRequest));
435+
Response response = target.request().post(Entity.json(taskResubscriptionRequest));
449436
InputStream inputStream = response.readEntity(InputStream.class);
450437
try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))) {
451438
String line;
@@ -532,7 +519,7 @@ public void testResubscribeNoExistingTaskError() throws Exception {
532519
TaskResubscriptionRequest request = new TaskResubscriptionRequest("1", new TaskIdParams("non-existent-task"));
533520
Client client = ClientBuilder.newClient();
534521
WebTarget target = client.target("http://localhost:8081/");
535-
Response response = target.request(MediaType.SERVER_SENT_EVENTS).post(Entity.json(request));
522+
Response response = target.request().post(Entity.json(request));
536523
InputStream inputStream = response.readEntity(InputStream.class);
537524
boolean dataRead = false;
538525
try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))) {
@@ -562,7 +549,6 @@ public void testError() {
562549
"1", new MessageSendParams(message, null, null));
563550
SendMessageResponse response = given()
564551
.contentType(MediaType.APPLICATION_JSON)
565-
.accept(MediaType.APPLICATION_JSON)
566552
.body(request)
567553
.when()
568554
.post("/")
@@ -729,4 +715,54 @@ public void testInvalidJSONRPCRequestNonExistentMethod() {
729715
assertNotNull(response.getError());
730716
assertEquals(new MethodNotFoundError().getCode(), response.getError().getCode());
731717
}
718+
719+
@Test
720+
public void testNonStreamingMethodWithAcceptHeader() {
721+
testGetTask(MediaType.APPLICATION_JSON);
722+
}
723+
724+
@Test
725+
public void testStreamingMethodWithAcceptHeader() throws Exception {
726+
testSendStreamingMessage(MediaType.SERVER_SENT_EVENTS);
727+
}
728+
729+
private void testSendStreamingMessage() throws Exception {
730+
testSendStreamingMessage(null);
731+
}
732+
733+
private void testSendStreamingMessage(String mediaType) throws Exception {
734+
Message message = new Message.Builder(MESSAGE)
735+
.taskId(MINIMAL_TASK.getId())
736+
.contextId(MINIMAL_TASK.getContextId())
737+
.build();
738+
SendStreamingMessageRequest request = new SendStreamingMessageRequest(
739+
"1", new MessageSendParams(message, null, null));
740+
Client client = ClientBuilder.newClient();
741+
WebTarget target = client.target("http://localhost:8081/");
742+
Response response;
743+
if (mediaType != null) {
744+
response = target.request(mediaType).post(Entity.json(request));
745+
} else {
746+
response = target.request().post(Entity.json(request));
747+
}
748+
InputStream inputStream = response.readEntity(InputStream.class);
749+
boolean dataRead = false;
750+
try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))) {
751+
String line;
752+
while ((line = reader.readLine()) != null) {
753+
if (line.startsWith("data: ")) {
754+
SendStreamingMessageResponse sendStreamingMessageResponse = Utils.OBJECT_MAPPER.readValue(line.substring("data: ".length()).trim(), SendStreamingMessageResponse.class);
755+
assertNull(sendStreamingMessageResponse.getError());
756+
Message messageResponse = (Message) sendStreamingMessageResponse.getResult();
757+
assertEquals(MESSAGE.getMessageId(), messageResponse.getMessageId());
758+
assertEquals(MESSAGE.getRole(), messageResponse.getRole());
759+
Part<?> part = messageResponse.getParts().get(0);
760+
assertEquals(Part.Kind.TEXT, part.getKind());
761+
assertEquals("test message", ((TextPart) part).getText());
762+
dataRead = true;
763+
}
764+
}
765+
}
766+
assertTrue(dataRead);
767+
}
732768
}

0 commit comments

Comments
 (0)