Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,11 @@ public Instrumenter newInstrumenter(Class<?> clazz, Map<MethodKey, CheckMethod>
return InstrumenterImpl.create(clazz, methods);
}

@Override
public Map<MethodKey, CheckMethod> lookupMethods(Class<?> checkerClass) throws IOException {
Map<MethodKey, CheckMethod> methodsToInstrument = new HashMap<>();
private interface CheckerMethodVisitor {
void visit(Class<?> currentClass, int access, String checkerMethodName, String checkerMethodDescriptor);
}

private void visitClassAndSupers(Class<?> checkerClass, CheckerMethodVisitor checkerMethodVisitor) throws ClassNotFoundException {
Set<Class<?>> visitedClasses = new HashSet<>();
ArrayDeque<Class<?>> classesToVisit = new ArrayDeque<>(Collections.singleton(checkerClass));
while (classesToVisit.isEmpty() == false) {
Expand All @@ -57,67 +58,76 @@ public Map<MethodKey, CheckMethod> lookupMethods(Class<?> checkerClass) throws I
}
visitedClasses.add(currentClass);

var classFileInfo = InstrumenterImpl.getClassFileInfo(currentClass);
ClassReader reader = new ClassReader(classFileInfo.bytecodes());
ClassVisitor visitor = new ClassVisitor(Opcodes.ASM9) {
try {
var classFileInfo = InstrumenterImpl.getClassFileInfo(currentClass);
ClassReader reader = new ClassReader(classFileInfo.bytecodes());
ClassVisitor visitor = new ClassVisitor(Opcodes.ASM9) {

@Override
public void visit(int version, int access, String name, String signature, String superName, String[] interfaces) {
super.visit(version, access, name, signature, superName, interfaces);
try {
if (OBJECT_INTERNAL_NAME.equals(superName) == false) {
classesToVisit.add(Class.forName(Type.getObjectType(superName).getClassName()));
@Override
public void visit(int version, int access, String name, String signature, String superName, String[] interfaces) {
super.visit(version, access, name, signature, superName, interfaces);
try {
if (OBJECT_INTERNAL_NAME.equals(superName) == false) {
classesToVisit.add(Class.forName(Type.getObjectType(superName).getClassName()));
}
for (var interfaceName : interfaces) {
classesToVisit.add(Class.forName(Type.getObjectType(interfaceName).getClassName()));
}
} catch (ClassNotFoundException e) {
throw new IllegalArgumentException("Cannot inspect checker class " + currentClass.getName(), e);
}
for (var interfaceName : interfaces) {
classesToVisit.add(Class.forName(Type.getObjectType(interfaceName).getClassName()));
}
} catch (ClassNotFoundException e) {
throw new IllegalArgumentException("Cannot inspect checker class " + checkerClass.getName(), e);
}
}

@Override
public MethodVisitor visitMethod(
int access,
String checkerMethodName,
String checkerMethodDescriptor,
String signature,
String[] exceptions
) {
var mv = super.visitMethod(access, checkerMethodName, checkerMethodDescriptor, signature, exceptions);
if (checkerMethodName.startsWith(InstrumentationService.CHECK_METHOD_PREFIX)) {
var checkerMethodArgumentTypes = Type.getArgumentTypes(checkerMethodDescriptor);
var methodToInstrument = parseCheckerMethodSignature(checkerMethodName, checkerMethodArgumentTypes);

var checkerParameterDescriptors = Arrays.stream(checkerMethodArgumentTypes).map(Type::getDescriptor).toList();
var checkMethod = new CheckMethod(
Type.getInternalName(currentClass),
checkerMethodName,
checkerParameterDescriptors
);

methodsToInstrument.putIfAbsent(methodToInstrument, checkMethod);
@Override
public MethodVisitor visitMethod(
int access,
String checkerMethodName,
String checkerMethodDescriptor,
String signature,
String[] exceptions
) {
var mv = super.visitMethod(access, checkerMethodName, checkerMethodDescriptor, signature, exceptions);
checkerMethodVisitor.visit(currentClass, access, checkerMethodName, checkerMethodDescriptor);
return mv;
}
return mv;
}
};
reader.accept(visitor, 0);
};
reader.accept(visitor, 0);
} catch (IOException e) {
throw new ClassNotFoundException("Cannot find a definition for class [" + checkerClass.getName() + "]", e);
}
}
}

@Override
public Map<MethodKey, CheckMethod> lookupMethods(Class<?> checkerClass) throws ClassNotFoundException {
Map<MethodKey, CheckMethod> methodsToInstrument = new HashMap<>();

visitClassAndSupers(checkerClass, (currentClass, access, checkerMethodName, checkerMethodDescriptor) -> {
if (checkerMethodName.startsWith(InstrumentationService.CHECK_METHOD_PREFIX)) {
var checkerMethodArgumentTypes = Type.getArgumentTypes(checkerMethodDescriptor);
var methodToInstrument = parseCheckerMethodSignature(checkerMethodName, checkerMethodArgumentTypes);

var checkerParameterDescriptors = Arrays.stream(checkerMethodArgumentTypes).map(Type::getDescriptor).toList();
var checkMethod = new CheckMethod(Type.getInternalName(currentClass), checkerMethodName, checkerParameterDescriptors);
methodsToInstrument.putIfAbsent(methodToInstrument, checkMethod);
}
});

return methodsToInstrument;
}

@SuppressForbidden(reason = "Need access to abstract methods (protected/package internal) in base class")
@Override
public InstrumentationInfo lookupImplementationMethod(
Class<?> targetSuperclass,
String methodName,
String targetMethodName,
Class<?> implementationClass,
Class<?> checkerClass,
String checkMethodName,
Class<?>... parameterTypes
) throws NoSuchMethodException, ClassNotFoundException {

var targetMethod = targetSuperclass.getDeclaredMethod(methodName, parameterTypes);
var targetMethod = targetSuperclass.getDeclaredMethod(targetMethodName, parameterTypes);
var implementationMethod = implementationClass.getMethod(targetMethod.getName(), targetMethod.getParameterTypes());
validateTargetMethod(implementationClass, targetMethod, implementationMethod);

Expand All @@ -128,33 +138,15 @@ public InstrumentationInfo lookupImplementationMethod(

CheckMethod[] checkMethod = new CheckMethod[1];

try {
InstrumenterImpl.ClassFileInfo classFileInfo = InstrumenterImpl.getClassFileInfo(checkerClass);
ClassReader reader = new ClassReader(classFileInfo.bytecodes());
ClassVisitor visitor = new ClassVisitor(Opcodes.ASM9) {
@Override
public MethodVisitor visitMethod(
int access,
String methodName,
String methodDescriptor,
String signature,
String[] exceptions
) {
var mv = super.visitMethod(access, methodName, methodDescriptor, signature, exceptions);
if (methodName.equals(checkMethodName)) {
var methodArgumentTypes = Type.getArgumentTypes(methodDescriptor);
if (Arrays.equals(methodArgumentTypes, checkMethodArgumentTypes)) {
var checkerParameterDescriptors = Arrays.stream(methodArgumentTypes).map(Type::getDescriptor).toList();
checkMethod[0] = new CheckMethod(Type.getInternalName(checkerClass), methodName, checkerParameterDescriptors);
}
}
return mv;
visitClassAndSupers(checkerClass, (currentClass, access, methodName, methodDescriptor) -> {
if (methodName.equals(checkMethodName)) {
var methodArgumentTypes = Type.getArgumentTypes(methodDescriptor);
if (Arrays.equals(methodArgumentTypes, checkMethodArgumentTypes)) {
var checkerParameterDescriptors = Arrays.stream(methodArgumentTypes).map(Type::getDescriptor).toList();
checkMethod[0] = new CheckMethod(Type.getInternalName(currentClass), methodName, checkerParameterDescriptors);
}
};
reader.accept(visitor, 0);
} catch (IOException e) {
throw new ClassNotFoundException("Cannot find a definition for class [" + checkerClass.getName() + "]", e);
}
}
});

if (checkMethod[0] == null) {
throw new NoSuchMethodException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import org.elasticsearch.test.ESTestCase;
import org.objectweb.asm.Type;

import java.io.IOException;
import java.util.List;
import java.util.Map;

Expand Down Expand Up @@ -90,7 +89,9 @@ interface TestCheckerMixed {
void checkInstanceMethodManual(Class<?> clazz, TestTargetBaseClass that, int x, String y);
}

public void testInstrumentationTargetLookup() throws IOException {
interface TestCheckerDerived3 extends TestCheckerMixed {}

public void testInstrumentationTargetLookup() throws ClassNotFoundException {
Map<MethodKey, CheckMethod> checkMethods = instrumentationService.lookupMethods(TestChecker.class);

assertThat(checkMethods, aMapWithSize(3));
Expand Down Expand Up @@ -143,7 +144,7 @@ public void testInstrumentationTargetLookup() throws IOException {
);
}

public void testInstrumentationTargetLookupWithOverloads() throws IOException {
public void testInstrumentationTargetLookupWithOverloads() throws ClassNotFoundException {
Map<MethodKey, CheckMethod> checkMethods = instrumentationService.lookupMethods(TestCheckerOverloads.class);

assertThat(checkMethods, aMapWithSize(2));
Expand Down Expand Up @@ -175,7 +176,7 @@ public void testInstrumentationTargetLookupWithOverloads() throws IOException {
);
}

public void testInstrumentationTargetLookupWithDerivedClass() throws IOException {
public void testInstrumentationTargetLookupWithDerivedClass() throws ClassNotFoundException {
Map<MethodKey, CheckMethod> checkMethods = instrumentationService.lookupMethods(TestCheckerDerived2.class);

assertThat(checkMethods, aMapWithSize(4));
Expand Down Expand Up @@ -244,7 +245,7 @@ public void testInstrumentationTargetLookupWithDerivedClass() throws IOException
);
}

public void testInstrumentationTargetLookupWithCtors() throws IOException {
public void testInstrumentationTargetLookupWithCtors() throws ClassNotFoundException {
Map<MethodKey, CheckMethod> checkMethods = instrumentationService.lookupMethods(TestCheckerCtors.class);

assertThat(checkMethods, aMapWithSize(2));
Expand Down Expand Up @@ -276,7 +277,7 @@ public void testInstrumentationTargetLookupWithCtors() throws IOException {
);
}

public void testInstrumentationTargetLookupWithExtraMethods() throws IOException {
public void testInstrumentationTargetLookupWithExtraMethods() throws ClassNotFoundException {
Map<MethodKey, CheckMethod> checkMethods = instrumentationService.lookupMethods(TestCheckerMixed.class);

assertThat(checkMethods, aMapWithSize(1));
Expand Down Expand Up @@ -371,7 +372,7 @@ public void testLookupImplementationMethodWithBaseClass() throws ClassNotFoundEx
);
}

public void testLookupImplementationMethodWithInheritance() throws ClassNotFoundException, NoSuchMethodException {
public void testLookupImplementationMethodWithInheritanceOnTarget() throws ClassNotFoundException, NoSuchMethodException {
var info = instrumentationService.lookupImplementationMethod(
TestTargetBaseClass.class,
"instanceMethod2",
Expand Down Expand Up @@ -409,6 +410,44 @@ public void testLookupImplementationMethodWithInheritance() throws ClassNotFound
);
}

public void testLookupImplementationMethodWithInheritanceOnChecker() throws ClassNotFoundException, NoSuchMethodException {
var info = instrumentationService.lookupImplementationMethod(
TestTargetBaseClass.class,
"instanceMethod2",
TestTargetImplementationClass.class,
TestCheckerDerived3.class,
"checkInstanceMethodManual",
int.class,
String.class
);

assertThat(
info.targetMethod(),
equalTo(
new MethodKey(
"org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests$TestTargetIntermediateClass",
"instanceMethod2",
List.of("I", "java/lang/String")
)
)
);
assertThat(
info.checkMethod(),
equalTo(
new CheckMethod(
"org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests$TestCheckerMixed",
"checkInstanceMethodManual",
List.of(
"Ljava/lang/Class;",
"Lorg/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests$TestTargetBaseClass;",
"I",
"Ljava/lang/String;"
)
)
)
);
}

public void testParseCheckerMethodSignatureStaticMethod() {
var methodKey = InstrumentationServiceImpl.parseCheckerMethodSignature(
"check$org_example_TestClass$$staticMethod",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

package org.elasticsearch.entitlement.instrumentation;

import java.io.IOException;
import java.util.Map;

/**
Expand All @@ -23,7 +22,7 @@ record InstrumentationInfo(MethodKey targetMethod, CheckMethod checkMethod) {}

Instrumenter newInstrumenter(Class<?> clazz, Map<MethodKey, CheckMethod> methods);

Map<MethodKey, CheckMethod> lookupMethods(Class<?> clazz) throws IOException;
Map<MethodKey, CheckMethod> lookupMethods(Class<?> clazz) throws ClassNotFoundException;

InstrumentationInfo lookupImplementationMethod(
Class<?> targetSuperclass,
Expand Down