Skip to content

Commit a5dcd14

Browse files
authored
[Entitlements] "dynamic" instrumentation method keys (elastic#120811) (elastic#121299)
1 parent c260034 commit a5dcd14

File tree

9 files changed

+323
-14
lines changed

9 files changed

+323
-14
lines changed

libs/entitlement/asm-provider/build.gradle

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ apply plugin: 'elasticsearch.build'
1111

1212
dependencies {
1313
compileOnly project(':libs:entitlement')
14+
compileOnly project(':libs:core')
1415
implementation 'org.ow2.asm:asm:9.7.1'
1516
testImplementation project(":test:framework")
1617
testImplementation project(":libs:entitlement:bridge")

libs/entitlement/asm-provider/src/main/java/module-info.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,7 @@
1414
requires org.objectweb.asm;
1515
requires org.elasticsearch.entitlement;
1616

17+
requires static org.elasticsearch.base; // for SuppressForbidden
18+
1719
provides InstrumentationService with InstrumentationServiceImpl;
1820
}

libs/entitlement/asm-provider/src/main/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImpl.java

Lines changed: 151 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
package org.elasticsearch.entitlement.instrumentation.impl;
1111

12+
import org.elasticsearch.core.SuppressForbidden;
1213
import org.elasticsearch.entitlement.instrumentation.CheckMethod;
1314
import org.elasticsearch.entitlement.instrumentation.InstrumentationService;
1415
import org.elasticsearch.entitlement.instrumentation.Instrumenter;
@@ -20,11 +21,15 @@
2021
import org.objectweb.asm.Type;
2122

2223
import java.io.IOException;
24+
import java.lang.reflect.Method;
25+
import java.lang.reflect.Modifier;
2326
import java.util.Arrays;
2427
import java.util.HashMap;
2528
import java.util.List;
2629
import java.util.Locale;
2730
import java.util.Map;
31+
import java.util.stream.Collectors;
32+
import java.util.stream.Stream;
2833

2934
public class InstrumentationServiceImpl implements InstrumentationService {
3035

@@ -48,22 +53,159 @@ public MethodVisitor visitMethod(
4853
String[] exceptions
4954
) {
5055
var mv = super.visitMethod(access, checkerMethodName, checkerMethodDescriptor, signature, exceptions);
56+
if (checkerMethodName.startsWith(InstrumentationService.CHECK_METHOD_PREFIX)) {
57+
var checkerMethodArgumentTypes = Type.getArgumentTypes(checkerMethodDescriptor);
58+
var methodToInstrument = parseCheckerMethodSignature(checkerMethodName, checkerMethodArgumentTypes);
5159

52-
var checkerMethodArgumentTypes = Type.getArgumentTypes(checkerMethodDescriptor);
53-
var methodToInstrument = parseCheckerMethodSignature(checkerMethodName, checkerMethodArgumentTypes);
54-
55-
var checkerParameterDescriptors = Arrays.stream(checkerMethodArgumentTypes).map(Type::getDescriptor).toList();
56-
var checkMethod = new CheckMethod(Type.getInternalName(checkerClass), checkerMethodName, checkerParameterDescriptors);
57-
58-
methodsToInstrument.put(methodToInstrument, checkMethod);
60+
var checkerParameterDescriptors = Arrays.stream(checkerMethodArgumentTypes).map(Type::getDescriptor).toList();
61+
var checkMethod = new CheckMethod(Type.getInternalName(checkerClass), checkerMethodName, checkerParameterDescriptors);
5962

63+
methodsToInstrument.put(methodToInstrument, checkMethod);
64+
}
6065
return mv;
6166
}
6267
};
6368
reader.accept(visitor, 0);
6469
return methodsToInstrument;
6570
}
6671

