diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index debd9cef7c5..b24ee918edb 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -46,6 +46,7 @@ # @DataDog/asm-java (AppSec/IAST) /buildSrc/call-site-instrumentation-plugin/ @DataDog/asm-java +/dd-java-agent/agent-aiguard/ @DataDog/asm-java /dd-java-agent/agent-iast/ @DataDog/asm-java /dd-java-agent/appsec/appsec-test-fixtures/ @DataDog/asm-java /dd-java-agent/instrumentation/*iast* @DataDog/asm-java @@ -58,6 +59,7 @@ /dd-smoke-tests/spring-security/ @DataDog/asm-java /dd-java-agent/instrumentation/commons-fileupload/ @DataDog/asm-java /dd-java-agent/instrumentation/spring/spring-security/ @DataDog/asm-java +/dd-trace-api/src/main/java/datadog/trace/api/aiguard/ @DataDog/asm-java /dd-trace-api/src/main/java/datadog/trace/api/EventTracker.java @DataDog/asm-java /internal-api/src/main/java/datadog/trace/api/gateway/ @DataDog/asm-java **/appsec/ @DataDog/asm-java diff --git a/communication/src/main/java/datadog/communication/serialization/Codec.java b/communication/src/main/java/datadog/communication/serialization/Codec.java index e6f92af6076..b0eb5530827 100644 --- a/communication/src/main/java/datadog/communication/serialization/Codec.java +++ b/communication/src/main/java/datadog/communication/serialization/Codec.java @@ -1,28 +1,36 @@ package datadog.communication.serialization; +import datadog.communication.serialization.custom.aiguard.FunctionWriter; +import datadog.communication.serialization.custom.aiguard.MessageWriter; +import datadog.communication.serialization.custom.aiguard.ToolCallWriter; import datadog.communication.serialization.custom.stacktrace.StackTraceEventFrameWriter; import datadog.communication.serialization.custom.stacktrace.StackTraceEventWriter; +import datadog.trace.api.Config; +import datadog.trace.api.aiguard.AIGuard; import datadog.trace.util.stacktrace.StackTraceEvent; import datadog.trace.util.stacktrace.StackTraceFrame; import java.nio.ByteBuffer; import java.nio.CharBuffer; import java.util.Collection; import java.util.Collections; +import java.util.HashMap; import java.util.Map; -import java.util.stream.Collectors; -import java.util.stream.Stream; public final class Codec extends ClassValue> { - private static final Map, ValueWriter> defaultConfig = - Stream.of( - new Object[][] { - {StackTraceEvent.class, new StackTraceEventWriter()}, - {StackTraceFrame.class, new StackTraceEventFrameWriter()}, - }) - .collect(Collectors.toMap(data -> (Class) data[0], data -> (ValueWriter) data[1])); + public static final Codec INSTANCE; - public static final Codec INSTANCE = new Codec(defaultConfig); + static { + final Map, ValueWriter> writers = new HashMap<>(1 << 3); + writers.put(StackTraceEvent.class, new StackTraceEventWriter()); + writers.put(StackTraceFrame.class, new StackTraceEventFrameWriter()); + if (Config.get().isAiGuardEnabled()) { + writers.put(AIGuard.Message.class, new MessageWriter()); + writers.put(AIGuard.ToolCall.class, new ToolCallWriter()); + writers.put(AIGuard.ToolCall.Function.class, new FunctionWriter()); + } + INSTANCE = new Codec(writers); + } private final Map, ValueWriter> config; diff --git a/communication/src/main/java/datadog/communication/serialization/custom/aiguard/FunctionWriter.java b/communication/src/main/java/datadog/communication/serialization/custom/aiguard/FunctionWriter.java new file mode 100644 index 00000000000..8db825f4d04 --- /dev/null +++ b/communication/src/main/java/datadog/communication/serialization/custom/aiguard/FunctionWriter.java @@ -0,0 +1,21 @@ +package datadog.communication.serialization.custom.aiguard; + +import datadog.communication.serialization.EncodingCache; +import datadog.communication.serialization.ValueWriter; +import datadog.communication.serialization.Writable; +import datadog.trace.api.aiguard.AIGuard; + +public class FunctionWriter implements ValueWriter { + + @Override + public void write( + final AIGuard.ToolCall.Function function, + final Writable writable, + final EncodingCache encodingCache) { + writable.startMap(2); + writable.writeString("name", encodingCache); + writable.writeString(function.getName(), encodingCache); + writable.writeString("arguments", encodingCache); + writable.writeString(function.getArguments(), encodingCache); + } +} diff --git a/communication/src/main/java/datadog/communication/serialization/custom/aiguard/MessageWriter.java b/communication/src/main/java/datadog/communication/serialization/custom/aiguard/MessageWriter.java new file mode 100644 index 00000000000..a1b089fbb1c --- /dev/null +++ b/communication/src/main/java/datadog/communication/serialization/custom/aiguard/MessageWriter.java @@ -0,0 +1,66 @@ +package datadog.communication.serialization.custom.aiguard; + +import datadog.communication.serialization.EncodingCache; +import datadog.communication.serialization.ValueWriter; +import datadog.communication.serialization.Writable; +import datadog.trace.api.aiguard.AIGuard; +import datadog.trace.util.Strings; +import java.util.List; + +public class MessageWriter implements ValueWriter { + + @Override + public void write( + final AIGuard.Message value, final Writable writable, final EncodingCache encodingCache) { + final int[] size = {0}; + final boolean hasRole = isNotBlank(value.getRole(), size); + final boolean hasContent = isNotBlank(value.getContent(), size); + final boolean hasToolCallId = isNotBlank(value.getToolCallId(), size); + final boolean hasToolCalls = isNotEmpty(value.getToolCalls(), size); + writable.startMap(size[0]); + writeString(hasRole, "role", value.getRole(), writable, encodingCache); + writeString(hasContent, "content", value.getContent(), writable, encodingCache); + writeString(hasToolCallId, "tool_call_id", value.getToolCallId(), writable, encodingCache); + writeToolCallArray(hasToolCalls, "tool_calls", value.getToolCalls(), writable, encodingCache); + } + + private static void writeString( + final boolean present, + final String key, + final String value, + final Writable writable, + final EncodingCache encodingCache) { + if (present) { + writable.writeString(key, encodingCache); + writable.writeString(value, encodingCache); + } + } + + private static void writeToolCallArray( + final boolean present, + final String key, + final List values, + final Writable writable, + final EncodingCache encodingCache) { + if (present) { + writable.writeString(key, encodingCache); + writable.writeObject(values, encodingCache); + } + } + + private static boolean isNotBlank(final String value, final int[] nonBlankCount) { + final boolean hasText = Strings.isNotBlank(value); + if (hasText) { + nonBlankCount[0]++; + } + return hasText; + } + + private static boolean isNotEmpty(final List value, final int[] nonEmptyCount) { + final boolean nonEmpty = value != null && !value.isEmpty(); + if (nonEmpty) { + nonEmptyCount[0]++; + } + return nonEmpty; + } +} diff --git a/communication/src/main/java/datadog/communication/serialization/custom/aiguard/ToolCallWriter.java b/communication/src/main/java/datadog/communication/serialization/custom/aiguard/ToolCallWriter.java new file mode 100644 index 00000000000..2b5107e1eb7 --- /dev/null +++ b/communication/src/main/java/datadog/communication/serialization/custom/aiguard/ToolCallWriter.java @@ -0,0 +1,19 @@ +package datadog.communication.serialization.custom.aiguard; + +import datadog.communication.serialization.EncodingCache; +import datadog.communication.serialization.ValueWriter; +import datadog.communication.serialization.Writable; +import datadog.trace.api.aiguard.AIGuard; + +public class ToolCallWriter implements ValueWriter { + + @Override + public void write( + final AIGuard.ToolCall value, final Writable writable, final EncodingCache encodingCache) { + writable.startMap(2); + writable.writeString("id", encodingCache); + writable.writeString(value.getId(), encodingCache); + writable.writeString("function", encodingCache); + writable.writeObject(value.getFunction(), encodingCache); + } +} diff --git a/communication/src/test/groovy/datadog/communication/serialization/aiguard/MessageWriterTest.groovy b/communication/src/test/groovy/datadog/communication/serialization/aiguard/MessageWriterTest.groovy new file mode 100644 index 00000000000..3328a2330ad --- /dev/null +++ b/communication/src/test/groovy/datadog/communication/serialization/aiguard/MessageWriterTest.groovy @@ -0,0 +1,119 @@ +package datadog.communication.serialization.aiguard + +import datadog.communication.serialization.EncodingCache +import datadog.communication.serialization.GrowableBuffer +import datadog.communication.serialization.msgpack.MsgPackWriter +import datadog.trace.api.aiguard.AIGuard +import datadog.trace.test.util.DDSpecification +import org.msgpack.core.MessagePack +import org.msgpack.value.Value + +import java.nio.charset.StandardCharsets +import java.util.function.Function + +class MessageWriterTest extends DDSpecification { + + private EncodingCache encodingCache + private GrowableBuffer buffer + private MsgPackWriter writer + + void setup() { + injectSysConfig('ai_guard.enabled', 'true') + final HashMap cache = new HashMap<>() + encodingCache = new EncodingCache() { + @Override + byte[] encode(CharSequence chars) { + cache.computeIfAbsent(chars, s -> s.toString().getBytes(StandardCharsets.UTF_8)) + } + } + buffer = new GrowableBuffer(1024) + writer = new MsgPackWriter(buffer) + } + + void 'test write message'() { + given: + final message = AIGuard.Message.message('user', 'What day is today?') + + when: + writer.writeObject(message, encodingCache) + + then: + try (final unpacker = MessagePack.newDefaultUnpacker(buffer.slice())) { + final value = asStringValueMap(unpacker.unpackValue()) + value.size() == 2 + value.role == 'user' + value.content == 'What day is today?' + } + } + + void 'test write tool call'() { + given: + final message = + AIGuard.Message.assistant( + AIGuard.ToolCall.toolCall('call_1', 'function_1', 'args_1'), + AIGuard.ToolCall.toolCall('call_2', 'function_2', 'args_2')) + + when: + writer.writeObject(message, encodingCache) + + then: + try (final unpacker = MessagePack.newDefaultUnpacker(buffer.slice())) { + final value = asStringKeyMap(unpacker.unpackValue()) + value.size() == 2 + asString(value.role) == 'assistant' + + final toolCalls = value.get('tool_calls').asArrayValue().list() + toolCalls.size() == 2 + + final firstCall = asStringKeyMap(toolCalls[0]) + asString(firstCall.id) == 'call_1' + final firstFunction = asStringValueMap(firstCall.function) + firstFunction.name == 'function_1' + firstFunction.arguments == 'args_1' + + final secondCall = asStringKeyMap(toolCalls[1]) + asString(secondCall.id) == 'call_2' + final secondFunction = asStringValueMap(secondCall.function) + secondFunction.name == 'function_2' + secondFunction.arguments == 'args_2' + } + } + + void 'test write tool output'() throws IOException { + given: + final message = AIGuard.Message.tool('call_1', 'output') + + when: + writer.writeObject(message, encodingCache) + + then: + try (final unpacker = MessagePack.newDefaultUnpacker(buffer.slice())) { + final value = asStringValueMap(unpacker.unpackValue()) + value.size() == 3 + value.role == 'tool' + value.tool_call_id == 'call_1' + value.content == 'output' + } + } + + private static Map mapValue( + final Value values, + final Function keyMapper, + final Function valueMapper) { + return values.asMapValue().entrySet().collectEntries { + [(keyMapper.apply(it.key)): valueMapper.apply(it.value)] + } + } + + private static Map asStringKeyMap(final Value values) { + return mapValue(values, MessageWriterTest::asString, Function.identity()) + } + + private static Map asStringValueMap(final Value values) { + return mapValue(values, MessageWriterTest::asString, MessageWriterTest::asString) + } + + private static String asString(final Value value) { + return value.asStringValue().asString() + } +} diff --git a/dd-java-agent/agent-aiguard/build.gradle b/dd-java-agent/agent-aiguard/build.gradle new file mode 100644 index 00000000000..f8dcb4df379 --- /dev/null +++ b/dd-java-agent/agent-aiguard/build.gradle @@ -0,0 +1,36 @@ +import com.github.jengelman.gradle.plugins.shadow.tasks.ShadowJar + +plugins { + id 'com.gradleup.shadow' +} + +apply from: "$rootDir/gradle/java.gradle" +apply from: "$rootDir/gradle/version.gradle" + +java { + sourceCompatibility = JavaVersion.VERSION_1_8 + targetCompatibility = JavaVersion.VERSION_1_8 +} + +dependencies { + api libs.slf4j + implementation libs.moshi + implementation libs.okhttp + + api project(':dd-trace-api') + implementation project(':internal-api') + implementation project(':communication') + + testImplementation project(':utils:test-utils') + testImplementation('org.skyscreamer:jsonassert:1.5.3') + testImplementation('com.fasterxml.jackson.core:jackson-databind:2.20.0') +} + +tasks.named("shadowJar", ShadowJar) { + dependencies deps.excludeShared +} + +tasks.named("jar", Jar) { + archiveClassifier = 'unbundled' +} + diff --git a/dd-java-agent/agent-aiguard/src/main/java/com/datadog/aiguard/AIGuardInternal.java b/dd-java-agent/agent-aiguard/src/main/java/com/datadog/aiguard/AIGuardInternal.java new file mode 100644 index 00000000000..a7d098a4b91 --- /dev/null +++ b/dd-java-agent/agent-aiguard/src/main/java/com/datadog/aiguard/AIGuardInternal.java @@ -0,0 +1,362 @@ +package com.datadog.aiguard; + +import static datadog.trace.util.Strings.isBlank; +import static java.util.Collections.singletonMap; + +import com.squareup.moshi.JsonAdapter; +import com.squareup.moshi.JsonReader; +import com.squareup.moshi.JsonWriter; +import com.squareup.moshi.Moshi; +import com.squareup.moshi.Types; +import datadog.communication.http.OkHttpUtils; +import datadog.trace.api.Config; +import datadog.trace.api.aiguard.AIGuard; +import datadog.trace.api.aiguard.AIGuard.AIGuardAbortError; +import datadog.trace.api.aiguard.AIGuard.AIGuardClientError; +import datadog.trace.api.aiguard.AIGuard.Action; +import datadog.trace.api.aiguard.AIGuard.Evaluation; +import datadog.trace.api.aiguard.AIGuard.Message; +import datadog.trace.api.aiguard.AIGuard.Options; +import datadog.trace.api.aiguard.AIGuard.ToolCall; +import datadog.trace.api.aiguard.AIGuard.ToolCall.Function; +import datadog.trace.api.aiguard.Evaluator; +import datadog.trace.api.aiguard.noop.NoOpEvaluator; +import datadog.trace.bootstrap.instrumentation.api.AgentScope; +import datadog.trace.bootstrap.instrumentation.api.AgentSpan; +import datadog.trace.bootstrap.instrumentation.api.AgentTracer; +import java.io.IOException; +import java.lang.annotation.Annotation; +import java.lang.reflect.Type; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import javax.annotation.Nullable; +import okhttp3.HttpUrl; +import okhttp3.MediaType; +import okhttp3.OkHttpClient; +import okhttp3.Request; +import okhttp3.RequestBody; +import okhttp3.Response; +import okhttp3.ResponseBody; +import okio.BufferedSink; + +/** + * Concrete implementation of the SDK used to interact with the AIGuard REST API. + * + *

