diff --git a/libs/entitlement/asm-provider/src/main/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImpl.java b/libs/entitlement/asm-provider/src/main/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImpl.java index 05a5af374e5d9..ffcc23e16d1f6 100644 --- a/libs/entitlement/asm-provider/src/main/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImpl.java +++ b/libs/entitlement/asm-provider/src/main/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImpl.java @@ -44,10 +44,11 @@ public Instrumenter newInstrumenter(Class clazz, Map return InstrumenterImpl.create(clazz, methods); } - @Override - public Map lookupMethods(Class checkerClass) throws IOException { - Map methodsToInstrument = new HashMap<>(); + private interface CheckerMethodVisitor { + void visit(Class currentClass, int access, String checkerMethodName, String checkerMethodDescriptor); + } + private void visitClassAndSupers(Class checkerClass, CheckerMethodVisitor checkerMethodVisitor) throws ClassNotFoundException { Set> visitedClasses = new HashSet<>(); ArrayDeque> classesToVisit = new ArrayDeque<>(Collections.singleton(checkerClass)); while (classesToVisit.isEmpty() == false) { @@ -57,52 +58,61 @@ public Map lookupMethods(Class checkerClass) throws I } visitedClasses.add(currentClass); - var classFileInfo = InstrumenterImpl.getClassFileInfo(currentClass); - ClassReader reader = new ClassReader(classFileInfo.bytecodes()); - ClassVisitor visitor = new ClassVisitor(Opcodes.ASM9) { + try { + var classFileInfo = InstrumenterImpl.getClassFileInfo(currentClass); + ClassReader reader = new ClassReader(classFileInfo.bytecodes()); + ClassVisitor visitor = new ClassVisitor(Opcodes.ASM9) { - @Override - public void visit(int version, int access, String name, String signature, String superName, String[] interfaces) { - super.visit(version, access, name, signature, superName, interfaces); - try { - if (OBJECT_INTERNAL_NAME.equals(superName) == false) { - classesToVisit.add(Class.forName(Type.getObjectType(superName).getClassName())); + @Override + public void visit(int version, int access, String name, String signature, String superName, String[] interfaces) { + super.visit(version, access, name, signature, superName, interfaces); + try { + if (OBJECT_INTERNAL_NAME.equals(superName) == false) { + classesToVisit.add(Class.forName(Type.getObjectType(superName).getClassName())); + } + for (var interfaceName : interfaces) { + classesToVisit.add(Class.forName(Type.getObjectType(interfaceName).getClassName())); + } + } catch (ClassNotFoundException e) { + throw new IllegalArgumentException("Cannot inspect checker class " + currentClass.getName(), e); } - for (var interfaceName : interfaces) { - classesToVisit.add(Class.forName(Type.getObjectType(interfaceName).getClassName())); - } - } catch (ClassNotFoundException e) { - throw new IllegalArgumentException("Cannot inspect checker class " + checkerClass.getName(), e); } - } - @Override - public MethodVisitor visitMethod( - int access, - String checkerMethodName, - String checkerMethodDescriptor, - String signature, - String[] exceptions - ) { - var mv = super.visitMethod(access, checkerMethodName, checkerMethodDescriptor, signature, exceptions); - if (checkerMethodName.startsWith(InstrumentationService.CHECK_METHOD_PREFIX)) { - var checkerMethodArgumentTypes = Type.getArgumentTypes(checkerMethodDescriptor); - var methodToInstrument = parseCheckerMethodSignature(checkerMethodName, checkerMethodArgumentTypes); - - var checkerParameterDescriptors = Arrays.stream(checkerMethodArgumentTypes).map(Type::getDescriptor).toList(); - var checkMethod = new CheckMethod( - Type.getInternalName(currentClass), - checkerMethodName, - checkerParameterDescriptors - ); - - methodsToInstrument.putIfAbsent(methodToInstrument, checkMethod); + @Override + public MethodVisitor visitMethod( + int access, + String checkerMethodName, + String checkerMethodDescriptor, + String signature, + String[] exceptions + ) { + var mv = super.visitMethod(access, checkerMethodName, checkerMethodDescriptor, signature, exceptions); + checkerMethodVisitor.visit(currentClass, access, checkerMethodName, checkerMethodDescriptor); + return mv; } - return mv; - } - }; - reader.accept(visitor, 0); + }; + reader.accept(visitor, 0); + } catch (IOException e) { + throw new ClassNotFoundException("Cannot find a definition for class [" + checkerClass.getName() + "]", e); + } } + } + + @Override + public Map lookupMethods(Class checkerClass) throws ClassNotFoundException { + Map methodsToInstrument = new HashMap<>(); + + visitClassAndSupers(checkerClass, (currentClass, access, checkerMethodName, checkerMethodDescriptor) -> { + if (checkerMethodName.startsWith(InstrumentationService.CHECK_METHOD_PREFIX)) { + var checkerMethodArgumentTypes = Type.getArgumentTypes(checkerMethodDescriptor); + var methodToInstrument = parseCheckerMethodSignature(checkerMethodName, checkerMethodArgumentTypes); + + var checkerParameterDescriptors = Arrays.stream(checkerMethodArgumentTypes).map(Type::getDescriptor).toList(); + var checkMethod = new CheckMethod(Type.getInternalName(currentClass), checkerMethodName, checkerParameterDescriptors); + methodsToInstrument.putIfAbsent(methodToInstrument, checkMethod); + } + }); + return methodsToInstrument; } @@ -110,14 +120,14 @@ public MethodVisitor visitMethod( @Override public InstrumentationInfo lookupImplementationMethod( Class targetSuperclass, - String methodName, + String targetMethodName, Class implementationClass, Class checkerClass, String checkMethodName, Class... parameterTypes ) throws NoSuchMethodException, ClassNotFoundException { - var targetMethod = targetSuperclass.getDeclaredMethod(methodName, parameterTypes); + var targetMethod = targetSuperclass.getDeclaredMethod(targetMethodName, parameterTypes); var implementationMethod = implementationClass.getMethod(targetMethod.getName(), targetMethod.getParameterTypes()); validateTargetMethod(implementationClass, targetMethod, implementationMethod); @@ -128,33 +138,15 @@ public InstrumentationInfo lookupImplementationMethod( CheckMethod[] checkMethod = new CheckMethod[1]; - try { - InstrumenterImpl.ClassFileInfo classFileInfo = InstrumenterImpl.getClassFileInfo(checkerClass); - ClassReader reader = new ClassReader(classFileInfo.bytecodes()); - ClassVisitor visitor = new ClassVisitor(Opcodes.ASM9) { - @Override - public MethodVisitor visitMethod( - int access, - String methodName, - String methodDescriptor, - String signature, - String[] exceptions - ) { - var mv = super.visitMethod(access, methodName, methodDescriptor, signature, exceptions); - if (methodName.equals(checkMethodName)) { - var methodArgumentTypes = Type.getArgumentTypes(methodDescriptor); - if (Arrays.equals(methodArgumentTypes, checkMethodArgumentTypes)) { - var checkerParameterDescriptors = Arrays.stream(methodArgumentTypes).map(Type::getDescriptor).toList(); - checkMethod[0] = new CheckMethod(Type.getInternalName(checkerClass), methodName, checkerParameterDescriptors); - } - } - return mv; + visitClassAndSupers(checkerClass, (currentClass, access, methodName, methodDescriptor) -> { + if (methodName.equals(checkMethodName)) { + var methodArgumentTypes = Type.getArgumentTypes(methodDescriptor); + if (Arrays.equals(methodArgumentTypes, checkMethodArgumentTypes)) { + var checkerParameterDescriptors = Arrays.stream(methodArgumentTypes).map(Type::getDescriptor).toList(); + checkMethod[0] = new CheckMethod(Type.getInternalName(currentClass), methodName, checkerParameterDescriptors); } - }; - reader.accept(visitor, 0); - } catch (IOException e) { - throw new ClassNotFoundException("Cannot find a definition for class [" + checkerClass.getName() + "]", e); - } + } + }); if (checkMethod[0] == null) { throw new NoSuchMethodException( diff --git a/libs/entitlement/asm-provider/src/test/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests.java b/libs/entitlement/asm-provider/src/test/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests.java index 2b9b70d46c0ea..25689f0b8a636 100644 --- a/libs/entitlement/asm-provider/src/test/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests.java +++ b/libs/entitlement/asm-provider/src/test/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests.java @@ -15,7 +15,6 @@ import org.elasticsearch.test.ESTestCase; import org.objectweb.asm.Type; -import java.io.IOException; import java.util.List; import java.util.Map; @@ -90,7 +89,9 @@ interface TestCheckerMixed { void checkInstanceMethodManual(Class clazz, TestTargetBaseClass that, int x, String y); } - public void testInstrumentationTargetLookup() throws IOException { + interface TestCheckerDerived3 extends TestCheckerMixed {} + + public void testInstrumentationTargetLookup() throws ClassNotFoundException { Map checkMethods = instrumentationService.lookupMethods(TestChecker.class); assertThat(checkMethods, aMapWithSize(3)); @@ -143,7 +144,7 @@ public void testInstrumentationTargetLookup() throws IOException { ); } - public void testInstrumentationTargetLookupWithOverloads() throws IOException { + public void testInstrumentationTargetLookupWithOverloads() throws ClassNotFoundException { Map checkMethods = instrumentationService.lookupMethods(TestCheckerOverloads.class); assertThat(checkMethods, aMapWithSize(2)); @@ -175,7 +176,7 @@ public void testInstrumentationTargetLookupWithOverloads() throws IOException { ); } - public void testInstrumentationTargetLookupWithDerivedClass() throws IOException { + public void testInstrumentationTargetLookupWithDerivedClass() throws ClassNotFoundException { Map checkMethods = instrumentationService.lookupMethods(TestCheckerDerived2.class); assertThat(checkMethods, aMapWithSize(4)); @@ -244,7 +245,7 @@ public void testInstrumentationTargetLookupWithDerivedClass() throws IOException ); } - public void testInstrumentationTargetLookupWithCtors() throws IOException { + public void testInstrumentationTargetLookupWithCtors() throws ClassNotFoundException { Map checkMethods = instrumentationService.lookupMethods(TestCheckerCtors.class); assertThat(checkMethods, aMapWithSize(2)); @@ -276,7 +277,7 @@ public void testInstrumentationTargetLookupWithCtors() throws IOException { ); } - public void testInstrumentationTargetLookupWithExtraMethods() throws IOException { + public void testInstrumentationTargetLookupWithExtraMethods() throws ClassNotFoundException { Map checkMethods = instrumentationService.lookupMethods(TestCheckerMixed.class); assertThat(checkMethods, aMapWithSize(1)); @@ -371,7 +372,7 @@ public void testLookupImplementationMethodWithBaseClass() throws ClassNotFoundEx ); } - public void testLookupImplementationMethodWithInheritance() throws ClassNotFoundException, NoSuchMethodException { + public void testLookupImplementationMethodWithInheritanceOnTarget() throws ClassNotFoundException, NoSuchMethodException { var info = instrumentationService.lookupImplementationMethod( TestTargetBaseClass.class, "instanceMethod2", @@ -409,6 +410,44 @@ public void testLookupImplementationMethodWithInheritance() throws ClassNotFound ); } + public void testLookupImplementationMethodWithInheritanceOnChecker() throws ClassNotFoundException, NoSuchMethodException { + var info = instrumentationService.lookupImplementationMethod( + TestTargetBaseClass.class, + "instanceMethod2", + TestTargetImplementationClass.class, + TestCheckerDerived3.class, + "checkInstanceMethodManual", + int.class, + String.class + ); + + assertThat( + info.targetMethod(), + equalTo( + new MethodKey( + "org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests$TestTargetIntermediateClass", + "instanceMethod2", + List.of("I", "java/lang/String") + ) + ) + ); + assertThat( + info.checkMethod(), + equalTo( + new CheckMethod( + "org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests$TestCheckerMixed", + "checkInstanceMethodManual", + List.of( + "Ljava/lang/Class;", + "Lorg/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests$TestTargetBaseClass;", + "I", + "Ljava/lang/String;" + ) + ) + ) + ); + } + public void testParseCheckerMethodSignatureStaticMethod() { var methodKey = InstrumentationServiceImpl.parseCheckerMethodSignature( "check$org_example_TestClass$$staticMethod", diff --git a/libs/entitlement/src/main/java/org/elasticsearch/entitlement/instrumentation/InstrumentationService.java b/libs/entitlement/src/main/java/org/elasticsearch/entitlement/instrumentation/InstrumentationService.java index 79673418eb321..ece51a8414b70 100644 --- a/libs/entitlement/src/main/java/org/elasticsearch/entitlement/instrumentation/InstrumentationService.java +++ b/libs/entitlement/src/main/java/org/elasticsearch/entitlement/instrumentation/InstrumentationService.java @@ -9,7 +9,6 @@ package org.elasticsearch.entitlement.instrumentation; -import java.io.IOException; import java.util.Map; /** @@ -23,7 +22,7 @@ record InstrumentationInfo(MethodKey targetMethod, CheckMethod checkMethod) {} Instrumenter newInstrumenter(Class clazz, Map methods); - Map lookupMethods(Class clazz) throws IOException; + Map lookupMethods(Class clazz) throws ClassNotFoundException; InstrumentationInfo lookupImplementationMethod( Class targetSuperclass,