Skip to content

Commit 9799d00

Browse files
authored
[Entitlements] Add support for instrumenting constructors (#117332)
1 parent d7737e7 commit 9799d00

File tree

13 files changed

+281
-23
lines changed

13 files changed

+281
-23
lines changed

libs/entitlement/asm-provider/src/main/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImpl.java

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,15 +91,18 @@ static MethodKey parseCheckerMethodSignature(String checkerMethodName, Type[] ch
9191
String.format(
9292
Locale.ROOT,
9393
"Checker method %s has incorrect name format. "
94-
+ "It should be either check$$methodName (instance) or check$package_ClassName$methodName (static)",
94+
+ "It should be either check$$methodName (instance), check$package_ClassName$methodName (static) or "
95+
+ "check$package_ClassName$ (ctor)",
9596
checkerMethodName
9697
)
9798
);
9899
}
99100

100-
// No "className" (check$$methodName) -> method is static, and we'll get the class from the actual typed argument
101+
// No "className" (check$$methodName) -> method is instance, and we'll get the class from the actual typed argument
101102
final boolean targetMethodIsStatic = classNameStartIndex + 1 != classNameEndIndex;
102-
final String targetMethodName = checkerMethodName.substring(classNameEndIndex + 1);
103+
// No "methodName" (check$package_ClassName$) -> method is ctor
104+
final boolean targetMethodIsCtor = classNameEndIndex + 1 == checkerMethodName.length();
105+
final String targetMethodName = targetMethodIsCtor ? "<init>" : checkerMethodName.substring(classNameEndIndex + 1);
103106

104107
final String targetClassName;
105108
final List<String> targetParameterTypes;

