6969import io .a2a .spec .UnsupportedOperationError ;
7070import io .a2a .util .Utils ;
7171import io .quarkus .test .junit .QuarkusTest ;
72+ import io .restassured .RestAssured ;
73+ import io .restassured .specification .RequestSpecification ;
7274
7375import org .junit .jupiter .api .Test ;
7476
@@ -113,12 +115,24 @@ public class A2AServerResourceTest {
113115
114116 @ Test
115117 public void testGetTaskSuccess () {
118+ testGetTask ();
119+ }
120+
121+ private void testGetTask () {
122+ testGetTask (null );
123+ }
124+
125+ private void testGetTask (String mediaType ) {
116126 taskStore .save (MINIMAL_TASK );
117127 try {
118128 GetTaskRequest request = new GetTaskRequest ("1" , new TaskQueryParams (MINIMAL_TASK .getId ()));
119- GetTaskResponse response = given ()
129+ RequestSpecification requestSpecification = RestAssured . given ()
120130 .contentType (MediaType .APPLICATION_JSON )
121- .body (request )
131+ .body (request );
132+ if (mediaType != null ) {
133+ requestSpecification = requestSpecification .accept (mediaType );
134+ }
135+ GetTaskResponse response = requestSpecification
122136 .when ()
123137 .post ("/" )
124138 .then ()
@@ -285,34 +299,7 @@ public void testSendMessageExistingTaskSuccess() {
285299
286300 @ Test
287301 public void testSendMessageStreamNewMessageSuccess () throws Exception {
288- Message message = new Message .Builder (MESSAGE )
289- .taskId (MINIMAL_TASK .getId ())
290- .contextId (MINIMAL_TASK .getContextId ())
291- .build ();
292- SendStreamingMessageRequest request = new SendStreamingMessageRequest (
293- "1" , new MessageSendParams (message , null , null ));
294- Client client = ClientBuilder .newClient ();
295- WebTarget target = client .target ("http://localhost:8081/" );
296- Response response = target .request (MediaType .SERVER_SENT_EVENTS ).post (Entity .json (request ));
297- InputStream inputStream = response .readEntity (InputStream .class );
298- boolean dataRead = false ;
299- try (BufferedReader reader = new BufferedReader (new InputStreamReader (inputStream , StandardCharsets .UTF_8 ))) {
300- String line ;
301- while ((line = reader .readLine ()) != null ) {
302- if (line .startsWith ("data: " )) {
303- SendStreamingMessageResponse sendStreamingMessageResponse = Utils .OBJECT_MAPPER .readValue (line .substring ("data: " .length ()).trim (), SendStreamingMessageResponse .class );
304- assertNull (sendStreamingMessageResponse .getError ());
305- Message messageResponse = (Message ) sendStreamingMessageResponse .getResult ();
306- assertEquals (MESSAGE .getMessageId (), messageResponse .getMessageId ());
307- assertEquals (MESSAGE .getRole (), messageResponse .getRole ());
308- Part <?> part = messageResponse .getParts ().get (0 );
309- assertEquals (Part .Kind .TEXT , part .getKind ());
310- assertEquals ("test message" , ((TextPart ) part ).getText ());
311- dataRead = true ;
312- }
313- }
314- }
315- assertTrue (dataRead );
302+ testSendStreamingMessage ();
316303 }
317304
318305 @ Test
@@ -327,7 +314,7 @@ public void testSendMessageStreamExistingTaskSuccess() {
327314 "1" , new MessageSendParams (message , null , null ));
328315 Client client = ClientBuilder .newClient ();
329316 WebTarget target = client .target ("http://localhost:8081/" );
330- Response response = target .request (MediaType . SERVER_SENT_EVENTS ).post (Entity .json (request ));
317+ Response response = target .request ().post (Entity .json (request ));
331318 InputStream inputStream = response .readEntity (InputStream .class );
332319 boolean dataRead = false ;
333320 try (BufferedReader reader = new BufferedReader (new InputStreamReader (inputStream , StandardCharsets .UTF_8 ))) {
@@ -445,7 +432,7 @@ public void testResubscribeExistingTaskSuccess() throws Exception {
445432 Client client = ClientBuilder .newClient ();
446433 WebTarget target = client .target ("http://localhost:8081/" );
447434 taskResubscriptionRequestSent .countDown ();
448- Response response = target .request (MediaType . SERVER_SENT_EVENTS ).post (Entity .json (taskResubscriptionRequest ));
435+ Response response = target .request ().post (Entity .json (taskResubscriptionRequest ));
449436 InputStream inputStream = response .readEntity (InputStream .class );
450437 try (BufferedReader reader = new BufferedReader (new InputStreamReader (inputStream , StandardCharsets .UTF_8 ))) {
451438 String line ;
@@ -532,7 +519,7 @@ public void testResubscribeNoExistingTaskError() throws Exception {
532519 TaskResubscriptionRequest request = new TaskResubscriptionRequest ("1" , new TaskIdParams ("non-existent-task" ));
533520 Client client = ClientBuilder .newClient ();
534521 WebTarget target = client .target ("http://localhost:8081/" );
535- Response response = target .request (MediaType . SERVER_SENT_EVENTS ).post (Entity .json (request ));
522+ Response response = target .request ().post (Entity .json (request ));
536523 InputStream inputStream = response .readEntity (InputStream .class );
537524 boolean dataRead = false ;
538525 try (BufferedReader reader = new BufferedReader (new InputStreamReader (inputStream , StandardCharsets .UTF_8 ))) {
@@ -562,7 +549,6 @@ public void testError() {
562549 "1" , new MessageSendParams (message , null , null ));
563550 SendMessageResponse response = given ()
564551 .contentType (MediaType .APPLICATION_JSON )
565- .accept (MediaType .APPLICATION_JSON )
566552 .body (request )
567553 .when ()
568554 .post ("/" )
@@ -729,4 +715,54 @@ public void testInvalidJSONRPCRequestNonExistentMethod() {
729715 assertNotNull (response .getError ());
730716 assertEquals (new MethodNotFoundError ().getCode (), response .getError ().getCode ());
731717 }
718+
719+ @ Test
720+ public void testNonStreamingMethodWithAcceptHeader () {
721+ testGetTask (MediaType .APPLICATION_JSON );
722+ }
723+
724+ @ Test
725+ public void testStreamingMethodWithAcceptHeader () throws Exception {
726+ testSendStreamingMessage (MediaType .SERVER_SENT_EVENTS );
727+ }
728+
729+ private void testSendStreamingMessage () throws Exception {
730+ testSendStreamingMessage (null );
731+ }
732+
733+ private void testSendStreamingMessage (String mediaType ) throws Exception {
734+ Message message = new Message .Builder (MESSAGE )
735+ .taskId (MINIMAL_TASK .getId ())
736+ .contextId (MINIMAL_TASK .getContextId ())
737+ .build ();
738+ SendStreamingMessageRequest request = new SendStreamingMessageRequest (
739+ "1" , new MessageSendParams (message , null , null ));
740+ Client client = ClientBuilder .newClient ();
741+ WebTarget target = client .target ("http://localhost:8081/" );
742+ Response response ;
743+ if (mediaType != null ) {
744+ response = target .request (mediaType ).post (Entity .json (request ));
745+ } else {
746+ response = target .request ().post (Entity .json (request ));
747+ }
748+ InputStream inputStream = response .readEntity (InputStream .class );
749+ boolean dataRead = false ;
750+ try (BufferedReader reader = new BufferedReader (new InputStreamReader (inputStream , StandardCharsets .UTF_8 ))) {
751+ String line ;
752+ while ((line = reader .readLine ()) != null ) {
753+ if (line .startsWith ("data: " )) {
754+ SendStreamingMessageResponse sendStreamingMessageResponse = Utils .OBJECT_MAPPER .readValue (line .substring ("data: " .length ()).trim (), SendStreamingMessageResponse .class );
755+ assertNull (sendStreamingMessageResponse .getError ());
756+ Message messageResponse = (Message ) sendStreamingMessageResponse .getResult ();
757+ assertEquals (MESSAGE .getMessageId (), messageResponse .getMessageId ());
758+ assertEquals (MESSAGE .getRole (), messageResponse .getRole ());
759+ Part <?> part = messageResponse .getParts ().get (0 );
760+ assertEquals (Part .Kind .TEXT , part .getKind ());
761+ assertEquals ("test message" , ((TextPart ) part ).getText ());
762+ dataRead = true ;
763+ }
764+ }
765+ }
766+ assertTrue (dataRead );
767+ }
732768}
0 commit comments