2323
2424import  java .lang .reflect .InvocationTargetException ;
2525import  java .lang .reflect .Method ;
26+ import  java .net .URL ;
27+ import  java .net .URLStreamHandlerFactory ;
2628import  java .util .Arrays ;
29+ import  java .util .List ;
2730import  java .util .Map ;
2831
2932import  static  org .elasticsearch .entitlement .instrumentation .impl .ASMUtils .bytecode2text ;
3033import  static  org .elasticsearch .entitlement .instrumentation .impl .InstrumenterImpl .getClassFileInfo ;
31- import  static  org .hamcrest .Matchers .is ;
34+ import  static  org .hamcrest .Matchers .instanceOf ;
3235import  static  org .hamcrest .Matchers .startsWith ;
3336import  static  org .objectweb .asm .Opcodes .INVOKESTATIC ;
3437
@@ -72,6 +75,11 @@ public interface Testable {
7275     * They must not throw {@link TestException}. 
7376     */ 
7477    public  static  class  ClassToInstrument  implements  Testable  {
78+ 
79+         public  ClassToInstrument () {}
80+ 
81+         public  ClassToInstrument (int  arg ) {}
82+ 
7583        public  static  void  systemExit (int  status ) {
7684            assertEquals (123 , status );
7785        }
@@ -91,12 +99,20 @@ public static void someStaticMethod(int arg, String anotherArg) {}
9199
92100    static  final  class  TestException  extends  RuntimeException  {}
93101
102+     /** 
103+      * Interface to test specific, "synthetic" cases (e.g. overloaded methods, overloaded constructors, etc.) that 
104+      * may be not present/may be difficult to find or not clear in the production EntitlementChecker interface 
105+      */ 
94106    public  interface  MockEntitlementChecker  extends  EntitlementChecker  {
95107        void  checkSomeStaticMethod (Class <?> clazz , int  arg );
96108
97109        void  checkSomeStaticMethod (Class <?> clazz , int  arg , String  anotherArg );
98110
99111        void  checkSomeInstanceMethod (Class <?> clazz , Testable  that , int  arg , String  anotherArg );
112+ 
113+         void  checkCtor (Class <?> clazz );
114+ 
115+         void  checkCtor (Class <?> clazz , int  arg );
100116    }
101117
102118    /** 
@@ -118,6 +134,9 @@ public static class TestEntitlementChecker implements MockEntitlementChecker {
118134        int  checkSomeStaticMethodIntStringCallCount  = 0 ;
119135        int  checkSomeInstanceMethodCallCount  = 0 ;
120136
137+         int  checkCtorCallCount  = 0 ;
138+         int  checkCtorIntCallCount  = 0 ;
139+ 
121140        @ Override 
122141        public  void  check$java_lang_System$exit (Class <?> callerClass , int  status ) {
123142            checkSystemExitCallCount ++;
@@ -126,6 +145,27 @@ public static class TestEntitlementChecker implements MockEntitlementChecker {
126145            throwIfActive ();
127146        }
128147
148+         @ Override 
149+         public  void  check$java_net_URLClassLoader$ (Class <?> callerClass , URL [] urls ) {}
150+ 
151+         @ Override 
152+         public  void  check$java_net_URLClassLoader$ (Class <?> callerClass , URL [] urls , ClassLoader  parent ) {}
153+ 
154+         @ Override 
155+         public  void  check$java_net_URLClassLoader$ (Class <?> callerClass , URL [] urls , ClassLoader  parent , URLStreamHandlerFactory  factory ) {}
156+ 
157+         @ Override 
158+         public  void  check$java_net_URLClassLoader$ (Class <?> callerClass , String  name , URL [] urls , ClassLoader  parent ) {}
159+ 
160+         @ Override 
161+         public  void  check$java_net_URLClassLoader$ (
162+             Class <?> callerClass ,
163+             String  name ,
164+             URL [] urls ,
165+             ClassLoader  parent ,
166+             URLStreamHandlerFactory  factory 
167+         ) {}
168+ 
129169        private  void  throwIfActive () {
130170            if  (isActive ) {
131171                throw  new  TestException ();
@@ -161,6 +201,21 @@ public void checkSomeInstanceMethod(Class<?> callerClass, Testable that, int arg
161201            assertEquals ("def" , anotherArg );
162202            throwIfActive ();
163203        }
204+ 
205+         @ Override 
206+         public  void  checkCtor (Class <?> callerClass ) {
207+             checkCtorCallCount ++;
208+             assertSame (InstrumenterTests .class , callerClass );
209+             throwIfActive ();
210+         }
211+ 
212+         @ Override 
213+         public  void  checkCtor (Class <?> callerClass , int  arg ) {
214+             checkCtorIntCallCount ++;
215+             assertSame (InstrumenterTests .class , callerClass );
216+             assertEquals (123 , arg );
217+             throwIfActive ();
218+         }
164219    }
165220
166221    public  void  testClassIsInstrumented () throws  Exception  {
@@ -225,7 +280,7 @@ public void testClassIsNotInstrumentedTwice() throws Exception {
225280        getTestEntitlementChecker ().checkSystemExitCallCount  = 0 ;
226281
227282        assertThrows (TestException .class , () -> callStaticMethod (newClass , "systemExit" , 123 ));
228-         assertThat ( getTestEntitlementChecker ().checkSystemExitCallCount ,  is ( 1 ) );
283+         assertEquals ( 1 ,  getTestEntitlementChecker ().checkSystemExitCallCount );
229284    }
230285
231286    public  void  testClassAllMethodsAreInstrumentedFirstPass () throws  Exception  {
@@ -259,10 +314,10 @@ public void testClassAllMethodsAreInstrumentedFirstPass() throws Exception {
259314        getTestEntitlementChecker ().checkSystemExitCallCount  = 0 ;
260315
261316        assertThrows (TestException .class , () -> callStaticMethod (newClass , "systemExit" , 123 ));
262-         assertThat ( getTestEntitlementChecker ().checkSystemExitCallCount ,  is ( 1 ) );
317+         assertEquals ( 1 ,  getTestEntitlementChecker ().checkSystemExitCallCount );
263318
264319        assertThrows (TestException .class , () -> callStaticMethod (newClass , "anotherSystemExit" , 123 ));
265-         assertThat ( getTestEntitlementChecker ().checkSystemExitCallCount ,  is ( 2 ) );
320+         assertEquals ( 2 ,  getTestEntitlementChecker ().checkSystemExitCallCount );
266321    }
267322
268323    public  void  testInstrumenterWorksWithOverloads () throws  Exception  {
@@ -294,8 +349,8 @@ public void testInstrumenterWorksWithOverloads() throws Exception {
294349        assertThrows (TestException .class , () -> callStaticMethod (newClass , "someStaticMethod" , 123 ));
295350        assertThrows (TestException .class , () -> callStaticMethod (newClass , "someStaticMethod" , 123 , "abc" ));
296351
297-         assertThat ( getTestEntitlementChecker ().checkSomeStaticMethodIntCallCount ,  is ( 1 ) );
298-         assertThat ( getTestEntitlementChecker ().checkSomeStaticMethodIntStringCallCount ,  is ( 1 ) );
352+         assertEquals ( 1 ,  getTestEntitlementChecker ().checkSomeStaticMethodIntCallCount );
353+         assertEquals ( 1 ,  getTestEntitlementChecker ().checkSomeStaticMethodIntStringCallCount );
299354    }
300355
301356    public  void  testInstrumenterWorksWithInstanceMethodsAndOverloads () throws  Exception  {
@@ -327,7 +382,41 @@ public void testInstrumenterWorksWithInstanceMethodsAndOverloads() throws Except
327382        testTargetClass .someMethod (123 );
328383        assertThrows (TestException .class , () -> testTargetClass .someMethod (123 , "def" ));
329384
330-         assertThat (getTestEntitlementChecker ().checkSomeInstanceMethodCallCount , is (1 ));
385+         assertEquals (1 , getTestEntitlementChecker ().checkSomeInstanceMethodCallCount );
386+     }
387+ 
388+     public  void  testInstrumenterWorksWithConstructors () throws  Exception  {
389+         var  classToInstrument  = ClassToInstrument .class ;
390+ 
391+         Map <MethodKey , CheckerMethod > methods  = Map .of (
392+             new  MethodKey (classToInstrument .getName ().replace ('.' , '/' ), "<init>" , List .of ()),
393+             getCheckerMethod (MockEntitlementChecker .class , "checkCtor" , Class .class ),
394+             new  MethodKey (classToInstrument .getName ().replace ('.' , '/' ), "<init>" , List .of ("I" )),
395+             getCheckerMethod (MockEntitlementChecker .class , "checkCtor" , Class .class , int .class )
396+         );
397+ 
398+         var  instrumenter  = createInstrumenter (methods );
399+ 
400+         byte [] newBytecode  = instrumenter .instrumentClassFile (classToInstrument ).bytecodes ();
401+ 
402+         if  (logger .isTraceEnabled ()) {
403+             logger .trace ("Bytecode after instrumentation:\n {}" , bytecode2text (newBytecode ));
404+         }
405+ 
406+         Class <?> newClass  = new  TestLoader (Testable .class .getClassLoader ()).defineClassFromBytes (
407+             classToInstrument .getName () + "_NEW" ,
408+             newBytecode 
409+         );
410+ 
411+         getTestEntitlementChecker ().isActive  = true ;
412+ 
413+         var  ex  = assertThrows (InvocationTargetException .class , () -> newClass .getConstructor ().newInstance ());
414+         assertThat (ex .getCause (), instanceOf (TestException .class ));
415+         var  ex2  = assertThrows (InvocationTargetException .class , () -> newClass .getConstructor (int .class ).newInstance (123 ));
416+         assertThat (ex2 .getCause (), instanceOf (TestException .class ));
417+ 
418+         assertEquals (1 , getTestEntitlementChecker ().checkCtorCallCount );
419+         assertEquals (1 , getTestEntitlementChecker ().checkCtorIntCallCount );
331420    }
332421
333422    /** This test doesn't replace classToInstrument in-place but instead loads a separate 
0 commit comments