libs/entitlement/asm-provider/src/main/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumenterImpl.java

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -154,11 +154,12 @@ public MethodVisitor visitMethod(int access, String name, String descriptor, Str
154154
var mv = super.visitMethod(access, name, descriptor, signature, exceptions);
155155
if (isAnnotationPresent == false) {
156156
boolean isStatic = (access & ACC_STATIC) != 0;
157+
boolean isCtor = "<init>".equals(name);
157158
var key = new MethodKey(className, name, Stream.of(Type.getArgumentTypes(descriptor)).map(Type::getInternalName).toList());
158159
var instrumentationMethod = instrumentationMethods.get(key);
159160
if (instrumentationMethod != null) {
160161
// LOGGER.debug("Will instrument method {}", key);
161-
return new EntitlementMethodVisitor(Opcodes.ASM9, mv, isStatic, descriptor, instrumentationMethod);
162+
return new EntitlementMethodVisitor(Opcodes.ASM9, mv, isStatic, isCtor, descriptor, instrumentationMethod);
162163
} else {
163164
// LOGGER.trace("Will not instrument method {}", key);
164165
}
@@ -187,6 +188,7 @@ private void addClassAnnotationIfNeeded() {
187188

188189
class EntitlementMethodVisitor extends MethodVisitor {
189190
private final boolean instrumentedMethodIsStatic;
191+
private final boolean instrumentedMethodIsCtor;
190192
private final String instrumentedMethodDescriptor;
191193
private final CheckerMethod instrumentationMethod;
192194
private boolean hasCallerSensitiveAnnotation = false;
@@ -195,11 +197,13 @@ class EntitlementMethodVisitor extends MethodVisitor {
195197
int api,
196198
MethodVisitor methodVisitor,
197199
boolean instrumentedMethodIsStatic,
200+
boolean instrumentedMethodIsCtor,
198201
String instrumentedMethodDescriptor,
199202
CheckerMethod instrumentationMethod
200203
) {
201204
super(api, methodVisitor);
202205
this.instrumentedMethodIsStatic = instrumentedMethodIsStatic;
206+
this.instrumentedMethodIsCtor = instrumentedMethodIsCtor;
203207
this.instrumentedMethodDescriptor = instrumentedMethodDescriptor;
204208
this.instrumentationMethod = instrumentationMethod;
205209
}
@@ -260,14 +264,15 @@ private void pushCallerClass() {
260264

261265
private void forwardIncomingArguments() {
262266
int localVarIndex = 0;
263-
if (instrumentedMethodIsStatic == false) {
267+
if (instrumentedMethodIsCtor) {
268+
localVarIndex++;
269+
} else if (instrumentedMethodIsStatic == false) {
264270
mv.visitVarInsn(Opcodes.ALOAD, localVarIndex++);
265271
}
266272
for (Type type : Type.getArgumentTypes(instrumentedMethodDescriptor)) {
267273
mv.visitVarInsn(type.getOpcode(Opcodes.ILOAD), localVarIndex);
268274
localVarIndex += type.getSize();
269275
}
270-
271276
}
272277

273278
private void invokeInstrumentationMethod() {

libs/entitlement/asm-provider/src/test/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests.java

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,12 @@ interface TestCheckerOverloads {
4545
void check$org_example_TestTargetClass$staticMethodWithOverload(Class<?> clazz, int x, String y);
4646
}
4747

48+
interface TestCheckerCtors {
49+
void check$org_example_TestTargetClass$(Class<?> clazz);
50+
51+
void check$org_example_TestTargetClass$(Class<?> clazz, int x, String y);
52+
}
53+
4854
public void testInstrumentationTargetLookup() throws IOException, ClassNotFoundException {
4955
Map<MethodKey, CheckerMethod> methodsMap = instrumentationService.lookupMethodsToInstrument(TestChecker.class.getName());
5056

@@ -142,6 +148,38 @@ public void testInstrumentationTargetLookupWithOverloads() throws IOException, C
142148
);
143149
}
144150

151+
public void testInstrumentationTargetLookupWithCtors() throws IOException, ClassNotFoundException {
152+
Map<MethodKey, CheckerMethod> methodsMap = instrumentationService.lookupMethodsToInstrument(TestCheckerCtors.class.getName());
153+
154+
assertThat(methodsMap, aMapWithSize(2));
155+
assertThat(
156+
methodsMap,
157+
hasEntry(
158+
equalTo(new MethodKey("org/example/TestTargetClass", "<init>", List.of("I", "java/lang/String"))),
159+
equalTo(
160+
new CheckerMethod(
161+
"org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests$TestCheckerCtors",
162+
"check$org_example_TestTargetClass$",
163+
List.of("Ljava/lang/Class;", "I", "Ljava/lang/String;")
164+
)
165+
)
166+
)
167+
);
168+
assertThat(
169+
methodsMap,
170+
hasEntry(
171+
equalTo(new MethodKey("org/example/TestTargetClass", "<init>", List.of())),
172+
equalTo(
173+
new CheckerMethod(
174+
"org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests$TestCheckerCtors",
175+
"check$org_example_TestTargetClass$",
176+
List.of("Ljava/lang/Class;")
177+
)
178+
)
179+
)
180+
);
181+
}
182+
145183
public void testParseCheckerMethodSignatureStaticMethod() {
146184
var methodKey = InstrumentationServiceImpl.parseCheckerMethodSignature(
147185
"check$org_example_TestClass$staticMethod",
@@ -169,6 +207,24 @@ public void testParseCheckerMethodSignatureStaticMethodInnerClass() {
169207
assertThat(methodKey, equalTo(new MethodKey("org/example/TestClass$InnerClass", "staticMethod", List.of())));
170208
}
171209

210+
public void testParseCheckerMethodSignatureCtor() {
211+
var methodKey = InstrumentationServiceImpl.parseCheckerMethodSignature(
212+
"check$org_example_TestClass$",
213+
new Type[] { Type.getType(Class.class) }
214+
);
215+
216+
assertThat(methodKey, equalTo(new MethodKey("org/example/TestClass", "<init>", List.of())));
217+
}
218+
219+
public void testParseCheckerMethodSignatureCtorWithArgs() {
220+
var methodKey = InstrumentationServiceImpl.parseCheckerMethodSignature(
221+
"check$org_example_TestClass$",
222+
new Type[] { Type.getType(Class.class), Type.getType("I"), Type.getType(String.class) }
223+
);
224+
225+
assertThat(methodKey, equalTo(new MethodKey("org/example/TestClass", "<init>", List.of("I", "java/lang/String"))));
226+
}
227+
172228
public void testParseCheckerMethodSignatureIncorrectName() {
173229
var exception = assertThrows(
174230
IllegalArgumentException.class,

libs/entitlement/asm-provider/src/test/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumenterTests.java

Lines changed: 96 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,15 @@
2323

2424
import java.lang.reflect.InvocationTargetException;
2525
import java.lang.reflect.Method;
26+
import java.net.URL;
27+
import java.net.URLStreamHandlerFactory;
2628
import java.util.Arrays;
29+
import java.util.List;
2730
import java.util.Map;
2831

2932
import static org.elasticsearch.entitlement.instrumentation.impl.ASMUtils.bytecode2text;
3033
import static org.elasticsearch.entitlement.instrumentation.impl.InstrumenterImpl.getClassFileInfo;
31-
import static org.hamcrest.Matchers.is;
34+
import static org.hamcrest.Matchers.instanceOf;
3235
import static org.hamcrest.Matchers.startsWith;
3336
import static org.objectweb.asm.Opcodes.INVOKESTATIC;
3437

@@ -72,6 +75,11 @@ public interface Testable {
7275
* They must not throw {@link TestException}.
7376
*/
7477
public static class ClassToInstrument implements Testable {
78+
79+
public ClassToInstrument() {}
80+
81+
public ClassToInstrument(int arg) {}
82+
7583
public static void systemExit(int status) {
7684
assertEquals(123, status);
7785
}
@@ -91,12 +99,20 @@ public static void someStaticMethod(int arg, String anotherArg) {}
9199

92100
static final class TestException extends RuntimeException {}
93101

102+
/**
103+
* Interface to test specific, "synthetic" cases (e.g. overloaded methods, overloaded constructors, etc.) that
104+
* may be not present/may be difficult to find or not clear in the production EntitlementChecker interface
105+
*/
94106
public interface MockEntitlementChecker extends EntitlementChecker {
95107
void checkSomeStaticMethod(Class<?> clazz, int arg);
96108

97109
void checkSomeStaticMethod(Class<?> clazz, int arg, String anotherArg);
98110

99111
void checkSomeInstanceMethod(Class<?> clazz, Testable that, int arg, String anotherArg);
112+
113+
void checkCtor(Class<?> clazz);
114+
115+
void checkCtor(Class<?> clazz, int arg);
100116
}
101117

102118
/**
@@ -118,6 +134,9 @@ public static class TestEntitlementChecker implements MockEntitlementChecker {
118134
int checkSomeStaticMethodIntStringCallCount = 0;
119135
int checkSomeInstanceMethodCallCount = 0;
120136

137+
int checkCtorCallCount = 0;
138+
int checkCtorIntCallCount = 0;
139+
121140
@Override
122141
public void check$java_lang_System$exit(Class<?> callerClass, int status) {
123142
checkSystemExitCallCount++;
@@ -126,6 +145,27 @@ public static class TestEntitlementChecker implements MockEntitlementChecker {
126145
throwIfActive();
127146
}
128147

148+
@Override
149+
public void check$java_net_URLClassLoader$(Class<?> callerClass, URL[] urls) {}
150+
151+
@Override
152+
public void check$java_net_URLClassLoader$(Class<?> callerClass, URL[] urls, ClassLoader parent) {}
153+
154+
@Override
155+
public void check$java_net_URLClassLoader$(Class<?> callerClass, URL[] urls, ClassLoader parent, URLStreamHandlerFactory factory) {}
156+
157+
@Override
158+
public void check$java_net_URLClassLoader$(Class<?> callerClass, String name, URL[] urls, ClassLoader parent) {}
159+
160+
@Override
161+
public void check$java_net_URLClassLoader$(
162+
Class<?> callerClass,
163+
String name,
164+
URL[] urls,
165+
ClassLoader parent,
166+
URLStreamHandlerFactory factory
167+
) {}
168+
129169
private void throwIfActive() {
130170
if (isActive) {
131171
throw new TestException();
@@ -161,6 +201,21 @@ public void checkSomeInstanceMethod(Class<?> callerClass, Testable that, int arg
161201
assertEquals("def", anotherArg);
162202
throwIfActive();
163203
}
204+
205+
@Override
206+
public void checkCtor(Class<?> callerClass) {
207+
checkCtorCallCount++;
208+
assertSame(InstrumenterTests.class, callerClass);
209+
throwIfActive();
210+
}
211+
212+
@Override
213+
public void checkCtor(Class<?> callerClass, int arg) {
214+
checkCtorIntCallCount++;
215+
assertSame(InstrumenterTests.class, callerClass);
216+
assertEquals(123, arg);
217+
throwIfActive();
218+
}
164219
}
165220

166221
public void testClassIsInstrumented() throws Exception {
@@ -225,7 +280,7 @@ public void testClassIsNotInstrumentedTwice() throws Exception {
225280
getTestEntitlementChecker().checkSystemExitCallCount = 0;
226281

227282
assertThrows(TestException.class, () -> callStaticMethod(newClass, "systemExit", 123));
228-
assertThat(getTestEntitlementChecker().checkSystemExitCallCount, is(1));
283+
assertEquals(1, getTestEntitlementChecker().checkSystemExitCallCount);
229284
}
230285

231286
public void testClassAllMethodsAreInstrumentedFirstPass() throws Exception {
@@ -259,10 +314,10 @@ public void testClassAllMethodsAreInstrumentedFirstPass() throws Exception {
259314
getTestEntitlementChecker().checkSystemExitCallCount = 0;
260315

261316
assertThrows(TestException.class, () -> callStaticMethod(newClass, "systemExit", 123));
262-
assertThat(getTestEntitlementChecker().checkSystemExitCallCount, is(1));
317+
assertEquals(1, getTestEntitlementChecker().checkSystemExitCallCount);
263318

264319
assertThrows(TestException.class, () -> callStaticMethod(newClass, "anotherSystemExit", 123));
265-
assertThat(getTestEntitlementChecker().checkSystemExitCallCount, is(2));
320+
assertEquals(2, getTestEntitlementChecker().checkSystemExitCallCount);
266321
}
267322

268323
public void testInstrumenterWorksWithOverloads() throws Exception {
@@ -294,8 +349,8 @@ public void testInstrumenterWorksWithOverloads() throws Exception {
294349
assertThrows(TestException.class, () -> callStaticMethod(newClass, "someStaticMethod", 123));
295350
assertThrows(TestException.class, () -> callStaticMethod(newClass, "someStaticMethod", 123, "abc"));
296351

297-
assertThat(getTestEntitlementChecker().checkSomeStaticMethodIntCallCount, is(1));
298-
assertThat(getTestEntitlementChecker().checkSomeStaticMethodIntStringCallCount, is(1));
352+
assertEquals(1, getTestEntitlementChecker().checkSomeStaticMethodIntCallCount);
353+
assertEquals(1, getTestEntitlementChecker().checkSomeStaticMethodIntStringCallCount);
299354
}
300355

301356
public void testInstrumenterWorksWithInstanceMethodsAndOverloads() throws Exception {
@@ -327,7 +382,41 @@ public void testInstrumenterWorksWithInstanceMethodsAndOverloads() throws Except
327382
testTargetClass.someMethod(123);
328383
assertThrows(TestException.class, () -> testTargetClass.someMethod(123, "def"));
329384

330-
assertThat(getTestEntitlementChecker().checkSomeInstanceMethodCallCount, is(1));
385+
assertEquals(1, getTestEntitlementChecker().checkSomeInstanceMethodCallCount);
386+
}
387+
388+
public void testInstrumenterWorksWithConstructors() throws Exception {
389+
var classToInstrument = ClassToInstrument.class;
390+
391+
Map<MethodKey, CheckerMethod> methods = Map.of(
392+
new MethodKey(classToInstrument.getName().replace('.', '/'), "<init>", List.of()),
393+
getCheckerMethod(MockEntitlementChecker.class, "checkCtor", Class.class),
394+
new MethodKey(classToInstrument.getName().replace('.', '/'), "<init>", List.of("I")),
395+
getCheckerMethod(MockEntitlementChecker.class, "checkCtor", Class.class, int.class)
396+
);
397+
398+
var instrumenter = createInstrumenter(methods);
399+
400+
byte[] newBytecode = instrumenter.instrumentClassFile(classToInstrument).bytecodes();
401+
402+
if (logger.isTraceEnabled()) {
403+
logger.trace("Bytecode after instrumentation:\n{}", bytecode2text(newBytecode));
404+
}
405+
406+
Class<?> newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes(
407+
classToInstrument.getName() + "_NEW",
408+
newBytecode
409+
);
410+
411+
getTestEntitlementChecker().isActive = true;
412+
413+
var ex = assertThrows(InvocationTargetException.class, () -> newClass.getConstructor().newInstance());
414+
assertThat(ex.getCause(), instanceOf(TestException.class));
415+
var ex2 = assertThrows(InvocationTargetException.class, () -> newClass.getConstructor(int.class).newInstance(123));
416+
assertThat(ex2.getCause(), instanceOf(TestException.class));
417+
418+
assertEquals(1, getTestEntitlementChecker().checkCtorCallCount);
419+
assertEquals(1, getTestEntitlementChecker().checkCtorIntCallCount);
331420
}
332421

333422
/** This test doesn't replace classToInstrument in-place but instead loads a separate

libs/entitlement/bridge/src/main/java/org/elasticsearch/entitlement/bridge/EntitlementChecker.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,20 @@
99

1010
package org.elasticsearch.entitlement.bridge;
1111

12+
import java.net.URL;
13+
import java.net.URLStreamHandlerFactory;
14+
1215
public interface EntitlementChecker {
1316
void check$java_lang_System$exit(Class<?> callerClass, int status);
17+
18+
// URLClassLoader ctor
19+
void check$java_net_URLClassLoader$(Class<?> callerClass, URL[] urls);
20+
21+
void check$java_net_URLClassLoader$(Class<?> callerClass, URL[] urls, ClassLoader parent);
22+
23+
void check$java_net_URLClassLoader$(Class<?> callerClass, URL[] urls, ClassLoader parent, URLStreamHandlerFactory factory);
24+
25+
void check$java_net_URLClassLoader$(Class<?> callerClass, String name, URL[] urls, ClassLoader parent);
26+
27+
void check$java_net_URLClassLoader$(Class<?> callerClass, String name, URL[] urls, ClassLoader parent, URLStreamHandlerFactory factory);
1428
}

libs/entitlement/src/main/java/org/elasticsearch/entitlement/initialization/EntitlementInitialization.java

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -169,10 +169,6 @@ private static ElasticsearchEntitlementChecker initChecker() throws IOException
169169
}
170170
}
171171

172-
private static String internalName(Class<?> c) {
173-
return c.getName().replace('.', '/');
174-
}
175-
176172
private static final InstrumentationService INSTRUMENTER_FACTORY = new ProviderLocator<>(
177173
"entitlement",
178174
InstrumentationService.class,

0 commit comments

Comments
 (0)