diff --git a/embabel-agent-api/src/main/kotlin/com/embabel/agent/a2a/server/support/A2AStreamingHandler.kt b/embabel-agent-api/src/main/kotlin/com/embabel/agent/a2a/server/support/A2AStreamingHandler.kt index 208c82fa1..b16aaa0d9 100644 --- a/embabel-agent-api/src/main/kotlin/com/embabel/agent/a2a/server/support/A2AStreamingHandler.kt +++ b/embabel-agent-api/src/main/kotlin/com/embabel/agent/a2a/server/support/A2AStreamingHandler.kt @@ -63,15 +63,9 @@ class A2AStreamingHandler( activeStreams.remove(streamId) } - // Send initial connection established event - try { - emitter.send(SseEmitter.event() - .name("connected") - .data(mapOf("streamId" to streamId)) - ) - } catch (e: Exception) { - logger.error("Error sending initial event", e) - emitter.completeWithError(e) + emitter.onError { throwable -> + logger.error("Stream error for streamId: {}", streamId, throwable) + activeStreams.remove(streamId) } return emitter @@ -88,49 +82,15 @@ class A2AStreamingHandler( try { val eventData = when (event) { - is Message -> { - SseEmitter.event() - .name("message") - .data(objectMapper.writeValueAsString(event), MediaType.APPLICATION_JSON) - } - is Task -> { - SseEmitter.event() - .name("task") - .data(objectMapper.writeValueAsString(event), MediaType.APPLICATION_JSON) - } - is TaskStatusUpdateEvent -> { - SseEmitter.event() - .name("task-update") - .data( - objectMapper.writeValueAsString( - SendStreamingMessageResponse( - "2.0", - streamId, - event, - null - ) - ), MediaType.APPLICATION_JSON - ) - } - is TaskArtifactUpdateEvent -> { - SseEmitter.event() - .name("task-update") - .data( - objectMapper.writeValueAsString( - SendStreamingMessageResponse( - "2.0", - streamId, - event, - null - ) - ), MediaType.APPLICATION_JSON - ) + is Message, is Task, is TaskArtifactUpdateEvent, is TaskStatusUpdateEvent -> { + createEventData(streamId, event) } } emitter.send(eventData) } catch (e: Exception) { - logger.error("Error sending stream event", e) - emitter.completeWithError(e) + logger.error("Error sending stream event to streamId: {}", streamId, e) + val emitter = activeStreams.remove(streamId) + emitter?.completeWithError(e) } } @@ -142,16 +102,48 @@ class A2AStreamingHandler( emitter?.complete() } + private fun createEventData(streamId: String, event: StreamingEventKind): SseEmitter.SseEventBuilder { + val eventName = when (event) { + is Message -> "message" + is Task -> "task" + is TaskStatusUpdateEvent, is TaskArtifactUpdateEvent -> "task-update" + } + + val response = SendStreamingMessageResponse( + "2.0", + streamId, + event, + null + ) + + return SseEmitter.event() + .name(eventName) + .data(objectMapper.writeValueAsString(response), MediaType.APPLICATION_JSON) + } + /** * Shuts down the streaming handler */ fun shutdown() { + logger.info("Shutting down, closing {} active streams", activeStreams.size) + + activeStreams.values.forEach { emitter -> + try { + emitter.complete() + } catch (e: Exception) { + logger.warn("Error completing emitter during shutdown", e) + } + } + activeStreams.clear() + scheduler.shutdown() try { if (!scheduler.awaitTermination(5, TimeUnit.SECONDS)) { + logger.warn("Scheduler did not terminate gracefully, forcing shutdown") scheduler.shutdownNow() } } catch (e: InterruptedException) { + Thread.currentThread().interrupt() scheduler.shutdownNow() } } diff --git a/embabel-agent-api/src/main/kotlin/com/embabel/agent/a2a/server/support/AutonomyA2ARequestHandler.kt b/embabel-agent-api/src/main/kotlin/com/embabel/agent/a2a/server/support/AutonomyA2ARequestHandler.kt index 7a2156e6e..bcef418de 100644 --- a/embabel-agent-api/src/main/kotlin/com/embabel/agent/a2a/server/support/AutonomyA2ARequestHandler.kt +++ b/embabel-agent-api/src/main/kotlin/com/embabel/agent/a2a/server/support/AutonomyA2ARequestHandler.kt @@ -163,23 +163,34 @@ class AutonomyA2ARequestHandler( val emitter = streamingHandler.createStream(streamId) Thread.startVirtualThread { + val taskId = ensureTaskId(params.message.taskId) + val contextId = ensureContextId(params.message.contextId) + val taskStatusUpdateEventBuilder = TaskStatusUpdateEvent.Builder() + .taskId(taskId) + .contextId(contextId) + val taskBuilder = Task.Builder() + .id(taskId) + .contextId(contextId) + try { - // Send initial status event + // Send task submitted + val taskSubmitted = taskBuilder + .status(createSubmittedTaskStatus(taskId, contextId)) + .history(listOfNotNull(params.message)) + .artifacts(emptyList()) + .metadata(null) + .build() + streamingHandler.sendStreamEvent(streamId, taskSubmitted) + + // Send status before running the agent streamingHandler.sendStreamEvent( - streamId, TaskStatusUpdateEvent.Builder() - .taskId(params.message.taskId) - .contextId(params.message.contextId) - .status(createWorkingTaskStatus(params, "Task started...")) + streamId, taskStatusUpdateEventBuilder + .status(createWorkingTaskStatus(params, "Starting task...")) .build() ) - // Send the received message, if any - params.message?.let { userMsg -> - streamingHandler.sendStreamEvent(streamId, userMsg) - } - val intent = params.message?.parts?.filterIsInstance()?.firstOrNull()?.text - ?: "Task ${params.message.taskId}" + ?: "Task ${taskId}" // Execute the task using autonomy service val result = autonomy.chooseAndRunAgent( @@ -190,35 +201,28 @@ class AutonomyA2ARequestHandler( // Send intermediate status updates streamingHandler.sendStreamEvent( - streamId, TaskStatusUpdateEvent.Builder() - .taskId(params.message.taskId) - .contextId(ensureContextId(params.message.contextId)) + streamId, taskStatusUpdateEventBuilder .status(createWorkingTaskStatus(params, "Processing task...")) .build() ) // Send result - val taskResult = Task.Builder() - .id(params.message.taskId) - .contextId("ctx_${UUID.randomUUID()}") - .status(createCompletedTaskStatus(params)) + val taskResult = taskBuilder + .status(createCompletedTaskStatus(taskId, contextId)) .history(listOfNotNull(params.message)) .artifacts( listOf( createResultArtifact(result, params.configuration?.acceptedOutputModes) ) ) - .metadata(null) .build() streamingHandler.sendStreamEvent(streamId, taskResult) } catch (e: Exception) { logger.error("Streaming error", e) try { streamingHandler.sendStreamEvent( - streamId, TaskStatusUpdateEvent.Builder() - .taskId(params.message.taskId) - .contextId(ensureContextId(params.message.contextId)) - .status(createFailedTaskStatus(params, e)) + streamId, taskStatusUpdateEventBuilder + .status(createFailedTaskStatus(taskId, contextId, e)) .build() ) } catch (sendError: Exception) { @@ -246,20 +250,25 @@ class AutonomyA2ARequestHandler( TODO() } - private fun createFailedTaskStatus(params: MessageSendParams, e: Exception): TaskStatus = TaskStatus( + private fun createFailedTaskStatus( + taskId: String, + contextId: String, + e: Exception + ): TaskStatus = TaskStatus( TaskState.FAILED, Message.Builder() .messageId(UUID.randomUUID().toString()) .role(Message.Role.AGENT) .parts(listOf(TextPart("Error: ${e.message}"))) - .contextId(params.message.contextId) - .taskId(params.message.taskId) + .contextId(contextId) + .taskId(taskId) .build(), LocalDateTime.now() ) private fun createCompletedTaskStatus( - params: MessageSendParams, + taskId: String, + contextId: String, textPart: String = "Task completed successfully" ): TaskStatus = TaskStatus( TaskState.COMPLETED, @@ -267,8 +276,8 @@ class AutonomyA2ARequestHandler( .messageId(UUID.randomUUID().toString()) .role(Message.Role.AGENT) .parts(listOf(TextPart(textPart))) - .contextId(params.message.contextId) - .taskId(params.message.taskId) + .contextId(contextId) + .taskId(taskId) .build(), LocalDateTime.now() ) @@ -288,6 +297,22 @@ class AutonomyA2ARequestHandler( LocalDateTime.now() ) + private fun createSubmittedTaskStatus( + taskId: String, + contextId: String, + textPart: String = "Submitted..." + ): TaskStatus = TaskStatus( + TaskState.SUBMITTED, + Message.Builder() + .messageId(UUID.randomUUID().toString()) + .role(Message.Role.AGENT) + .parts(listOf(TextPart(textPart))) + .contextId(contextId) + .taskId(taskId) + .build(), + LocalDateTime.now() + ) + private fun ensureContextId(providedContextId: String?): String { return providedContextId ?: ("ctx_" + UUID.randomUUID().toString()) }