Skip to content

Commit 0ecee22

Browse files
Initial implementation of AI Guard SDK
1 parent 050cad8 commit 0ecee22

File tree

16 files changed

+1375
-1
lines changed

16 files changed

+1375
-1
lines changed

.github/CODEOWNERS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646

4747
# @DataDog/asm-java (AppSec/IAST)
4848
/buildSrc/call-site-instrumentation-plugin/ @DataDog/asm-java
49+
/dd-java-agent/agent-aiguard/ @DataDog/asm-java
4950
/dd-java-agent/agent-iast/ @DataDog/asm-java
5051
/dd-java-agent/appsec/appsec-test-fixtures/ @DataDog/asm-java
5152
/dd-java-agent/instrumentation/*iast* @DataDog/asm-java
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
plugins {
2+
id 'com.gradleup.shadow'
3+
}
4+
5+
apply from: "$rootDir/gradle/java.gradle"
6+
apply from: "$rootDir/gradle/version.gradle"
7+
8+
java {
9+
sourceCompatibility = JavaVersion.VERSION_1_8
10+
targetCompatibility = JavaVersion.VERSION_1_8
11+
}
12+
13+
dependencies {
14+
api libs.slf4j
15+
implementation libs.moshi
16+
implementation libs.okhttp
17+
18+
api project(':dd-trace-api')
19+
implementation project(':internal-api')
20+
implementation project(':communication')
21+
22+
testImplementation project(':utils:test-utils')
23+
testImplementation('org.skyscreamer:jsonassert:1.5.1')
24+
}
25+
26+
shadowJar {
27+
dependencies deps.excludeShared
28+
}
29+
30+
jar {
31+
archiveClassifier = 'unbundled'
32+
}
33+
34+
ext {
35+
minimumBranchCoverage = 0.6
36+
minimumInstructionCoverage = 0.8
37+
excludedClassesCoverage = []
38+
excludedClassesBranchCoverage = []
39+
excludedClassesInstructionCoverage = []
40+
}
41+
42+
spotless {
43+
java {
44+
target 'src/**/*.java'
45+
}
46+
}
Lines changed: 271 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,271 @@
1+
package com.datadog.aiguard;
2+
3+
import com.squareup.moshi.JsonReader;
4+
import com.squareup.moshi.JsonWriter;
5+
import com.squareup.moshi.Moshi;
6+
import datadog.communication.http.OkHttpUtils;
7+
import datadog.trace.api.Config;
8+
import datadog.trace.api.aiguard.AIGuard;
9+
import datadog.trace.api.aiguard.AIGuard.AIGuardAbortError;
10+
import datadog.trace.api.aiguard.AIGuard.AIGuardClientError;
11+
import datadog.trace.api.aiguard.AIGuard.Action;
12+
import datadog.trace.api.aiguard.AIGuard.Evaluation;
13+
import datadog.trace.api.aiguard.AIGuard.Message;
14+
import datadog.trace.api.aiguard.AIGuard.Options;
15+
import datadog.trace.api.aiguard.AIGuard.ToolCall;
16+
import datadog.trace.api.aiguard.AIGuard.ToolCall.Function;
17+
import datadog.trace.api.aiguard.Evaluator;
18+
import datadog.trace.api.aiguard.noop.NoOpEvaluator;
19+
import datadog.trace.bootstrap.instrumentation.api.AgentScope;
20+
import datadog.trace.bootstrap.instrumentation.api.AgentSpan;
21+
import datadog.trace.bootstrap.instrumentation.api.AgentTracer;
22+
import java.io.IOException;
23+
import java.util.Collection;
24+
import java.util.HashMap;
25+
import java.util.List;
26+
import java.util.Map;
27+
import java.util.stream.Collectors;
28+
import javax.annotation.Nullable;
29+
import okhttp3.HttpUrl;
30+
import okhttp3.MediaType;
31+
import okhttp3.OkHttpClient;
32+
import okhttp3.Request;
33+
import okhttp3.RequestBody;
34+
import okhttp3.Response;
35+
import okhttp3.ResponseBody;
36+
import okio.BufferedSink;
37+
38+
public class AIGuardInternal implements Evaluator {
39+
40+
static final String SPAN_NAME = "ai_guard";
41+
static final String TARGET_TAG = "ai_guard.target";
42+
static final String TOOL_TAG = "ai_guard.tool";
43+
static final String ACTION_TAG = "ai_guard.action";
44+
static final String REASON_TAG = "ai_guard.reason";
45+
static final String BLOCKED_TAG = "ai_guard.blocked";
46+
static final String META_STRUCT_TAG = "ai_guard";
47+
static final String META_STRUCT_KEY = "messages";
48+
49+
public static void install() {
50+
final Config config = Config.get();
51+
final String apiKey = config.getApiKey();
52+
final String appKey = config.getApplicationKey();
53+
if (isEmpty(apiKey) || isEmpty(appKey)) {
54+
throw new RuntimeException(
55+
"AI Guard: Missing api and/or application key, use DD_API_KEY and DD_APP_KEY");
56+
}
57+
String endpoint = config.getAiGuardEndpoint();
58+
if (isEmpty(endpoint)) {
59+
endpoint = String.format("https://app.%s/api/v2/ai-guard", config.getSite());
60+
}
61+
final Map<String, String> headers = new HashMap<>(2);
62+
headers.put("DD-API-KEY", apiKey);
63+
headers.put("DD-APP-KEY", appKey);
64+
final HttpUrl url = HttpUrl.get(endpoint).newBuilder().addPathSegment("evaluate").build();
65+
final int timeout = config.getAiGuardTimeout();
66+
final OkHttpClient client = buildClient(url, timeout);
67+
Installer.install(new AIGuardInternal(url, headers, client));
68+
}
69+
70+
/** Used by tests to reset status */
71+
static void uninstall() {
72+
Installer.install(new NoOpEvaluator());
73+
}
74+
75+
private final HttpUrl url;
76+
private final Moshi moshi;
77+
private final OkHttpClient client;
78+
private final Map<String, String> meta;
79+
private final Map<String, String> headers;
80+
81+
AIGuardInternal(final HttpUrl url, final Map<String, String> headers, final OkHttpClient client) {
82+
this.url = url;
83+
this.headers = headers;
84+
this.client = client;
85+
this.moshi = new Moshi.Builder().build();
86+
final Config config = Config.get();
87+
this.meta = new HashMap<>(2);
88+
this.meta.put("service", config.getServiceName());
89+
this.meta.put("env", config.getEnv());
90+
}
91+
92+
private static List<Message> truncate(List<Message> messages) {
93+
final Config config = Config.get();
94+
if (messages.size() > config.getAiGuardMaxMessagesLength()) {
95+
messages = messages.subList(0, config.getAiGuardMaxMessagesLength());
96+
}
97+
for (int i = 0; i < messages.size(); i++) {
98+
Message source = messages.get(i);
99+
if (source.getContent() != null
100+
&& source.getContent().length() > config.getAiGuardMaxContentSize()) {
101+
source =
102+
new Message(
103+
source.getRole(),
104+
source.getContent().substring(0, config.getAiGuardMaxContentSize()),
105+
source.getToolCalls(),
106+
source.getToolCallId());
107+
messages.set(i, source);
108+
}
109+
}
110+
return messages;
111+
}
112+
113+
private static boolean isToolCall(final Message message) {
114+
return message.getToolCalls() != null || message.getToolCallId() != null;
115+
}
116+
117+
private static String getToolName(final Message current, final List<Message> messages) {
118+
if (current.getToolCalls() != null) {
119+
// assistant message with tool calls
120+
return current.getToolCalls().stream()
121+
.map(ToolCall::getFunction)
122+
.map(Function::getName)
123+
.collect(Collectors.joining(","));
124+
} else {
125+
// assistant message with tool output (search the linked tool call in reverse order)
126+
final String id = current.getToolCallId();
127+
for (int i = messages.size() - 1; i >= 0; i--) {
128+
final Message message = messages.get(i);
129+
if (message.getToolCalls() != null) {
130+
for (final ToolCall toolCall : message.getToolCalls()) {
131+
if (toolCall.getId().equals(id)) {
132+
return toolCall.getFunction() == null ? null : toolCall.getFunction().getName();
133+
}
134+
}
135+
}
136+
}
137+
return null;
138+
}
139+
}
140+
141+
private boolean isBlockingEnabled(final Object isBlockingEnabled) {
142+
return isBlockingEnabled != null && isBlockingEnabled.toString().equalsIgnoreCase("true");
143+
}
144+
145+
@Override
146+
public Evaluation evaluate(final List<Message> messages, final Options options) {
147+
if (messages == null || messages.isEmpty()) {
148+
throw new IllegalArgumentException("messages must not be empty");
149+
}
150+
final AgentTracer.TracerAPI tracer = AgentTracer.get();
151+
final AgentSpan span = tracer.buildSpan(SPAN_NAME, SPAN_NAME).start();
152+
try (final AgentScope scope = tracer.activateSpan(span)) {
153+
final Message current = messages.get(messages.size() - 1);
154+
if (isToolCall(current)) {
155+
span.setTag(TARGET_TAG, "tool");
156+
final String toolName = getToolName(current, messages);
157+
if (toolName != null) {
158+
span.setTag(TOOL_TAG, toolName);
159+
}
160+
} else {
161+
span.setTag(TARGET_TAG, "prompt");
162+
}
163+
final Map<String, Object> metaStruct = new HashMap<>(1);
164+
metaStruct.put(META_STRUCT_KEY, truncate(messages));
165+
span.setMetaStruct(META_STRUCT_TAG, metaStruct);
166+
final Request.Builder request =
167+
new Request.Builder()
168+
.url(url)
169+
.method("POST", new MoshiJsonRequestBody(moshi, messages, meta));
170+
headers.forEach(request::header);
171+
try (final Response response = client.newCall(request.build()).execute()) {
172+
final Map<String, Object> result = parseResponseBody(response);
173+
final String actionStr = (String) result.get("action");
174+
if (actionStr == null) {
175+
throw new IllegalArgumentException("action field is missing in the response");
176+
}
177+
final Action action = Action.valueOf(actionStr);
178+
final String reason = (String) result.get("reason");
179+
span.setTag(ACTION_TAG, action);
180+
span.setTag(REASON_TAG, reason);
181+
final boolean blockingEnabled = isBlockingEnabled(result.get("is_blocking_enabled"));
182+
if (blockingEnabled && options.block() && action != Action.ALLOW) {
183+
span.setTag(BLOCKED_TAG, true);
184+
throw new AIGuardAbortError(action, reason);
185+
}
186+
return new Evaluation(action, reason);
187+
}
188+
} catch (AIGuardAbortError | AIGuardClientError e) {
189+
span.addThrowable(e);
190+
throw e;
191+
} catch (final Exception e) {
192+
final AIGuardClientError error =
193+
new AIGuardClientError("AI Guard service returned unexpected response", e);
194+
span.addThrowable(error);
195+
throw error;
196+
}
197+
}
198+
199+
@SuppressWarnings("unchecked")
200+
private Map<String, Object> parseResponseBody(final Response response) throws IOException {
201+
final ResponseBody body = response.body();
202+
if (body == null) {
203+
throw fail(response.code(), null);
204+
}
205+
final JsonReader reader = JsonReader.of(body.source());
206+
final Map<?, ?> parsedBody = moshi.adapter(Map.class).fromJson(reader);
207+
final Object errors = parsedBody.get("errors");
208+
if (errors != null) {
209+
throw fail(response.code(), errors);
210+
}
211+
final Map<?, ?> data = (Map<?, ?>) parsedBody.get("data");
212+
return (Map<String, Object>) data.get("attributes");
213+
}
214+
215+
private AIGuardClientError fail(final int statusCode, final Object errors) {
216+
return new AIGuardClientError("AI Guard service call failed, status" + statusCode, errors);
217+
}
218+
219+
private static OkHttpClient buildClient(final HttpUrl url, final long timeout) {
220+
return OkHttpUtils.buildHttpClient(url, timeout).newBuilder().build();
221+
}
222+
223+
private static boolean isEmpty(final String value) {
224+
return value == null || value.isEmpty();
225+
}
226+
227+
private static class Installer extends AIGuard {
228+
public static void install(final Evaluator evaluator) {
229+
AIGuard.EVALUATOR = evaluator;
230+
}
231+
}
232+
233+
static class MoshiJsonRequestBody extends RequestBody {
234+
235+
private static final MediaType JSON = MediaType.parse("application/json");
236+
237+
private final Moshi moshi;
238+
private final Map<String, String> meta;
239+
private final Collection<Message> messages;
240+
241+
public MoshiJsonRequestBody(
242+
final Moshi moshi, final Collection<Message> messages, final Map<String, String> meta) {
243+
this.moshi = moshi;
244+
this.messages = messages;
245+
this.meta = meta;
246+
}
247+
248+
@Nullable
249+
@Override
250+
public MediaType contentType() {
251+
return JSON;
252+
}
253+
254+
@Override
255+
public void writeTo(final BufferedSink sink) throws IOException {
256+
final JsonWriter writer = JsonWriter.of(sink);
257+
writer.beginObject(); // request
258+
writer.name("data");
259+
writer.beginObject(); // data
260+
writer.name("attributes");
261+
writer.beginObject(); // attributes
262+
writer.name("messages");
263+
moshi.adapter(Object.class).toJson(writer, messages);
264+
writer.name("meta");
265+
writer.jsonValue(meta);
266+
writer.endObject(); // attributes
267+
writer.endObject(); // data
268+
writer.endObject(); // request
269+
}
270+
}
271+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
package com.datadog.aiguard;
2+
3+
public abstract class AIGuardSystem {
4+
5+
private AIGuardSystem() {}
6+
7+
public static void start() {
8+
initializeSDK();
9+
}
10+
11+
private static void initializeSDK() {
12+
AIGuardInternal.install();
13+
}
14+
}

0 commit comments

Comments
 (0)