Skip to content

Commit 02cae20

Browse files
Initial implementation of the AI Guard SDK (#9628)
Initial implementation of the AI Guard SDK
1 parent 6549d20 commit 02cae20

File tree

25 files changed

+2005
-11
lines changed

25 files changed

+2005
-11
lines changed

.github/CODEOWNERS

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646

4747
# @DataDog/asm-java (AppSec/IAST)
4848
/buildSrc/call-site-instrumentation-plugin/ @DataDog/asm-java
49+
/dd-java-agent/agent-aiguard/ @DataDog/asm-java
4950
/dd-java-agent/agent-iast/ @DataDog/asm-java
5051
/dd-java-agent/appsec/appsec-test-fixtures/ @DataDog/asm-java
5152
/dd-java-agent/instrumentation/*iast* @DataDog/asm-java
@@ -58,6 +59,7 @@
5859
/dd-smoke-tests/spring-security/ @DataDog/asm-java
5960
/dd-java-agent/instrumentation/commons-fileupload/ @DataDog/asm-java
6061
/dd-java-agent/instrumentation/spring/spring-security/ @DataDog/asm-java
62+
/dd-trace-api/src/main/java/datadog/trace/api/aiguard/ @DataDog/asm-java
6163
/dd-trace-api/src/main/java/datadog/trace/api/EventTracker.java @DataDog/asm-java
6264
/internal-api/src/main/java/datadog/trace/api/gateway/ @DataDog/asm-java
6365
**/appsec/ @DataDog/asm-java

communication/src/main/java/datadog/communication/serialization/Codec.java

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,36 @@
11
package datadog.communication.serialization;
22

3+
import datadog.communication.serialization.custom.aiguard.FunctionWriter;
4+
import datadog.communication.serialization.custom.aiguard.MessageWriter;
5+
import datadog.communication.serialization.custom.aiguard.ToolCallWriter;
36
import datadog.communication.serialization.custom.stacktrace.StackTraceEventFrameWriter;
47
import datadog.communication.serialization.custom.stacktrace.StackTraceEventWriter;
8+
import datadog.trace.api.Config;
9+
import datadog.trace.api.aiguard.AIGuard;
510
import datadog.trace.util.stacktrace.StackTraceEvent;
611
import datadog.trace.util.stacktrace.StackTraceFrame;
712
import java.nio.ByteBuffer;
813
import java.nio.CharBuffer;
914
import java.util.Collection;
1015
import java.util.Collections;
16+
import java.util.HashMap;
1117
import java.util.Map;
12-
import java.util.stream.Collectors;
13-
import java.util.stream.Stream;
1418

1519
public final class Codec extends ClassValue<ValueWriter<?>> {
1620

17-
private static final Map<Class<?>, ValueWriter<?>> defaultConfig =
18-
Stream.of(
19-
new Object[][] {
20-
{StackTraceEvent.class, new StackTraceEventWriter()},
21-
{StackTraceFrame.class, new StackTraceEventFrameWriter()},
22-
})
23-
.collect(Collectors.toMap(data -> (Class<?>) data[0], data -> (ValueWriter<?>) data[1]));
21+
public static final Codec INSTANCE;
2422

25-
public static final Codec INSTANCE = new Codec(defaultConfig);
23+
static {
24+
final Map<Class<?>, ValueWriter<?>> writers = new HashMap<>(1 << 3);
25+
writers.put(StackTraceEvent.class, new StackTraceEventWriter());
26+
writers.put(StackTraceFrame.class, new StackTraceEventFrameWriter());
27+
if (Config.get().isAiGuardEnabled()) {
28+
writers.put(AIGuard.Message.class, new MessageWriter());
29+
writers.put(AIGuard.ToolCall.class, new ToolCallWriter());
30+
writers.put(AIGuard.ToolCall.Function.class, new FunctionWriter());
31+
}
32+
INSTANCE = new Codec(writers);
33+
}
2634

2735
private final Map<Class<?>, ValueWriter<?>> config;
2836

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
package datadog.communication.serialization.custom.aiguard;
2+
3+
import datadog.communication.serialization.EncodingCache;
4+
import datadog.communication.serialization.ValueWriter;
5+
import datadog.communication.serialization.Writable;
6+
import datadog.trace.api.aiguard.AIGuard;
7+
8+
public class FunctionWriter implements ValueWriter<AIGuard.ToolCall.Function> {
9+
10+
@Override
11+
public void write(
12+
final AIGuard.ToolCall.Function function,
13+
final Writable writable,
14+
final EncodingCache encodingCache) {
15+
writable.startMap(2);
16+
writable.writeString("name", encodingCache);
17+
writable.writeString(function.getName(), encodingCache);
18+
writable.writeString("arguments", encodingCache);
19+
writable.writeString(function.getArguments(), encodingCache);
20+
}
21+
}
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
package datadog.communication.serialization.custom.aiguard;
2+
3+
import datadog.communication.serialization.EncodingCache;
4+
import datadog.communication.serialization.ValueWriter;
5+
import datadog.communication.serialization.Writable;
6+
import datadog.trace.api.aiguard.AIGuard;
7+
import datadog.trace.util.Strings;
8+
import java.util.List;
9+
10+
public class MessageWriter implements ValueWriter<AIGuard.Message> {
11+
12+
@Override
13+
public void write(
14+
final AIGuard.Message value, final Writable writable, final EncodingCache encodingCache) {
15+
final int[] size = {0};
16+
final boolean hasRole = isNotBlank(value.getRole(), size);
17+
final boolean hasContent = isNotBlank(value.getContent(), size);
18+
final boolean hasToolCallId = isNotBlank(value.getToolCallId(), size);
19+
final boolean hasToolCalls = isNotEmpty(value.getToolCalls(), size);
20+
writable.startMap(size[0]);
21+
writeString(hasRole, "role", value.getRole(), writable, encodingCache);
22+
writeString(hasContent, "content", value.getContent(), writable, encodingCache);
23+
writeString(hasToolCallId, "tool_call_id", value.getToolCallId(), writable, encodingCache);
24+
writeToolCallArray(hasToolCalls, "tool_calls", value.getToolCalls(), writable, encodingCache);
25+
}
26+
27+
private static void writeString(
28+
final boolean present,
29+
final String key,
30+
final String value,
31+
final Writable writable,
32+
final EncodingCache encodingCache) {
33+
if (present) {
34+
writable.writeString(key, encodingCache);
35+
writable.writeString(value, encodingCache);
36+
}
37+
}
38+
39+
private static void writeToolCallArray(
40+
final boolean present,
41+
final String key,
42+
final List<AIGuard.ToolCall> values,
43+
final Writable writable,
44+
final EncodingCache encodingCache) {
45+
if (present) {
46+
writable.writeString(key, encodingCache);
47+
writable.writeObject(values, encodingCache);
48+
}
49+
}
50+
51+
private static boolean isNotBlank(final String value, final int[] nonBlankCount) {
52+
final boolean hasText = Strings.isNotBlank(value);
53+
if (hasText) {
54+
nonBlankCount[0]++;
55+
}
56+
return hasText;
57+
}
58+
59+
private static boolean isNotEmpty(final List<?> value, final int[] nonEmptyCount) {
60+
final boolean nonEmpty = value != null && !value.isEmpty();
61+
if (nonEmpty) {
62+
nonEmptyCount[0]++;
63+
}
64+
return nonEmpty;
65+
}
66+
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
package datadog.communication.serialization.custom.aiguard;
2+
3+
import datadog.communication.serialization.EncodingCache;
4+
import datadog.communication.serialization.ValueWriter;
5+
import datadog.communication.serialization.Writable;
6+
import datadog.trace.api.aiguard.AIGuard;
7+
8+
public class ToolCallWriter implements ValueWriter<AIGuard.ToolCall> {
9+
10+
@Override
11+
public void write(
12+
final AIGuard.ToolCall value, final Writable writable, final EncodingCache encodingCache) {
13+
writable.startMap(2);
14+
writable.writeString("id", encodingCache);
15+
writable.writeString(value.getId(), encodingCache);
16+
writable.writeString("function", encodingCache);
17+
writable.writeObject(value.getFunction(), encodingCache);
18+
}
19+
}
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
package datadog.communication.serialization.aiguard
2+
3+
import datadog.communication.serialization.EncodingCache
4+
import datadog.communication.serialization.GrowableBuffer
5+
import datadog.communication.serialization.msgpack.MsgPackWriter
6+
import datadog.trace.api.aiguard.AIGuard
7+
import datadog.trace.test.util.DDSpecification
8+
import org.msgpack.core.MessagePack
9+
import org.msgpack.value.Value
10+
11+
import java.nio.charset.StandardCharsets
12+
import java.util.function.Function
13+
14+
class MessageWriterTest extends DDSpecification {
15+
16+
private EncodingCache encodingCache
17+
private GrowableBuffer buffer
18+
private MsgPackWriter writer
19+
20+
void setup() {
21+
injectSysConfig('ai_guard.enabled', 'true')
22+
final HashMap<CharSequence, byte[]> cache = new HashMap<>()
23+
encodingCache = new EncodingCache() {
24+
@Override
25+
byte[] encode(CharSequence chars) {
26+
cache.computeIfAbsent(chars, s -> s.toString().getBytes(StandardCharsets.UTF_8))
27+
}
28+
}
29+
buffer = new GrowableBuffer(1024)
30+
writer = new MsgPackWriter(buffer)
31+
}
32+
33+
void 'test write message'() {
34+
given:
35+
final message = AIGuard.Message.message('user', 'What day is today?')
36+
37+
when:
38+
writer.writeObject(message, encodingCache)
39+
40+
then:
41+
try (final unpacker = MessagePack.newDefaultUnpacker(buffer.slice())) {
42+
final value = asStringValueMap(unpacker.unpackValue())
43+
value.size() == 2
44+
value.role == 'user'
45+
value.content == 'What day is today?'
46+
}
47+
}
48+
49+
void 'test write tool call'() {
50+
given:
51+
final message =
52+
AIGuard.Message.assistant(
53+
AIGuard.ToolCall.toolCall('call_1', 'function_1', 'args_1'),
54+
AIGuard.ToolCall.toolCall('call_2', 'function_2', 'args_2'))
55+
56+
when:
57+
writer.writeObject(message, encodingCache)
58+
59+
then:
60+
try (final unpacker = MessagePack.newDefaultUnpacker(buffer.slice())) {
61+
final value = asStringKeyMap(unpacker.unpackValue())
62+
value.size() == 2
63+
asString(value.role) == 'assistant'
64+
65+
final toolCalls = value.get('tool_calls').asArrayValue().list()
66+
toolCalls.size() == 2
67+
68+
final firstCall = asStringKeyMap(toolCalls[0])
69+
asString(firstCall.id) == 'call_1'
70+
final firstFunction = asStringValueMap(firstCall.function)
71+
firstFunction.name == 'function_1'
72+
firstFunction.arguments == 'args_1'
73+
74+
final secondCall = asStringKeyMap(toolCalls[1])
75+
asString(secondCall.id) == 'call_2'
76+
final secondFunction = asStringValueMap(secondCall.function)
77+
secondFunction.name == 'function_2'
78+
secondFunction.arguments == 'args_2'
79+
}
80+
}
81+
82+
void 'test write tool output'() throws IOException {
83+
given:
84+
final message = AIGuard.Message.tool('call_1', 'output')
85+
86+
when:
87+
writer.writeObject(message, encodingCache)
88+
89+
then:
90+
try (final unpacker = MessagePack.newDefaultUnpacker(buffer.slice())) {
91+
final value = asStringValueMap(unpacker.unpackValue())
92+
value.size() == 3
93+
value.role == 'tool'
94+
value.tool_call_id == 'call_1'
95+
value.content == 'output'
96+
}
97+
}
98+
99+
private static <K, V> Map<K, V> mapValue(
100+
final Value values,
101+
final Function<Value, K> keyMapper,
102+
final Function<Value, V> valueMapper) {
103+
return values.asMapValue().entrySet().collectEntries {
104+
[(keyMapper.apply(it.key)): valueMapper.apply(it.value)]
105+
}
106+
}
107+
108+
private static Map<String, Value> asStringKeyMap(final Value values) {
109+
return mapValue(values, MessageWriterTest::asString, Function.identity())
110+
}
111+
112+
private static Map<String, String> asStringValueMap(final Value values) {
113+
return mapValue(values, MessageWriterTest::asString, MessageWriterTest::asString)
114+
}
115+
116+
private static String asString(final Value value) {
117+
return value.asStringValue().asString()
118+
}
119+
}
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import com.github.jengelman.gradle.plugins.shadow.tasks.ShadowJar
2+
3+
plugins {
4+
id 'com.gradleup.shadow'
5+
}
6+
7+
apply from: "$rootDir/gradle/java.gradle"
8+
apply from: "$rootDir/gradle/version.gradle"
9+
10+
java {
11+
sourceCompatibility = JavaVersion.VERSION_1_8
12+
targetCompatibility = JavaVersion.VERSION_1_8
13+
}
14+
15+
dependencies {
16+
api libs.slf4j
17+
implementation libs.moshi
18+
implementation libs.okhttp
19+
20+
api project(':dd-trace-api')
21+
implementation project(':internal-api')
22+
implementation project(':communication')
23+
24+
testImplementation project(':utils:test-utils')
25+
testImplementation('org.skyscreamer:jsonassert:1.5.3')
26+
testImplementation('com.fasterxml.jackson.core:jackson-databind:2.20.0')
27+
}
28+
29+
tasks.named("shadowJar", ShadowJar) {
30+
dependencies deps.excludeShared
31+
}
32+
33+
tasks.named("jar", Jar) {
34+
archiveClassifier = 'unbundled'
35+
}
36+

0 commit comments

Comments
 (0)