Commit 6e654e5

Extract and assert chatCompletion toolCalls single and streamed
1 parent ffe043a commit 6e654e5

File tree

5 files changed: +249 -10 lines changed

dd-java-agent/instrumentation/openai-java/openai-java-3.0/src/main/java/datadog/trace/instrumentation/openai_java/OpenAiDecorator.java

Lines changed: 77 additions & 1 deletion
@@ -4,6 +4,7 @@
 import com.openai.core.JsonField;
 import com.openai.core.http.Headers;
 import com.openai.core.http.HttpResponse;
+import com.openai.helpers.ChatCompletionAccumulator;
 import com.openai.models.ResponsesModel;
 import com.openai.models.chat.completions.ChatCompletion;
 import com.openai.models.chat.completions.ChatCompletionChunk;
@@ -319,24 +320,99 @@ public void withChatCompletionChunks(AgentSpan span, List<ChatCompletionChunk> c
     for (int i = 0; i < choiceNum; i++) {
       contents[i] = new StringBuilder(128);
     }
+    // collect tool calls by choices for all chunks
+    // Map from choice index -> (tool call index -> accumulated tool call data)
+    @SuppressWarnings("unchecked")
+    Map<Long, StreamingToolCallData>[] toolCallsByChoice = new Map[choiceNum];
+    for (int i = 0; i < choiceNum; i++) {
+      toolCallsByChoice[i] = new HashMap<>();
+    }
+
+    // Create an accumulator
+    ChatCompletionAccumulator accumulator = ChatCompletionAccumulator.create();
+
+    // Accumulate each chunk as it arrives
+    for (ChatCompletionChunk chunk : chunks) {
+      accumulator.accumulate(chunk);
+    }
+
+    // Get the final ChatCompletion
+    ChatCompletion chatCompletion = accumulator.chatCompletion();
+
     for (ChatCompletionChunk chunk : chunks) {
       // choices can be empty for the last chunk
       List<ChatCompletionChunk.Choice> choices = chunk.choices();
       for (int i = 0; i < choiceNum && i < choices.size(); i++) {
         ChatCompletionChunk.Choice choice = choices.get(i);
         ChatCompletionChunk.Choice.Delta delta = choice.delta();
         delta.content().ifPresent(contents[i]::append);
+
+        // accumulate tool calls
+        Optional<List<ChatCompletionChunk.Choice.Delta.ToolCall>> toolCallsOpt = delta.toolCalls();
+        if (toolCallsOpt.isPresent()) {
+          for (ChatCompletionChunk.Choice.Delta.ToolCall toolCall : toolCallsOpt.get()) {
+            long index = toolCall.index();
+            StreamingToolCallData data =
+                toolCallsByChoice[i].computeIfAbsent(index, k -> new StreamingToolCallData());
+            toolCall.id().ifPresent(id -> data.id = id);
+            toolCall
+                .type()
+                .flatMap(t -> t._value().asString())
+                .ifPresent(type -> data.type = type);
+            toolCall
+                .function()
+                .ifPresent(
+                    fn -> {
+                      fn.name().ifPresent(data.name::append);
+                      fn.arguments().ifPresent(data.arguments::append);
+                    });
+          }
+        }
       }
       chunk.usage().ifPresent(usage -> withCompletionUsage(span, usage));
     }
     // build LLMMessages
     List<LLMObs.LLMMessage> llmMessages = new ArrayList<>(choiceNum);
     for (int i = 0; i < choiceNum; i++) {
-      llmMessages.add(LLMObs.LLMMessage.from(roles[i], contents[i].toString()));
+      List<LLMObs.ToolCall> toolCalls = buildToolCallsFromStreamingData(toolCallsByChoice[i]);
+      llmMessages.add(LLMObs.LLMMessage.from(roles[i], contents[i].toString(), toolCalls));
     }
     span.setTag("_ml_obs_tag.output", llmMessages);
   }
 
+  /** Helper class to accumulate streaming tool call data across chunks */
+  private static class StreamingToolCallData {
+    String id;
+    String type = "function";
+    StringBuilder name = new StringBuilder();
+    StringBuilder arguments = new StringBuilder();
+  }
+
+  private List<LLMObs.ToolCall> buildToolCallsFromStreamingData(
+      Map<Long, StreamingToolCallData> toolCallDataMap) {
+    if (toolCallDataMap.isEmpty()) {
+      return Collections.emptyList();
+    }
+    List<LLMObs.ToolCall> toolCalls = new ArrayList<>();
+    // Sort by index to maintain order
+    toolCallDataMap.entrySet().stream()
+        .sorted(Map.Entry.comparingByKey())
+        .forEach(
+            entry -> {
+              StreamingToolCallData data = entry.getValue();
+              String name = data.name.toString();
+              String argumentsJson = data.arguments.toString();
+              Map<String, Object> arguments = Collections.singletonMap("value", argumentsJson);
+              try {
+                arguments = ToolCallExtractor.parseArguments(argumentsJson);
+              } catch (Exception e) {
+                // keep default map with raw value
+              }
+              toolCalls.add(LLMObs.ToolCall.from(name, data.type, data.id, arguments));
+            });
+    return toolCalls;
+  }
+
   public void withEmbeddingCreateParams(AgentSpan span, EmbeddingCreateParams params) {
     span.setTag("_ml_obs_tag.span.kind", Tags.LLMOBS_EMBEDDING_SPAN_KIND);
     span.setResourceName(EMBEDDINGS_CREATE);
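
The streaming branch above never sees a complete tool call in a single chunk: each ChatCompletionChunk.Choice.Delta.ToolCall carries an index plus optional id/type/name/argument fragments, and the decorator concatenates those fragments per index before building LLMObs.ToolCall entries. A minimal standalone sketch of that merge-by-index idea, using hypothetical Fragment/Accumulated types instead of the OpenAI SDK classes (so it is an illustration of the pattern, not the instrumented code):

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

// Standalone sketch: streamed tool-call fragments are merged by their index,
// id/type keep the last value seen, and name/argument pieces are concatenated.
// Fragment and Accumulated are hypothetical stand-ins for the SDK delta types.
public class ToolCallMergeSketch {
  static final class Fragment {
    final long index;
    final String id;       // typically present only on the first fragment
    final String namePart; // may be null
    final String argsPart; // partial JSON, may be null
    Fragment(long index, String id, String namePart, String argsPart) {
      this.index = index;
      this.id = id;
      this.namePart = namePart;
      this.argsPart = argsPart;
    }
  }

  static final class Accumulated {
    String id;
    final StringBuilder name = new StringBuilder();
    final StringBuilder arguments = new StringBuilder();
  }

  public static void main(String[] args) {
    List<Fragment> stream =
        Arrays.asList(
            new Fragment(0, "call_1", "extract_student_info", "{\"name\":\"Da"),
            new Fragment(0, null, null, "vid Nguyen\"}"));

    // TreeMap keeps tool calls ordered by index, like the sorted stream above
    Map<Long, Accumulated> byIndex = new TreeMap<>();
    for (Fragment f : stream) {
      Accumulated acc = byIndex.computeIfAbsent(f.index, k -> new Accumulated());
      if (f.id != null) {
        acc.id = f.id;
      }
      if (f.namePart != null) {
        acc.name.append(f.namePart);
      }
      if (f.argsPart != null) {
        acc.arguments.append(f.argsPart);
      }
    }

    byIndex.forEach(
        (i, acc) -> System.out.println(acc.id + " " + acc.name + " " + acc.arguments));
    // prints: call_1 extract_student_info {"name":"David Nguyen"}
  }
}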

dd-java-agent/instrumentation/openai-java/openai-java-3.0/src/main/java/datadog/trace/instrumentation/openai_java/OpenAiModule.java

Lines changed: 2 additions & 0 deletions
@@ -16,6 +16,8 @@ public OpenAiModule() {
   public String[] helperClassNames() {
     return new String[] {
       packageName + ".OpenAiDecorator",
+      packageName + ".OpenAiDecorator$1",
+      packageName + ".OpenAiDecorator$StreamingToolCallData",
       packageName + ".ResponseWrappers",
       packageName + ".ResponseWrappers$DDHttpResponseFor",
       packageName + ".ResponseWrappers$1",

dd-java-agent/instrumentation/openai-java/openai-java-3.0/src/main/java/datadog/trace/instrumentation/openai_java/ToolCallExtractor.java

Lines changed: 7 additions & 2 deletions
@@ -29,11 +29,12 @@ public static LLMObs.ToolCall getToolCall(ChatCompletionMessageToolCall toolCall
     String name = function.name();
     String argumentsJson = function.arguments();
 
-    Map<String, Object> arguments = Collections.singletonMap("value", argumentsJson);
+    Map<String, Object> arguments;
     try {
-      arguments = MAPPER.readValue(argumentsJson, MAP_TYPE_REF);
+      arguments = parseArguments(argumentsJson);
     } catch (Exception e) {
       log.debug("Failed to parse tool call arguments as JSON: {}", argumentsJson, e);
+      arguments = Collections.singletonMap("value", argumentsJson);
     }
 
     String type = "function";
@@ -48,4 +49,8 @@ public static LLMObs.ToolCall getToolCall(ChatCompletionMessageToolCall toolCall
     }
     return null;
   }
+
+  public static Map<String, Object> parseArguments(String argumentsJson) throws Exception {
+    return MAPPER.readValue(argumentsJson, MAP_TYPE_REF);
+  }
 }
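
With this change, parseArguments is the single place where the arguments JSON becomes a map, and both callers (getToolCall above and the streaming path in OpenAiDecorator) fall back to a singletonMap with the raw string when parsing throws. A rough sketch of that parse-or-fallback pattern, assuming a plain Jackson ObjectMapper in place of the class's own MAPPER/MAP_TYPE_REF constants; the parseOrWrap helper name is illustrative, not from the codebase:

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.util.Collections;
import java.util.Map;

// Sketch of the parse-or-fallback pattern, assuming plain Jackson; the real
// ToolCallExtractor keeps parsing and fallback in separate call sites.
public class ParseArgumentsSketch {
  private static final ObjectMapper MAPPER = new ObjectMapper();
  private static final TypeReference<Map<String, Object>> MAP_TYPE_REF =
      new TypeReference<Map<String, Object>>() {};

  static Map<String, Object> parseOrWrap(String argumentsJson) {
    try {
      // Valid JSON becomes a structured map of argument names to values
      return MAPPER.readValue(argumentsJson, MAP_TYPE_REF);
    } catch (Exception e) {
      // Anything else is kept verbatim under a single "value" key
      return Collections.singletonMap("value", argumentsJson);
    }
  }

  public static void main(String[] args) {
    System.out.println(parseOrWrap("{\"grades\": 3.8}")); // {grades=3.8}
    System.out.println(parseOrWrap("not json"));          // {value=not json}
  }
}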

dd-java-agent/instrumentation/openai-java/openai-java-3.0/src/test/groovy/ChatCompletionServiceTest.groovy

Lines changed: 56 additions & 7 deletions
@@ -8,6 +8,7 @@ import com.openai.models.chat.completions.ChatCompletion
 import com.openai.models.chat.completions.ChatCompletionChunk
 import com.openai.models.completions.Completion
 import datadog.trace.api.DDSpanTypes
+import datadog.trace.api.llmobs.LLMObs
 import datadog.trace.bootstrap.instrumentation.api.Tags
 import datadog.trace.instrumentation.openai_java.OpenAiDecorator
 import java.util.concurrent.CompletableFuture
@@ -119,21 +120,65 @@ class ChatCompletionServiceTest extends OpenAiTest {
   }
 
   def "create chat/completion test with tool calls"() {
-    ChatCompletion resp = runUnderTrace("parent") {
+    runUnderTrace("parent") {
       openAiClient.chat().completions().create(chatCompletionCreateParamsWithTools())
     }
 
     expect:
-    resp != null
-    resp.choices().size() == 1
-    resp.choices().get(0).message().toolCalls().isPresent()
-    resp.choices().get(0).message().toolCalls().get().size() == 1
-    resp.choices().get(0).message().toolCalls().get().get(0).function().get().function().name() == "extract_student_info"
+    List<LLMObs.LLMMessage> outputTag = []
+    assertChatCompletionTrace(false, outputTag)
     and:
-    assertChatCompletionTrace(false)
+    outputTag.size() == 1
+    LLMObs.LLMMessage outputMsg = outputTag.get(0)
+    outputMsg.toolCalls.size() == 1
+    def toolcall = outputMsg.toolCalls.get(0)
+    toolcall.name == "extract_student_info"
+    toolcall.toolId instanceof String
+    toolcall.type == "function"
+    toolcall.arguments == [
+      name: 'David Nguyen',
+      major: 'computer science',
+      school: 'Stanford University',
+      grades: 3.8,
+      clubs: ['Chess Club', 'South Asian Student Association']
+    ]
+  }
+
+  def "create streaming chat/completion test with tool calls"() {
+    runnableUnderTrace("parent") {
+      StreamResponse<ChatCompletionChunk> streamCompletion = openAiClient.chat().completions().createStreaming(chatCompletionCreateParamsWithTools())
+      try (Stream stream = streamCompletion.stream()) {
+        stream.forEach { chunk ->
+          // chunks.add(chunk)
+        }
+      }
+    }
+
+    expect:
+    List<LLMObs.LLMMessage> outputTag = []
+    assertChatCompletionTrace(true, outputTag)
+    and:
+    outputTag.size() == 1
+    LLMObs.LLMMessage outputMsg = outputTag.get(0)
+    outputMsg.toolCalls.size() == 1
+    def toolcall = outputMsg.toolCalls.get(0)
+    toolcall.name == "extract_student_info"
+    toolcall.toolId instanceof String
+    toolcall.type == "function"
+    toolcall.arguments == [
+      name: 'David Nguyen',
+      major: 'computer science',
+      school: 'Stanford University',
+      grades: 3.8,
+      clubs: ['Chess Club', 'South Asian Student Association']
+    ]
   }
 
   private void assertChatCompletionTrace(boolean isStreaming) {
+    assertChatCompletionTrace(isStreaming, null)
+  }
+
+  private void assertChatCompletionTrace(boolean isStreaming, List outputTagsOut) {
     assertTraces(1) {
       trace(3) {
         sortSpansByStart()
@@ -155,6 +200,10 @@ class ChatCompletionServiceTest extends OpenAiTest {
           "_ml_obs_tag.metadata" Map
           "_ml_obs_tag.input" List
           "_ml_obs_tag.output" List
+          def outputTags = tag("_ml_obs_tag.output")
+          if (outputTagsOut != null && outputTags != null) {
+            outputTagsOut.addAll(outputTags)
+          }
           if (!isStreaming) {
             // streamed completions missing usage data
             "_ml_obs_metric.input_tokens" Long