An instance of this class is initialized and configured automatically during agent startup + * through {@link AIGuardSystem#start()}. + */ +public class AIGuardInternal implements Evaluator { + + public static class BadConfigurationException extends RuntimeException { + public BadConfigurationException(final String message) { + super(message); + } + } + + static final String SPAN_NAME = "ai_guard"; + static final String TARGET_TAG = "ai_guard.target"; + static final String TOOL_TAG = "ai_guard.tool"; + static final String ACTION_TAG = "ai_guard.action"; + static final String REASON_TAG = "ai_guard.reason"; + static final String BLOCKED_TAG = "ai_guard.blocked"; + static final String META_STRUCT_TAG = "ai_guard"; + static final String META_STRUCT_KEY = "messages"; + + public static void install() { + final Config config = Config.get(); + final String apiKey = config.getApiKey(); + final String appKey = config.getApplicationKey(); + if (isBlank(apiKey) || isBlank(appKey)) { + throw new BadConfigurationException( + "AI Guard: Missing api and/or application key, use DD_API_KEY and DD_APP_KEY"); + } + String endpoint = config.getAiGuardEndpoint(); + if (isBlank(endpoint)) { + endpoint = String.format("https://app.%s/api/v2/ai-guard", config.getSite()); + } + final Map headers = mapOf("DD-API-KEY", apiKey, "DD-APPLICATION-KEY", appKey); + final HttpUrl url = HttpUrl.get(endpoint).newBuilder().addPathSegment("evaluate").build(); + final int timeout = config.getAiGuardTimeout(); + final OkHttpClient client = buildClient(url, timeout); + Installer.install(new AIGuardInternal(url, headers, client)); + } + + /** Used by tests to reset status */ + static void uninstall() { + Installer.install(new NoOpEvaluator()); + } + + private final HttpUrl url; + private final Moshi moshi; + private final OkHttpClient client; + private final Map meta; + private final Map headers; + + AIGuardInternal(final HttpUrl url, final Map headers, final OkHttpClient client) { + this.url = url; + this.headers = headers; + this.client = client; + this.moshi = new Moshi.Builder().add(new AIGuardFactory()).build(); + final Config config = Config.get(); + this.meta = mapOf("service", config.getServiceName(), "env", config.getEnv()); + } + + /** + * Creates a deep copy of the messages before storing them in the metastruct to avoid concurrent + * modifications prior to trace serialization. + */ + private static List messagesForMetaStruct(List messages) { + final Config config = Config.get(); + final int size = Math.min(messages.size(), config.getAiGuardMaxMessagesLength()); + final List result = new ArrayList<>(size); + final int maxContent = config.getAiGuardMaxContentSize(); + for (int i = 0; i < size; i++) { + Message source = messages.get(i); + final String content = source.getContent(); + if (content != null && content.length() > maxContent) { + source = + new Message( + source.getRole(), + content.substring(0, maxContent), + source.getToolCalls(), + source.getToolCallId()); + } + result.add(source); + } + return result; + } + + private static boolean isToolCall(final Message message) { + return message.getToolCalls() != null || message.getToolCallId() != null; + } + + private static String getToolName(final Message current, final List messages) { + if (current.getToolCalls() != null) { + // assistant message with tool calls + return current.getToolCalls().stream() + .map(ToolCall::getFunction) + .map(Function::getName) + .collect(Collectors.joining(",")); + } + // assistant message with tool output (search the linked tool call in reverse order) + final String id = current.getToolCallId(); + for (int i = messages.size() - 1; i >= 0; i--) { + final Message message = messages.get(i); + if (message.getToolCalls() != null) { + for (final ToolCall toolCall : message.getToolCalls()) { + if (toolCall.getId().equals(id)) { + return toolCall.getFunction() == null ? null : toolCall.getFunction().getName(); + } + } + } + } + return null; + } + + private boolean isBlockingEnabled(final Options options, final Object isBlockingEnabled) { + return options.block() && "true".equalsIgnoreCase(isBlockingEnabled.toString()); + } + + @Override + public Evaluation evaluate(final List messages, final Options options) { + if (messages == null || messages.isEmpty()) { + throw new IllegalArgumentException("Messages must not be empty"); + } + final AgentTracer.TracerAPI tracer = AgentTracer.get(); + final AgentTracer.SpanBuilder builder = tracer.buildSpan(SPAN_NAME, SPAN_NAME); + final AgentSpan parent = AgentTracer.activeSpan(); + if (parent != null) { + builder.asChildOf(parent.context()); + } + final AgentSpan span = builder.start(); + try (final AgentScope scope = tracer.activateSpan(span)) { + final Message last = messages.get(messages.size() - 1); + if (isToolCall(last)) { + span.setTag(TARGET_TAG, "tool"); + final String toolName = getToolName(last, messages); + if (toolName != null) { + span.setTag(TOOL_TAG, toolName); + } + } else { + span.setTag(TARGET_TAG, "prompt"); + } + final Map metaStruct = + singletonMap(META_STRUCT_KEY, messagesForMetaStruct(messages)); + span.setMetaStruct(META_STRUCT_TAG, metaStruct); + final Request.Builder request = + new Request.Builder() + .url(url) + .method("POST", new MoshiJsonRequestBody(moshi, messages, meta)); + headers.forEach(request::header); + try (final Response response = client.newCall(request.build()).execute()) { + final Map result = parseResponseBody(response); + final String actionStr = (String) result.get("action"); + if (actionStr == null) { + throw new IllegalArgumentException("Action field is missing in the response"); + } + final Action action = Action.valueOf(actionStr); + final String reason = (String) result.get("reason"); + span.setTag(ACTION_TAG, action); + span.setTag(REASON_TAG, reason); + final boolean blockingEnabled = + isBlockingEnabled(options, result.get("is_blocking_enabled")); + if (blockingEnabled && action != Action.ALLOW) { + span.setTag(BLOCKED_TAG, true); + throw new AIGuardAbortError(action, reason); + } + return new Evaluation(action, reason); + } + } catch (AIGuardAbortError | AIGuardClientError e) { + span.addThrowable(e); + throw e; + } catch (final Exception e) { + final AIGuardClientError error = + new AIGuardClientError("AI Guard service returned unexpected response", e); + span.addThrowable(error); + throw error; + } finally { + span.finish(); + } + } + + @SuppressWarnings("unchecked") + private Map parseResponseBody(final Response response) throws IOException { + final ResponseBody body = response.body(); + if (body == null) { + throw fail(response.code(), null); + } + final JsonReader reader = JsonReader.of(body.source()); + final Map parsedBody = moshi.adapter(Map.class).fromJson(reader); + final Object errors = parsedBody.get("errors"); + if (errors != null) { + throw fail(response.code(), errors); + } + final Map data = (Map) parsedBody.get("data"); + return (Map) data.get("attributes"); + } + + private AIGuardClientError fail(final int statusCode, final Object errors) { + return new AIGuardClientError("AI Guard service call failed, status: " + statusCode, errors); + } + + private static OkHttpClient buildClient(final HttpUrl url, final long timeout) { + return OkHttpUtils.buildHttpClient(url, timeout).newBuilder().build(); + } + + private static Map mapOf( + final String key1, final String prop1, final String key2, final String prop2) { + final Map map = new HashMap<>(2); + map.put(key1, prop1); + map.put(key2, prop2); + return map; + } + + private static class Installer extends AIGuard { + public static void install(final Evaluator evaluator) { + AIGuard.EVALUATOR = evaluator; + } + } + + static class AIGuardFactory implements JsonAdapter.Factory { + + @Nullable + @Override + public JsonAdapter create( + final Type type, final Set annotations, final Moshi moshi) { + final Class rawType = Types.getRawType(type); + if (rawType != AIGuard.Message.class) { + return null; + } + return new MessageAdapter(moshi.adapter(AIGuard.ToolCall.class)).nullSafe(); + } + } + + static class MessageAdapter extends JsonAdapter { + + private final JsonAdapter toolCallAdapter; + + MessageAdapter(final JsonAdapter toolCallAdapter) { + this.toolCallAdapter = toolCallAdapter; + } + + @Nullable + @Override + public Message fromJson(JsonReader reader) throws IOException { + throw new UnsupportedOperationException("Serializing only adapter"); + } + + @Override + public void toJson(final JsonWriter writer, final Message value) throws IOException { + writer.beginObject(); + writeValue(writer, "role", value.getRole()); + writeValue(writer, "content", value.getContent()); + writeArray(writer, "tool_calls", value.getToolCalls()); + writeValue(writer, "tool_call_id", value.getToolCallId()); + writer.endObject(); + } + + private void writeValue(final JsonWriter writer, final String name, final Object value) + throws IOException { + if (value != null) { + writer.name(name); + writer.jsonValue(value); + } + } + + private void writeArray(final JsonWriter writer, final String name, final List value) + throws IOException { + if (value != null) { + writer.name(name); + writer.beginArray(); + for (final ToolCall toolCall : value) { + toolCallAdapter.toJson(writer, toolCall); + } + writer.endArray(); + } + } + } + + static class MoshiJsonRequestBody extends RequestBody { + + private static final MediaType JSON = MediaType.parse("application/json"); + + private final Moshi moshi; + private final Map meta; + private final Collection messages; + + public MoshiJsonRequestBody( + final Moshi moshi, final Collection messages, final Map meta) { + this.moshi = moshi; + this.messages = messages; + this.meta = meta; + } + + @Nullable + @Override + public MediaType contentType() { + return JSON; + } + + @Override + public void writeTo(final BufferedSink sink) throws IOException { + final JsonWriter writer = JsonWriter.of(sink); + writer.beginObject(); // request + writer.name("data"); + writer.beginObject(); // data + writer.name("attributes"); + writer.beginObject(); // attributes + writer.name("messages"); + moshi.adapter(Object.class).toJson(writer, messages); + writer.name("meta"); + writer.jsonValue(meta); + writer.endObject(); // attributes + writer.endObject(); // data + writer.endObject(); // request + } + } +} diff --git a/dd-java-agent/agent-aiguard/src/main/java/com/datadog/aiguard/AIGuardSystem.java b/dd-java-agent/agent-aiguard/src/main/java/com/datadog/aiguard/AIGuardSystem.java new file mode 100644 index 00000000000..43dc227bb07 --- /dev/null +++ b/dd-java-agent/agent-aiguard/src/main/java/com/datadog/aiguard/AIGuardSystem.java @@ -0,0 +1,14 @@ +package com.datadog.aiguard; + +public abstract class AIGuardSystem { + + private AIGuardSystem() {} + + public static void start() { + initializeSDK(); + } + + private static void initializeSDK() { + AIGuardInternal.install(); + } +} diff --git a/dd-java-agent/agent-aiguard/src/test/groovy/com/datadog/aiguard/AIGuardInternalTests.groovy b/dd-java-agent/agent-aiguard/src/test/groovy/com/datadog/aiguard/AIGuardInternalTests.groovy new file mode 100644 index 00000000000..913224f70ac --- /dev/null +++ b/dd-java-agent/agent-aiguard/src/test/groovy/com/datadog/aiguard/AIGuardInternalTests.groovy @@ -0,0 +1,507 @@ +package com.datadog.aiguard + +import com.fasterxml.jackson.annotation.JsonInclude +import com.fasterxml.jackson.databind.ObjectMapper +import com.fasterxml.jackson.databind.PropertyNamingStrategies +import com.squareup.moshi.Moshi +import datadog.trace.api.Config +import datadog.trace.api.aiguard.AIGuard +import datadog.trace.bootstrap.instrumentation.api.AgentSpan +import datadog.trace.bootstrap.instrumentation.api.AgentTracer +import datadog.trace.test.util.DDSpecification +import okhttp3.Call +import okhttp3.HttpUrl +import okhttp3.MediaType +import okhttp3.OkHttpClient +import okhttp3.Protocol +import okhttp3.Request +import okhttp3.RequestBody +import okhttp3.Response +import okhttp3.ResponseBody +import okio.Okio +import spock.lang.Shared + +import org.skyscreamer.jsonassert.JSONAssert +import org.skyscreamer.jsonassert.JSONCompareMode + +import static datadog.trace.api.aiguard.AIGuard.Action.ABORT +import static datadog.trace.api.aiguard.AIGuard.Action.ALLOW +import static datadog.trace.api.aiguard.AIGuard.Action.DENY +import static org.codehaus.groovy.runtime.DefaultGroovyMethods.combinations + +class AIGuardInternalTests extends DDSpecification { + + @Shared + protected static final URL = HttpUrl.parse('https://app.datadoghq.com/api/v2/ai-guard/evaluate') + + @Shared + protected static final HEADERS = ['DD-API-KEY': 'api', 'DD-APPLICATION-KEY': 'app'] + + @Shared + protected static final ORIGINAL_TRACER = AgentTracer.get() + + @Shared + protected static final MOSHI = new Moshi.Builder().build() + + @Shared + protected static final MAPPER = new ObjectMapper() + .setPropertyNamingStrategy(PropertyNamingStrategies.SNAKE_CASE) + .setDefaultPropertyInclusion( + JsonInclude.Value.construct(JsonInclude.Include.NON_NULL, JsonInclude.Include.NON_NULL) + ) + + @Shared + protected static final TOOL_CALL = [ + AIGuard.Message.message('system', 'You are a beautiful AI assistant'), + AIGuard.Message.message('user', 'What is 2 + 2'), + AIGuard.Message.assistant( + AIGuard.ToolCall.toolCall('call_1', 'calc', '{ "operator": "+", "args": [2, 2] }') + ) + ] + + @Shared + protected static final TOOL_OUTPUT = TOOL_CALL + [AIGuard.Message.tool('call_1', '5')] + + @Shared + protected static final PROMPT = TOOL_OUTPUT + [AIGuard.Message.message('assistant', '2 + 2 is 5'), AIGuard.Message.message('user', '')] + + protected AgentSpan span + + void setup() { + injectEnvConfig('SERVICE', 'ai_guard_test') + injectEnvConfig('ENV', 'test') + + span = Mock(AgentSpan) + final builder = Mock(AgentTracer.SpanBuilder) { + start() >> span + } + final tracer = Stub(AgentTracer.TracerAPI) { + buildSpan(_ as String, _ as String) >> builder + } + AgentTracer.forceRegister(tracer) + } + + void cleanup() { + AgentTracer.forceRegister(ORIGINAL_TRACER) + AIGuardInternal.uninstall() + } + + void 'test missing api/app keys'() { + given: + if (apiKey) { + injectEnvConfig('API_KEY', apiKey) + } + if (appKey) { + injectEnvConfig('APP_KEY', appKey) + } + + when: + AIGuardInternal.install() + + then: + thrown(AIGuardInternal.BadConfigurationException) + + where: + apiKey | appKey + 'apiKey' | null + 'apiKey' | '' + null | 'appKey' + '' | 'appKey' + null | null + '' | '' + } + + void 'test endpoint discovery'() { + given: + injectEnvConfig('API_KEY', 'api') + injectEnvConfig('APP_KEY', 'app') + if (endpoint != null) { + injectEnvConfig("AI_GUARD_ENDPOINT", endpoint) + } else { + removeEnvConfig("AI_GUARD_ENDPOINT") + } + if (site != null) { + injectEnvConfig('SITE', site) + } else { + removeEnvConfig('SITE') + } + + when: + AIGuardInternal.install() + + then: + final internal = (AIGuardInternal) AIGuard.EVALUATOR + internal.url.toString() == expected + + where: + endpoint | site | expected + 'https://test' | null | 'https://test/evaluate' + null | null | 'https://app.datadoghq.com/api/v2/ai-guard/evaluate' + null | 'datadoghq.com' | 'https://app.datadoghq.com/api/v2/ai-guard/evaluate' + null | 'datad0g.com' | 'https://app.datad0g.com/api/v2/ai-guard/evaluate' + } + + void 'test evaluate'() { + given: + Request request = null + Throwable error = null + AIGuard.Evaluation eval = null + final throwAbortError = suite.blocking && suite.action != ALLOW + final call = Mock(Call) { + execute() >> { + return mockResponse( + request, + 200, + [data: [attributes: [action: suite.action, reason: suite.reason, is_blocking_enabled: suite.blocking]]] + ) + } + } + final client = Mock(OkHttpClient) { + newCall(_ as Request) >> { + request = (Request) it[0] + return call + } + } + final aiguard = new AIGuardInternal(URL, HEADERS, client) + + when: + try { + eval = aiguard.evaluate(suite.messages, new AIGuard.Options().block(suite.blocking)) + } catch (Throwable e) { + error = e + } + + then: + 1 * span.setTag(AIGuardInternal.TARGET_TAG, suite.target) + if (suite.target == 'tool') { + 1 * span.setTag(AIGuardInternal.TOOL_TAG, 'calc') + } + 1 * span.setTag(AIGuardInternal.ACTION_TAG, suite.action) + 1 * span.setTag(AIGuardInternal.REASON_TAG, suite.reason) + 1 * span.setMetaStruct(AIGuardInternal.META_STRUCT_TAG, [messages: suite.messages]) + if (throwAbortError) { + 1 * span.addThrowable(_ as AIGuard.AIGuardAbortError) + } + + assertRequest(request, suite.messages) + if (throwAbortError) { + error instanceof AIGuard.AIGuardAbortError + error.action == suite.action + error.reason == suite.reason + } else { + error == null + eval.action == suite.action + eval.reason == suite.reason + } + + where: + suite << TestSuite.build() + } + + void 'test evaluate with API errors'() { + given: + final errors = [[status: 400, title: 'Bad request']] + Request request = null + final call = Mock(Call) { + execute() >> { + return mockResponse(request, 404, [errors: errors]) + } + } + final client = Mock(OkHttpClient) { + newCall(_ as Request) >> { + request = (Request) it[0] + return call + } + } + final aiguard = new AIGuardInternal(URL, HEADERS, client) + + when: + aiguard.evaluate(TOOL_CALL, AIGuard.Options.DEFAULT) + + then: + final exception = thrown(AIGuard.AIGuardClientError) + exception.errors == errors + 1 * span.addThrowable(_ as AIGuard.AIGuardClientError) + } + + void 'test evaluate with invalid JSON'() { + given: + Request request = null + final call = Mock(Call) { + execute() >> { + return mockResponse(request, 200, [bad: 'This is an invalid response']) + } + } + final client = Mock(OkHttpClient) { + newCall(_ as Request) >> { + request = (Request) it[0] + return call + } + } + final aiguard = new AIGuardInternal(URL, HEADERS, client) + + when: + aiguard.evaluate(TOOL_CALL, AIGuard.Options.DEFAULT) + + then: + thrown(AIGuard.AIGuardClientError) + 1 * span.addThrowable(_ as AIGuard.AIGuardClientError) + } + + void 'test evaluate with missing action'() { + given: + Request request = null + final call = Mock(Call) { + execute() >> { + return mockResponse(request, 200, [data: [attributes: [reason: 'I miss something']]]) + } + } + final client = Mock(OkHttpClient) { + newCall(_ as Request) >> { + request = (Request) it[0] + return call + } + } + final aiguard = new AIGuardInternal(URL, HEADERS, client) + + when: + aiguard.evaluate(TOOL_CALL, AIGuard.Options.DEFAULT) + + then: + thrown(AIGuard.AIGuardClientError) + 1 * span.addThrowable(_ as AIGuard.AIGuardClientError) + } + + void 'test evaluate with non JSON response'() { + given: + Request request = null + final call = Mock(Call) { + execute() >> { + return mockResponse(request, 200, [data: [attributes: [reason: 'I miss something']]]) + } + } + final client = Mock(OkHttpClient) { + newCall(_ as Request) >> { + request = (Request) it[0] + return call + } + } + final aiguard = new AIGuardInternal(URL, HEADERS, client) + + when: + aiguard.evaluate(TOOL_CALL, AIGuard.Options.DEFAULT) + + then: + thrown(AIGuard.AIGuardClientError) + 1 * span.addThrowable(_ as AIGuard.AIGuardClientError) + } + + void 'test evaluate with empty response'() { + given: + Request request = null + final call = Mock(Call) { + execute() >> { + return mockResponse(request, 200, null) + } + } + final client = Mock(OkHttpClient) { + newCall(_ as Request) >> { + request = (Request) it[0] + return call + } + } + final aiguard = new AIGuardInternal(URL, HEADERS, client) + + when: + aiguard.evaluate(TOOL_CALL, AIGuard.Options.DEFAULT) + + then: + thrown(AIGuard.AIGuardClientError) + 1 * span.addThrowable(_ as AIGuard.AIGuardClientError) + } + + void 'test message length truncation'() { + given: + final maxMessages = Config.get().getAiGuardMaxMessagesLength() + Request request = null + final call = Mock(Call) { + execute() >> { + return mockResponse(request, 200, [data: [attributes: [action: ALLOW, reason: 'It is fine']]]) + } + } + final client = Mock(OkHttpClient) { + newCall(_ as Request) >> { + request = (Request) it[0] + return call + } + } + final aiguard = new AIGuardInternal(URL, HEADERS, client) + final messages = (0..maxMessages) + .collect { AIGuard.Message.message('user', "This is a prompt: ${it}") } + .toList() + + when: + aiguard.evaluate(messages, AIGuard.Options.DEFAULT) + + then: + 1 * span.setMetaStruct(AIGuardInternal.META_STRUCT_TAG, _) >> { + final received = (List) it[1].messages + assert received.size() == maxMessages + assert received.size() < messages.size() + } + } + + void 'test message content truncation'() { + given: + final maxContent = Config.get().getAiGuardMaxContentSize() + Request request = null + final call = Mock(Call) { + execute() >> { + return mockResponse(request, 200, [data: [attributes: [action: ALLOW, reason: 'It is fine']]]) + } + } + final client = Mock(OkHttpClient) { + newCall(_ as Request) >> { + request = (Request) it[0] + return call + } + } + final aiguard = new AIGuardInternal(URL, HEADERS, client) + final message = AIGuard.Message.message("user", (0..maxContent).collect { 'A' }.join()) + + when: + aiguard.evaluate([message], AIGuard.Options.DEFAULT) + + then: + 1 * span.setMetaStruct(AIGuardInternal.META_STRUCT_TAG, _) >> { + final received = (List) it[1].messages + received.last().with { + assert it.content.length() == maxContent + assert it.content.length() < message.content.length() + } + } + } + + void 'test no messages'() { + given: + final aiguard = new AIGuardInternal(URL, HEADERS, Stub(OkHttpClient)) + + when: + aiguard.evaluate(messages, AIGuard.Options.DEFAULT) + + then: + thrown(IllegalArgumentException) + + + where: + messages << [[], null] + } + + void 'test missing tool name'() { + given: + def request + final call = Mock(Call) { + execute() >> { + return mockResponse( + request, + 200, + [data: [attributes: [action: 'ALLOW', reason: 'Just do it']]] + ) + } + } + final client = Mock(OkHttpClient) { + newCall(_ as Request) >> { + request = (Request) it[0] + return call + } + } + final aiguard = new AIGuardInternal(URL, HEADERS, client) + + when: + aiguard.evaluate([AIGuard.Message.tool('call_1', 'Content')], AIGuard.Options.DEFAULT) + + then: + 1 * span.setTag(AIGuardInternal.TARGET_TAG, 'tool') + 0 * span.setTag(AIGuardInternal.TOOL_TAG, _) + } + + private static assertRequest(final Request request, final List messages) { + assert request.url() == URL + assert request.method() == 'POST' + HEADERS.each { entry -> + assert request.header(entry.key) == entry.value + } + assert request.body().contentType().toString().contains('application/json') + final receivedBody = readRequestBody(request.body()) + final expectedBody = snakeCaseJson([data: [attributes: [messages: messages, meta: [service: 'ai_guard_test', env: 'test']]]]) + JSONAssert.assertEquals(expectedBody, receivedBody, JSONCompareMode.NON_EXTENSIBLE) + return true + } + + private static String snakeCaseJson(final Object value) { + MAPPER.writeValueAsString(value) + } + + private static String readRequestBody(final RequestBody body) { + final output = new ByteArrayOutputStream() + final buffer = Okio.buffer(Okio.sink(output)) + body.writeTo(buffer) + buffer.flush() + return new String(output.toByteArray()) + } + + private static Response mockResponse(final Request request, final int status, final Object body) { + return new Response.Builder() + .protocol(Protocol.HTTP_1_1) + .message('ok') + .request(request) + .code(status) + .body(body == null ? null : ResponseBody.create(MediaType.parse('application/json'), MOSHI.adapter(Object).toJson(body))) + .build() + } + + private static class TestSuite { + private final AIGuard.Action action + private final String reason + private final boolean blocking + private final String description + private final String target + private final List messages + + TestSuite(AIGuard.Action action, String reason, boolean blocking, String description, String target, List messages) { + this.action = action + this.reason = reason + this.blocking = blocking + this.description = description + this.target = target + this.messages = messages + } + + static List build() { + def actionValues = [[ALLOW, 'Go ahead'], [DENY, 'Nope'], [ABORT, 'Kill it with fire']] + def blockingValues = [true, false] + def suiteValues = [ + ['tool call', 'tool', TOOL_CALL], + ['tool output', 'tool', TOOL_OUTPUT], + ['prompt', 'prompt', PROMPT] + ] + return combinations([actionValues, blockingValues, suiteValues] as Iterable) + .collect { action, blocking, suite -> + new TestSuite(action[0], action[1], blocking, suite[0], suite[1], suite[2]) + } + } + + + @Override + String toString() { + return "TestSuite{" + + "description='" + description + '\'' + + ", action=" + action + + ", reason='" + reason + '\'' + + ", blocking=" + blocking + + ", target='" + target + '\'' + + ", messages=" + messages + + '}' + } + } +} diff --git a/dd-java-agent/agent-aiguard/src/test/groovy/com/datadog/aiguard/AIGuardSystemTests.groovy b/dd-java-agent/agent-aiguard/src/test/groovy/com/datadog/aiguard/AIGuardSystemTests.groovy new file mode 100644 index 00000000000..0929422df52 --- /dev/null +++ b/dd-java-agent/agent-aiguard/src/test/groovy/com/datadog/aiguard/AIGuardSystemTests.groovy @@ -0,0 +1,22 @@ +package com.datadog.aiguard + +import datadog.trace.api.aiguard.AIGuard +import datadog.trace.test.util.DDSpecification + +class AIGuardSystemTests extends DDSpecification { + + void cleanup() { + AIGuardInternal.uninstall() + } + + void 'test SDK initialization'() { + injectEnvConfig('API_KEY', 'api') + injectEnvConfig('APP_KEY', 'app') + + when: + AIGuardSystem.start() + + then: + AIGuard.EVALUATOR instanceof AIGuardInternal + } +} diff --git a/dd-java-agent/agent-bootstrap/src/main/java/datadog/trace/bootstrap/Agent.java b/dd-java-agent/agent-bootstrap/src/main/java/datadog/trace/bootstrap/Agent.java index 028965e4296..48557510aa7 100644 --- a/dd-java-agent/agent-bootstrap/src/main/java/datadog/trace/bootstrap/Agent.java +++ b/dd-java-agent/agent-bootstrap/src/main/java/datadog/trace/bootstrap/Agent.java @@ -643,6 +643,7 @@ public void execute() { // start debugger before remote config to subscribe to it before starting to poll maybeStartDebugger(instrumentation, scoClass, sco); maybeStartRemoteConfig(scoClass, sco); + maybeStartAiGuard(); if (telemetryEnabled) { startTelemetry(instrumentation, scoClass, sco); @@ -935,6 +936,20 @@ private static StatsDClientManager statsDClientManager() throws Exception { return (StatsDClientManager) statsDClientManagerMethod.invoke(null); } + private static void maybeStartAiGuard() { + if (!Config.get().isAiGuardEnabled()) { + return; + } + try { + final Class aiGuardSystemClass = + AGENT_CLASSLOADER.loadClass("com.datadog.aiguard.AIGuardSystem"); + final Method aiGuardInstallerMethod = aiGuardSystemClass.getMethod("start"); + aiGuardInstallerMethod.invoke(null); + } catch (final Exception e) { + log.debug("Error initializing AI Guard", e); + } + } + private static void maybeStartAppSec(Class scoClass, Object o) { try { diff --git a/dd-java-agent/build.gradle b/dd-java-agent/build.gradle index a20908bcfae..fe12afb598a 100644 --- a/dd-java-agent/build.gradle +++ b/dd-java-agent/build.gradle @@ -167,6 +167,7 @@ includeSubprojShadowJar(project(':dd-java-agent:instrumentation'), 'inst') includeSubprojShadowJar(project(':dd-java-agent:agent-jmxfetch'), 'metrics') includeSubprojShadowJar(project(':dd-java-agent:agent-profiling'), 'profiling') includeSubprojShadowJar(project(':dd-java-agent:appsec'), 'appsec') +includeSubprojShadowJar(project(':dd-java-agent:agent-aiguard'), 'aiguard') includeSubprojShadowJar(project(':dd-java-agent:agent-iast'), 'iast') includeSubprojShadowJar(project(':dd-java-agent:agent-debugger'), 'debugger') includeSubprojShadowJar(project(':dd-java-agent:agent-ci-visibility'), 'ci-visibility') diff --git a/dd-smoke-tests/appsec/springboot/build.gradle b/dd-smoke-tests/appsec/springboot/build.gradle index 101bd197f47..37ca43b62a9 100644 --- a/dd-smoke-tests/appsec/springboot/build.gradle +++ b/dd-smoke-tests/appsec/springboot/build.gradle @@ -14,6 +14,7 @@ tasks.named("jar", Jar) { } dependencies { + implementation project(':dd-trace-api') implementation group: 'org.springframework.boot', name: 'spring-boot-starter-web', version: '2.6.0' implementation(group: 'com.fasterxml.jackson.core', name: 'jackson-databind', version: '2.6.0') implementation group: 'com.h2database', name: 'h2', version: '2.1.212' diff --git a/dd-smoke-tests/appsec/springboot/src/main/java/datadog/smoketest/appsec/springboot/controller/AIGuardController.java b/dd-smoke-tests/appsec/springboot/src/main/java/datadog/smoketest/appsec/springboot/controller/AIGuardController.java new file mode 100644 index 00000000000..fc3cb3942be --- /dev/null +++ b/dd-smoke-tests/appsec/springboot/src/main/java/datadog/smoketest/appsec/springboot/controller/AIGuardController.java @@ -0,0 +1,98 @@ +package datadog.smoketest.appsec.springboot.controller; + +import static java.util.Arrays.asList; +import static org.springframework.http.MediaType.APPLICATION_JSON_VALUE; + +import datadog.trace.api.aiguard.AIGuard; +import datadog.trace.api.aiguard.AIGuard.AIGuardAbortError; +import datadog.trace.api.aiguard.AIGuard.Evaluation; +import datadog.trace.api.aiguard.AIGuard.Message; +import datadog.trace.api.aiguard.AIGuard.Options; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.springframework.http.HttpStatus; +import org.springframework.http.ResponseEntity; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RequestBody; +import org.springframework.web.bind.annotation.RequestHeader; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; + +@RestController +@RequestMapping(value = "/aiguard") +public class AIGuardController { + + @GetMapping(value = "/allow") + public ResponseEntity allow() { + final Evaluation result = + AIGuard.evaluate( + asList( + Message.message("system", "You are a beautiful AI"), + Message.message("user", "I am harmless"))); + return ResponseEntity.ok(result); + } + + @GetMapping(value = "/deny") + public ResponseEntity deny(final @RequestHeader("X-Blocking-Enabled") boolean block) { + try { + final Evaluation result = + AIGuard.evaluate( + asList( + Message.message("system", "You are a beautiful AI"), + Message.message("user", "You should not trust me" + (block ? " [block]" : ""))), + new Options().block(block)); + return ResponseEntity.ok(result); + } catch (AIGuardAbortError e) { + return ResponseEntity.status(HttpStatus.FORBIDDEN).body(e.getReason()); + } + } + + @GetMapping(value = "/abort") + public ResponseEntity abort(final @RequestHeader("X-Blocking-Enabled") boolean block) { + try { + final Evaluation result = + AIGuard.evaluate( + asList( + Message.message("system", "You are a beautiful AI"), + Message.message("user", "Nuke yourself" + (block ? " [block]" : ""))), + new Options().block(block)); + return ResponseEntity.ok(result); + } catch (AIGuardAbortError e) { + return ResponseEntity.status(HttpStatus.FORBIDDEN).body(e.getReason()); + } + } + + /** Mocking endpoint for the AI Guard REST API */ + @SuppressWarnings("unchecked") + @PostMapping( + value = "/evaluate", + consumes = APPLICATION_JSON_VALUE, + produces = APPLICATION_JSON_VALUE) + public ResponseEntity> evaluate( + @RequestBody final Map request) { + final Map data = (Map) request.get("data"); + final Map attributes = (Map) data.get("attributes"); + final List> messages = + (List>) attributes.get("messages"); + final Map last = messages.get(messages.size() - 1); + String action = "ALLOW"; + String reason = "The prompt looks harmless"; + String content = (String) last.get("content"); + if (content.startsWith("You should not trust me")) { + action = "DENY"; + reason = "I am feeling suspicious today"; + } else if (content.startsWith("Nuke yourself")) { + action = "ABORT"; + reason = "The user is trying to destroy me"; + } + final Map evaluation = new HashMap<>(3); + evaluation.put("action", action); + evaluation.put("reason", reason); + evaluation.put("is_blocking_enabled", content.endsWith("[block]")); + return ResponseEntity.ok() + .body(Collections.singletonMap("data", Collections.singletonMap("attributes", evaluation))); + } +} diff --git a/dd-smoke-tests/appsec/springboot/src/test/groovy/datadog/smoketest/appsec/AIGuardSmokeTest.groovy b/dd-smoke-tests/appsec/springboot/src/test/groovy/datadog/smoketest/appsec/AIGuardSmokeTest.groovy new file mode 100644 index 00000000000..c7fed2158b6 --- /dev/null +++ b/dd-smoke-tests/appsec/springboot/src/test/groovy/datadog/smoketest/appsec/AIGuardSmokeTest.groovy @@ -0,0 +1,103 @@ +package datadog.smoketest.appsec + +import datadog.trace.test.agent.decoder.DecodedSpan +import groovy.json.JsonSlurper +import okhttp3.Request +import spock.lang.Shared + +class AIGuardSmokeTest extends AbstractAppSecServerSmokeTest { + + @Shared + protected String[] defaultAIGuardProperties = [ + '-Ddd.ai_guard.enabled=true', + "-Ddd.ai_guard.endpoint=http://localhost:${httpPort}/aiguard".toString(), + ] + + @Override + def logLevel() { + 'DEBUG' + } + + @Override + Closure decodedTracesCallback() { + // just return the traces + return {} + } + + @Override + ProcessBuilder createProcessBuilder() { + final springBootShadowJar = System.getProperty("datadog.smoketest.appsec.springboot.shadowJar.path") + final command = [javaPath()] + command.addAll(defaultJavaProperties) + command.addAll(defaultAppSecProperties) + command.addAll(defaultAIGuardProperties) + command.addAll(['-jar', springBootShadowJar, "--server.port=${httpPort}".toString()]) + final builder = new ProcessBuilder(command).directory(new File(buildDirectory)) + builder.environment().put('DD_APPLICATION_KEY', 'test') + return builder + } + + void 'test message evaluation'() { + given: + final blocking = test.blocking as boolean + final action = test.action as String + final reason = test.reason as String + def request = new Request.Builder() + .url("http://localhost:${httpPort}/aiguard${test.endpoint}") + .header('X-Blocking-Enabled', "${blocking}") + .get() + .build() + + when: + final response = client.newCall(request).execute() + + then: + if (blocking && action != 'ALLOW') { + assert response.code() == 403 + assert response.body().string().contains(reason) + } else { + assert response.code() == 200 + final body = new JsonSlurper().parse(response.body().bytes()) + assert body.reason == reason + assert body.action == action + } + + and: + waitForTraceCount(2) // default call + internal API mock + final span = traces*.spans + ?.flatten() + ?.find { it.resource == 'ai_guard' } as DecodedSpan + assert span.meta.get('ai_guard.action') == action + assert span.meta.get('ai_guard.reason') == reason + assert span.meta.get('ai_guard.target') == 'prompt' + final messages = span.metaStruct.get('ai_guard').messages as List> + assert messages.size() == 2 + messages[0].with { + assert role == 'system' + assert content == 'You are a beautiful AI' + } + messages[1].with { + assert role == 'user' + assert content != null + } + + where: + test << testSuite() + } + + private static List testSuite() { + return combinations([ + [endpoint: '/allow', action: 'ALLOW', reason: 'The prompt looks harmless'], + [endpoint: '/deny', action: 'DENY', reason: 'I am feeling suspicious today'], + [endpoint: '/abort', action: 'ABORT', reason: 'The user is trying to destroy me'] + ], [[blocking: true], [blocking: false],]) + } + + private static List combinations(list1, list2) { + list1.collectMany { a -> + list2.collect { b -> + a + b + } + } + } +} diff --git a/dd-trace-api/build.gradle.kts b/dd-trace-api/build.gradle.kts index 640267ecec6..f191c71be0e 100644 --- a/dd-trace-api/build.gradle.kts +++ b/dd-trace-api/build.gradle.kts @@ -24,12 +24,17 @@ val excludedClassesCoverage by extra( "datadog.trace.api.SpanCorrelation*", "datadog.trace.api.internal.TraceSegment", "datadog.trace.api.internal.TraceSegment.NoOp", + "datadog.trace.api.aiguard.AIGuard", + "datadog.trace.api.aiguard.AIGuard.AIGuardAbortError", + "datadog.trace.api.aiguard.AIGuard.AIGuardClientError", + "datadog.trace.api.aiguard.AIGuard.Options", "datadog.trace.api.civisibility.CIVisibility", "datadog.trace.api.civisibility.DDTestModule", "datadog.trace.api.civisibility.noop.NoOpDDTest", "datadog.trace.api.civisibility.noop.NoOpDDTestModule", "datadog.trace.api.civisibility.noop.NoOpDDTestSession", "datadog.trace.api.civisibility.noop.NoOpDDTestSuite", + "datadog.trace.api.config.AIGuardConfig", "datadog.trace.api.config.ProfilingConfig", "datadog.trace.api.interceptor.MutableSpan", "datadog.trace.api.profiling.Profiling", diff --git a/dd-trace-api/src/main/java/datadog/trace/api/aiguard/AIGuard.java b/dd-trace-api/src/main/java/datadog/trace/api/aiguard/AIGuard.java new file mode 100644 index 00000000000..c24d8751fbf --- /dev/null +++ b/dd-trace-api/src/main/java/datadog/trace/api/aiguard/AIGuard.java @@ -0,0 +1,431 @@ +package datadog.trace.api.aiguard; + +import datadog.trace.api.aiguard.noop.NoOpEvaluator; +import java.util.Arrays; +import java.util.List; + +/** + * SDK for calling the AIGuard REST API to evaluate AI prompts and tool calls for security threats. + * + *

Example usage: + * + *

{@code
+ * var messages = List.of(
+ *     AIGuard.Message.message("user", "Delete all my files"),
+ *     AIGuard.Message.message("assistant", "I'll help you delete your files")
+ * );
+ *
+ * var result = AIGuard.evaluate(messages);
+ * if (result.getAction() != AIGuard.Action.ALLOW) {
+ *     System.out.println("Unsafe: " + result.getReason());
+ * }
+ * }
+ */ +public abstract class AIGuard { + + protected static Evaluator EVALUATOR = new NoOpEvaluator(); + + protected AIGuard() {} + + /** + * Evaluates a collection of messages using default options to determine if they are safe to + * execute. + * + * @see #evaluate(List, Options) + */ + public static Evaluation evaluate(final List messages) { + return evaluate(messages, Options.DEFAULT); + } + + /** + * Evaluates a collection of messages with custom options to determine if they are safe to + * execute. + * + * @param messages the collection of messages to evaluate (prompts, responses, tool calls, etc.) + * @param options configuration options for the evaluation process + * @return an {@link Evaluation} containing the security decision and reasoning + * @throws AIGuardAbortError if the evaluation action is not ALLOW (DENY or ABORT) and blocking is + * enabled + * @throws AIGuardClientError if there are client-side errors communicating with the AIGuard REST + * API + */ + public static Evaluation evaluate(final List messages, final Options options) { + return EVALUATOR.evaluate(messages, options); + } + + /** + * Exception thrown when AIGuard evaluation results in blocking the execution due to security + * concerns. + * + *

Important: This exception is thrown when the evaluation action is not + * {@code ALLOW} (i.e., {@code DENY} or {@code ABORT}) and blocking mode is enabled. + */ + public static class AIGuardAbortError extends RuntimeException { + private final Action action; + private final String reason; + + public AIGuardAbortError(final Action action, final String reason) { + super(reason); + this.action = action; + this.reason = reason; + } + + public Action getAction() { + return action; + } + + public String getReason() { + return reason; + } + } + + /** + * Exception thrown when there are client-side errors communicating with the AIGuard REST API. + * + *

This exception indicates problems with the AIGuard client implementation such as: + * + *

    + *
  • Network connectivity issues when calling the AIGuard REST API + *
  • Authentication failures with the AIGuard service + *
  • Invalid configuration or missing API credentials + *
  • Request timeout or service unavailability + *
  • Malformed requests or unsupported API versions + *
+ */ + public static class AIGuardClientError extends RuntimeException { + + private final Object errors; + + public AIGuardClientError(final String message, final Throwable cause) { + super(message, cause); + errors = null; + } + + public AIGuardClientError(final String message, final Object errors) { + super(message, null); + this.errors = errors; + } + + public Object getErrors() { + return errors; + } + } + + /** Actions that can be recommended by an AIGuard evaluation. */ + public enum Action { + /** Content is safe to proceed with execution */ + ALLOW, + /** Current action should be blocked from execution */ + DENY, + /** Workflow should be immediately terminated due to severe risk */ + ABORT + } + + /** + * Represents the result of an AIGuard security evaluation, containing both the recommended action + * and the reasoning behind the decision. + * + *

The evaluation provides three possible actions: + * + *

    + *
  • {@link Action#ALLOW} - Content is safe to proceed + *
  • {@link Action#DENY} - Content should be blocked + *
  • {@link Action#ABORT} - Execution should be immediately terminated + *
+ */ + public static class Evaluation { + + final Action action; + final String reason; + + /** + * Creates a new evaluation result. + * + * @param action the recommended action for the evaluated content + * @param reason human-readable explanation for the decision + */ + public Evaluation(final Action action, final String reason) { + this.action = action; + this.reason = reason; + } + + /** + * Returns the recommended action for the evaluated content. + * + * @return the action (ALLOW, DENY, or ABORT) + */ + public Action getAction() { + return action; + } + + /** + * Returns the human-readable reasoning for the evaluation decision. + * + * @return explanation of why this action was recommended + */ + public String getReason() { + return reason; + } + } + + /** + * Represents a message in an AI conversation. Messages can represent user prompts, assistant + * responses, system messages, or tool outputs. + * + *

Example usage: + * + *

{@code
+   * // User prompt
+   * var userPrompt = AIGuard.Message.message("user", "What's the weather like?");
+   *
+   * // Assistant response with tool calls
+   * var assistantWithTools = AIGuard.Message.assistant(
+   *     AIGuard.ToolCall.toolCall("call_123", "get_weather", "{\"location\": \"New York\"}")
+   * );
+   *
+   * // Tool response
+   * var toolResponse = AIGuard.Message.tool("call_123", "Sunny, 75°F");
+   * }
+ */ + public static class Message { + + private final String role; + private final String content; + private final List toolCalls; + private final String toolCallId; + + /** + * Creates a new message with the specified parameters. + * + * @param role the role of the message sender (e.g., "user", "assistant", "system", "tool") + * @param content the text content of the message, or null for assistant messages with only tool + * calls + * @param toolCalls list of tool calls associated with this message, or null if no tool calls + * @param toolCallId the tool call ID this message is responding to, or null if not a tool + * response + */ + public Message( + final String role, + final String content, + final List toolCalls, + final String toolCallId) { + this.role = role; + this.content = content; + this.toolCalls = toolCalls; + this.toolCallId = toolCallId; + } + + /** + * Returns the role of the message sender. + * + * @return the role (e.g., "user", "assistant", "system", "tool") + */ + public String getRole() { + return role; + } + + /** + * Returns the text content of the message. + * + * @return the message content, or null for assistant messages with only tool calls + */ + public String getContent() { + return content; + } + + /** + * Returns the list of tool calls associated with this message. + * + * @return list of tool calls, or null if this message has no tool calls + */ + public List getToolCalls() { + return toolCalls; + } + + /** + * Returns the tool call ID that this message is responding to. + * + * @return the tool call ID, or null if this is not a tool response message + */ + public String getToolCallId() { + return toolCallId; + } + + /** + * Creates a message with specified role and text content. + * + * @param role the role of the message sender (e.g., "user", "system") + * @param content the text content of the message + * @return a new Message instance + */ + public static Message message(final String role, final String content) { + return new Message(role, content, null, null); + } + + /** + * Creates a tool response message. + * + * @param toolCallId the ID of the tool call this message is responding to + * @param content the result or output from the tool execution + * @return a new Message instance with role "tool" + */ + public static Message tool(final String toolCallId, final String content) { + return new Message("tool", content, null, toolCallId); + } + + /** + * Creates an assistant message with tool calls but no text content. + * + * @param toolCalls the tool calls the assistant wants to make + * @return a new Message instance with role "assistant" and no text content + */ + public static Message assistant(final ToolCall... toolCalls) { + return new Message("assistant", null, Arrays.asList(toolCalls), null); + } + } + + /** + * Configuration options for AIGuard evaluation behavior. + * + *

Options control how the evaluation process behaves, including whether to block execution + * when unsafe content is detected. + * + *

Example usage: + * + *

{@code
+   * // Use default options (non-blocking)
+   * var result = AIGuard.evaluate(messages);
+   *
+   * // Enable blocking mode
+   * var options = new AIGuard.Options()
+   *     .block(true);
+   * var result = AIGuard.evaluate(messages, options);
+   * }
+ */ + public static final class Options { + + /** Default options with blocking disabled. */ + public static final Options DEFAULT = new Options().block(false); + + private boolean block; + + /** + * Returns whether blocking mode is enabled. + * + * @return true if execution should be blocked on DENY/ABORT actions + */ + public boolean block() { + return block; + } + + /** + * Enable/disable blocking mode + * + * @param block true if execution should be blocked on DENY/ABORT actions + */ + public Options block(final boolean block) { + this.block = block; + return this; + } + } + + /** + * Represents a function call made by an AI assistant. Tool calls contain an identifier and + * function details (name and arguments). + * + *

Example usage: + * + *

{@code
+   * // Create a tool call
+   * var toolCall = AIGuard.ToolCall.toolCall("call_123", "get_weather", "{\"location\": \"NYC\"}");
+   *
+   * // Use in an assistant message
+   * var assistantMessage = AIGuard.Message.assistant(toolCall);
+   * }
+ */ + public static class ToolCall { + + private final String id; + private final Function function; + + /** + * Creates a new tool call with the specified ID and function. + * + * @param id unique identifier for this tool call + * @param function the function details (name and arguments) + */ + public ToolCall(final String id, final Function function) { + this.id = id; + this.function = function; + } + + /** + * Returns the unique identifier for this tool call. + * + * @return the tool call ID + */ + public String getId() { + return id; + } + + /** + * Returns the function details for this tool call. + * + * @return the Function object containing name and arguments + */ + public Function getFunction() { + return function; + } + + /** + * Represents the function details within a tool call, including the function name and its + * arguments. + */ + public static class Function { + + private final String name; + private final String arguments; + + /** + * Creates a new function with the specified name and arguments. + * + * @param name the name of the function to call + * @param arguments the function arguments as a JSON string + */ + public Function(String name, String arguments) { + this.name = name; + this.arguments = arguments; + } + + /** + * Returns the name of the function to call. + * + * @return the function name + */ + public String getName() { + return name; + } + + /** + * Returns the function arguments as a JSON string. + * + * @return the arguments in JSON format + */ + public String getArguments() { + return arguments; + } + } + + /** + * Factory method to create a new tool call with the specified parameters. + * + * @param id unique identifier for the tool call + * @param name the name of the function to call + * @param arguments the function arguments as a JSON string + * @return a new ToolCall instance + */ + public static ToolCall toolCall(final String id, final String name, final String arguments) { + return new ToolCall(id, new ToolCall.Function(name, arguments)); + } + } +} diff --git a/dd-trace-api/src/main/java/datadog/trace/api/aiguard/Evaluator.java b/dd-trace-api/src/main/java/datadog/trace/api/aiguard/Evaluator.java new file mode 100644 index 00000000000..fe69c813909 --- /dev/null +++ b/dd-trace-api/src/main/java/datadog/trace/api/aiguard/Evaluator.java @@ -0,0 +1,7 @@ +package datadog.trace.api.aiguard; + +import java.util.List; + +public interface Evaluator { + AIGuard.Evaluation evaluate(List messages, AIGuard.Options options); +} diff --git a/dd-trace-api/src/main/java/datadog/trace/api/aiguard/noop/NoOpEvaluator.java b/dd-trace-api/src/main/java/datadog/trace/api/aiguard/noop/NoOpEvaluator.java new file mode 100644 index 00000000000..bdb5a1869c4 --- /dev/null +++ b/dd-trace-api/src/main/java/datadog/trace/api/aiguard/noop/NoOpEvaluator.java @@ -0,0 +1,17 @@ +package datadog.trace.api.aiguard.noop; + +import static datadog.trace.api.aiguard.AIGuard.Action.ALLOW; + +import datadog.trace.api.aiguard.AIGuard.Evaluation; +import datadog.trace.api.aiguard.AIGuard.Message; +import datadog.trace.api.aiguard.AIGuard.Options; +import datadog.trace.api.aiguard.Evaluator; +import java.util.List; + +public final class NoOpEvaluator implements Evaluator { + + @Override + public Evaluation evaluate(final List messages, final Options options) { + return new Evaluation(ALLOW, "AI Guard is not enabled"); + } +} diff --git a/dd-trace-api/src/main/java/datadog/trace/api/config/AIGuardConfig.java b/dd-trace-api/src/main/java/datadog/trace/api/config/AIGuardConfig.java new file mode 100644 index 00000000000..2d685c4a098 --- /dev/null +++ b/dd-trace-api/src/main/java/datadog/trace/api/config/AIGuardConfig.java @@ -0,0 +1,15 @@ +package datadog.trace.api.config; + +public final class AIGuardConfig { + + public static final String AI_GUARD_ENABLED = "ai_guard.enabled"; + public static final String AI_GUARD_ENDPOINT = "ai_guard.endpoint"; + public static final String AI_GUARD_TIMEOUT = "ai_guard.timeout"; + public static final String AI_GUARD_MAX_CONTENT_SIZE = "ai_guard.max-content-size"; + public static final String AI_GUARD_MAX_MESSAGES_LENGTH = "ai_guard.max-messages-length"; + + public static final boolean DEFAULT_AI_GUARD_ENABLED = false; + public static final int DEFAULT_AI_GUARD_TIMEOUT = 10_000; + public static final int DEFAULT_AI_GUARD_MAX_CONTENT_SIZE = 512 * 1024; + public static final int DEFAULT_AI_GUARD_MAX_MESSAGES_LENGTH = 16; +} diff --git a/dd-trace-api/src/main/java/datadog/trace/api/config/GeneralConfig.java b/dd-trace-api/src/main/java/datadog/trace/api/config/GeneralConfig.java index 0a68ea849f9..c3f3d63ca8b 100644 --- a/dd-trace-api/src/main/java/datadog/trace/api/config/GeneralConfig.java +++ b/dd-trace-api/src/main/java/datadog/trace/api/config/GeneralConfig.java @@ -16,6 +16,7 @@ public final class GeneralConfig { public static final String CONFIGURATION_FILE = "trace.config"; public static final String API_KEY = "api-key"; public static final String APPLICATION_KEY = "application-key"; + public static final String APP_KEY = "app-key"; // alias for application key public static final String API_KEY_FILE = "api-key-file"; public static final String APPLICATION_KEY_FILE = "application-key-file"; public static final String SITE = "site"; diff --git a/dd-trace-api/src/test/groovy/datadog/trace/api/aiguard/AIGuardTest.groovy b/dd-trace-api/src/test/groovy/datadog/trace/api/aiguard/AIGuardTest.groovy new file mode 100644 index 00000000000..f987cb307b0 --- /dev/null +++ b/dd-trace-api/src/test/groovy/datadog/trace/api/aiguard/AIGuardTest.groovy @@ -0,0 +1,72 @@ +package datadog.trace.api.aiguard + +import spock.lang.Specification + +import static datadog.trace.api.aiguard.AIGuard.Action.ALLOW + + +class AIGuardTest extends Specification { + + void 'test text message'() { + when: + final message = AIGuard.Message.message('user', 'What day is today?') + + then: + message.role == 'user' + message.content == 'What day is today?' + message.toolCallId == null + message.toolCalls == null + } + + void 'test assistant tool call'() { + when: + final message = AIGuard.Message.assistant( + AIGuard.ToolCall.toolCall('1', 'execute_http_request', '{ "url": "http://localhost" }'), + AIGuard.ToolCall.toolCall('2', 'random_number', '{ "min": 0, "max": 10 }') + ) + + then: + message.role == 'assistant' + message.content == null + message.toolCallId == null + message.toolCalls.size() == 2 + + final http = message.toolCalls[0] + http.id == '1' + http.function.name == 'execute_http_request' + http.function.arguments == '{ "url": "http://localhost" }' + + final random = message.toolCalls[1] + random.id == '2' + random.function.name == 'random_number' + random.function.arguments == '{ "min": 0, "max": 10 }' + } + + void 'test tool'() { + when: + final message = AIGuard.Message.tool('2', '5') + + then: + message.role == 'tool' + message.content == '5' + message.toolCallId == '2' + message.toolCalls == null + } + + void 'test noop implementation'() { + when: + final eval = AIGuard.evaluate([ + AIGuard.Message.message('system', 'You are a beautiful AI assistant'), + AIGuard.Message.message('user', 'What day is today?'), + AIGuard.Message.message('assistant', 'Today is monday'), + AIGuard.Message.message('user', 'Give me a random number'), + AIGuard.Message.assistant(AIGuard.ToolCall.toolCall('1', 'generate_random_number', '{ "min": 0, "max": 10 }')), + AIGuard.Message.tool('1', '5'), + AIGuard.Message.message('assistant', 'Your number is 5') + ]) + + then: + eval.action == ALLOW + eval.reason == 'AI Guard is not enabled' + } +} diff --git a/internal-api/src/main/java/datadog/trace/api/Config.java b/internal-api/src/main/java/datadog/trace/api/Config.java index 92a7cb00e6c..56ee2a22fd0 100644 --- a/internal-api/src/main/java/datadog/trace/api/Config.java +++ b/internal-api/src/main/java/datadog/trace/api/Config.java @@ -185,6 +185,15 @@ import static datadog.trace.api.DDTags.SCHEMA_VERSION_TAG_KEY; import static datadog.trace.api.DDTags.SERVICE; import static datadog.trace.api.DDTags.SERVICE_TAG; +import static datadog.trace.api.config.AIGuardConfig.AI_GUARD_ENABLED; +import static datadog.trace.api.config.AIGuardConfig.AI_GUARD_ENDPOINT; +import static datadog.trace.api.config.AIGuardConfig.AI_GUARD_MAX_CONTENT_SIZE; +import static datadog.trace.api.config.AIGuardConfig.AI_GUARD_MAX_MESSAGES_LENGTH; +import static datadog.trace.api.config.AIGuardConfig.AI_GUARD_TIMEOUT; +import static datadog.trace.api.config.AIGuardConfig.DEFAULT_AI_GUARD_ENABLED; +import static datadog.trace.api.config.AIGuardConfig.DEFAULT_AI_GUARD_MAX_CONTENT_SIZE; +import static datadog.trace.api.config.AIGuardConfig.DEFAULT_AI_GUARD_MAX_MESSAGES_LENGTH; +import static datadog.trace.api.config.AIGuardConfig.DEFAULT_AI_GUARD_TIMEOUT; import static datadog.trace.api.config.AppSecConfig.API_SECURITY_DOWNSTREAM_REQUEST_ANALYSIS_SAMPLE_RATE; import static datadog.trace.api.config.AppSecConfig.API_SECURITY_ENABLED; import static datadog.trace.api.config.AppSecConfig.API_SECURITY_ENABLED_EXPERIMENTAL; @@ -330,6 +339,7 @@ import static datadog.trace.api.config.GeneralConfig.API_KEY_FILE; import static datadog.trace.api.config.GeneralConfig.APPLICATION_KEY; import static datadog.trace.api.config.GeneralConfig.APPLICATION_KEY_FILE; +import static datadog.trace.api.config.GeneralConfig.APP_KEY; import static datadog.trace.api.config.GeneralConfig.AZURE_APP_SERVICES; import static datadog.trace.api.config.GeneralConfig.DATA_JOBS_ENABLED; import static datadog.trace.api.config.GeneralConfig.DATA_JOBS_OPENLINEAGE_ENABLED; @@ -1255,6 +1265,12 @@ public static String getHostName() { private final RumInjectorConfig rumInjectorConfig; + private final boolean aiGuardEnabled; + private final String aiGuardEndpoint; + private final int aiGuardTimeout; + private final int aiGuardMaxMessagesLength; + private final int aiGuardMaxContentSize; + static { // Bind telemetry collector to config module before initializing ConfigProvider OtelEnvMetricCollectorProvider.register(OtelEnvMetricCollectorImpl.getInstance()); @@ -1298,7 +1314,7 @@ private Config(final ConfigProvider configProvider, final InstrumenterConfig ins String tmpApplicationKey = configProvider.getStringExcludingSource( - APPLICATION_KEY, null, SystemPropertiesConfigSource.class); + APPLICATION_KEY, null, SystemPropertiesConfigSource.class, APP_KEY); String applicationKeyFile = configProvider.getString(APPLICATION_KEY_FILE); if (applicationKeyFile != null) { try { @@ -2797,6 +2813,15 @@ PROFILING_DATADOG_PROFILER_ENABLED, isDatadogProfilerSafeInCurrentEnvironment()) this.rumInjectorConfig = parseRumConfig(configProvider); + this.aiGuardEnabled = configProvider.getBoolean(AI_GUARD_ENABLED, DEFAULT_AI_GUARD_ENABLED); + this.aiGuardEndpoint = configProvider.getString(AI_GUARD_ENDPOINT); + this.aiGuardTimeout = configProvider.getInteger(AI_GUARD_TIMEOUT, DEFAULT_AI_GUARD_TIMEOUT); + this.aiGuardMaxContentSize = + configProvider.getInteger(AI_GUARD_MAX_CONTENT_SIZE, DEFAULT_AI_GUARD_MAX_CONTENT_SIZE); + this.aiGuardMaxMessagesLength = + configProvider.getInteger( + AI_GUARD_MAX_MESSAGES_LENGTH, DEFAULT_AI_GUARD_MAX_MESSAGES_LENGTH); + log.debug("New instance: {}", this); } @@ -5158,6 +5183,26 @@ public RumInjectorConfig getRumInjectorConfig() { return this.rumInjectorConfig; } + public boolean isAiGuardEnabled() { + return aiGuardEnabled; + } + + public String getAiGuardEndpoint() { + return aiGuardEndpoint; + } + + public int getAiGuardMaxContentSize() { + return aiGuardMaxContentSize; + } + + public int getAiGuardMaxMessagesLength() { + return aiGuardMaxMessagesLength; + } + + public int getAiGuardTimeout() { + return aiGuardTimeout; + } + private Set getSettingsSetFromEnvironment( String name, Function mapper, boolean splitOnWS) { final String value = configProvider.getString(name, ""); @@ -5854,6 +5899,10 @@ public String toString() { + experimentalPropagateProcessTagsEnabled + ", rumInjectorConfig=" + (rumInjectorConfig == null ? "null" : rumInjectorConfig.jsonPayload()) + + ", aiGuardEnabled=" + + aiGuardEnabled + + ", aiGuardEndpoint=" + + aiGuardEndpoint + '}'; } } diff --git a/settings.gradle.kts b/settings.gradle.kts index 40ea950f3c4..b95f02266f9 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -137,6 +137,9 @@ include( ":dd-java-agent:cws-tls", ) +// AI Guard +include(":dd-java-agent:agent-aiguard") + // misc include( ":dd-java-agent:testing",