| 1 | +// Copyright Envoy AI Gateway Authors |
| 2 | +// SPDX-License-Identifier: Apache-2.0 |
| 3 | +// The full text of the Apache license is available in the LICENSE file at |
| 4 | +// the root of the repo. |
| 5 | + |
| 6 | +package anthropic |
| 7 | + |
| 8 | +import ( |
| 9 | + "encoding/json" |
| 10 | + "fmt" |
| 11 | + |
| 12 | + "go.opentelemetry.io/otel/attribute" |
| 13 | + "go.opentelemetry.io/otel/codes" |
| 14 | + "go.opentelemetry.io/otel/trace" |
| 15 | + |
| 16 | + "github.com/envoyproxy/ai-gateway/internal/apischema/anthropic" |
| 17 | + "github.com/envoyproxy/ai-gateway/internal/metrics" |
| 18 | + tracing "github.com/envoyproxy/ai-gateway/internal/tracing/api" |
| 19 | + "github.com/envoyproxy/ai-gateway/internal/tracing/openinference" |
| 20 | +) |
| 21 | + |
| 22 | +// MessageRecorder implements a recorder for OpenInference Anthropic Messages spans.
| 23 | +type MessageRecorder struct { |
| 24 | + traceConfig *openinference.TraceConfig |
| 25 | +} |
| 26 | + |
| 27 | +// NewMessageRecorderFromEnv creates a tracing.MessageRecorder
| 28 | +// from environment variables using the OpenInference configuration specification.
| 29 | +// |
| 30 | +// See: https://github.com/Arize-ai/openinference/blob/main/spec/configuration.md |
| 31 | +func NewMessageRecorderFromEnv() tracing.MessageRecorder { |
| 32 | + return NewMessageRecorder(nil) |
| 33 | +} |
| 34 | + |
| 35 | +// NewMessageRecorder creates a tracing.MessageRecorder with the |
| 36 | +// given config using the OpenInference configuration specification. |
| 37 | +// |
| 38 | +// Parameters: |
| 39 | +// - config: configuration for redaction. Defaults to NewTraceConfigFromEnv(). |
| 40 | +// |
| 41 | +// See: https://github.com/Arize-ai/openinference/blob/main/spec/configuration.md |
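|  | +//
|  | +// Example (illustrative): NewMessageRecorder(&openinference.TraceConfig{HideInputs: true})
|  | +// redacts the recorded input value and skips per-message input attributes.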
| 42 | +func NewMessageRecorder(config *openinference.TraceConfig) tracing.MessageRecorder { |
| 43 | + if config == nil { |
| 44 | + config = openinference.NewTraceConfigFromEnv() |
| 45 | + } |
| 46 | + return &MessageRecorder{traceConfig: config} |
| 47 | +} |
| 48 | + |
| 49 | +// startOpts sets trace.SpanKindInternal as that's the span kind used in |
| 50 | +// OpenInference. |
| 51 | +var startOpts = []trace.SpanStartOption{trace.WithSpanKind(trace.SpanKindInternal)} |
| 52 | + |
| 53 | +// StartParams implements the same method as defined in tracing.MessageRecorder. |
| 54 | +func (r *MessageRecorder) StartParams(*anthropic.MessagesRequest, []byte) (spanName string, opts []trace.SpanStartOption) { |
| 55 | + return "Message", startOpts |
| 56 | +} |
| 57 | + |
| 58 | +// RecordRequest implements the same method as defined in tracing.MessageRecorder. |
| 59 | +func (r *MessageRecorder) RecordRequest(span trace.Span, chatReq *anthropic.MessagesRequest, body []byte) { |
| 60 | + span.SetAttributes(buildRequestAttributes(chatReq, string(body), r.traceConfig)...) |
| 61 | +} |
| 62 | + |
| 63 | +// RecordResponseChunks implements the same method as defined in tracing.MessageRecorder. |
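|  | +// It emits a stream event when at least one chunk was received, then folds the
|  | +// chunks into a single MessagesResponse and records it via RecordResponse.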
| 64 | +func (r *MessageRecorder) RecordResponseChunks(span trace.Span, chunks []*anthropic.MessagesStreamChunk) { |
| 65 | + if len(chunks) > 0 { |
| 66 | + span.AddEvent("First Token Stream Event") |
| 67 | + } |
| 68 | + converted := convertSSEToResponse(chunks) |
| 69 | + r.RecordResponse(span, converted) |
| 70 | +} |
| 71 | + |
| 72 | +// RecordResponseOnError implements the same method as defined in tracing.MessageRecorder. |
| 73 | +func (r *MessageRecorder) RecordResponseOnError(span trace.Span, statusCode int, body []byte) { |
| 74 | + openinference.RecordResponseError(span, statusCode, string(body)) |
| 75 | +} |
| 76 | + |
| 77 | +// RecordResponse implements the same method as defined in tracing.MessageRecorder. |
| 78 | +func (r *MessageRecorder) RecordResponse(span trace.Span, resp *anthropic.MessagesResponse) { |
| 79 | + // Set output attributes, including the serialized response body unless
| 80 | + // outputs are hidden.
| 81 | + attrs := buildResponseAttributes(resp, r.traceConfig)
| 82 | + |
| 83 | + bodyString := openinference.RedactedValue |
| 84 | + if !r.traceConfig.HideOutputs { |
| 85 | + marshaled, err := json.Marshal(resp) |
| 86 | + if err == nil { |
| 87 | + bodyString = string(marshaled) |
| 88 | + } |
| 89 | + } |
| 90 | + attrs = append(attrs, attribute.String(openinference.OutputValue, bodyString)) |
| 91 | + span.SetAttributes(attrs...) |
| 92 | + span.SetStatus(codes.Ok, "") |
| 93 | +} |
| 94 | + |
| 95 | +// llmInvocationParameters is the representation of LLMInvocationParameters:
| 96 | +// all request parameters except messages and tools, which are recorded under
| 97 | +// their own attributes.
| 98 | +// See: openinference-instrumentation-openai _request_attributes_extractor.py.
| 99 | +type llmInvocationParameters struct { |
| 100 | + anthropic.MessagesRequest |
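|  | + // Messages and Tools shadow the embedded fields; left nil with omitempty,
|  | + // they are dropped from the marshaled invocation parameters.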
| 101 | + Messages []anthropic.MessageParam `json:"messages,omitempty"` |
| 102 | + Tools []anthropic.Tool `json:"tools,omitempty"` |
| 103 | +} |
| 104 | + |
| 105 | +// buildRequestAttributes builds OpenInference attributes from the request. |
| 106 | +func buildRequestAttributes(req *anthropic.MessagesRequest, body string, config *openinference.TraceConfig) []attribute.KeyValue { |
| 107 | + attrs := []attribute.KeyValue{ |
| 108 | + attribute.String(openinference.SpanKind, openinference.SpanKindLLM), |
| 109 | + attribute.String(openinference.LLMSystem, openinference.LLMSystemAnthropic), |
| 110 | + attribute.String(openinference.LLMModelName, req.Model), |
| 111 | + } |
| 112 | + |
| 113 | + if config.HideInputs { |
| 114 | + attrs = append(attrs, attribute.String(openinference.InputValue, openinference.RedactedValue)) |
| 115 | + } else { |
| 116 | + attrs = append(attrs, attribute.String(openinference.InputValue, body)) |
| 117 | + attrs = append(attrs, attribute.String(openinference.InputMimeType, openinference.MimeTypeJSON)) |
| 118 | + } |
| 119 | + |
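|  | + // Record the invocation parameters (the request minus messages and tools,
|  | + // per llmInvocationParameters below) unless they are hidden.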
| 120 | + if !config.HideLLMInvocationParameters { |
| 121 | + if invocationParamsJSON, err := json.Marshal(llmInvocationParameters{ |
| 122 | + MessagesRequest: *req, |
| 123 | + }); err == nil { |
| 124 | + attrs = append(attrs, attribute.String(openinference.LLMInvocationParameters, string(invocationParamsJSON))) |
| 125 | + } |
| 126 | + } |
| 127 | + |
| 128 | + if !config.HideInputs && !config.HideInputMessages { |
| 129 | + for i, msg := range req.Messages { |
| 130 | + role := msg.Role |
| 131 | + attrs = append(attrs, attribute.String(openinference.InputMessageAttribute(i, openinference.MessageRole), string(role))) |
| 132 | + switch content := msg.Content; { |
| 133 | + case content.Text != "": |
| 134 | + maybeRedacted := content.Text |
| 135 | + if config.HideInputText { |
| 136 | + maybeRedacted = openinference.RedactedValue |
| 137 | + } |
| 138 | + attrs = append(attrs, attribute.String(openinference.InputMessageAttribute(i, openinference.MessageContent), maybeRedacted)) |
| 139 | + case content.Array != nil: |
| 140 | + for j, param := range content.Array { |
| 141 | + switch { |
| 142 | + case param.Text != nil: |
| 143 | + maybeRedacted := param.Text.Text |
| 144 | + if config.HideInputText { |
| 145 | + maybeRedacted = openinference.RedactedValue |
| 146 | + } |
| 147 | + attrs = append(attrs, |
| 148 | + attribute.String(openinference.InputMessageContentAttribute(i, j, "text"), maybeRedacted), |
| 149 | + attribute.String(openinference.InputMessageContentAttribute(i, j, "type"), "text"), |
| 150 | + ) |
| 151 | + default: |
| 152 | + // TODO: support for other content types. |
| 153 | + } |
| 154 | + } |
| 155 | + } |
| 156 | + } |
| 157 | + } |
| 158 | + |
| 159 | + // Add indexed attributes for each tool. |
| 160 | + for i, tool := range req.Tools { |
| 161 | + if toolJSON, err := json.Marshal(tool); err == nil { |
| 162 | + attrs = append(attrs, |
| 163 | + attribute.String(fmt.Sprintf("%s.%d.tool.json_schema", openinference.LLMTools, i), string(toolJSON)), |
| 164 | + ) |
| 165 | + } |
| 166 | + } |
| 167 | + return attrs |
| 168 | +} |
| 169 | + |
| 170 | +func buildResponseAttributes(resp *anthropic.MessagesResponse, config *openinference.TraceConfig) []attribute.KeyValue { |
| 171 | + attrs := []attribute.KeyValue{ |
| 172 | + attribute.String(openinference.LLMModelName, resp.Model), |
| 173 | + } |
| 174 | + |
| 175 | + if !config.HideOutputs { |
| 176 | + attrs = append(attrs, attribute.String(openinference.OutputMimeType, openinference.MimeTypeJSON)) |
| 177 | + } |
| 178 | + |
| 179 | + // Note: the compound condition below mirrors the Python OpenInference OpenAI config.py.
| 180 | + role := resp.Role |
| 181 | + if !config.HideOutputs && !config.HideOutputMessages { |
| 182 | + for i, content := range resp.Content { |
| 183 | + attrs = append(attrs, attribute.String(openinference.OutputMessageAttribute(i, openinference.MessageRole), string(role))) |
| 184 | + |
| 185 | + switch { |
| 186 | + case content.Text != nil: |
| 187 | + txt := content.Text.Text |
| 188 | + if config.HideOutputText { |
| 189 | + txt = openinference.RedactedValue |
| 190 | + } |
| 191 | + attrs = append(attrs, attribute.String(openinference.OutputMessageAttribute(i, openinference.MessageContent), txt)) |
| 192 | + case content.Tool != nil: |
| 193 | + tool := content.Tool |
| 194 | + attrs = append(attrs, |
| 195 | + attribute.String(openinference.OutputMessageToolCallAttribute(i, 0, openinference.ToolCallID), tool.ID), |
| 196 | + attribute.String(openinference.OutputMessageToolCallAttribute(i, 0, openinference.ToolCallFunctionName), tool.Name), |
| 197 | + ) |
| 198 | + inputStr, err := json.Marshal(tool.Input) |
| 199 | + if err == nil { |
| 200 | + attrs = append(attrs, |
| 201 | + attribute.String(openinference.OutputMessageToolCallAttribute(i, 0, openinference.ToolCallFunctionArguments), string(inputStr)), |
| 202 | + ) |
| 203 | + } |
| 204 | + } |
| 205 | + } |
| 206 | + } |
| 207 | + |
| 208 | + // Token counts are considered metadata and are still included even when output content is hidden. |
| 209 | + u := resp.Usage |
| 210 | + cost := metrics.ExtractTokenUsageFromAnthropic( |
| 211 | + int64(u.InputTokens), |
| 212 | + int64(u.OutputTokens), |
| 213 | + int64(u.CacheReadInputTokens), |
| 214 | + int64(u.CacheCreationInputTokens), |
| 215 | + ) |
| 216 | + input, _ := cost.InputTokens() |
| 217 | + cache, _ := cost.CachedInputTokens() |
| 218 | + output, _ := cost.OutputTokens() |
| 219 | + total, _ := cost.TotalTokens() |
| 220 | + |
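|  | + // Map the normalized counts onto the OpenInference token count attributes;
|  | + // the cached count is recorded as prompt cache hits.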
| 221 | + attrs = append(attrs, |
| 222 | + attribute.Int(openinference.LLMTokenCountPrompt, int(input)), |
| 223 | + attribute.Int(openinference.LLMTokenCountPromptCacheHit, int(cache)), |
| 224 | + attribute.Int(openinference.LLMTokenCountCompletion, int(output)), |
| 225 | + attribute.Int(openinference.LLMTokenCountTotal, int(total)), |
| 226 | + ) |
| 227 | + return attrs |
| 228 | +} |
| 229 | + |
| 230 | +// convertSSEToResponse folds a complete SSE stream into a single
| 231 | +// anthropic.MessagesResponse. When the result is later marshaled to JSON,
| 232 | +// zero values are not serialized, including fields whose values are zero or
| 233 | +// empty and nested objects where all fields have zero values.
| 234 | +//
| 235 | +// TODO: This could be refactored into a stateful, streaming form that does not
| 236 | +// require all chunks at once; that would avoid the slice allocation for events.
| 237 | +// TODO: Or, even better, a chunk-aware variant of buildResponseAttributes could
| 238 | +// accept a single anthropic.MessagesStreamChunk at a time, so chunks would not
| 239 | +// need to be accumulated in memory.
| 240 | +func convertSSEToResponse(chunks []*anthropic.MessagesStreamChunk) *anthropic.MessagesResponse { |
| 241 | + var response anthropic.MessagesResponse |
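|  | + // toolInputs accumulates partial JSON for tool-use content blocks, keyed by
|  | + // content block index, until the corresponding content_block_stop arrives.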
| 242 | + toolInputs := make(map[int]string) |
| 243 | + |
| 244 | + for _, event := range chunks { |
| 245 | + switch { |
| 246 | + case event.MessageStart != nil: |
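|  | + // message_start carries the initial response envelope; the pointer conversion
|  | + // below relies on its type sharing MessagesResponse's underlying struct layout.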
| 247 | + response = *(*anthropic.MessagesResponse)(event.MessageStart) |
| 248 | + // Ensure Content is initialized if nil. |
| 249 | + if response.Content == nil { |
| 250 | + response.Content = []anthropic.MessagesContentBlock{} |
| 251 | + } |
| 252 | + |
| 253 | + case event.MessageDelta != nil: |
| 254 | + delta := event.MessageDelta |
| 255 | + if response.Usage == nil { |
| 256 | + response.Usage = &delta.Usage |
| 257 | + } else { |
| 258 | + // Usage is cumulative for output tokens in message_delta. |
| 259 | + // Input tokens are usually in message_start. |
| 260 | + response.Usage.OutputTokens = delta.Usage.OutputTokens |
| 261 | + } |
| 262 | + response.StopReason = &delta.Delta.StopReason |
| 263 | + response.StopSequence = &delta.Delta.StopSequence |
| 264 | + |
| 265 | + case event.ContentBlockStart != nil: |
| 266 | + idx := event.ContentBlockStart.Index |
| 267 | + // Grow slice if needed. |
| 268 | + if idx >= len(response.Content) { |
| 269 | + newContent := make([]anthropic.MessagesContentBlock, idx+1) |
| 270 | + copy(newContent, response.Content) |
| 271 | + response.Content = newContent |
| 272 | + } |
| 273 | + response.Content[idx] = event.ContentBlockStart.ContentBlock |
| 274 | + |
| 275 | + case event.ContentBlockDelta != nil: |
| 276 | + idx := event.ContentBlockDelta.Index |
| 277 | + if idx < len(response.Content) { |
| 278 | + block := &response.Content[idx] |
| 279 | + delta := event.ContentBlockDelta.Delta |
| 280 | + |
| 281 | + if block.Text != nil && delta.Text != "" { |
| 282 | + block.Text.Text += delta.Text |
| 283 | + } |
| 284 | + if block.Tool != nil && delta.PartialJSON != "" { |
| 285 | + toolInputs[idx] += delta.PartialJSON |
| 286 | + } |
| 287 | + if block.Thinking != nil { |
| 288 | + if delta.Thinking != "" { |
| 289 | + block.Thinking.Thinking += delta.Thinking |
| 290 | + } |
| 291 | + if delta.Signature != "" { |
| 292 | + block.Thinking.Signature = delta.Signature |
| 293 | + } |
| 294 | + } |
| 295 | + } |
| 296 | + |
| 297 | + case event.ContentBlockStop != nil: |
| 298 | + idx := event.ContentBlockStop.Index |
| 299 | + if jsonStr, ok := toolInputs[idx]; ok { |
| 300 | + if idx < len(response.Content) && response.Content[idx].Tool != nil { |
| 301 | + var input map[string]any |
| 302 | + if err := json.Unmarshal([]byte(jsonStr), &input); err == nil { |
| 303 | + response.Content[idx].Tool.Input = input |
| 304 | + } |
| 305 | + } |
| 306 | + delete(toolInputs, idx) |
| 307 | + } |
| 308 | + |
| 309 | + case event.MessageStop != nil: |
| 310 | + // Nothing to do. |
| 311 | + } |
| 312 | + } |
| 313 | + return &response |
| 314 | +} |