Skip to content

Commit 25efa74

Browse files
committed
[Entitlements] Make lookupImplementationMethod inheritance-aware (elastic#122474)
1 parent ba32add commit 25efa74

File tree

3 files changed

+110
-80
lines changed

3 files changed

+110
-80
lines changed

libs/entitlement/asm-provider/src/main/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImpl.java

Lines changed: 63 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,11 @@ public Instrumenter newInstrumenter(Class<?> clazz, Map<MethodKey, CheckMethod>
4444
return InstrumenterImpl.create(clazz, methods);
4545
}
4646

47-
@Override
48-
public Map<MethodKey, CheckMethod> lookupMethods(Class<?> checkerClass) throws IOException {
49-
Map<MethodKey, CheckMethod> methodsToInstrument = new HashMap<>();
47+
private interface CheckerMethodVisitor {
48+
void visit(Class<?> currentClass, int access, String checkerMethodName, String checkerMethodDescriptor);
49+
}
5050

51+
private void visitClassAndSupers(Class<?> checkerClass, CheckerMethodVisitor checkerMethodVisitor) throws ClassNotFoundException {
5152
Set<Class<?>> visitedClasses = new HashSet<>();
5253
ArrayDeque<Class<?>> classesToVisit = new ArrayDeque<>(Collections.singleton(checkerClass));
5354
while (classesToVisit.isEmpty() == false) {
@@ -57,67 +58,76 @@ public Map<MethodKey, CheckMethod> lookupMethods(Class<?> checkerClass) throws I
5758
}
5859
visitedClasses.add(currentClass);
5960

60-
var classFileInfo = InstrumenterImpl.getClassFileInfo(currentClass);
61-
ClassReader reader = new ClassReader(classFileInfo.bytecodes());
62-
ClassVisitor visitor = new ClassVisitor(Opcodes.ASM9) {
61+
try {
62+
var classFileInfo = InstrumenterImpl.getClassFileInfo(currentClass);
63+
ClassReader reader = new ClassReader(classFileInfo.bytecodes());
64+
ClassVisitor visitor = new ClassVisitor(Opcodes.ASM9) {
6365

64-
@Override
65-
public void visit(int version, int access, String name, String signature, String superName, String[] interfaces) {
66-
super.visit(version, access, name, signature, superName, interfaces);
67-
try {
68-
if (OBJECT_INTERNAL_NAME.equals(superName) == false) {
69-
classesToVisit.add(Class.forName(Type.getObjectType(superName).getClassName()));
66+
@Override
67+
public void visit(int version, int access, String name, String signature, String superName, String[] interfaces) {
68+
super.visit(version, access, name, signature, superName, interfaces);
69+
try {
70+
if (OBJECT_INTERNAL_NAME.equals(superName) == false) {
71+
classesToVisit.add(Class.forName(Type.getObjectType(superName).getClassName()));
72+
}
73+
for (var interfaceName : interfaces) {
74+
classesToVisit.add(Class.forName(Type.getObjectType(interfaceName).getClassName()));
75+
}
76+
} catch (ClassNotFoundException e) {
77+
throw new IllegalArgumentException("Cannot inspect checker class " + currentClass.getName(), e);
7078
}
71-
for (var interfaceName : interfaces) {
72-
classesToVisit.add(Class.forName(Type.getObjectType(interfaceName).getClassName()));
73-
}
74-
} catch (ClassNotFoundException e) {
75-
throw new IllegalArgumentException("Cannot inspect checker class " + checkerClass.getName(), e);
7679
}
77-
}
7880

79-
@Override
80-
public MethodVisitor visitMethod(
81-
int access,
82-
String checkerMethodName,
83-
String checkerMethodDescriptor,
84-
String signature,
85-
String[] exceptions
86-
) {
87-
var mv = super.visitMethod(access, checkerMethodName, checkerMethodDescriptor, signature, exceptions);
88-
if (checkerMethodName.startsWith(InstrumentationService.CHECK_METHOD_PREFIX)) {
89-
var checkerMethodArgumentTypes = Type.getArgumentTypes(checkerMethodDescriptor);
90-
var methodToInstrument = parseCheckerMethodSignature(checkerMethodName, checkerMethodArgumentTypes);
91-
92-
var checkerParameterDescriptors = Arrays.stream(checkerMethodArgumentTypes).map(Type::getDescriptor).toList();
93-
var checkMethod = new CheckMethod(
94-
Type.getInternalName(currentClass),
95-
checkerMethodName,
96-
checkerParameterDescriptors
97-
);
98-
99-
methodsToInstrument.putIfAbsent(methodToInstrument, checkMethod);
81+
@Override
82+
public MethodVisitor visitMethod(
83+
int access,
84+
String checkerMethodName,
85+
String checkerMethodDescriptor,
86+
String signature,
87+
String[] exceptions
88+
) {
89+
var mv = super.visitMethod(access, checkerMethodName, checkerMethodDescriptor, signature, exceptions);
90+
checkerMethodVisitor.visit(currentClass, access, checkerMethodName, checkerMethodDescriptor);
91+
return mv;
10092
}
101-
return mv;
102-
}
103-
};
104-
reader.accept(visitor, 0);
93+
};
94+
reader.accept(visitor, 0);
95+
} catch (IOException e) {
96+
throw new ClassNotFoundException("Cannot find a definition for class [" + checkerClass.getName() + "]", e);
97+
}
10598
}
99+
}
100+
101+
@Override
102+
public Map<MethodKey, CheckMethod> lookupMethods(Class<?> checkerClass) throws ClassNotFoundException {
103+
Map<MethodKey, CheckMethod> methodsToInstrument = new HashMap<>();
104+
105+
visitClassAndSupers(checkerClass, (currentClass, access, checkerMethodName, checkerMethodDescriptor) -> {
106+
if (checkerMethodName.startsWith(InstrumentationService.CHECK_METHOD_PREFIX)) {
107+
var checkerMethodArgumentTypes = Type.getArgumentTypes(checkerMethodDescriptor);
108+
var methodToInstrument = parseCheckerMethodSignature(checkerMethodName, checkerMethodArgumentTypes);
109+
110+
var checkerParameterDescriptors = Arrays.stream(checkerMethodArgumentTypes).map(Type::getDescriptor).toList();
111+
var checkMethod = new CheckMethod(Type.getInternalName(currentClass), checkerMethodName, checkerParameterDescriptors);
112+
methodsToInstrument.putIfAbsent(methodToInstrument, checkMethod);
113+
}
114+
});
115+
106116
return methodsToInstrument;
107117
}
108118

109119
@SuppressForbidden(reason = "Need access to abstract methods (protected/package internal) in base class")
110120
@Override
111121
public InstrumentationInfo lookupImplementationMethod(
112122
Class<?> targetSuperclass,
113-
String methodName,
123+
String targetMethodName,
114124
Class<?> implementationClass,
115125
Class<?> checkerClass,
116126
String checkMethodName,
117127
Class<?>... parameterTypes
118128
) throws NoSuchMethodException, ClassNotFoundException {
119129

120-
var targetMethod = targetSuperclass.getDeclaredMethod(methodName, parameterTypes);
130+
var targetMethod = targetSuperclass.getDeclaredMethod(targetMethodName, parameterTypes);
121131
var implementationMethod = implementationClass.getMethod(targetMethod.getName(), targetMethod.getParameterTypes());
122132
validateTargetMethod(implementationClass, targetMethod, implementationMethod);
123133

@@ -128,33 +138,15 @@ public InstrumentationInfo lookupImplementationMethod(
128138

129139
CheckMethod[] checkMethod = new CheckMethod[1];
130140

131-
try {
132-
InstrumenterImpl.ClassFileInfo classFileInfo = InstrumenterImpl.getClassFileInfo(checkerClass);
133-
ClassReader reader = new ClassReader(classFileInfo.bytecodes());
134-
ClassVisitor visitor = new ClassVisitor(Opcodes.ASM9) {
135-
@Override
136-
public MethodVisitor visitMethod(
137-
int access,
138-
String methodName,
139-
String methodDescriptor,
140-
String signature,
141-
String[] exceptions
142-
) {
143-
var mv = super.visitMethod(access, methodName, methodDescriptor, signature, exceptions);
144-
if (methodName.equals(checkMethodName)) {
145-
var methodArgumentTypes = Type.getArgumentTypes(methodDescriptor);
146-
if (Arrays.equals(methodArgumentTypes, checkMethodArgumentTypes)) {
147-
var checkerParameterDescriptors = Arrays.stream(methodArgumentTypes).map(Type::getDescriptor).toList();
148-
checkMethod[0] = new CheckMethod(Type.getInternalName(checkerClass), methodName, checkerParameterDescriptors);
149-
}
150-
}
151-
return mv;
141+
visitClassAndSupers(checkerClass, (currentClass, access, methodName, methodDescriptor) -> {
142+
if (methodName.equals(checkMethodName)) {
143+
var methodArgumentTypes = Type.getArgumentTypes(methodDescriptor);
144+
if (Arrays.equals(methodArgumentTypes, checkMethodArgumentTypes)) {
145+
var checkerParameterDescriptors = Arrays.stream(methodArgumentTypes).map(Type::getDescriptor).toList();
146+
checkMethod[0] = new CheckMethod(Type.getInternalName(currentClass), methodName, checkerParameterDescriptors);
152147
}
153-
};
154-
reader.accept(visitor, 0);
155-
} catch (IOException e) {
156-
throw new ClassNotFoundException("Cannot find a definition for class [" + checkerClass.getName() + "]", e);
157-
}
148+
}
149+
});
158150

159151
if (checkMethod[0] == null) {
160152
throw new NoSuchMethodException(

libs/entitlement/asm-provider/src/test/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests.java

Lines changed: 46 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import org.elasticsearch.test.ESTestCase;
1616
import org.objectweb.asm.Type;
1717

18-
import java.io.IOException;
1918
import java.util.List;
2019
import java.util.Map;
2120

@@ -90,7 +89,9 @@ interface TestCheckerMixed {
9089
void checkInstanceMethodManual(Class<?> clazz, TestTargetBaseClass that, int x, String y);
9190
}
9291

93-
public void testInstrumentationTargetLookup() throws IOException {
92+
interface TestCheckerDerived3 extends TestCheckerMixed {}
93+
94+
public void testInstrumentationTargetLookup() throws ClassNotFoundException {
9495
Map<MethodKey, CheckMethod> checkMethods = instrumentationService.lookupMethods(TestChecker.class);
9596

9697
assertThat(checkMethods, aMapWithSize(3));
@@ -143,7 +144,7 @@ public void testInstrumentationTargetLookup() throws IOException {
143144
);
144145
}
145146

146-
public void testInstrumentationTargetLookupWithOverloads() throws IOException {
147+
public void testInstrumentationTargetLookupWithOverloads() throws ClassNotFoundException {
147148
Map<MethodKey, CheckMethod> checkMethods = instrumentationService.lookupMethods(TestCheckerOverloads.class);
148149

149150
assertThat(checkMethods, aMapWithSize(2));
@@ -175,7 +176,7 @@ public void testInstrumentationTargetLookupWithOverloads() throws IOException {
175176
);
176177
}
177178

178-
public void testInstrumentationTargetLookupWithDerivedClass() throws IOException {
179+
public void testInstrumentationTargetLookupWithDerivedClass() throws ClassNotFoundException {
179180
Map<MethodKey, CheckMethod> checkMethods = instrumentationService.lookupMethods(TestCheckerDerived2.class);
180181

181182
assertThat(checkMethods, aMapWithSize(4));
@@ -244,7 +245,7 @@ public void testInstrumentationTargetLookupWithDerivedClass() throws IOException
244245
);
245246
}
246247

247-
public void testInstrumentationTargetLookupWithCtors() throws IOException {
248+
public void testInstrumentationTargetLookupWithCtors() throws ClassNotFoundException {
248249
Map<MethodKey, CheckMethod> checkMethods = instrumentationService.lookupMethods(TestCheckerCtors.class);
249250

250251
assertThat(checkMethods, aMapWithSize(2));
@@ -276,7 +277,7 @@ public void testInstrumentationTargetLookupWithCtors() throws IOException {
276277
);
277278
}
278279

279-
public void testInstrumentationTargetLookupWithExtraMethods() throws IOException {
280+
public void testInstrumentationTargetLookupWithExtraMethods() throws ClassNotFoundException {
280281
Map<MethodKey, CheckMethod> checkMethods = instrumentationService.lookupMethods(TestCheckerMixed.class);
281282

282283
assertThat(checkMethods, aMapWithSize(1));
@@ -371,7 +372,7 @@ public void testLookupImplementationMethodWithBaseClass() throws ClassNotFoundEx
371372
);
372373
}
373374

374-
public void testLookupImplementationMethodWithInheritance() throws ClassNotFoundException, NoSuchMethodException {
375+
public void testLookupImplementationMethodWithInheritanceOnTarget() throws ClassNotFoundException, NoSuchMethodException {
375376
var info = instrumentationService.lookupImplementationMethod(
376377
TestTargetBaseClass.class,
377378
"instanceMethod2",
@@ -409,6 +410,44 @@ public void testLookupImplementationMethodWithInheritance() throws ClassNotFound
409410
);
410411
}
411412

413+
public void testLookupImplementationMethodWithInheritanceOnChecker() throws ClassNotFoundException, NoSuchMethodException {
414+
var info = instrumentationService.lookupImplementationMethod(
415+
TestTargetBaseClass.class,
416+
"instanceMethod2",
417+
TestTargetImplementationClass.class,
418+
TestCheckerDerived3.class,
419+
"checkInstanceMethodManual",
420+
int.class,
421+
String.class
422+
);
423+
424+
assertThat(
425+
info.targetMethod(),
426+
equalTo(
427+
new MethodKey(
428+
"org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests$TestTargetIntermediateClass",
429+
"instanceMethod2",
430+
List.of("I", "java/lang/String")
431+
)
432+
)
433+
);
434+
assertThat(
435+
info.checkMethod(),
436+
equalTo(
437+
new CheckMethod(
438+
"org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests$TestCheckerMixed",
439+
"checkInstanceMethodManual",
440+
List.of(
441+
"Ljava/lang/Class;",
442+
"Lorg/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests$TestTargetBaseClass;",
443+
"I",
444+
"Ljava/lang/String;"
445+
)
446+
)
447+
)
448+
);
449+
}
450+
412451
public void testParseCheckerMethodSignatureStaticMethod() {
413452
var methodKey = InstrumentationServiceImpl.parseCheckerMethodSignature(
414453
"check$org_example_TestClass$$staticMethod",

libs/entitlement/src/main/java/org/elasticsearch/entitlement/instrumentation/InstrumentationService.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
package org.elasticsearch.entitlement.instrumentation;
1111

12-
import java.io.IOException;
1312
import java.util.Map;
1413

1514
/**
@@ -23,7 +22,7 @@ record InstrumentationInfo(MethodKey targetMethod, CheckMethod checkMethod) {}
2322

2423
Instrumenter newInstrumenter(Class<?> clazz, Map<MethodKey, CheckMethod> methods);
2524

26-
Map<MethodKey, CheckMethod> lookupMethods(Class<?> clazz) throws IOException;
25+
Map<MethodKey, CheckMethod> lookupMethods(Class<?> clazz) throws ClassNotFoundException;
2726

2827
InstrumentationInfo lookupImplementationMethod(
2928
Class<?> targetSuperclass,

0 commit comments

Comments
 (0)