99
1010package org .elasticsearch .entitlement .instrumentation .impl ;
1111
12+ import org .elasticsearch .entitlement .instrumentation .CheckerMethod ;
1213import org .elasticsearch .entitlement .instrumentation .InstrumentationService ;
1314import org .elasticsearch .entitlement .instrumentation .Instrumenter ;
1415import org .elasticsearch .entitlement .instrumentation .MethodKey ;
16+ import org .objectweb .asm .ClassReader ;
17+ import org .objectweb .asm .ClassVisitor ;
18+ import org .objectweb .asm .MethodVisitor ;
19+ import org .objectweb .asm .Opcodes ;
1520import org .objectweb .asm .Type ;
1621
22+ import java .io .IOException ;
1723import java .lang .reflect .Method ;
18- import java .lang .reflect .Modifier ;
24+ import java .util .Arrays ;
25+ import java .util .HashMap ;
26+ import java .util .List ;
27+ import java .util .Locale ;
1928import java .util .Map ;
2029import java .util .stream .Stream ;
2130
2231public class InstrumentationServiceImpl implements InstrumentationService {
32+
2333 @ Override
24- public Instrumenter newInstrumenter (String classNameSuffix , Map <MethodKey , Method > instrumentationMethods ) {
34+ public Instrumenter newInstrumenter (String classNameSuffix , Map <MethodKey , CheckerMethod > instrumentationMethods ) {
2535 return new InstrumenterImpl (classNameSuffix , instrumentationMethods );
2636 }
2737
@@ -33,9 +43,97 @@ public MethodKey methodKeyForTarget(Method targetMethod) {
3343 return new MethodKey (
3444 Type .getInternalName (targetMethod .getDeclaringClass ()),
3545 targetMethod .getName (),
36- Stream .of (actualType .getArgumentTypes ()).map (Type ::getInternalName ).toList (),
37- Modifier .isStatic (targetMethod .getModifiers ())
46+ Stream .of (actualType .getArgumentTypes ()).map (Type ::getInternalName ).toList ()
3847 );
3948 }
4049
50+ @ Override
51+ public Map <MethodKey , CheckerMethod > lookupMethodsToInstrument (String entitlementCheckerClassName ) throws ClassNotFoundException ,
52+ IOException {
53+ var methodsToInstrument = new HashMap <MethodKey , CheckerMethod >();
54+ var checkerClass = Class .forName (entitlementCheckerClassName );
55+ var classFileInfo = InstrumenterImpl .getClassFileInfo (checkerClass );
56+ ClassReader reader = new ClassReader (classFileInfo .bytecodes ());
57+ ClassVisitor visitor = new ClassVisitor (Opcodes .ASM9 ) {
58+ @ Override
59+ public MethodVisitor visitMethod (
60+ int access ,
61+ String checkerMethodName ,
62+ String checkerMethodDescriptor ,
63+ String signature ,
64+ String [] exceptions
65+ ) {
66+ var mv = super .visitMethod (access , checkerMethodName , checkerMethodDescriptor , signature , exceptions );
67+
68+ var checkerMethodArgumentTypes = Type .getArgumentTypes (checkerMethodDescriptor );
69+ var methodToInstrument = parseCheckerMethodSignature (checkerMethodName , checkerMethodArgumentTypes );
70+
71+ var checkerParameterDescriptors = Arrays .stream (checkerMethodArgumentTypes ).map (Type ::getDescriptor ).toList ();
72+ var checkerMethod = new CheckerMethod (Type .getInternalName (checkerClass ), checkerMethodName , checkerParameterDescriptors );
73+
74+ methodsToInstrument .put (methodToInstrument , checkerMethod );
75+
76+ return mv ;
77+ }
78+ };
79+ reader .accept (visitor , 0 );
80+ return methodsToInstrument ;
81+ }
82+
83+ private static final Type CLASS_TYPE = Type .getType (Class .class );
84+
85+ static MethodKey parseCheckerMethodSignature (String checkerMethodName , Type [] checkerMethodArgumentTypes ) {
86+ var classNameStartIndex = checkerMethodName .indexOf ('$' );
87+ var classNameEndIndex = checkerMethodName .lastIndexOf ('$' );
88+
89+ if (classNameStartIndex == -1 || classNameStartIndex >= classNameEndIndex ) {
90+ throw new IllegalArgumentException (
91+ String .format (
92+ Locale .ROOT ,
93+ "Checker method %s has incorrect name format. "
94+ + "It should be either check$$methodName (instance) or check$package_ClassName$methodName (static)" ,
95+ checkerMethodName
96+ )
97+ );
98+ }
99+
100+ // No "className" (check$$methodName) -> method is static, and we'll get the class from the actual typed argument
101+ final boolean targetMethodIsStatic = classNameStartIndex + 1 != classNameEndIndex ;
102+ final String targetMethodName = checkerMethodName .substring (classNameEndIndex + 1 );
103+
104+ final String targetClassName ;
105+ final List <String > targetParameterTypes ;
106+ if (targetMethodIsStatic ) {
107+ if (checkerMethodArgumentTypes .length < 1 || CLASS_TYPE .equals (checkerMethodArgumentTypes [0 ]) == false ) {
108+ throw new IllegalArgumentException (
109+ String .format (
110+ Locale .ROOT ,
111+ "Checker method %s has incorrect argument types. " + "It must have a first argument of Class<?> type." ,
112+ checkerMethodName
113+ )
114+ );
115+ }
116+
117+ targetClassName = checkerMethodName .substring (classNameStartIndex + 1 , classNameEndIndex ).replace ('_' , '/' );
118+ targetParameterTypes = Arrays .stream (checkerMethodArgumentTypes ).skip (1 ).map (Type ::getInternalName ).toList ();
119+ } else {
120+ if (checkerMethodArgumentTypes .length < 2
121+ || CLASS_TYPE .equals (checkerMethodArgumentTypes [0 ]) == false
122+ || checkerMethodArgumentTypes [1 ].getSort () != Type .OBJECT ) {
123+ throw new IllegalArgumentException (
124+ String .format (
125+ Locale .ROOT ,
126+ "Checker method %s has incorrect argument types. "
127+ + "It must have a first argument of Class<?> type, and a second argument of the class containing the method to "
128+ + "instrument" ,
129+ checkerMethodName
130+ )
131+ );
132+ }
133+ var targetClassType = checkerMethodArgumentTypes [1 ];
134+ targetClassName = targetClassType .getInternalName ();
135+ targetParameterTypes = Arrays .stream (checkerMethodArgumentTypes ).skip (2 ).map (Type ::getInternalName ).toList ();
136+ }
137+ return new MethodKey (targetClassName , targetMethodName , targetParameterTypes );
138+ }
41139}
0 commit comments