72+
@SuppressForbidden(reason = "Need access to abstract methods (protected/package internal) in base class")
73+
@Override
74+
public InstrumentationInfo lookupImplementationMethod(
75+
Class<?> targetSuperclass,
76+
String methodName,
77+
Class<?> implementationClass,
78+
Class<?> checkerClass,
79+
String checkMethodName,
80+
Class<?>... parameterTypes
81+
) throws NoSuchMethodException, ClassNotFoundException {
82+
83+
var targetMethod = targetSuperclass.getDeclaredMethod(methodName, parameterTypes);
84+
validateTargetMethod(implementationClass, targetMethod);
85+
86+
var checkerAdditionalArguments = Stream.of(Class.class, targetSuperclass);
87+
var checkMethodArgumentTypes = Stream.concat(checkerAdditionalArguments, Arrays.stream(parameterTypes))
88+
.map(Type::getType)
89+
.toArray(Type[]::new);
90+
91+
CheckMethod[] checkMethod = new CheckMethod[1];
92+
93+
try {
94+
InstrumenterImpl.ClassFileInfo classFileInfo = InstrumenterImpl.getClassFileInfo(checkerClass);
95+
ClassReader reader = new ClassReader(classFileInfo.bytecodes());
96+
ClassVisitor visitor = new ClassVisitor(Opcodes.ASM9) {
97+
@Override
98+
public MethodVisitor visitMethod(
99+
int access,
100+
String methodName,
101+
String methodDescriptor,
102+
String signature,
103+
String[] exceptions
104+
) {
105+
var mv = super.visitMethod(access, methodName, methodDescriptor, signature, exceptions);
106+
if (methodName.equals(checkMethodName)) {
107+
var methodArgumentTypes = Type.getArgumentTypes(methodDescriptor);
108+
if (Arrays.equals(methodArgumentTypes, checkMethodArgumentTypes)) {
109+
var checkerParameterDescriptors = Arrays.stream(methodArgumentTypes).map(Type::getDescriptor).toList();
110+
checkMethod[0] = new CheckMethod(Type.getInternalName(checkerClass), methodName, checkerParameterDescriptors);
111+
}
112+
}
113+
return mv;
114+
}
115+
};
116+
reader.accept(visitor, 0);
117+
} catch (IOException e) {
118+
throw new ClassNotFoundException("Cannot find a definition for class [" + checkerClass.getName() + "]", e);
119+
}
120+
121+
if (checkMethod[0] == null) {
122+
throw new NoSuchMethodException(
123+
String.format(
124+
Locale.ROOT,
125+
"Cannot find a method with name [%s] and arguments [%s] in class [%s]",
126+
checkMethodName,
127+
Arrays.stream(checkMethodArgumentTypes).map(Type::toString).collect(Collectors.joining()),
128+
checkerClass.getName()
129+
)
130+
);
131+
}
132+
133+
return new InstrumentationInfo(
134+
new MethodKey(
135+
Type.getInternalName(implementationClass),
136+
targetMethod.getName(),
137+
Arrays.stream(parameterTypes).map(c -> Type.getType(c).getInternalName()).toList()
138+
),
139+
checkMethod[0]
140+
);
141+
}
142+
143+
private static void validateTargetMethod(Class<?> implementationClass, Method targetMethod) {
144+
if (targetMethod.getDeclaringClass().isAssignableFrom(implementationClass) == false) {
145+
throw new IllegalArgumentException(
146+
String.format(
147+
Locale.ROOT,
148+
"Not an implementation class for %s: %s does not implement %s",
149+
targetMethod.getName(),
150+
implementationClass.getName(),
151+
targetMethod.getDeclaringClass().getName()
152+
)
153+
);
154+
}
155+
if (Modifier.isPrivate(targetMethod.getModifiers())) {
156+
throw new IllegalArgumentException(
157+
String.format(
158+
Locale.ROOT,
159+
"Not a valid instrumentation method: %s is private in %s",
160+
targetMethod.getName(),
161+
targetMethod.getDeclaringClass().getName()
162+
)
163+
);
164+
}
165+
if (Modifier.isStatic(targetMethod.getModifiers())) {
166+
throw new IllegalArgumentException(
167+
String.format(
168+
Locale.ROOT,
169+
"Not a valid instrumentation method: %s is static in %s",
170+
targetMethod.getName(),
171+
targetMethod.getDeclaringClass().getName()
172+
)
173+
);
174+
}
175+
try {
176+
var implementationMethod = implementationClass.getMethod(targetMethod.getName(), targetMethod.getParameterTypes());
177+
var methodModifiers = implementationMethod.getModifiers();
178+
if (Modifier.isAbstract(methodModifiers)) {
179+
throw new IllegalArgumentException(
180+
String.format(
181+
Locale.ROOT,
182+
"Not a valid instrumentation method: %s is abstract in %s",
183+
targetMethod.getName(),
184+
implementationClass.getName()
185+
)
186+
);
187+
}
188+
if (Modifier.isPublic(methodModifiers) == false) {
189+
throw new IllegalArgumentException(
190+
String.format(
191+
Locale.ROOT,
192+
"Not a valid instrumentation method: %s is not public in %s",
193+
targetMethod.getName(),
194+
implementationClass.getName()
195+
)
196+
);
197+
}
198+
} catch (NoSuchMethodException e) {
199+
assert false
200+
: String.format(
201+
Locale.ROOT,
202+
"Not a valid instrumentation method: %s cannot be found in %s",
203+
targetMethod.getName(),
204+
implementationClass.getName()
205+
);
206+
}
207+
}
208+
67209
private static final Type CLASS_TYPE = Type.getType(Class.class);
68210

