99
1010package org .elasticsearch .entitlement .instrumentation .impl ;
1111
12+ import org .elasticsearch .core .SuppressForbidden ;
1213import org .elasticsearch .entitlement .instrumentation .CheckMethod ;
1314import org .elasticsearch .entitlement .instrumentation .InstrumentationService ;
1415import org .elasticsearch .entitlement .instrumentation .Instrumenter ;
2021import org .objectweb .asm .Type ;
2122
2223import java .io .IOException ;
24+ import java .lang .reflect .Method ;
25+ import java .lang .reflect .Modifier ;
2326import java .util .Arrays ;
2427import java .util .HashMap ;
2528import java .util .List ;
2629import java .util .Locale ;
2730import java .util .Map ;
31+ import java .util .stream .Collectors ;
32+ import java .util .stream .Stream ;
2833
2934public class InstrumentationServiceImpl implements InstrumentationService {
3035
@@ -48,22 +53,159 @@ public MethodVisitor visitMethod(
4853 String [] exceptions
4954 ) {
5055 var mv = super .visitMethod (access , checkerMethodName , checkerMethodDescriptor , signature , exceptions );
56+ if (checkerMethodName .startsWith (InstrumentationService .CHECK_METHOD_PREFIX )) {
57+ var checkerMethodArgumentTypes = Type .getArgumentTypes (checkerMethodDescriptor );
58+ var methodToInstrument = parseCheckerMethodSignature (checkerMethodName , checkerMethodArgumentTypes );
5159
52- var checkerMethodArgumentTypes = Type .getArgumentTypes (checkerMethodDescriptor );
53- var methodToInstrument = parseCheckerMethodSignature (checkerMethodName , checkerMethodArgumentTypes );
54-
55- var checkerParameterDescriptors = Arrays .stream (checkerMethodArgumentTypes ).map (Type ::getDescriptor ).toList ();
56- var checkMethod = new CheckMethod (Type .getInternalName (checkerClass ), checkerMethodName , checkerParameterDescriptors );
57-
58- methodsToInstrument .put (methodToInstrument , checkMethod );
60+ var checkerParameterDescriptors = Arrays .stream (checkerMethodArgumentTypes ).map (Type ::getDescriptor ).toList ();
61+ var checkMethod = new CheckMethod (Type .getInternalName (checkerClass ), checkerMethodName , checkerParameterDescriptors );
5962
63+ methodsToInstrument .put (methodToInstrument , checkMethod );
64+ }
6065 return mv ;
6166 }
6267 };
6368 reader .accept (visitor , 0 );
6469 return methodsToInstrument ;
6570 }
6671
72+ @ SuppressForbidden (reason = "Need access to abstract methods (protected/package internal) in base class" )
73+ @ Override
74+ public InstrumentationInfo lookupImplementationMethod (
75+ Class <?> targetSuperclass ,
76+ String methodName ,
77+ Class <?> implementationClass ,
78+ Class <?> checkerClass ,
79+ String checkMethodName ,
80+ Class <?>... parameterTypes
81+ ) throws NoSuchMethodException , ClassNotFoundException {
82+
83+ var targetMethod = targetSuperclass .getDeclaredMethod (methodName , parameterTypes );
84+ validateTargetMethod (implementationClass , targetMethod );
85+
86+ var checkerAdditionalArguments = Stream .of (Class .class , targetSuperclass );
87+ var checkMethodArgumentTypes = Stream .concat (checkerAdditionalArguments , Arrays .stream (parameterTypes ))
88+ .map (Type ::getType )
89+ .toArray (Type []::new );
90+
91+ CheckMethod [] checkMethod = new CheckMethod [1 ];
92+
93+ try {
94+ InstrumenterImpl .ClassFileInfo classFileInfo = InstrumenterImpl .getClassFileInfo (checkerClass );
95+ ClassReader reader = new ClassReader (classFileInfo .bytecodes ());
96+ ClassVisitor visitor = new ClassVisitor (Opcodes .ASM9 ) {
97+ @ Override
98+ public MethodVisitor visitMethod (
99+ int access ,
100+ String methodName ,
101+ String methodDescriptor ,
102+ String signature ,
103+ String [] exceptions
104+ ) {
105+ var mv = super .visitMethod (access , methodName , methodDescriptor , signature , exceptions );
106+ if (methodName .equals (checkMethodName )) {
107+ var methodArgumentTypes = Type .getArgumentTypes (methodDescriptor );
108+ if (Arrays .equals (methodArgumentTypes , checkMethodArgumentTypes )) {
109+ var checkerParameterDescriptors = Arrays .stream (methodArgumentTypes ).map (Type ::getDescriptor ).toList ();
110+ checkMethod [0 ] = new CheckMethod (Type .getInternalName (checkerClass ), methodName , checkerParameterDescriptors );
111+ }
112+ }
113+ return mv ;
114+ }
115+ };
116+ reader .accept (visitor , 0 );
117+ } catch (IOException e ) {
118+ throw new ClassNotFoundException ("Cannot find a definition for class [" + checkerClass .getName () + "]" , e );
119+ }
120+
121+ if (checkMethod [0 ] == null ) {
122+ throw new NoSuchMethodException (
123+ String .format (
124+ Locale .ROOT ,
125+ "Cannot find a method with name [%s] and arguments [%s] in class [%s]" ,
126+ checkMethodName ,
127+ Arrays .stream (checkMethodArgumentTypes ).map (Type ::toString ).collect (Collectors .joining ()),
128+ checkerClass .getName ()
129+ )
130+ );
131+ }
132+
133+ return new InstrumentationInfo (
134+ new MethodKey (
135+ Type .getInternalName (implementationClass ),
136+ targetMethod .getName (),
137+ Arrays .stream (parameterTypes ).map (c -> Type .getType (c ).getInternalName ()).toList ()
138+ ),
139+ checkMethod [0 ]
140+ );
141+ }
142+
143+ private static void validateTargetMethod (Class <?> implementationClass , Method targetMethod ) {
144+ if (targetMethod .getDeclaringClass ().isAssignableFrom (implementationClass ) == false ) {
145+ throw new IllegalArgumentException (
146+ String .format (
147+ Locale .ROOT ,
148+ "Not an implementation class for %s: %s does not implement %s" ,
149+ targetMethod .getName (),
150+ implementationClass .getName (),
151+ targetMethod .getDeclaringClass ().getName ()
152+ )
153+ );
154+ }
155+ if (Modifier .isPrivate (targetMethod .getModifiers ())) {
156+ throw new IllegalArgumentException (
157+ String .format (
158+ Locale .ROOT ,
159+ "Not a valid instrumentation method: %s is private in %s" ,
160+ targetMethod .getName (),
161+ targetMethod .getDeclaringClass ().getName ()
162+ )
163+ );
164+ }
165+ if (Modifier .isStatic (targetMethod .getModifiers ())) {
166+ throw new IllegalArgumentException (
167+ String .format (
168+ Locale .ROOT ,
169+ "Not a valid instrumentation method: %s is static in %s" ,
170+ targetMethod .getName (),
171+ targetMethod .getDeclaringClass ().getName ()
172+ )
173+ );
174+ }
175+ try {
176+ var implementationMethod = implementationClass .getMethod (targetMethod .getName (), targetMethod .getParameterTypes ());
177+ var methodModifiers = implementationMethod .getModifiers ();
178+ if (Modifier .isAbstract (methodModifiers )) {
179+ throw new IllegalArgumentException (
180+ String .format (
181+ Locale .ROOT ,
182+ "Not a valid instrumentation method: %s is abstract in %s" ,
183+ targetMethod .getName (),
184+ implementationClass .getName ()
185+ )
186+ );
187+ }
188+ if (Modifier .isPublic (methodModifiers ) == false ) {
189+ throw new IllegalArgumentException (
190+ String .format (
191+ Locale .ROOT ,
192+ "Not a valid instrumentation method: %s is not public in %s" ,
193+ targetMethod .getName (),
194+ implementationClass .getName ()
195+ )
196+ );
197+ }
198+ } catch (NoSuchMethodException e ) {
199+ assert false
200+ : String .format (
201+ Locale .ROOT ,
202+ "Not a valid instrumentation method: %s cannot be found in %s" ,
203+ targetMethod .getName (),
204+ implementationClass .getName ()
205+ );
206+ }
207+ }
208+
67209 private static final Type CLASS_TYPE = Type .getType (Class .class );
68210
69211 static ParsedCheckerMethod parseCheckerMethodName (String checkerMethodName ) {
@@ -85,8 +227,8 @@ static ParsedCheckerMethod parseCheckerMethodName(String checkerMethodName) {
85227 String .format (
86228 Locale .ROOT ,
87229 "Checker method %s has incorrect name format. "
88- + "It should be either check$$methodName (instance), check$package_ClassName$methodName (static) or "
89- + "check$package_ClassName$ (ctor)" ,
230+ + "It should be either check$package_ClassName $methodName (instance), check$package_ClassName$$ methodName (static) "
231+ + "or check$package_ClassName$ (ctor)" ,
90232 checkerMethodName
91233 )
92234 );
0 commit comments