|
| 1 | +package datadog.trace.llmobs.writer.ddintake |
| 2 | + |
| 3 | +import datadog.trace.common.writer.ListWriter |
| 4 | +import datadog.communication.serialization.ByteBufferConsumer |
| 5 | +import datadog.communication.serialization.FlushingBuffer |
| 6 | +import datadog.communication.serialization.msgpack.MsgPackWriter |
| 7 | +import datadog.trace.bootstrap.instrumentation.api.InternalSpanTypes |
| 8 | +import datadog.trace.bootstrap.instrumentation.api.Tags |
| 9 | +import datadog.trace.api.llmobs.LLMObsTags |
| 10 | +import datadog.trace.api.DDTags |
| 11 | +import datadog.trace.core.DDSpan |
| 12 | +import datadog.trace.core.test.DDCoreSpecification |
| 13 | +import org.msgpack.core.MessagePack |
| 14 | +import org.msgpack.core.MessageUnpacker |
| 15 | + |
| 16 | +import java.nio.ByteBuffer |
| 17 | + |
| 18 | +class LLMObsSpanMapperTest extends DDCoreSpecification { |
| 19 | + |
| 20 | + def "test LLMObs span mapper"() { |
| 21 | + setup: |
| 22 | + def tracer = tracerBuilder().writer(new ListWriter()).build() |
| 23 | + DDSpan span = (DDSpan) tracer.buildSpan("llm-operation") |
| 24 | + .withServiceName("my-llm-service") |
| 25 | + .withSpanType(InternalSpanTypes.LLMOBS) |
| 26 | + .start() |
| 27 | + |
| 28 | + // Add LLM-specific tags with proper prefixes |
| 29 | + span.setTag("_ml_obs_tag.span.kind", Tags.LLMOBS_WORKFLOW_SPAN_KIND) |
| 30 | + span.setTag("_ml_obs_tag." + LLMObsTags.MODEL_NAME, "gpt-4") |
| 31 | + span.setTag("_ml_obs_tag." + LLMObsTags.MODEL_PROVIDER, "openai") |
| 32 | + span.setTag("_ml_obs_tag.input", "What is the weather?") |
| 33 | + span.setTag("_ml_obs_tag.output", "It's sunny today.") |
| 34 | + span.setTag("_ml_obs_tag.custom_tag", "test-value") |
| 35 | + span.setTag("_ml_obs_metric.input_tokens", 10.0) |
| 36 | + span.setTag("_ml_obs_metric.output_tokens", 5.0) |
| 37 | + span.setTag("_ml_obs_metric.total_cost", 0.005) |
| 38 | + |
| 39 | + // Add some metadata |
| 40 | + Map<String, Object> metadata = [ |
| 41 | + "temperature": 0.7, |
| 42 | + "max_tokens": 100 |
| 43 | + ] |
| 44 | + span.setTag("_ml_obs_tag." + LLMObsTags.METADATA, metadata) |
| 45 | + |
| 46 | + def trace = [span] |
| 47 | + |
| 48 | + when: |
| 49 | + LLMObsSpanMapper spanMapper = new LLMObsSpanMapper() |
| 50 | + CapturingByteBufferConsumer sink = new CapturingByteBufferConsumer() |
| 51 | + MsgPackWriter packer = new MsgPackWriter(new FlushingBuffer(1024, sink)) |
| 52 | + packer.format(trace, spanMapper) |
| 53 | + packer.flush() |
| 54 | + |
| 55 | + then: |
| 56 | + sink.captured != null |
| 57 | + |
| 58 | + MessageUnpacker unpacker = MessagePack.newDefaultUnpacker(sink.captured) |
| 59 | + |
| 60 | + int topLevelMapSize = unpacker.unpackMapHeader() |
| 61 | + topLevelMapSize == 3 |
| 62 | + |
| 63 | + Map<String, Object> topLevel = [:] |
| 64 | + for (int i = 0; i < topLevelMapSize; i++) { |
| 65 | + String key = unpacker.unpackString() |
| 66 | + if (key == "event_type") { |
| 67 | + topLevel[key] = unpacker.unpackString() |
| 68 | + } else if (key == "_dd.stage") { |
| 69 | + topLevel[key] = unpacker.unpackString() |
| 70 | + } else if (key == "spans") { |
| 71 | + int spansArraySize = unpacker.unpackArrayHeader() |
| 72 | + topLevel[key] = spansArraySize |
| 73 | + unpacker.skipValue() // TODO: add check for span data |
| 74 | + } |
| 75 | + } |
| 76 | + |
| 77 | + topLevel["event_type"] == "span" |
| 78 | + topLevel["_dd.stage"] == "raw" |
| 79 | + topLevel["spans"] == 1 // Should have 1 span |
| 80 | + |
| 81 | + cleanup: |
| 82 | + tracer.close() |
| 83 | + } |
| 84 | + |
| 85 | + def "test non-LLMObs span is filtered out"() { |
| 86 | + setup: |
| 87 | + def tracer = tracerBuilder().writer(new ListWriter()).build() |
| 88 | + DDSpan regularSpan = (DDSpan) tracer.buildSpan("regular-operation") |
| 89 | + .withServiceName("my-service") |
| 90 | + .start() |
| 91 | + |
| 92 | + def trace = [regularSpan] |
| 93 | + |
| 94 | + when: |
| 95 | + LLMObsSpanMapper spanMapper = new LLMObsSpanMapper() |
| 96 | + CapturingByteBufferConsumer sink = new CapturingByteBufferConsumer() |
| 97 | + MsgPackWriter packer = new MsgPackWriter(new FlushingBuffer(1024, sink)) |
| 98 | + packer.format(trace, spanMapper) |
| 99 | + packer.flush() |
| 100 | + |
| 101 | + then: |
| 102 | + sink.captured != null |
| 103 | + |
| 104 | + MessageUnpacker unpacker = MessagePack.newDefaultUnpacker(sink.captured) |
| 105 | + |
| 106 | + int topLevelMapSize = unpacker.unpackMapHeader() |
| 107 | + topLevelMapSize == 3 |
| 108 | + |
| 109 | + Map<String, Object> topLevel = [:] |
| 110 | + for (int i = 0; i < topLevelMapSize; i++) { |
| 111 | + String key = unpacker.unpackString() |
| 112 | + if (key == "event_type") { |
| 113 | + topLevel[key] = unpacker.unpackString() |
| 114 | + } else if (key == "_dd.stage") { |
| 115 | + topLevel[key] = unpacker.unpackString() |
| 116 | + } else if (key == "spans") { |
| 117 | + int spansArraySize = unpacker.unpackArrayHeader() |
| 118 | + topLevel[key] = spansArraySize |
| 119 | + // Since array is empty, no need to skip anything |
| 120 | + } |
| 121 | + } |
| 122 | + |
| 123 | + // Verify that no spans are included since regular span is filtered out |
| 124 | + topLevel["spans"] == 0 |
| 125 | + |
| 126 | + cleanup: |
| 127 | + tracer.close() |
| 128 | + } |
| 129 | + |
| 130 | + def "test LLM span with error"() { |
| 131 | + setup: |
| 132 | + def tracer = tracerBuilder().writer(new ListWriter()).build() |
| 133 | + DDSpan span = (DDSpan) tracer.buildSpan("llm-operation") |
| 134 | + .withServiceName("my-llm-service") |
| 135 | + .withSpanType(InternalSpanTypes.LLMOBS) |
| 136 | + .start() |
| 137 | + |
| 138 | + span.setTag("_ml_obs_tag.span.kind", Tags.LLMOBS_LLM_SPAN_KIND) |
| 139 | + span.setError(true) |
| 140 | + span.setTag(DDTags.ERROR_MSG, "API rate limit exceeded") |
| 141 | + span.setTag(DDTags.ERROR_TYPE, "RateLimitError") |
| 142 | + span.setTag(DDTags.ERROR_STACK, "java.lang.RuntimeException: API rate limit exceeded") |
| 143 | + |
| 144 | + def trace = [span] |
| 145 | + |
| 146 | + when: |
| 147 | + LLMObsSpanMapper spanMapper = new LLMObsSpanMapper() |
| 148 | + CapturingByteBufferConsumer sink = new CapturingByteBufferConsumer() |
| 149 | + MsgPackWriter packer = new MsgPackWriter(new FlushingBuffer(1024, sink)) |
| 150 | + packer.format(trace, spanMapper) |
| 151 | + packer.flush() |
| 152 | + |
| 153 | + then: |
| 154 | + sink.captured != null |
| 155 | + |
| 156 | + MessageUnpacker unpacker = MessagePack.newDefaultUnpacker(sink.captured) |
| 157 | + |
| 158 | + int topLevelMapSize = unpacker.unpackMapHeader() |
| 159 | + topLevelMapSize == 3 |
| 160 | + |
| 161 | + Map<String, Object> topLevel = [:] |
| 162 | + for (int i = 0; i < topLevelMapSize; i++) { |
| 163 | + String key = unpacker.unpackString() |
| 164 | + if (key == "event_type") { |
| 165 | + topLevel[key] = unpacker.unpackString() |
| 166 | + } else if (key == "_dd.stage") { |
| 167 | + topLevel[key] = unpacker.unpackString() |
| 168 | + } else if (key == "spans") { |
| 169 | + int spansArraySize = unpacker.unpackArrayHeader() |
| 170 | + topLevel[key] = spansArraySize |
| 171 | + |
| 172 | + // Parse the spans array to check error information |
| 173 | + for (int spanIndex = 0; spanIndex < spansArraySize; spanIndex++) { |
| 174 | + int spanMapSize = unpacker.unpackMapHeader() |
| 175 | + spanMapSize == 11 |
| 176 | + |
| 177 | + Map<String, Object> spanData = [:] |
| 178 | + for (int fieldIndex = 0; fieldIndex < spanMapSize; fieldIndex++) { |
| 179 | + String fieldKey = unpacker.unpackString() |
| 180 | + if (fieldKey == "error") { |
| 181 | + spanData[fieldKey] = unpacker.unpackInt() |
| 182 | + } else if (fieldKey == "status") { |
| 183 | + spanData[fieldKey] = unpacker.unpackString() |
| 184 | + } else if (fieldKey == "meta") { |
| 185 | + int metaMapSize = unpacker.unpackMapHeader() |
| 186 | + Map<String, String> metaMap = [:] |
| 187 | + for (int metaIndex = 0; metaIndex < metaMapSize; metaIndex++) { |
| 188 | + String metaKey = unpacker.unpackString() |
| 189 | + String metaValue = unpacker.unpackString() |
| 190 | + metaMap[metaKey] = metaValue |
| 191 | + } |
| 192 | + spanData[fieldKey] = metaMap |
| 193 | + } else { |
| 194 | + // Skip other fields |
| 195 | + unpacker.skipValue() |
| 196 | + } |
| 197 | + } |
| 198 | + |
| 199 | + // Verify error information |
| 200 | + spanData["error"] == 1 |
| 201 | + spanData["status"] == "error" |
| 202 | + |
| 203 | + Map<String, String> meta = (Map<String, String>) spanData["meta"] |
| 204 | + meta[DDTags.ERROR_MSG] == "API rate limit exceeded" |
| 205 | + meta[DDTags.ERROR_TYPE] == "RateLimitError" |
| 206 | + meta[DDTags.ERROR_STACK] == "java.lang.RuntimeException: API rate limit exceeded" |
| 207 | + } |
| 208 | + } |
| 209 | + } |
| 210 | + |
| 211 | + topLevel["spans"] == 1 |
| 212 | + |
| 213 | + cleanup: |
| 214 | + tracer.close() |
| 215 | + } |
| 216 | + |
| 217 | + static class CapturingByteBufferConsumer implements ByteBufferConsumer { |
| 218 | + |
| 219 | + ByteBuffer captured |
| 220 | + |
| 221 | + @Override |
| 222 | + void accept(int messageCount, ByteBuffer buffer) { |
| 223 | + captured = buffer |
| 224 | + } |
| 225 | + } |
| 226 | +} |
0 commit comments