Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
package datadog.trace.agent.tooling.bytebuddy.iast;

import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
import javax.annotation.Nullable;
import net.bytebuddy.asm.AsmVisitorWrapper;
import net.bytebuddy.description.field.FieldDescription;
import net.bytebuddy.description.field.FieldList;
import net.bytebuddy.description.method.MethodList;
import net.bytebuddy.description.type.TypeDescription;
import net.bytebuddy.implementation.Implementation;
import net.bytebuddy.jar.asm.ClassVisitor;
import net.bytebuddy.jar.asm.FieldVisitor;
import net.bytebuddy.jar.asm.Opcodes;
import net.bytebuddy.jar.asm.Type;
import net.bytebuddy.pool.TypePool;
import net.bytebuddy.utility.OpenedClassReader;

public class TaintableDbVisitor implements AsmVisitorWrapper {

public static volatile boolean DEBUG = false;
static volatile boolean ENABLED = true;

private static final String INTERFACE_NAME = "datadog/trace/api/iast/TaintableDb";
private static final String FIELD_NAME = "$$DD$recordsRead";
private static final String GETTER_NAME = "$$DD$getRecordsRead";
private static final String SETTER_NAME = "$$DD$setRecordsRead";

private final Set<String> types;

public TaintableDbVisitor(final String... classNames) {
types = new HashSet<>(Arrays.asList(classNames));
}

@Override
public int mergeWriter(final int flags) {
return flags;
}

@Override
public int mergeReader(int flags) {
return flags;
}

@Override
public ClassVisitor wrap(
final TypeDescription instrumentedType,
final ClassVisitor classVisitor,
final Implementation.Context implementationContext,
final TypePool typePool,
final FieldList<FieldDescription.InDefinedShape> fields,
final MethodList<?> methods,
final int writerFlags,
final int readerFlags) {
if (ENABLED) {
return types.contains(instrumentedType.getName())
? new AddTaintableDbInterfaceVisitor(classVisitor)
: classVisitor;
} else {
return NoOp.INSTANCE.wrap(
instrumentedType,
classVisitor,
implementationContext,
typePool,
fields,
methods,
writerFlags,
readerFlags);
}
}

private static class AddTaintableDbInterfaceVisitor extends ClassVisitor {

private String owner;

private boolean addTaintable = true;

protected AddTaintableDbInterfaceVisitor(final ClassVisitor classVisitor) {
super(OpenedClassReader.ASM_API, classVisitor);
}

@Override
public void visit(
final int version,
final int access,
final String name,
final String signature,
final String superName,
final String[] interfaces) {
owner = name;
if (interfaces != null) {
for (final String iface : interfaces) {
if (INTERFACE_NAME.equals(iface)) {
addTaintable = false;
break;
}
}
}
super.visit(
version,
access,
name,
signature,
superName,
addTaintable ? addInterface(interfaces) : interfaces);
}

@Override
public void visitEnd() {
if (addTaintable) {
addField();
// addGetter();
// if (!DEBUG) {
// addSetter();
// } else {
// addSetterDebug();
// }
}
}

private String[] addInterface(@Nullable final String[] interfaces) {
if (interfaces == null || interfaces.length == 0) {
return new String[] {INTERFACE_NAME};
} else {
final String[] newInterfaces = new String[interfaces.length + 1];
System.arraycopy(interfaces, 0, newInterfaces, 0, interfaces.length);
newInterfaces[newInterfaces.length - 1] = INTERFACE_NAME;
return newInterfaces;
}
}

private void addField() {
final FieldVisitor fv =
cv.visitField(
Opcodes.ACC_PRIVATE | Opcodes.ACC_TRANSIENT | Opcodes.ACC_VOLATILE,
FIELD_NAME,
Type.INT_TYPE.getDescriptor(),
null,
null);
fv.visitEnd();
}

// private void addGetter() {
// final MethodVisitor mv =
// cv.visitMethod(Opcodes.ACC_PUBLIC, GETTER_NAME, Type.INT_TYPE.getDescriptor(), null,
// null);
// mv.visitCode();
// mv.visitVarInsn(Opcodes.ALOAD, 0);
// mv.visitFieldInsn(Opcodes.GETSTATIC , owner, FIELD_NAME, Type.INT_TYPE.getDescriptor());
// mv.visitInsn(Opcodes.ARETURN);
// mv.visitMaxs(1, 1);
// mv.visitEnd();
// }
//
// private void addSetter() {
// final MethodVisitor mv =
// cv.visitMethod(
// Opcodes.ACC_PUBLIC, SETTER_NAME, Type.INT_TYPE.getDescriptor(), null, null);
// mv.visitCode();
// mv.visitVarInsn(Opcodes.ALOAD, 0);
// mv.visitVarInsn(Opcodes.ILOAD, 1);
// mv.visitFieldInsn(Opcodes.PUTFIELD, owner, FIELD_NAME, Type.INT_TYPE.getDescriptor());
// mv.visitInsn(Opcodes.RETURN);
// mv.visitMaxs(2, 2);
// mv.visitEnd();
// }
//
// private void addSetterDebug() {
// final MethodVisitor mv =
// cv.visitMethod(
// Opcodes.ACC_PUBLIC, SETTER_NAME, Type.INT_TYPE.getDescriptor(), null, null);
// mv.visitCode();
// mv.visitVarInsn(Opcodes.ALOAD, 0);
// mv.visitVarInsn(Opcodes.ILOAD, 1);
// mv.visitFieldInsn(Opcodes.PUTFIELD, owner, FIELD_NAME, Type.INT_TYPE.getDescriptor());
//
// mv.visitVarInsn(Opcodes.ALOAD, 0);
// mv.visitMethodInsn(
// Opcodes.INVOKESTATIC,
// Type.getInternalName(TaintableDb.DebugLogger.class),
// "logTaint",
// "(Ldatadog/trace/api/iast/TaintableDb;)V",
// false);
// mv.visitInsn(Opcodes.RETURN);
// mv.visitMaxs(2, 2);
// mv.visitEnd();
// }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
1 io.opentelemetry.javaagent.*
1 java.*
0 java.lang.ClassLoader
0 java.sql.ResultSet
# allow exception profiling instrumentation
0 java.lang.Exception
0 java.lang.Error
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package datadog.trace.instrumentation.jdbc;

import static datadog.trace.agent.tooling.bytebuddy.matcher.HierarchyMatchers.implementsInterface;
import static net.bytebuddy.matcher.ElementMatchers.isMethod;
import static net.bytebuddy.matcher.ElementMatchers.named;
import static net.bytebuddy.matcher.ElementMatchers.takesArguments;

import com.google.auto.service.AutoService;
import datadog.trace.advice.ActiveRequestContext;
import datadog.trace.advice.RequiresRequestContext;
import datadog.trace.agent.tooling.Instrumenter;
import datadog.trace.agent.tooling.InstrumenterModule;
import datadog.trace.agent.tooling.bytebuddy.iast.TaintableDbVisitor;
import datadog.trace.agent.tooling.bytebuddy.matcher.NameMatchers;
import datadog.trace.api.gateway.RequestContext;
import datadog.trace.api.gateway.RequestContextSlot;
import datadog.trace.api.iast.IastContext;
import datadog.trace.api.iast.InstrumentationBridge;
import datadog.trace.api.iast.Source;
import datadog.trace.api.iast.SourceTypes;
import datadog.trace.api.iast.TaintableDb;
import datadog.trace.api.iast.propagation.PropagationModule;
import net.bytebuddy.asm.Advice;
import net.bytebuddy.description.type.TypeDescription;
import net.bytebuddy.matcher.ElementMatcher;

@AutoService(InstrumenterModule.class)
public class IastResultSetInstrumentation extends InstrumenterModule.Iast
implements Instrumenter.ForTypeHierarchy, Instrumenter.HasTypeAdvice {

public IastResultSetInstrumentation() {
super("jdbc", "resultset");
}

// @Override
// public String instrumentedType() {
// return "java.sql.ResultSet";
// }

@Override
public String hierarchyMarkerType() {
return "java.sql.ResultSet";
}

@Override
public ElementMatcher<TypeDescription> hierarchyMatcher() {
return implementsInterface(NameMatchers.named("java.sql.ResultSet"));
}

@Override
public void typeAdvice(TypeTransformer transformer) {
transformer.applyAdvice(new TaintableDbVisitor(hierarchyMarkerType()));
}

@Override
public void methodAdvice(MethodTransformer transformer) {
transformer.applyAdvice(
isMethod().and(named("getInt")).and(takesArguments(int.class)),
IastResultSetInstrumentation.class.getName() + "$GetParameterAdvice");
}

@RequiresRequestContext(RequestContextSlot.IAST)
public static class GetParameterAdvice {
@Advice.OnMethodExit(suppress = Throwable.class)
@Source(SourceTypes.SQL_TABLE)
public static void onExit(
@Advice.Argument(0) final int columnIndex,
@Advice.Return final Object value,
@Advice.This final TaintableDb resultSet,
@ActiveRequestContext RequestContext reqCtx) {
// int recordsRead = resultSet.$$DD$RecordsRead;
// int recordsRead = resultSet.$$DD$getRecordsRead();
// resultSet.$$DD$setRecordsRead(recordsRead + 1);
// if (recordsRead > 1) {
// return;
// }
if (value == null) {
return;
}
final PropagationModule module = InstrumentationBridge.PROPAGATION;
if (module == null) {
return;
}
IastContext ctx = reqCtx.getData(RequestContextSlot.IAST);
module.taintString(ctx, String.valueOf(value), SourceTypes.SQL_TABLE);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ private SourceTypes() {}
public static final byte GRPC_BODY = 13;
public static final byte KAFKA_MESSAGE_KEY = 14;
public static final byte KAFKA_MESSAGE_VALUE = 15;
public static final byte SQL_TABLE = 16;

/** Array indexed with all source types, the index should match the source types values */
public static final String[] STRINGS = {
Expand All @@ -42,7 +43,8 @@ private SourceTypes() {}
"http.request.path",
"grpc.request.body",
"kafka.message.key",
"kafka.message.value"
"kafka.message.value",
"sql.row.value"
};

public static String toString(final byte source) {
Expand Down
49 changes: 49 additions & 0 deletions internal-api/src/main/java/datadog/trace/api/iast/TaintableDb.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package datadog.trace.api.iast;

import de.thetaphi.forbiddenapis.SuppressForbidden;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public interface TaintableDb {

int $$DD$getRecordsRead();

void $$DD$setRecordsRead(final int recordsRead);

@SuppressForbidden
class DebugLogger {
private static final Logger LOGGER;

static {
try {
LOGGER = LoggerFactory.getLogger("TaintableDb tainted objects");
Class<?> levelCls = Class.forName("ch.qos.logback.classic.Level");
Method setLevel = LOGGER.getClass().getMethod("setLevel", levelCls);
Object debugLevel = levelCls.getField("DEBUG").get(null);
setLevel.invoke(LOGGER, debugLevel);
} catch (IllegalAccessException
| NoSuchFieldException
| ClassNotFoundException
| NoSuchMethodException
| InvocationTargetException e) {
throw new RuntimeException(e);
}
}

public static void logTaint(TaintableDb t) {
String content;
if (t.getClass().getName().startsWith("java.")) {
content = t.toString();
} else {
content = "(value not shown)"; // toString() may trigger tainting
}
LOGGER.debug(
"taint: {}[{}] {}",
t.getClass().getSimpleName(),
Integer.toHexString(System.identityHashCode(t)),
content);
}
}
}
Loading