Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,13 @@
import io.a2a.util.Utils;
import io.agentscope.core.a2a.server.AgentScopeA2aServer;
import io.agentscope.core.a2a.server.transport.jsonrpc.JsonRpcTransportWrapper;
import jakarta.servlet.http.HttpServletRequest;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Map;
import java.util.logging.Logger;
import org.springframework.http.MediaType;
import org.springframework.http.codec.ServerSentEvent;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestHeader;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.ResponseBody;
import org.springframework.web.bind.annotation.RestController;
Expand All @@ -54,8 +52,8 @@ public A2aJsonRpcController(AgentScopeA2aServer agentScopeA2aServer) {
consumes = MediaType.APPLICATION_JSON_VALUE,
produces = {MediaType.APPLICATION_JSON_VALUE, MediaType.TEXT_EVENT_STREAM_VALUE})
@ResponseBody
public Object handleRequest(@RequestBody String body, HttpServletRequest httpRequest) {
Map<String, String> header = getHeaders(httpRequest);
public Object handleRequest(
@RequestBody String body, @RequestHeader Map<String, String> header) {
Object result = getJsonRpcHandler().handleRequest(body, header, Map.of());
if (result instanceof Flux<?> fluxResult) {
return fluxResult
Expand All @@ -76,17 +74,6 @@ private JsonRpcTransportWrapper getJsonRpcHandler() {
return jsonRpcHandler;
}

private Map<String, String> getHeaders(HttpServletRequest request) {
Map<String, String> headers = new HashMap<>();
Enumeration<String> headerNames = request.getHeaderNames();
while (headerNames.hasMoreElements()) {
String headerName = headerNames.nextElement();
String headerValue = request.getHeader(headerName);
headers.put(headerName, headerValue);
}
return headers;
}

private ServerSentEvent<String> convertToSse(JSONRPCResponse<?> response) {
try {
String data = Utils.OBJECT_MAPPER.writeValueAsString(response);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,8 @@
import io.a2a.spec.TransportProtocol;
import io.agentscope.core.a2a.server.AgentScopeA2aServer;
import io.agentscope.core.a2a.server.transport.jsonrpc.JsonRpcTransportWrapper;
import jakarta.servlet.http.HttpServletRequest;
import java.util.Collections;
import java.util.Enumeration;
import java.util.Map;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Nested;
Expand All @@ -57,7 +56,7 @@ class A2aJsonRpcControllerTest {

@Mock private JsonRpcTransportWrapper jsonRpcTransportWrapper;

@Mock private HttpServletRequest httpRequest;
private Map<String, String> headers;

@BeforeEach
void setUp() {
Expand All @@ -67,6 +66,7 @@ void setUp() {
eq(TransportProtocol.JSONRPC.asString()),
eq(JsonRpcTransportWrapper.class)))
.thenReturn(jsonRpcTransportWrapper);
headers = Collections.emptyMap();
}

@Nested
Expand All @@ -79,11 +79,10 @@ void shouldHandleJsonRpcRequestAndReturnPlainObject() {
String requestBody = "{\"method\": \"test\"}";
String responseBody = "{\"result\": \"success\"}";

when(httpRequest.getHeaderNames()).thenReturn(Collections.emptyEnumeration());
when(jsonRpcTransportWrapper.handleRequest(anyString(), anyMap(), anyMap()))
.thenReturn(responseBody);

Object result = controller.handleRequest(requestBody, httpRequest);
Object result = controller.handleRequest(requestBody, headers);

assertEquals(responseBody, result);

Expand Down Expand Up @@ -111,11 +110,10 @@ void shouldHandleJsonRpcRequestAndReturnFluxWithJsonRpcResponse() {
Message message = A2A.toAgentMessage("test");
SendStreamingMessageResponse response = new SendStreamingMessageResponse(1, message);

when(httpRequest.getHeaderNames()).thenReturn(Collections.emptyEnumeration());
when(jsonRpcTransportWrapper.handleRequest(anyString(), anyMap(), anyMap()))
.thenReturn(Flux.just(response));

Object result = controller.handleRequest(requestBody, httpRequest);
Object result = controller.handleRequest(requestBody, headers);

assertTrue(result instanceof Flux);

Expand All @@ -141,17 +139,13 @@ void shouldHandleRequestWithHeaders() {
String responseBody = "{\"result\": \"success\"}";

// Mock headers
Enumeration<String> headerNames =
Collections.enumeration(
java.util.Arrays.asList("Content-Type", "Authorization"));
when(httpRequest.getHeaderNames()).thenReturn(headerNames);
when(httpRequest.getHeader("Content-Type")).thenReturn("application/json");
when(httpRequest.getHeader("Authorization")).thenReturn("Bearer token");
Map<String, String> header =
Map.of("Content-Type", "application/json", "Authorization", "Bearer token");

when(jsonRpcTransportWrapper.handleRequest(anyString(), anyMap(), anyMap()))
.thenReturn(responseBody);

Object result = controller.handleRequest(requestBody, httpRequest);
Object result = controller.handleRequest(requestBody, header);

assertEquals(responseBody, result);

Expand Down Expand Up @@ -179,16 +173,12 @@ void shouldExtractHeadersFromHttpServletRequest() {
String requestBody = "{\"method\": \"test\"}";
String responseBody = "{\"result\": \"success\"}";

Enumeration<String> headerNames =
Collections.enumeration(java.util.Arrays.asList("Header1", "Header2"));
when(httpRequest.getHeaderNames()).thenReturn(headerNames);
when(httpRequest.getHeader("Header1")).thenReturn("Value1");
when(httpRequest.getHeader("Header2")).thenReturn("Value2");
Map<String, String> header = Map.of("Header1", "Value1", "Header2", "Value2");

when(jsonRpcTransportWrapper.handleRequest(anyString(), anyMap(), anyMap()))
.thenReturn(responseBody);

controller.handleRequest(requestBody, httpRequest);
controller.handleRequest(requestBody, header);

ArgumentCaptor<java.util.Map<String, String>> headersCaptor =
ArgumentCaptor.forClass(java.util.Map.class);
Expand All @@ -206,11 +196,10 @@ void shouldHandleEmptyHeaders() {
String requestBody = "{\"method\": \"test\"}";
String responseBody = "{\"result\": \"success\"}";

when(httpRequest.getHeaderNames()).thenReturn(Collections.emptyEnumeration());
when(jsonRpcTransportWrapper.handleRequest(anyString(), anyMap(), anyMap()))
.thenReturn(responseBody);

controller.handleRequest(requestBody, httpRequest);
controller.handleRequest(requestBody, headers);

ArgumentCaptor<java.util.Map<String, String>> headersCaptor =
ArgumentCaptor.forClass(java.util.Map.class);
Expand All @@ -232,16 +221,15 @@ void shouldLazilyInitializeJsonRpcHandler() {
String requestBody = "{\"method\": \"test\"}";
String responseBody = "{\"result\": \"success\"}";

when(httpRequest.getHeaderNames()).thenReturn(Collections.emptyEnumeration());
when(jsonRpcTransportWrapper.handleRequest(anyString(), anyMap(), anyMap()))
.thenReturn(responseBody);

// First call should initialize the handler
Object result1 = controller.handleRequest(requestBody, httpRequest);
Object result1 = controller.handleRequest(requestBody, headers);
assertEquals(responseBody, result1);

// Second call should reuse the same handler
Object result2 = controller.handleRequest(requestBody, httpRequest);
Object result2 = controller.handleRequest(requestBody, headers);
assertEquals(responseBody, result2);

// Should only fetch the transport wrapper once
Expand Down
Loading