Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,9 @@ class A2AStreamingHandler(
activeStreams.remove(streamId)
}

// Send initial connection established event
try {
emitter.send(SseEmitter.event()
.name("connected")
.data(mapOf("streamId" to streamId))
)
} catch (e: Exception) {
logger.error("Error sending initial event", e)
emitter.completeWithError(e)
emitter.onError { throwable ->
logger.error("Stream error for streamId: {}", streamId, throwable)
activeStreams.remove(streamId)
}

return emitter
Expand All @@ -88,49 +82,15 @@ class A2AStreamingHandler(

try {
val eventData = when (event) {
is Message -> {
SseEmitter.event()
.name("message")
.data(objectMapper.writeValueAsString(event), MediaType.APPLICATION_JSON)
}
is Task -> {
SseEmitter.event()
.name("task")
.data(objectMapper.writeValueAsString(event), MediaType.APPLICATION_JSON)
}
is TaskStatusUpdateEvent -> {
SseEmitter.event()
.name("task-update")
.data(
objectMapper.writeValueAsString(
SendStreamingMessageResponse(
"2.0",
streamId,
event,
null
)
), MediaType.APPLICATION_JSON
)
}
is TaskArtifactUpdateEvent -> {
SseEmitter.event()
.name("task-update")
.data(
objectMapper.writeValueAsString(
SendStreamingMessageResponse(
"2.0",
streamId,
event,
null
)
), MediaType.APPLICATION_JSON
)
is Message, is Task, is TaskArtifactUpdateEvent, is TaskStatusUpdateEvent -> {
createEventData(streamId, event)
}
}
emitter.send(eventData)
} catch (e: Exception) {
logger.error("Error sending stream event", e)
emitter.completeWithError(e)
logger.error("Error sending stream event to streamId: {}", streamId, e)
val emitter = activeStreams.remove(streamId)
emitter?.completeWithError(e)
}
}

Expand All @@ -142,16 +102,48 @@ class A2AStreamingHandler(
emitter?.complete()
}

private fun createEventData(streamId: String, event: StreamingEventKind): SseEmitter.SseEventBuilder {
val eventName = when (event) {
is Message -> "message"
is Task -> "task"
is TaskStatusUpdateEvent, is TaskArtifactUpdateEvent -> "task-update"
}

val response = SendStreamingMessageResponse(
"2.0",
streamId,
event,
null
)

return SseEmitter.event()
.name(eventName)
.data(objectMapper.writeValueAsString(response), MediaType.APPLICATION_JSON)
}

/**
* Shuts down the streaming handler
*/
fun shutdown() {
logger.info("Shutting down, closing {} active streams", activeStreams.size)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code here was suggested by AI but, imho, looks good


activeStreams.values.forEach { emitter ->
try {
emitter.complete()
} catch (e: Exception) {
logger.warn("Error completing emitter during shutdown", e)
}
}
activeStreams.clear()

scheduler.shutdown()
try {
if (!scheduler.awaitTermination(5, TimeUnit.SECONDS)) {
logger.warn("Scheduler did not terminate gracefully, forcing shutdown")
scheduler.shutdownNow()
}
} catch (e: InterruptedException) {
Thread.currentThread().interrupt()
scheduler.shutdownNow()
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,23 +163,34 @@ class AutonomyA2ARequestHandler(
val emitter = streamingHandler.createStream(streamId)

Thread.startVirtualThread {
val taskId = ensureTaskId(params.message.taskId)
val contextId = ensureContextId(params.message.contextId)
val taskStatusUpdateEventBuilder = TaskStatusUpdateEvent.Builder()
.taskId(taskId)
.contextId(contextId)
val taskBuilder = Task.Builder()
.id(taskId)
.contextId(contextId)

try {
// Send initial status event
// Send task submitted
val taskSubmitted = taskBuilder
.status(createSubmittedTaskStatus(taskId, contextId))
.history(listOfNotNull(params.message))
.artifacts(emptyList())
.metadata(null)
.build()
streamingHandler.sendStreamEvent(streamId, taskSubmitted)

// Send status before running the agent
streamingHandler.sendStreamEvent(
streamId, TaskStatusUpdateEvent.Builder()
.taskId(params.message.taskId)
.contextId(params.message.contextId)
.status(createWorkingTaskStatus(params, "Task started..."))
streamId, taskStatusUpdateEventBuilder
.status(createWorkingTaskStatus(params, "Starting task..."))
.build()
)

// Send the received message, if any
params.message?.let { userMsg ->
streamingHandler.sendStreamEvent(streamId, userMsg)
}

val intent = params.message?.parts?.filterIsInstance<TextPart>()?.firstOrNull()?.text
?: "Task ${params.message.taskId}"
?: "Task ${taskId}"

// Execute the task using autonomy service
val result = autonomy.chooseAndRunAgent(
Expand All @@ -190,35 +201,28 @@ class AutonomyA2ARequestHandler(

// Send intermediate status updates
streamingHandler.sendStreamEvent(
streamId, TaskStatusUpdateEvent.Builder()
.taskId(params.message.taskId)
.contextId(ensureContextId(params.message.contextId))
streamId, taskStatusUpdateEventBuilder
.status(createWorkingTaskStatus(params, "Processing task..."))
.build()
)

// Send result
val taskResult = Task.Builder()
.id(params.message.taskId)
.contextId("ctx_${UUID.randomUUID()}")
.status(createCompletedTaskStatus(params))
val taskResult = taskBuilder
.status(createCompletedTaskStatus(taskId, contextId))
.history(listOfNotNull(params.message))
.artifacts(
listOf(
createResultArtifact(result, params.configuration?.acceptedOutputModes)
)
)
.metadata(null)
.build()
streamingHandler.sendStreamEvent(streamId, taskResult)
} catch (e: Exception) {
logger.error("Streaming error", e)
try {
streamingHandler.sendStreamEvent(
streamId, TaskStatusUpdateEvent.Builder()
.taskId(params.message.taskId)
.contextId(ensureContextId(params.message.contextId))
.status(createFailedTaskStatus(params, e))
streamId, taskStatusUpdateEventBuilder
.status(createFailedTaskStatus(taskId, contextId, e))
.build()
)
} catch (sendError: Exception) {
Expand Down Expand Up @@ -246,29 +250,34 @@ class AutonomyA2ARequestHandler(
TODO()
}

private fun createFailedTaskStatus(params: MessageSendParams, e: Exception): TaskStatus = TaskStatus(
private fun createFailedTaskStatus(
taskId: String,
contextId: String,
e: Exception
): TaskStatus = TaskStatus(
TaskState.FAILED,
Message.Builder()
.messageId(UUID.randomUUID().toString())
.role(Message.Role.AGENT)
.parts(listOf(TextPart("Error: ${e.message}")))
.contextId(params.message.contextId)
.taskId(params.message.taskId)
.contextId(contextId)
.taskId(taskId)
.build(),
LocalDateTime.now()
)

private fun createCompletedTaskStatus(
params: MessageSendParams,
taskId: String,
contextId: String,
textPart: String = "Task completed successfully"
): TaskStatus = TaskStatus(
TaskState.COMPLETED,
Message.Builder()
.messageId(UUID.randomUUID().toString())
.role(Message.Role.AGENT)
.parts(listOf(TextPart(textPart)))
.contextId(params.message.contextId)
.taskId(params.message.taskId)
.contextId(contextId)
.taskId(taskId)
.build(),
LocalDateTime.now()
)
Expand All @@ -288,6 +297,22 @@ class AutonomyA2ARequestHandler(
LocalDateTime.now()
)

private fun createSubmittedTaskStatus(
taskId: String,
contextId: String,
textPart: String = "Submitted..."
): TaskStatus = TaskStatus(
TaskState.SUBMITTED,
Message.Builder()
.messageId(UUID.randomUUID().toString())
.role(Message.Role.AGENT)
.parts(listOf(TextPart(textPart)))
.contextId(contextId)
.taskId(taskId)
.build(),
LocalDateTime.now()
)

private fun ensureContextId(providedContextId: String?): String {
return providedContextId ?: ("ctx_" + UUID.randomUUID().toString())
}
Expand Down
Loading