2323
2424import java .lang .reflect .InvocationTargetException ;
2525import java .lang .reflect .Method ;
26+ import java .net .URL ;
27+ import java .net .URLStreamHandlerFactory ;
2628import java .util .Arrays ;
29+ import java .util .List ;
2730import java .util .Map ;
2831
2932import static org .elasticsearch .entitlement .instrumentation .impl .ASMUtils .bytecode2text ;
3033import static org .elasticsearch .entitlement .instrumentation .impl .InstrumenterImpl .getClassFileInfo ;
31- import static org .hamcrest .Matchers .is ;
34+ import static org .hamcrest .Matchers .instanceOf ;
3235import static org .hamcrest .Matchers .startsWith ;
3336import static org .objectweb .asm .Opcodes .INVOKESTATIC ;
3437
@@ -72,6 +75,11 @@ public interface Testable {
7275 * They must not throw {@link TestException}.
7376 */
7477 public static class ClassToInstrument implements Testable {
78+
79+ public ClassToInstrument () {}
80+
81+ public ClassToInstrument (int arg ) {}
82+
7583 public static void systemExit (int status ) {
7684 assertEquals (123 , status );
7785 }
@@ -91,12 +99,20 @@ public static void someStaticMethod(int arg, String anotherArg) {}
9199
92100 static final class TestException extends RuntimeException {}
93101
102+ /**
103+ * Interface to test specific, "synthetic" cases (e.g. overloaded methods, overloaded constructors, etc.) that
104+ * may be not present/may be difficult to find or not clear in the production EntitlementChecker interface
105+ */
94106 public interface MockEntitlementChecker extends EntitlementChecker {
95107 void checkSomeStaticMethod (Class <?> clazz , int arg );
96108
97109 void checkSomeStaticMethod (Class <?> clazz , int arg , String anotherArg );
98110
99111 void checkSomeInstanceMethod (Class <?> clazz , Testable that , int arg , String anotherArg );
112+
113+ void checkCtor (Class <?> clazz );
114+
115+ void checkCtor (Class <?> clazz , int arg );
100116 }
101117
102118 /**
@@ -118,6 +134,9 @@ public static class TestEntitlementChecker implements MockEntitlementChecker {
118134 int checkSomeStaticMethodIntStringCallCount = 0 ;
119135 int checkSomeInstanceMethodCallCount = 0 ;
120136
137+ int checkCtorCallCount = 0 ;
138+ int checkCtorIntCallCount = 0 ;
139+
121140 @ Override
122141 public void check$java_lang_System$exit (Class <?> callerClass , int status ) {
123142 checkSystemExitCallCount ++;
@@ -126,6 +145,27 @@ public static class TestEntitlementChecker implements MockEntitlementChecker {
126145 throwIfActive ();
127146 }
128147
148+ @ Override
149+ public void check$java_net_URLClassLoader$ (Class <?> callerClass , URL [] urls ) {}
150+
151+ @ Override
152+ public void check$java_net_URLClassLoader$ (Class <?> callerClass , URL [] urls , ClassLoader parent ) {}
153+
154+ @ Override
155+ public void check$java_net_URLClassLoader$ (Class <?> callerClass , URL [] urls , ClassLoader parent , URLStreamHandlerFactory factory ) {}
156+
157+ @ Override
158+ public void check$java_net_URLClassLoader$ (Class <?> callerClass , String name , URL [] urls , ClassLoader parent ) {}
159+
160+ @ Override
161+ public void check$java_net_URLClassLoader$ (
162+ Class <?> callerClass ,
163+ String name ,
164+ URL [] urls ,
165+ ClassLoader parent ,
166+ URLStreamHandlerFactory factory
167+ ) {}
168+
129169 private void throwIfActive () {
130170 if (isActive ) {
131171 throw new TestException ();
@@ -161,6 +201,21 @@ public void checkSomeInstanceMethod(Class<?> callerClass, Testable that, int arg
161201 assertEquals ("def" , anotherArg );
162202 throwIfActive ();
163203 }
204+
205+ @ Override
206+ public void checkCtor (Class <?> callerClass ) {
207+ checkCtorCallCount ++;
208+ assertSame (InstrumenterTests .class , callerClass );
209+ throwIfActive ();
210+ }
211+
212+ @ Override
213+ public void checkCtor (Class <?> callerClass , int arg ) {
214+ checkCtorIntCallCount ++;
215+ assertSame (InstrumenterTests .class , callerClass );
216+ assertEquals (123 , arg );
217+ throwIfActive ();
218+ }
164219 }
165220
166221 public void testClassIsInstrumented () throws Exception {
@@ -225,7 +280,7 @@ public void testClassIsNotInstrumentedTwice() throws Exception {
225280 getTestEntitlementChecker ().checkSystemExitCallCount = 0 ;
226281
227282 assertThrows (TestException .class , () -> callStaticMethod (newClass , "systemExit" , 123 ));
228- assertThat ( getTestEntitlementChecker ().checkSystemExitCallCount , is ( 1 ) );
283+ assertEquals ( 1 , getTestEntitlementChecker ().checkSystemExitCallCount );
229284 }
230285
231286 public void testClassAllMethodsAreInstrumentedFirstPass () throws Exception {
@@ -259,10 +314,10 @@ public void testClassAllMethodsAreInstrumentedFirstPass() throws Exception {
259314 getTestEntitlementChecker ().checkSystemExitCallCount = 0 ;
260315
261316 assertThrows (TestException .class , () -> callStaticMethod (newClass , "systemExit" , 123 ));
262- assertThat ( getTestEntitlementChecker ().checkSystemExitCallCount , is ( 1 ) );
317+ assertEquals ( 1 , getTestEntitlementChecker ().checkSystemExitCallCount );
263318
264319 assertThrows (TestException .class , () -> callStaticMethod (newClass , "anotherSystemExit" , 123 ));
265- assertThat ( getTestEntitlementChecker ().checkSystemExitCallCount , is ( 2 ) );
320+ assertEquals ( 2 , getTestEntitlementChecker ().checkSystemExitCallCount );
266321 }
267322
268323 public void testInstrumenterWorksWithOverloads () throws Exception {
@@ -294,8 +349,8 @@ public void testInstrumenterWorksWithOverloads() throws Exception {
294349 assertThrows (TestException .class , () -> callStaticMethod (newClass , "someStaticMethod" , 123 ));
295350 assertThrows (TestException .class , () -> callStaticMethod (newClass , "someStaticMethod" , 123 , "abc" ));
296351
297- assertThat ( getTestEntitlementChecker ().checkSomeStaticMethodIntCallCount , is ( 1 ) );
298- assertThat ( getTestEntitlementChecker ().checkSomeStaticMethodIntStringCallCount , is ( 1 ) );
352+ assertEquals ( 1 , getTestEntitlementChecker ().checkSomeStaticMethodIntCallCount );
353+ assertEquals ( 1 , getTestEntitlementChecker ().checkSomeStaticMethodIntStringCallCount );
299354 }
300355
301356 public void testInstrumenterWorksWithInstanceMethodsAndOverloads () throws Exception {
@@ -327,7 +382,41 @@ public void testInstrumenterWorksWithInstanceMethodsAndOverloads() throws Except
327382 testTargetClass .someMethod (123 );
328383 assertThrows (TestException .class , () -> testTargetClass .someMethod (123 , "def" ));
329384
330- assertThat (getTestEntitlementChecker ().checkSomeInstanceMethodCallCount , is (1 ));
385+ assertEquals (1 , getTestEntitlementChecker ().checkSomeInstanceMethodCallCount );
386+ }
387+
388+ public void testInstrumenterWorksWithConstructors () throws Exception {
389+ var classToInstrument = ClassToInstrument .class ;
390+
391+ Map <MethodKey , CheckerMethod > methods = Map .of (
392+ new MethodKey (classToInstrument .getName ().replace ('.' , '/' ), "<init>" , List .of ()),
393+ getCheckerMethod (MockEntitlementChecker .class , "checkCtor" , Class .class ),
394+ new MethodKey (classToInstrument .getName ().replace ('.' , '/' ), "<init>" , List .of ("I" )),
395+ getCheckerMethod (MockEntitlementChecker .class , "checkCtor" , Class .class , int .class )
396+ );
397+
398+ var instrumenter = createInstrumenter (methods );
399+
400+ byte [] newBytecode = instrumenter .instrumentClassFile (classToInstrument ).bytecodes ();
401+
402+ if (logger .isTraceEnabled ()) {
403+ logger .trace ("Bytecode after instrumentation:\n {}" , bytecode2text (newBytecode ));
404+ }
405+
406+ Class <?> newClass = new TestLoader (Testable .class .getClassLoader ()).defineClassFromBytes (
407+ classToInstrument .getName () + "_NEW" ,
408+ newBytecode
409+ );
410+
411+ getTestEntitlementChecker ().isActive = true ;
412+
413+ var ex = assertThrows (InvocationTargetException .class , () -> newClass .getConstructor ().newInstance ());
414+ assertThat (ex .getCause (), instanceOf (TestException .class ));
415+ var ex2 = assertThrows (InvocationTargetException .class , () -> newClass .getConstructor (int .class ).newInstance (123 ));
416+ assertThat (ex2 .getCause (), instanceOf (TestException .class ));
417+
418+ assertEquals (1 , getTestEntitlementChecker ().checkCtorCallCount );
419+ assertEquals (1 , getTestEntitlementChecker ().checkCtorIntCallCount );
331420 }
332421
333422 /** This test doesn't replace classToInstrument in-place but instead loads a separate
0 commit comments