Skip to content

Commit c783318

Browse files
committed
Use ChatCompletionAccumulator to simplify streamed chat/response decoration
1 parent 6e654e5 commit c783318

File tree

4 files changed

+14
-129
lines changed

4 files changed

+14
-129
lines changed

dd-java-agent/instrumentation/openai-java/openai-java-3.0/src/main/java/datadog/trace/instrumentation/openai_java/OpenAiDecorator.java

Lines changed: 1 addition & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -295,122 +295,12 @@ private static LLMObs.LLMMessage llmMessage(ChatCompletion.Choice choice) {
295295
}
296296

297297
public void withChatCompletionChunks(AgentSpan span, List<ChatCompletionChunk> chunks) {
298-
if (chunks.isEmpty()) {
299-
return;
300-
}
301-
ChatCompletionChunk firstChunk = chunks.get(0);
302-
String modelName = firstChunk.model();
303-
span.setTag(RESPONSE_MODEL, modelName);
304-
span.setTag("_ml_obs_tag.model_name", modelName);
305-
span.setTag("_ml_obs_tag.model_provider", "openai");
306-
307-
// assume that number of choices is the same for each chunk
308-
final int choiceNum = firstChunk.choices().size();
309-
// collect roles by choices by the first chunk
310-
String[] roles = new String[choiceNum];
311-
for (int i = 0; i < choiceNum; i++) {
312-
ChatCompletionChunk.Choice choice = firstChunk.choices().get(i);
313-
Optional<String> role = choice.delta().role().flatMap(r -> r._value().asString());
314-
if (role.isPresent()) {
315-
roles[i] = role.get();
316-
}
317-
}
318-
// collect content by choices for all chunks
319-
StringBuilder[] contents = new StringBuilder[choiceNum];
320-
for (int i = 0; i < choiceNum; i++) {
321-
contents[i] = new StringBuilder(128);
322-
}
323-
// collect tool calls by choices for all chunks
324-
// Map from choice index -> (tool call index -> accumulated tool call data)
325-
@SuppressWarnings("unchecked")
326-
Map<Long, StreamingToolCallData>[] toolCallsByChoice = new Map[choiceNum];
327-
for (int i = 0; i < choiceNum; i++) {
328-
toolCallsByChoice[i] = new HashMap<>();
329-
}
330-
331-
// Create an accumulator
332298
ChatCompletionAccumulator accumulator = ChatCompletionAccumulator.create();
333-
334-
// Accumulate each chunk as it arrives
335299
for (ChatCompletionChunk chunk : chunks) {
336300
accumulator.accumulate(chunk);
337301
}
338-
339-
// Get the final ChatCompletion
340302
ChatCompletion chatCompletion = accumulator.chatCompletion();
341-
342-
for (ChatCompletionChunk chunk : chunks) {
343-
// choices can be empty for the last chunk
344-
List<ChatCompletionChunk.Choice> choices = chunk.choices();
345-
for (int i = 0; i < choiceNum && i < choices.size(); i++) {
346-
ChatCompletionChunk.Choice choice = choices.get(i);
347-
ChatCompletionChunk.Choice.Delta delta = choice.delta();
348-
delta.content().ifPresent(contents[i]::append);
349-
350-
// accumulate tool calls
351-
Optional<List<ChatCompletionChunk.Choice.Delta.ToolCall>> toolCallsOpt = delta.toolCalls();
352-
if (toolCallsOpt.isPresent()) {
353-
for (ChatCompletionChunk.Choice.Delta.ToolCall toolCall : toolCallsOpt.get()) {
354-
long index = toolCall.index();
355-
StreamingToolCallData data =
356-
toolCallsByChoice[i].computeIfAbsent(index, k -> new StreamingToolCallData());
357-
toolCall.id().ifPresent(id -> data.id = id);
358-
toolCall
359-
.type()
360-
.flatMap(t -> t._value().asString())
361-
.ifPresent(type -> data.type = type);
362-
toolCall
363-
.function()
364-
.ifPresent(
365-
fn -> {
366-
fn.name().ifPresent(data.name::append);
367-
fn.arguments().ifPresent(data.arguments::append);
368-
});
369-
}
370-
}
371-
}
372-
chunk.usage().ifPresent(usage -> withCompletionUsage(span, usage));
373-
}
374-
// build LLMMessages
375-
List<LLMObs.LLMMessage> llmMessages = new ArrayList<>(choiceNum);
376-
for (int i = 0; i < choiceNum; i++) {
377-
List<LLMObs.ToolCall> toolCalls = buildToolCallsFromStreamingData(toolCallsByChoice[i]);
378-
llmMessages.add(LLMObs.LLMMessage.from(roles[i], contents[i].toString(), toolCalls));
379-
}
380-
span.setTag("_ml_obs_tag.output", llmMessages);
381-
}
382-
383-
/** Helper class to accumulate streaming tool call data across chunks */
384-
private static class StreamingToolCallData {
385-
String id;
386-
String type = "function";
387-
StringBuilder name = new StringBuilder();
388-
StringBuilder arguments = new StringBuilder();
389-
}
390-
391-
private List<LLMObs.ToolCall> buildToolCallsFromStreamingData(
392-
Map<Long, StreamingToolCallData> toolCallDataMap) {
393-
if (toolCallDataMap.isEmpty()) {
394-
return Collections.emptyList();
395-
}
396-
List<LLMObs.ToolCall> toolCalls = new ArrayList<>();
397-
// Sort by index to maintain order
398-
toolCallDataMap.entrySet().stream()
399-
.sorted(Map.Entry.comparingByKey())
400-
.forEach(
401-
entry -> {
402-
StreamingToolCallData data = entry.getValue();
403-
String name = data.name.toString();
404-
String argumentsJson = data.arguments.toString();
405-
Map<String, Object> arguments = Collections.singletonMap("value", argumentsJson);
406-
try {
407-
arguments = ToolCallExtractor.parseArguments(argumentsJson);
408-
} catch (Exception e) {
409-
// keep default map with raw value
410-
}
411-
toolCalls.add(LLMObs.ToolCall.from(name, data.type, data.id, arguments));
412-
});
413-
return toolCalls;
303+
withChatCompletion(span, chatCompletion);
414304
}
415305