69211
static ParsedCheckerMethod parseCheckerMethodName(String checkerMethodName) {
@@ -85,8 +227,8 @@ static ParsedCheckerMethod parseCheckerMethodName(String checkerMethodName) {
85227
String.format(
86228
Locale.ROOT,
87229
"Checker method %s has incorrect name format. "
88-
+ "It should be either check$$methodName (instance), check$package_ClassName$methodName (static) or "
89-
+ "check$package_ClassName$ (ctor)",
230+
+ "It should be either check$package_ClassName$methodName (instance), check$package_ClassName$$methodName (static) "
231+
+ "or check$package_ClassName$ (ctor)",
90232
checkerMethodName
91233
)
92234
);

libs/entitlement/asm-provider/src/test/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests.java

Lines changed: 120 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,23 @@ public class InstrumentationServiceImplTests extends ESTestCase {
2929

3030
final InstrumentationService instrumentationService = new InstrumentationServiceImpl();
3131

32-
static class TestTargetClass {}
32+
interface TestTargetInterface {
33+
void instanceMethod(int x, String y);
34+
}
35+
36+
static class TestTargetClass implements TestTargetInterface {
37+
@Override
38+
public void instanceMethod(int x, String y) {}
39+
}
40+
41+
abstract static class TestTargetBaseClass {
42+
abstract void instanceMethod(int x, String y);
43+
}
44+
45+
static class TestTargetImplementationClass extends TestTargetBaseClass {
46+
@Override
47+
public void instanceMethod(int x, String y) {}
48+
}
3349

3450
interface TestChecker {
3551
void check$org_example_TestTargetClass$$staticMethod(Class<?> clazz, int arg0, String arg1, Object arg2);
@@ -51,6 +67,14 @@ interface TestCheckerCtors {
5167
void check$org_example_TestTargetClass$(Class<?> clazz, int x, String y);
5268
}
5369

70+
interface TestCheckerMixed {
71+
void check$org_example_TestTargetClass$$staticMethod(Class<?> clazz, int arg0, String arg1, Object arg2);
72+
73+
void checkInstanceMethodManual(Class<?> clazz, TestTargetInterface that, int x, String y);
74+
75+
void checkInstanceMethodManual(Class<?> clazz, TestTargetBaseClass that, int x, String y);
76+
}
77+
5478
public void testInstrumentationTargetLookup() throws IOException {
5579
Map<MethodKey, CheckMethod> checkMethods = instrumentationService.lookupMethods(TestChecker.class);
5680

@@ -168,6 +192,101 @@ public void testInstrumentationTargetLookupWithCtors() throws IOException {
168192
);
169193
}
170194

195+
public void testInstrumentationTargetLookupWithExtraMethods() throws IOException {
196+
Map<MethodKey, CheckMethod> checkMethods = instrumentationService.lookupMethods(TestCheckerMixed.class);
197+
198+
assertThat(checkMethods, aMapWithSize(1));
199+
assertThat(
200+
checkMethods,
201+
hasEntry(
202+
equalTo(new MethodKey("org/example/TestTargetClass", "staticMethod", List.of("I", "java/lang/String", "java/lang/Object"))),
203+
equalTo(
204+
new CheckMethod(
205+
"org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests$TestCheckerMixed",
206+
"check$org_example_TestTargetClass$$staticMethod",
207+
List.of("Ljava/lang/Class;", "I", "Ljava/lang/String;", "Ljava/lang/Object;")
208+
)
209+
)
210+
)
211+
);
212+
}
213+
214+
public void testLookupImplementationMethodWithInterface() throws ClassNotFoundException, NoSuchMethodException {
215+
var info = instrumentationService.lookupImplementationMethod(
216+
TestTargetInterface.class,
217+
"instanceMethod",
218+
TestTargetClass.class,
219+
TestCheckerMixed.class,
220+
"checkInstanceMethodManual",
221+
int.class,
222+
String.class
223+
);
224+
225+
assertThat(
226+
info.targetMethod(),
227+
equalTo(
228+
new MethodKey(
229+
"org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests$TestTargetClass",
230+
"instanceMethod",
231+
List.of("I", "java/lang/String")
232+
)
233+
)
234+
);
235+
assertThat(
236+
info.checkMethod(),
237+
equalTo(
238+
new CheckMethod(
239+
"org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests$TestCheckerMixed",
240+
"checkInstanceMethodManual",
241+
List.of(
242+
"Ljava/lang/Class;",
243+
"Lorg/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests$TestTargetInterface;",
244+
"I",
245+
"Ljava/lang/String;"
246+
)
247+
)
248+
)
249+
);
250+
}
251+
252+
public void testLookupImplementationMethodWithBaseClass() throws ClassNotFoundException, NoSuchMethodException {
253+
var info = instrumentationService.lookupImplementationMethod(
254+
TestTargetBaseClass.class,
255+
"instanceMethod",
256+
TestTargetImplementationClass.class,
257+
TestCheckerMixed.class,
258+
"checkInstanceMethodManual",
259+
int.class,
260+
String.class
261+
);
262+
263+
assertThat(
264+
info.targetMethod(),
265+
equalTo(
266+
new MethodKey(
267+
"org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests$TestTargetImplementationClass",
268+
"instanceMethod",
269+
List.of("I", "java/lang/String")
270+
)
271+
)
272+
);
273+
assertThat(
274+
info.checkMethod(),
275+
equalTo(
276+
new CheckMethod(
277+
"org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests$TestCheckerMixed",
278+
"checkInstanceMethodManual",
279+
List.of(
280+
"Ljava/lang/Class;",
281+
"Lorg/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests$TestTargetBaseClass;",
282+
"I",
283+
"Ljava/lang/String;"
284+
)
285+
)
286+
)
287+
);
288+
}
289+
171290
public void testParseCheckerMethodSignatureStaticMethod() {
172291
var methodKey = InstrumentationServiceImpl.parseCheckerMethodSignature(
173292
"check$org_example_TestClass$$staticMethod",

libs/entitlement/bridge/src/main/java/org/elasticsearch/entitlement/bridge/EntitlementChecker.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,10 @@
4242
import java.nio.channels.ServerSocketChannel;
4343
import java.nio.channels.SocketChannel;
4444
import java.nio.charset.Charset;
45+
import java.nio.file.OpenOption;
4546
import java.nio.file.Path;
4647
import java.nio.file.attribute.UserPrincipal;
48+
import java.nio.file.spi.FileSystemProvider;
4749
import java.security.cert.CertStoreParameters;
4850
import java.util.List;
4951
import java.util.Locale;
@@ -448,4 +450,7 @@ public interface EntitlementChecker {
448450
void check$java_nio_file_Files$$probeContentType(Class<?> callerClass, Path path);
449451

450452
void check$java_nio_file_Files$$setOwner(Class<?> callerClass, Path path, UserPrincipal principal);
453+
454+
// hand-wired methods
455+
void checkNewInputStream(Class<?> callerClass, FileSystemProvider that, Path path, OpenOption... options);
451456
}

0 commit comments

Comments
 (0)