Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -23,49 +23,86 @@
import java.io.IOException;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class InstrumentationServiceImpl implements InstrumentationService {

private static final String OBJECT_INTERNAL_NAME = Type.getInternalName(Object.class);

@Override
public Instrumenter newInstrumenter(Class<?> clazz, Map<MethodKey, CheckMethod> methods) {
return InstrumenterImpl.create(clazz, methods);
}

@Override
public Map<MethodKey, CheckMethod> lookupMethods(Class<?> checkerClass) throws IOException {
var methodsToInstrument = new HashMap<MethodKey, CheckMethod>();
var classFileInfo = InstrumenterImpl.getClassFileInfo(checkerClass);
ClassReader reader = new ClassReader(classFileInfo.bytecodes());
ClassVisitor visitor = new ClassVisitor(Opcodes.ASM9) {
@Override
public MethodVisitor visitMethod(
int access,
String checkerMethodName,
String checkerMethodDescriptor,
String signature,
String[] exceptions
) {
var mv = super.visitMethod(access, checkerMethodName, checkerMethodDescriptor, signature, exceptions);
if (checkerMethodName.startsWith(InstrumentationService.CHECK_METHOD_PREFIX)) {
var checkerMethodArgumentTypes = Type.getArgumentTypes(checkerMethodDescriptor);
var methodToInstrument = parseCheckerMethodSignature(checkerMethodName, checkerMethodArgumentTypes);
Map<MethodKey, CheckMethod> methodsToInstrument = new HashMap<>();

var checkerParameterDescriptors = Arrays.stream(checkerMethodArgumentTypes).map(Type::getDescriptor).toList();
var checkMethod = new CheckMethod(Type.getInternalName(checkerClass), checkerMethodName, checkerParameterDescriptors);
Set<Class<?>> visitedClasses = new HashSet<>();
ArrayDeque<Class<?>> classesToVisit = new ArrayDeque<>(Collections.singleton(checkerClass));
while (classesToVisit.isEmpty() == false) {
var currentClass = classesToVisit.remove();
if (visitedClasses.contains(currentClass)) {
continue;
}
visitedClasses.add(currentClass);

methodsToInstrument.put(methodToInstrument, checkMethod);
var classFileInfo = InstrumenterImpl.getClassFileInfo(currentClass);
ClassReader reader = new ClassReader(classFileInfo.bytecodes());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need asm here? Since we have Class objects already, can't we use reflection to find all of the methods?

Copy link
Contributor Author

@ldematte ldematte Jan 31, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking the same, but every time we use reflection (e.g. in #120811 (comment)), we end up regretting it and go back to ASM; for example in the linked issue I tried to use reflection first, only to end up with NoClassDefFoundError errors in tests due to the fact that our interface uses types that may not (are not) always available (at least at this time).
It might be fine in this case if we use it only for finding the base classes, but since we need to use ASM to find all the methods anyway, why risk?

ClassVisitor visitor = new ClassVisitor(Opcodes.ASM9) {

@Override
public void visit(int version, int access, String name, String signature, String superName, String[] interfaces) {
super.visit(version, access, name, signature, superName, interfaces);
try {
if (OBJECT_INTERNAL_NAME.equals(superName) == false) {
classesToVisit.add(Class.forName(Type.getObjectType(superName).getClassName()));
}
for (var interfaceName : interfaces) {
classesToVisit.add(Class.forName(Type.getObjectType(interfaceName).getClassName()));
}
} catch (ClassNotFoundException e) {
throw new IllegalArgumentException("Cannot inspect checker class " + checkerClass.getName(), e);
}
}
return mv;
}
};
reader.accept(visitor, 0);

@Override
public MethodVisitor visitMethod(
int access,
String checkerMethodName,
String checkerMethodDescriptor,
String signature,
String[] exceptions
) {
var mv = super.visitMethod(access, checkerMethodName, checkerMethodDescriptor, signature, exceptions);
if (checkerMethodName.startsWith(InstrumentationService.CHECK_METHOD_PREFIX)) {
var checkerMethodArgumentTypes = Type.getArgumentTypes(checkerMethodDescriptor);
var methodToInstrument = parseCheckerMethodSignature(checkerMethodName, checkerMethodArgumentTypes);

var checkerParameterDescriptors = Arrays.stream(checkerMethodArgumentTypes).map(Type::getDescriptor).toList();
var checkMethod = new CheckMethod(
Type.getInternalName(currentClass),
checkerMethodName,
checkerParameterDescriptors
);

methodsToInstrument.putIfAbsent(methodToInstrument, checkMethod);
}
return mv;
}
};
reader.accept(visitor, 0);
}
return methodsToInstrument;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,14 @@ interface TestChecker {
void check$org_example_TestTargetClass$instanceMethodWithArgs(Class<?> clazz, TestTargetClass that, int x, int y);
}

interface TestCheckerDerived extends TestChecker {
void check$org_example_TestTargetClass$instanceMethodNoArgs(Class<?> clazz, TestTargetClass that);

void check$org_example_TestTargetClass$differentInstanceMethod(Class<?> clazz, TestTargetClass that);
}

interface TestCheckerDerived2 extends TestCheckerDerived, TestChecker {}

interface TestCheckerOverloads {
void check$org_example_TestTargetClass$$staticMethodWithOverload(Class<?> clazz, int x, int y);

Expand Down Expand Up @@ -160,6 +168,75 @@ public void testInstrumentationTargetLookupWithOverloads() throws IOException {
);
}

public void testInstrumentationTargetLookupWithDerivedClass() throws IOException {
Map<MethodKey, CheckMethod> checkMethods = instrumentationService.lookupMethods(TestCheckerDerived2.class);

assertThat(checkMethods, aMapWithSize(4));
assertThat(
checkMethods,
hasEntry(
equalTo(new MethodKey("org/example/TestTargetClass", "staticMethod", List.of("I", "java/lang/String", "java/lang/Object"))),
equalTo(
new CheckMethod(
"org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests$TestChecker",
"check$org_example_TestTargetClass$$staticMethod",
List.of("Ljava/lang/Class;", "I", "Ljava/lang/String;", "Ljava/lang/Object;")
)
)
)
);
assertThat(
checkMethods,
hasEntry(
equalTo(new MethodKey("org/example/TestTargetClass", "instanceMethodNoArgs", List.of())),
equalTo(
new CheckMethod(
"org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests$TestCheckerDerived",
"check$org_example_TestTargetClass$instanceMethodNoArgs",
List.of(
"Ljava/lang/Class;",
"Lorg/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests$TestTargetClass;"
)
)
)
)
);
assertThat(
checkMethods,
hasEntry(
equalTo(new MethodKey("org/example/TestTargetClass", "instanceMethodWithArgs", List.of("I", "I"))),
equalTo(
new CheckMethod(
"org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests$TestChecker",
"check$org_example_TestTargetClass$instanceMethodWithArgs",
List.of(
"Ljava/lang/Class;",
"Lorg/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests$TestTargetClass;",
"I",
"I"
)
)
)
)
);
assertThat(
checkMethods,
hasEntry(
equalTo(new MethodKey("org/example/TestTargetClass", "differentInstanceMethod", List.of())),
equalTo(
new CheckMethod(
"org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests$TestCheckerDerived",
"check$org_example_TestTargetClass$differentInstanceMethod",
List.of(
"Ljava/lang/Class;",
"Lorg/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests$TestTargetClass;"
)
)
)
)
);
}

public void testInstrumentationTargetLookupWithCtors() throws IOException {
Map<MethodKey, CheckMethod> checkMethods = instrumentationService.lookupMethods(TestCheckerCtors.class);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,9 @@ public static EntitlementChecker checker() {
public static void initialize(Instrumentation inst) throws Exception {
manager = initChecker();

Map<MethodKey, CheckMethod> checkMethods = new HashMap<>(INSTRUMENTATION_SERVICE.lookupMethods(EntitlementChecker.class));
var latestCheckerInterface = getVersionSpecificCheckerClass(EntitlementChecker.class);

Map<MethodKey, CheckMethod> checkMethods = new HashMap<>(INSTRUMENTATION_SERVICE.lookupMethods(latestCheckerInterface));
var fileSystemProviderClass = FileSystems.getDefault().provider().getClass();
Stream.of(
INSTRUMENTATION_SERVICE.lookupImplementationMethod(
Expand All @@ -83,7 +84,7 @@ public static void initialize(Instrumentation inst) throws Exception {

var classesToTransform = checkMethods.keySet().stream().map(MethodKey::className).collect(Collectors.toSet());

Instrumenter instrumenter = INSTRUMENTATION_SERVICE.newInstrumenter(EntitlementChecker.class, checkMethods);
Instrumenter instrumenter = INSTRUMENTATION_SERVICE.newInstrumenter(latestCheckerInterface, checkMethods);
inst.addTransformer(new Transformer(instrumenter, classesToTransform), true);
inst.retransformClasses(findClassesToRetransform(inst.getAllLoadedClasses(), classesToTransform));
}
Expand Down Expand Up @@ -130,23 +131,40 @@ private static PolicyManager createPolicyManager() {
return new PolicyManager(serverPolicy, agentEntitlements, pluginPolicies, resolver, AGENTS_PACKAGE_NAME, ENTITLEMENTS_MODULE);
}

private static ElasticsearchEntitlementChecker initChecker() {
final PolicyManager policyManager = createPolicyManager();

/**
* Returns the "most recent" checker class compatible with the current runtime Java version.
* For checkers, we have (optionally) version specific classes, each with a prefix (e.g. Java23).
* The mapping cannot be automatic, as it depends on the actual presence of these classes in the final Jar (see
* the various mainXX source sets).
*/
private static Class<?> getVersionSpecificCheckerClass(Class<?> baseClass) {
String packageName = baseClass.getPackageName();
String baseClassName = baseClass.getSimpleName();
int javaVersion = Runtime.version().feature();

final String classNamePrefix;
if (javaVersion >= 23) {
// All Java version from 23 onwards will be able to use che checks in the Java23EntitlementChecker interface and implementation
classNamePrefix = "Java23";
} else {
// For any other Java version, the basic EntitlementChecker interface and implementation contains all the supported checks
classNamePrefix = "";
}
final String className = "org.elasticsearch.entitlement.runtime.api." + classNamePrefix + "ElasticsearchEntitlementChecker";
final String className = packageName + "." + classNamePrefix + baseClassName;
Class<?> clazz;
try {
clazz = Class.forName(className);
} catch (ClassNotFoundException e) {
throw new AssertionError("entitlement lib cannot find entitlement impl", e);
throw new AssertionError("entitlement lib cannot find entitlement class " + className, e);
}
return clazz;
}

private static ElasticsearchEntitlementChecker initChecker() {
final PolicyManager policyManager = createPolicyManager();

final Class<?> clazz = getVersionSpecificCheckerClass(ElasticsearchEntitlementChecker.class);

Constructor<?> constructor;
try {
constructor = clazz.getConstructor(PolicyManager.class);
Expand Down