|
4 | 4 | import com.openai.core.JsonField; |
5 | 5 | import com.openai.core.http.Headers; |
6 | 6 | import com.openai.core.http.HttpResponse; |
| 7 | +import com.openai.helpers.ChatCompletionAccumulator; |
7 | 8 | import com.openai.models.ResponsesModel; |
8 | 9 | import com.openai.models.chat.completions.ChatCompletion; |
9 | 10 | import com.openai.models.chat.completions.ChatCompletionChunk; |
@@ -319,24 +320,99 @@ public void withChatCompletionChunks(AgentSpan span, List<ChatCompletionChunk> c |
319 | 320 | for (int i = 0; i < choiceNum; i++) { |
320 | 321 | contents[i] = new StringBuilder(128); |
321 | 322 | } |
| 323 | + // collect tool calls by choices for all chunks |
| 324 | + // Map from choice index -> (tool call index -> accumulated tool call data) |
| 325 | + @SuppressWarnings("unchecked") |
| 326 | + Map<Long, StreamingToolCallData>[] toolCallsByChoice = new Map[choiceNum]; |
| 327 | + for (int i = 0; i < choiceNum; i++) { |
| 328 | + toolCallsByChoice[i] = new HashMap<>(); |
| 329 | + } |
| 330 | + |
| 331 | + // Create an accumulator |
| 332 | + ChatCompletionAccumulator accumulator = ChatCompletionAccumulator.create(); |
| 333 | + |
| 334 | +// Accumulate each chunk as it arrives |
| 335 | + for (ChatCompletionChunk chunk : chunks) { |
| 336 | + accumulator.accumulate(chunk); |
| 337 | + } |
| 338 | + |
| 339 | +// Get the final ChatCompletion |
| 340 | + ChatCompletion chatCompletion = accumulator.chatCompletion(); |
| 341 | + |
322 | 342 | for (ChatCompletionChunk chunk : chunks) { |
323 | 343 | // choices can be empty for the last chunk |
324 | 344 | List<ChatCompletionChunk.Choice> choices = chunk.choices(); |
325 | 345 | for (int i = 0; i < choiceNum && i < choices.size(); i++) { |
326 | 346 | ChatCompletionChunk.Choice choice = choices.get(i); |
327 | 347 | ChatCompletionChunk.Choice.Delta delta = choice.delta(); |
328 | 348 | delta.content().ifPresent(contents[i]::append); |
| 349 | + |
| 350 | + // accumulate tool calls |
| 351 | + Optional<List<ChatCompletionChunk.Choice.Delta.ToolCall>> toolCallsOpt = delta.toolCalls(); |
| 352 | + if (toolCallsOpt.isPresent()) { |
| 353 | + for (ChatCompletionChunk.Choice.Delta.ToolCall toolCall : toolCallsOpt.get()) { |
| 354 | + long index = toolCall.index(); |
| 355 | + StreamingToolCallData data = |
| 356 | + toolCallsByChoice[i].computeIfAbsent(index, k -> new StreamingToolCallData()); |
| 357 | + toolCall.id().ifPresent(id -> data.id = id); |
| 358 | + toolCall |
| 359 | + .type() |
| 360 | + .flatMap(t -> t._value().asString()) |
| 361 | + .ifPresent(type -> data.type = type); |
| 362 | + toolCall |
| 363 | + .function() |
| 364 | + .ifPresent( |
| 365 | + fn -> { |
| 366 | + fn.name().ifPresent(data.name::append); |
| 367 | + fn.arguments().ifPresent(data.arguments::append); |
| 368 | + }); |
| 369 | + } |
| 370 | + } |
329 | 371 | } |
330 | 372 | chunk.usage().ifPresent(usage -> withCompletionUsage(span, usage)); |
331 | 373 | } |
332 | 374 | // build LLMMessages |
333 | 375 | List<LLMObs.LLMMessage> llmMessages = new ArrayList<>(choiceNum); |
334 | 376 | for (int i = 0; i < choiceNum; i++) { |
335 | | - llmMessages.add(LLMObs.LLMMessage.from(roles[i], contents[i].toString())); |
| 377 | + List<LLMObs.ToolCall> toolCalls = buildToolCallsFromStreamingData(toolCallsByChoice[i]); |
| 378 | + llmMessages.add(LLMObs.LLMMessage.from(roles[i], contents[i].toString(), toolCalls)); |
336 | 379 | } |
337 | 380 | span.setTag("_ml_obs_tag.output", llmMessages); |
338 | 381 | } |
339 | 382 |
|
| 383 | + /** Helper class to accumulate streaming tool call data across chunks */ |
| 384 | + private static class StreamingToolCallData { |
| 385 | + String id; |
| 386 | + String type = "function"; |
| 387 | + StringBuilder name = new StringBuilder(); |
| 388 | + StringBuilder arguments = new StringBuilder(); |
| 389 | + } |
| 390 | + |
| 391 | + private List<LLMObs.ToolCall> buildToolCallsFromStreamingData( |
| 392 | + Map<Long, StreamingToolCallData> toolCallDataMap) { |
| 393 | + if (toolCallDataMap.isEmpty()) { |
| 394 | + return Collections.emptyList(); |
| 395 | + } |
| 396 | + List<LLMObs.ToolCall> toolCalls = new ArrayList<>(); |
| 397 | + // Sort by index to maintain order |
| 398 | + toolCallDataMap.entrySet().stream() |
| 399 | + .sorted(Map.Entry.comparingByKey()) |
| 400 | + .forEach( |
| 401 | + entry -> { |
| 402 | + StreamingToolCallData data = entry.getValue(); |
| 403 | + String name = data.name.toString(); |
| 404 | + String argumentsJson = data.arguments.toString(); |
| 405 | + Map<String, Object> arguments = Collections.singletonMap("value", argumentsJson); |
| 406 | + try { |
| 407 | + arguments = ToolCallExtractor.parseArguments(argumentsJson); |
| 408 | + } catch (Exception e) { |
| 409 | + // keep default map with raw value |
| 410 | + } |
| 411 | + toolCalls.add(LLMObs.ToolCall.from(name, data.type, data.id, arguments)); |
| 412 | + }); |
| 413 | + return toolCalls; |
| 414 | + } |
| 415 | + |
340 | 416 | public void withEmbeddingCreateParams(AgentSpan span, EmbeddingCreateParams params) { |
341 | 417 | span.setTag("_ml_obs_tag.span.kind", Tags.LLMOBS_EMBEDDING_SPAN_KIND); |
342 | 418 | span.setResourceName(EMBEDDINGS_CREATE); |
|
0 commit comments