Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Copyright The OpenTelemetry Authors
* SPDX-License-Identifier: Apache-2.0
*/

package io.opentelemetry.javaagent.instrumentation.openai.v1_1;

import static io.opentelemetry.javaagent.instrumentation.openai.v1_1.OpenAiSingletons.TELEMETRY;
import static net.bytebuddy.matcher.ElementMatchers.named;
import static net.bytebuddy.matcher.ElementMatchers.returns;

import com.openai.client.OpenAIClientAsync;
import io.opentelemetry.javaagent.extension.instrumentation.TypeInstrumentation;
import io.opentelemetry.javaagent.extension.instrumentation.TypeTransformer;
import net.bytebuddy.asm.Advice;
import net.bytebuddy.description.type.TypeDescription;
import net.bytebuddy.matcher.ElementMatcher;

public class OpenAiClientAsyncInstrumentation implements TypeInstrumentation {
@Override
public ElementMatcher<TypeDescription> typeMatcher() {
return named("com.openai.client.okhttp.OpenAIOkHttpClientAsync$Builder");
}

@Override
public void transform(TypeTransformer transformer) {
transformer.applyAdviceToMethod(
named("build").and(returns(named("com.openai.client.OpenAIClientAsync"))),
OpenAiClientAsyncInstrumentation.class.getName() + "$BuildAdvice");
}

@SuppressWarnings("unused")
public static class BuildAdvice {
@Advice.OnMethodExit(suppress = Throwable.class)
@Advice.AssignReturned.ToReturned
public static OpenAIClientAsync onExit(@Advice.Return OpenAIClientAsync client) {
return TELEMETRY.wrap(client);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
package io.opentelemetry.javaagent.instrumentation.openai.v1_1;

import static io.opentelemetry.javaagent.extension.matcher.AgentElementMatchers.hasClassesNamed;
import static java.util.Collections.singletonList;
import static java.util.Arrays.asList;

import com.google.auto.service.AutoService;
import io.opentelemetry.javaagent.extension.instrumentation.InstrumentationModule;
Expand All @@ -27,6 +27,6 @@ public ElementMatcher.Junction<ClassLoader> classLoaderMatcher() {

@Override
public List<TypeInstrumentation> typeInstrumentations() {
return singletonList(new OpenAiClientInstrumentation());
return asList(new OpenAiClientInstrumentation(), new OpenAiClientAsyncInstrumentation());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package io.opentelemetry.javaagent.instrumentation.openai.v1_1;

import com.openai.client.OpenAIClient;
import com.openai.client.OpenAIClientAsync;
import io.opentelemetry.instrumentation.openai.v1_1.AbstractChatTest;
import io.opentelemetry.instrumentation.testing.junit.AgentInstrumentationExtension;
import io.opentelemetry.instrumentation.testing.junit.InstrumentationExtension;
Expand All @@ -31,6 +32,11 @@ protected OpenAIClient wrap(OpenAIClient client) {
return client;
}

@Override
protected OpenAIClientAsync wrap(OpenAIClientAsync client) {
return client;
}

@Override
protected final List<Consumer<SpanDataAssert>> maybeWithTransportSpan(
Consumer<SpanDataAssert> span) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,16 @@
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import javax.annotation.Nullable;

final class ChatCompletionEventsHelper {

private static final AttributeKey<String> EVENT_NAME = stringKey("event.name");

public static void emitPromptLogEvents(
Logger eventLogger, ChatCompletionCreateParams request, boolean captureMessageContent) {
Context context,
Logger eventLogger,
ChatCompletionCreateParams request,
boolean captureMessageContent) {
for (ChatCompletionMessageParam msg : request.messages()) {
String eventType;
Map<String, Value<?>> body = new HashMap<>();
Expand Down Expand Up @@ -84,7 +86,7 @@ public static void emitPromptLogEvents(
} else {
continue;
}
newEvent(eventLogger, eventType).setBody(Value.of(body)).emit();
newEvent(eventLogger, eventType).setContext(context).setBody(Value.of(body)).emit();
}
}

Expand Down Expand Up @@ -160,7 +162,10 @@ private static String joinContentParts(List<ChatCompletionContentPartText> conte
}

public static void emitCompletionLogEvents(
Logger eventLogger, ChatCompletion completion, boolean captureMessageContent) {
Context context,
Logger eventLogger,
ChatCompletion completion,
boolean captureMessageContent) {
for (ChatCompletion.Choice choice : completion.choices()) {
ChatCompletionMessage choiceMsg = choice.message();
Map<String, Value<?>> message = new HashMap<>();
Expand All @@ -179,25 +184,25 @@ public static void emitCompletionLogEvents(
.collect(Collectors.toList())));
});
emitCompletionLogEvent(
eventLogger, choice.index(), choice.finishReason().toString(), Value.of(message), null);
context,
eventLogger,
choice.index(),
choice.finishReason().toString(),
Value.of(message));
}
}

public static void emitCompletionLogEvent(
Context context,
Logger eventLogger,
long index,
String finishReason,
Value<?> eventMessageObject,
@Nullable Context contextOverride) {
Value<?> eventMessageObject) {
Map<String, Value<?>> body = new HashMap<>();
body.put("finish_reason", Value.of(finishReason));
body.put("index", Value.of(index));
body.put("message", eventMessageObject);
LogRecordBuilder builder = newEvent(eventLogger, "gen_ai.choice").setBody(Value.of(body));
if (contextOverride != null) {
builder.setContext(contextOverride);
}
builder.emit();
newEvent(eventLogger, "gen_ai.choice").setContext(context).setBody(Value.of(body)).emit();
}

private static LogRecordBuilder newEvent(Logger eventLogger, String name) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
* Copyright The OpenTelemetry Authors
* SPDX-License-Identifier: Apache-2.0
*/

package io.opentelemetry.instrumentation.openai.v1_1;

import io.opentelemetry.context.Context;
import io.opentelemetry.context.Scope;
import java.util.concurrent.CompletableFuture;

final class CompletableFutureWrapper {
private CompletableFutureWrapper() {}

static <T> CompletableFuture<T> wrap(CompletableFuture<T> future, Context context) {
CompletableFuture<T> result = new CompletableFuture<>();
future.whenComplete(
(T value, Throwable throwable) -> {
try (Scope ignored = context.makeCurrent()) {
if (throwable != null) {
result.completeExceptionally(throwable);
} else {
result.complete(value);
}
}
});

return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -73,65 +73,70 @@ public Object invoke(Object proxy, Method method, Object[] args) throws Throwabl

private ChatCompletion create(
ChatCompletionCreateParams chatCompletionCreateParams, RequestOptions requestOptions) {
Context parentCtx = Context.current();
if (!instrumenter.shouldStart(parentCtx, chatCompletionCreateParams)) {
return createWithLogs(chatCompletionCreateParams, requestOptions);
Context parentContext = Context.current();
if (!instrumenter.shouldStart(parentContext, chatCompletionCreateParams)) {
return createWithLogs(parentContext, chatCompletionCreateParams, requestOptions);
}

Context ctx = instrumenter.start(parentCtx, chatCompletionCreateParams);
Context context = instrumenter.start(parentContext, chatCompletionCreateParams);
ChatCompletion completion;
try (Scope ignored = ctx.makeCurrent()) {
completion = createWithLogs(chatCompletionCreateParams, requestOptions);
try (Scope ignored = context.makeCurrent()) {
completion = createWithLogs(context, chatCompletionCreateParams, requestOptions);
} catch (Throwable t) {
instrumenter.end(ctx, chatCompletionCreateParams, null, t);
instrumenter.end(context, chatCompletionCreateParams, null, t);
throw t;
}

instrumenter.end(ctx, chatCompletionCreateParams, completion, null);
instrumenter.end(context, chatCompletionCreateParams, completion, null);
return completion;
}

private ChatCompletion createWithLogs(
ChatCompletionCreateParams chatCompletionCreateParams, RequestOptions requestOptions) {
Context context,
ChatCompletionCreateParams chatCompletionCreateParams,
RequestOptions requestOptions) {
ChatCompletionEventsHelper.emitPromptLogEvents(
eventLogger, chatCompletionCreateParams, captureMessageContent);
context, eventLogger, chatCompletionCreateParams, captureMessageContent);
ChatCompletion result = delegate.create(chatCompletionCreateParams, requestOptions);
ChatCompletionEventsHelper.emitCompletionLogEvents(eventLogger, result, captureMessageContent);
ChatCompletionEventsHelper.emitCompletionLogEvents(
context, eventLogger, result, captureMessageContent);
return result;
}

private StreamResponse<ChatCompletionChunk> createStreaming(
ChatCompletionCreateParams chatCompletionCreateParams, RequestOptions requestOptions) {
Context parentCtx = Context.current();
if (!instrumenter.shouldStart(parentCtx, chatCompletionCreateParams)) {
return createStreamingWithLogs(chatCompletionCreateParams, requestOptions, parentCtx, false);
Context parentContext = Context.current();
if (!instrumenter.shouldStart(parentContext, chatCompletionCreateParams)) {
return createStreamingWithLogs(
parentContext, chatCompletionCreateParams, requestOptions, false);
}

Context ctx = instrumenter.start(parentCtx, chatCompletionCreateParams);
try (Scope ignored = ctx.makeCurrent()) {
return createStreamingWithLogs(chatCompletionCreateParams, requestOptions, ctx, true);
Context context = instrumenter.start(parentContext, chatCompletionCreateParams);
try (Scope ignored = context.makeCurrent()) {
return createStreamingWithLogs(context, chatCompletionCreateParams, requestOptions, true);
} catch (Throwable t) {
instrumenter.end(ctx, chatCompletionCreateParams, null, t);
instrumenter.end(context, chatCompletionCreateParams, null, t);
throw t;
}
}

private StreamResponse<ChatCompletionChunk> createStreamingWithLogs(
Context context,
ChatCompletionCreateParams chatCompletionCreateParams,
RequestOptions requestOptions,
Context parentCtx,
boolean newSpan) {
ChatCompletionEventsHelper.emitPromptLogEvents(
eventLogger, chatCompletionCreateParams, captureMessageContent);
context, eventLogger, chatCompletionCreateParams, captureMessageContent);
StreamResponse<ChatCompletionChunk> result =
delegate.createStreaming(chatCompletionCreateParams, requestOptions);
return new TracingStreamedResponse(
result,
parentCtx,
chatCompletionCreateParams,
instrumenter,
eventLogger,
captureMessageContent,
newSpan);
new StreamListener(
context,
chatCompletionCreateParams,
instrumenter,
eventLogger,
captureMessageContent,
newSpan));
}
}
Loading
Loading