416306
public void withEmbeddingCreateParams(AgentSpan span, EmbeddingCreateParams params) {

dd-java-agent/instrumentation/openai-java/openai-java-3.0/src/main/java/datadog/trace/instrumentation/openai_java/OpenAiModule.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@ public OpenAiModule() {
1616
public String[] helperClassNames() {
1717
return new String[] {
1818
packageName + ".OpenAiDecorator",
19-
packageName + ".OpenAiDecorator$1",
20-
packageName + ".OpenAiDecorator$StreamingToolCallData",
2119
packageName + ".ResponseWrappers",
2220
packageName + ".ResponseWrappers$DDHttpResponseFor",
2321
packageName + ".ResponseWrappers$1",

dd-java-agent/instrumentation/openai-java/openai-java-3.0/src/main/java/datadog/trace/instrumentation/openai_java/ToolCallExtractor.java

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ public static LLMObs.ToolCall getToolCall(ChatCompletionMessageToolCall toolCall
3131

3232
Map<String, Object> arguments;
3333
try {
34-
arguments = parseArguments(argumentsJson);
34+
arguments = MAPPER.readValue(argumentsJson, MAP_TYPE_REF);
3535
} catch (Exception e) {
3636
log.debug("Failed to parse tool call arguments as JSON: {}", argumentsJson, e);
3737
arguments = Collections.singletonMap("value", argumentsJson);
@@ -49,8 +49,4 @@ public static LLMObs.ToolCall getToolCall(ChatCompletionMessageToolCall toolCall
4949
}
5050
return null;
5151
}
52-
53-
public static Map<String, Object> parseArguments(String argumentsJson) throws Exception {
54-
return MAPPER.readValue(argumentsJson, MAP_TYPE_REF);
55-
}
5652
}

dd-java-agent/instrumentation/openai-java/openai-java-3.0/src/test/groovy/ChatCompletionServiceTest.groovy

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -136,19 +136,20 @@ class ChatCompletionServiceTest extends OpenAiTest {
136136
toolcall.toolId instanceof String
137137
toolcall.type == "function"
138138
toolcall.arguments == [
139-
name: 'David Nguyen',
140-
major: 'computer science',
141-
school: 'Stanford University',
142-
grades: 3.8,
143-
clubs: ['Chess Club', 'South Asian Student Association']
139+
name: 'David Nguyen',
140+
major: 'computer science',
141+
school: 'Stanford University',
142+
grades: 3.8,
143+
clubs: ['Chess Club', 'South Asian Student Association']
144144
]
145145
}
146146

147147
def "create streaming chat/completion test with tool calls"() {
148148
runnableUnderTrace("parent") {
149149
StreamResponse<ChatCompletionChunk> streamCompletion = openAiClient.chat().completions().createStreaming(chatCompletionCreateParamsWithTools())
150150
try (Stream stream = streamCompletion.stream()) {
151-
stream.forEach { chunk ->
151+
stream.forEach {
152+
chunk ->
152153
// chunks.add(chunk)
153154
}
154155
}
@@ -166,11 +167,11 @@ class ChatCompletionServiceTest extends OpenAiTest {
166167
toolcall.toolId instanceof String
167168
toolcall.type == "function"
168169
toolcall.arguments == [
169-
name: 'David Nguyen',
170-
major: 'computer science',
171-
school: 'Stanford University',
172-
grades: 3.8,
173-
clubs: ['Chess Club', 'South Asian Student Association']
170+
name: 'David Nguyen',
171+
major: 'computer science',
172+
school: 'Stanford University',
173+
grades: 3.8,
174+
clubs: ['Chess Club', 'South Asian Student Association']
174175
]
175176
}
176177

0 commit comments

Comments
 (0)