Skip to content

Commit f36db5e

Browse files
authored
fix: fixing the handling of historyLength being set to 0 by default in gRPC (#429)
* Using int instead of Integer and setting the default value to 0. * If historyLength == 0 then we should return the full history * Removing the handling of null Issue: #423 Fixes #423 🦕 Signed-off-by: Emmanuel Hugonnet <[email protected]>
1 parent 425508d commit f36db5e

File tree

8 files changed

+17
-24
lines changed

8 files changed

+17
-24
lines changed

client/transport/grpc/src/main/java/io/a2a/client/transport/grpc/GrpcTransport.java

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,5 @@
11
package io.a2a.client.transport.grpc;
22

3-
import static io.a2a.grpc.A2AServiceGrpc.A2AServiceBlockingV2Stub;
4-
import static io.a2a.grpc.A2AServiceGrpc.A2AServiceStub;
5-
import static io.a2a.grpc.utils.ProtoUtils.FromProto;
6-
import static io.a2a.grpc.utils.ProtoUtils.ToProto;
73
import static io.a2a.util.Assert.checkNotNullParam;
84

95
import java.util.List;
@@ -127,9 +123,7 @@ public Task getTask(TaskQueryParams request, @Nullable ClientCallContext context
127123

128124
GetTaskRequest.Builder requestBuilder = GetTaskRequest.newBuilder();
129125
requestBuilder.setName("tasks/" + request.id());
130-
if (request.historyLength() != null) {
131-
requestBuilder.setHistoryLength(request.historyLength());
132-
}
126+
requestBuilder.setHistoryLength(request.historyLength());
133127
GetTaskRequest getTaskRequest = requestBuilder.build();
134128
PayloadAndHeaders payloadAndHeaders = applyInterceptors(io.a2a.spec.GetTaskRequest.METHOD, getTaskRequest,
135129
agentCard, context);

client/transport/rest/src/main/java/io/a2a/client/transport/rest/RestTransport.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ public Task getTask(TaskQueryParams taskQueryParams, @Nullable ClientCallContext
125125
agentCard, context);
126126
try {
127127
String url;
128-
if (taskQueryParams.historyLength() != null) {
128+
if (taskQueryParams.historyLength() > 0) {
129129
url = agentUrl + String.format("/v1/tasks/%1s?historyLength=%2d", taskQueryParams.id(), taskQueryParams.historyLength());
130130
} else {
131131
url = agentUrl + String.format("/v1/tasks/%1s", taskQueryParams.id());

reference/rest/src/main/java/io/a2a/server/rest/quarkus/A2AServerRoutes.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,9 @@ public void getTask(RoutingContext rc) {
122122
if (taskId == null || taskId.isEmpty()) {
123123
response = jsonRestHandler.createErrorResponse(new InvalidParamsError("bad task id"));
124124
} else {
125-
Integer historyLength = null;
125+
int historyLength = 0;
126126
if (rc.request().params().contains("history_length")) {
127-
historyLength = Integer.valueOf(rc.request().params().get("history_length"));
127+
historyLength = Integer.parseInt(rc.request().params().get("history_length"));
128128
}
129129
response = jsonRestHandler.getTask(taskId, historyLength, context);
130130
}

reference/rest/src/test/java/io/a2a/server/rest/quarkus/A2AServerRoutesTest.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import static org.junit.jupiter.api.Assertions.assertEquals;
55
import static org.junit.jupiter.api.Assertions.assertNotNull;
66
import static org.mockito.ArgumentMatchers.any;
7+
import static org.mockito.ArgumentMatchers.anyInt;
78
import static org.mockito.ArgumentMatchers.anyString;
89
import static org.mockito.ArgumentMatchers.eq;
910
import static org.mockito.Mockito.mock;
@@ -137,15 +138,15 @@ public void testGetTask_MethodNameSetInContext() {
137138
when(mockHttpResponse.getStatusCode()).thenReturn(200);
138139
when(mockHttpResponse.getContentType()).thenReturn("application/json");
139140
when(mockHttpResponse.getBody()).thenReturn("{test:value}");
140-
when(mockRestHandler.getTask(anyString(), any(), any(ServerCallContext.class))).thenReturn(mockHttpResponse);
141+
when(mockRestHandler.getTask(anyString(), anyInt(), any(ServerCallContext.class))).thenReturn(mockHttpResponse);
141142

142143
ArgumentCaptor<ServerCallContext> contextCaptor = ArgumentCaptor.forClass(ServerCallContext.class);
143144

144145
// Act
145146
routes.getTask(mockRoutingContext);
146147

147148
// Assert
148-
verify(mockRestHandler).getTask(eq("task123"), eq(null), contextCaptor.capture());
149+
verify(mockRestHandler).getTask(eq("task123"), anyInt(), contextCaptor.capture());
149150
ServerCallContext capturedContext = contextCaptor.getValue();
150151
assertNotNull(capturedContext);
151152
assertEquals(GetTaskRequest.METHOD, capturedContext.getState().get(METHOD_NAME_KEY));

server-common/src/main/java/io/a2a/server/requesthandlers/DefaultRequestHandler.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import java.util.function.Supplier;
2020

2121
import jakarta.enterprise.context.ApplicationScoped;
22-
import jakarta.enterprise.inject.Instance;
2322
import jakarta.inject.Inject;
2423

2524
import io.a2a.server.ServerCallContext;
@@ -34,7 +33,6 @@
3433
import io.a2a.server.events.TaskQueueExistsException;
3534
import io.a2a.server.tasks.PushNotificationConfigStore;
3635
import io.a2a.server.tasks.PushNotificationSender;
37-
import io.a2a.server.tasks.TaskStateProvider;
3836
import io.a2a.server.tasks.ResultAggregator;
3937
import io.a2a.server.tasks.TaskManager;
4038
import io.a2a.server.tasks.TaskStore;
@@ -103,10 +101,10 @@ public Task onGetTask(TaskQueryParams params, ServerCallContext context) throws
103101
LOGGER.debug("No task found for {}. Throwing TaskNotFoundError", params.id());
104102
throw new TaskNotFoundError();
105103
}
106-
if (params.historyLength() != null && task.getHistory() != null && params.historyLength() < task.getHistory().size()) {
104+
if (task.getHistory() != null && params.historyLength() < task.getHistory().size()) {
107105
List<Message> history;
108106
if (params.historyLength() <= 0) {
109-
history = new ArrayList<>();
107+
history = task.getHistory();
110108
} else {
111109
history = task.getHistory().subList(
112110
task.getHistory().size() - params.historyLength(),

spec/src/main/java/io/a2a/spec/TaskQueryParams.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,20 +17,20 @@
1717

1818
@JsonInclude(JsonInclude.Include.NON_ABSENT)
1919
@JsonIgnoreProperties(ignoreUnknown = true)
20-
public record TaskQueryParams(String id, @Nullable Integer historyLength, @Nullable Map<String, Object> metadata) {
20+
public record TaskQueryParams(String id, int historyLength, @Nullable Map<String, Object> metadata) {
2121

2222
public TaskQueryParams {
2323
Assert.checkNotNullParam("id", id);
24-
if (historyLength != null && historyLength < 0) {
24+
if (historyLength < 0) {
2525
throw new IllegalArgumentException("Invalid history length");
2626
}
2727
}
2828

2929
public TaskQueryParams(String id) {
30-
this(id, null, null);
30+
this(id, 0, null);
3131
}
3232

33-
public TaskQueryParams(String id, @Nullable Integer historyLength) {
33+
public TaskQueryParams(String id, int historyLength) {
3434
this(id, historyLength, null);
3535
}
3636
}

transport/rest/src/main/java/io/a2a/transport/rest/handler/RestHandler.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ public HTTPRestResponse resubscribeTask(String taskId, ServerCallContext context
162162
}
163163
}
164164

165-
public HTTPRestResponse getTask(String taskId, @Nullable Integer historyLength, ServerCallContext context) {
165+
public HTTPRestResponse getTask(String taskId, int historyLength, ServerCallContext context) {
166166
try {
167167
TaskQueryParams params = new TaskQueryParams(taskId, historyLength);
168168
Task task = requestHandler.onGetTask(params, context);

transport/rest/src/test/java/io/a2a/transport/rest/handler/RestHandlerTest.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ public void testGetTaskSuccess() {
2626
RestHandler handler = new RestHandler(CARD, requestHandler, internalExecutor);
2727
taskStore.save(MINIMAL_TASK);
2828

29-
RestHandler.HTTPRestResponse response = handler.getTask(MINIMAL_TASK.getId(),null, callContext);
29+
RestHandler.HTTPRestResponse response = handler.getTask(MINIMAL_TASK.getId(), 0, callContext);
3030

3131
Assertions.assertEquals(200, response.getStatusCode());
3232
Assertions.assertEquals("application/json", response.getContentType());
@@ -43,7 +43,7 @@ public void testGetTaskSuccess() {
4343
public void testGetTaskNotFound() {
4444
RestHandler handler = new RestHandler(CARD, requestHandler, internalExecutor);
4545

46-
RestHandler.HTTPRestResponse response = handler.getTask("nonexistent", null, callContext);
46+
RestHandler.HTTPRestResponse response = handler.getTask("nonexistent", 0, callContext);
4747

4848
Assertions.assertEquals(404, response.getStatusCode());
4949
Assertions.assertEquals("application/json", response.getContentType());
@@ -315,7 +315,7 @@ public void testHttpStatusCodeMapping() {
315315
Assertions.assertEquals(400, response.getStatusCode());
316316

317317
// Test 404 for not found
318-
response = handler.getTask("nonexistent", null, callContext);
318+
response = handler.getTask("nonexistent", 0, callContext);
319319
Assertions.assertEquals(404, response.getStatusCode());
320320
}
321321

0 commit comments

Comments
 (0)