From 173f52a48468a663cbd7d8d99b329b895a9ee9d3 Mon Sep 17 00:00:00 2001 From: timdev <168923949+timdev-ger@users.noreply.github.com> Date: Mon, 23 Sep 2024 10:39:54 +0200 Subject: [PATCH 1/2] Update Transformer.java Ensure that the operand stack and local variables are properly handled when frames are manipulated. Here's an improvement to this part of the code, where we make sure that the correct frame type and operand stack size are calculated --- .../ea/async/instrumentation/Transformer.java | 913 +++++++----------- 1 file changed, 337 insertions(+), 576 deletions(-) diff --git a/async/src/main/java/com/ea/async/instrumentation/Transformer.java b/async/src/main/java/com/ea/async/instrumentation/Transformer.java index b4e3703..53d3341 100644 --- a/async/src/main/java/com/ea/async/instrumentation/Transformer.java +++ b/async/src/main/java/com/ea/async/instrumentation/Transformer.java @@ -84,8 +84,7 @@ * * @author Daniel Sperry */ -public class Transformer implements ClassFileTransformer -{ +public class Transformer implements ClassFileTransformer { /** * Name of the property that will be set by the * Agent to flag that the instrumentation is already running. @@ -121,41 +120,35 @@ public class Transformer implements ClassFileTransformer public static final String JOIN_METHOD_NAME = "join"; public static final String JOIN_METHOD_DESC = "()Ljava/lang/Object;"; - public static final String LAMBDA_DESC = "(Ljava/lang/invoke/MethodHandles$Lookup;" - + "Ljava/lang/String;" - + "Ljava/lang/invoke/MethodType;" - + "Ljava/lang/invoke/MethodType;" - + "Ljava/lang/invoke/MethodHandle;" - + "Ljava/lang/invoke/MethodType;" - + ")Ljava/lang/invoke/CallSite;"; + public static final String LAMBDA_DESC = "(Ljava/lang/invoke/MethodHandles$Lookup;" + + "Ljava/lang/String;" + + "Ljava/lang/invoke/MethodType;" + + "Ljava/lang/invoke/MethodType;" + + "Ljava/lang/invoke/MethodHandle;" + + "Ljava/lang/invoke/MethodType;" + + ")Ljava/lang/invoke/CallSite;"; public static final Handle METAFACTORY_HANDLE = new Handle(Opcodes.H_INVOKESTATIC, - "java/lang/invoke/LambdaMetafactory", - "metafactory", - LAMBDA_DESC); + "java/lang/invoke/LambdaMetafactory", + "metafactory", + LAMBDA_DESC); - public static WeakHashMap futureMap = new WeakHashMap<>(); - public static WeakHashMap> classLoaderMap = new WeakHashMap<>(); + public static WeakHashMap < URL, Boolean > futureMap = new WeakHashMap < > (); + public static WeakHashMap < ClassLoader, Set < URL >> classLoaderMap = new WeakHashMap < > (); - private Consumer errorListener; + private Consumer < String > errorListener; @Override - public byte[] transform(final ClassLoader loader, final String className, final Class classBeingRedefined, final ProtectionDomain protectionDomain, final byte[] classfileBuffer) throws IllegalClassFormatException - { - try - { - if (className != null && className.startsWith("java")) - { + public byte[] transform(final ClassLoader loader, final String className, final Class << ? > classBeingRedefined, final ProtectionDomain protectionDomain, final byte[] classfileBuffer) throws IllegalClassFormatException { + try { + if (className != null && className.startsWith("java")) { return null; } ClassReader cr = new ClassReader(classfileBuffer); - if (needsInstrumentation(cr)) - { + if (needsInstrumentation(cr)) { return transform(loader, cr); } return null; - } - catch (Exception | Error e) - { + } catch (Exception | Error e) { // Avoid using slf4j or any dependency here. // this is supposed to be a critical error. // it should be ok to write directly to the syserr. @@ -166,8 +159,7 @@ public byte[] transform(final ClassLoader loader, final String className, final } } - static class SwitchEntry - { + static class SwitchEntry { final Label resumeLabel; final Label futureIsDoneLabel; final int key; @@ -178,8 +170,7 @@ static class SwitchEntry public int[] argumentToLocal; public int[] localToiArgument; - public SwitchEntry(final int key, final FrameAnalyzer.ExtendedFrame frame, final int index) - { + public SwitchEntry(final int key, final FrameAnalyzer.ExtendedFrame frame, final int index) { this.key = key; this.frame = frame; this.index = index; @@ -188,8 +179,7 @@ public SwitchEntry(final int key, final FrameAnalyzer.ExtendedFrame frame, final } } - static class Argument - { + static class Argument { BasicValue value; String name; int iArgumentLocal; @@ -197,18 +187,13 @@ static class Argument int tmpLocalMapping = -1; } - public byte[] instrument(final ClassLoader classLoader, InputStream inputStream) - { - try - { + public byte[] instrument(final ClassLoader classLoader, InputStream inputStream) { + try { ClassReader cr = new ClassReader(inputStream); - if (needsInstrumentation(cr)) - { + if (needsInstrumentation(cr)) { return transform(classLoader, cr); } - } - catch (Exception e) - { + } catch (Exception e) { throw new RuntimeException(e); } return null; @@ -221,88 +206,70 @@ public byte[] instrument(final ClassLoader classLoader, InputStream inputStream) * @param cr the class reader for this class * @return null or the new class bytes */ - public byte[] transform(final ClassLoader classLoader, ClassReader cr) throws AnalyzerException - { + public byte[] transform(final ClassLoader classLoader, ClassReader cr) throws AnalyzerException { ClassNode classNode = new ClassNode(); // using EXPAND_FRAMES because F_SAME causes problems when inserting new frames cr.accept(classNode, ClassReader.EXPAND_FRAMES); int countInstrumented = 0; - Map nameUseCount = new HashMap<>(); + Map < String, Integer > nameUseCount = new HashMap < > (); - for (MethodNode original : (List) new ArrayList(classNode.methods)) - { + for (MethodNode original: (List < MethodNode > ) new ArrayList(classNode.methods)) { Integer countOriginalUses = nameUseCount.get(original.name); nameUseCount.put(original.name, countOriginalUses == null ? 1 : countOriginalUses + 1); boolean hasAwaitCall = false; boolean hasAwaitInitCall = false; - for (Iterator it = original.instructions.iterator(); it.hasNext(); ) - { + for (Iterator it = original.instructions.iterator(); it.hasNext();) { Object o = it.next(); - if (o instanceof MethodInsnNode) - { - if (!hasAwaitCall) - { + if (o instanceof MethodInsnNode) { + if (!hasAwaitCall) { hasAwaitCall = isAwaitCall((MethodInsnNode) o); } - if (!hasAwaitInitCall) - { + if (!hasAwaitInitCall) { hasAwaitInitCall = isAwaitInitCall((MethodInsnNode) o); } } } - if (!hasAwaitCall && !hasAwaitInitCall) - { + if (!hasAwaitCall && !hasAwaitInitCall) { continue; } countInstrumented++; boolean nonCompFutReturn = !original.desc.endsWith(COMPLETABLE_FUTURE_RET) && !original.desc.endsWith(COMPLETION_STAGE_RET); - if (original.desc.endsWith(COMPLETABLE_FUTURE_RET) - || original.desc.endsWith(COMPLETION_STAGE_RET) - || returnsCompletionStageSubClass(classLoader, original) - ) - { + if (original.desc.endsWith(COMPLETABLE_FUTURE_RET) || + original.desc.endsWith(COMPLETION_STAGE_RET) || + returnsCompletionStageSubClass(classLoader, original) + ) { // async method transformAsyncMethod(classNode, original, nameUseCount); - } - else - { + } else { // non async method // Removing calls to `Await.init()` // and advising on wrong uses of Await.await final MethodNode replacement = new MethodNode(original.access, - original.name, original.desc, original.signature, (String[]) original.exceptions.toArray(new String[original.exceptions.size()])); + original.name, original.desc, original.signature, (String[]) original.exceptions.toArray(new String[original.exceptions.size()])); - class MyMethodVisitor extends MethodVisitor - { - public MyMethodVisitor(final MethodNode mv) - { + class MyMethodVisitor extends MethodVisitor { + public MyMethodVisitor(final MethodNode mv) { super(Opcodes.ASM5, mv); } @Override - public void visitMethodInsn(final int opcode, final String owner, final String name, final String desc, final boolean itf) - { - if (isAwaitCall(opcode, owner, name, desc)) - { + public void visitMethodInsn(final int opcode, final String owner, final String name, final String desc, final boolean itf) { + if (isAwaitCall(opcode, owner, name, desc)) { // replaces invalid awaits with the join notifyError("Invalid use of await at %s.%s", cr.getClassName(), original.name); visitMethodInsn(INVOKEINTERFACE, COMPLETION_STAGE_NAME, "toCompletableFuture", "()Ljava/util/concurrent/CompletableFuture;", true); visitMethodInsn(INVOKEVIRTUAL, COMPLETABLE_FUTURE_NAME, JOIN_METHOD_NAME, JOIN_METHOD_DESC, false); - } - else if (isAwaitInitCall(opcode, owner, name, desc)) - { + } else if (isAwaitInitCall(opcode, owner, name, desc)) { // replaces all references to Await.init with NOP super.visitInsn(NOP); - } - else - { + } else { super.visitMethodInsn(opcode, owner, name, desc, itf); } } @@ -314,17 +281,14 @@ else if (isAwaitInitCall(opcode, owner, name, desc)) } } // no changes. - if (countInstrumented == 0) - { + if (countInstrumented == 0) { return null; } // avoiding using COMPUTE_FRAMES - final ClassWriter cw = new ClassWriter(ClassWriter.COMPUTE_MAXS) - { + final ClassWriter cw = new ClassWriter(ClassWriter.COMPUTE_MAXS) { @Override - protected String getCommonSuperClass(final String type1, final String type2) - { + protected String getCommonSuperClass(final String type1, final String type2) { // this is only called if COMPUTE_FRAMES is enabled // implementing this properly would require loading information @@ -343,79 +307,63 @@ protected String getCommonSuperClass(final String type1, final String type2) return bytes; } - private boolean returnsCompletionStageSubClass(ClassLoader classLoader, final MethodNode original) - { + private boolean returnsCompletionStageSubClass(ClassLoader classLoader, final MethodNode original) { final Type retType = Type.getReturnType(original.desc); - if (retType.getSort() != Type.OBJECT) - { + if (retType.getSort() != Type.OBJECT) { return false; } final String retTypeName = retType.getInternalName(); return isCompletionStage(classLoader, retTypeName); } - private boolean isCompletionStage(final ClassLoader classLoader, final String internalName) - { - if (COMPLETION_STAGE_NAME.equals(internalName)) - { + private boolean isCompletionStage(final ClassLoader classLoader, final String internalName) { + if (COMPLETION_STAGE_NAME.equals(internalName)) { return true; } - if (COMPLETABLE_FUTURE_NAME.equals(internalName)) - { + if (COMPLETABLE_FUTURE_NAME.equals(internalName)) { return true; } URL resource = classLoader.getResource(internalName + ".class"); Boolean aBoolean; - synchronized (futureMap) - { + synchronized(futureMap) { aBoolean = futureMap.get(resource); } - if (aBoolean != null) - { + if (aBoolean != null) { return aBoolean; } aBoolean = isCompletionStage(classLoader, classLoader.getResourceAsStream(internalName + ".class")); - synchronized (classLoaderMap) - { - Set urls = classLoaderMap.get(classLoader); - if (urls == null) - { + synchronized(classLoaderMap) { + Set < URL > urls = classLoaderMap.get(classLoader); + if (urls == null) { // class loader hold the references to the urls - classLoaderMap.put(classLoader, urls = new HashSet<>()); + classLoaderMap.put(classLoader, urls = new HashSet < > ()); } urls.add(resource); } - synchronized (futureMap) - { + synchronized(futureMap) { futureMap.put(resource, aBoolean); } return aBoolean; } - private boolean isCompletionStage(ClassLoader classLoader, InputStream resource) - { - if (resource == null) - { + private boolean isCompletionStage(ClassLoader classLoader, InputStream resource) { + if (resource == null) { return false; } - try - { + try { return isCompletionStage(classLoader, new ClassReader(resource).getSuperName()); - } - catch (IOException e) - { + } catch (IOException e) { return false; } } - private void transformAsyncMethod(final ClassNode classNode, final MethodNode original, final Map nameUseCount) throws AnalyzerException - { - final List switchEntries = new ArrayList<>(); + private void transformAsyncMethod(final ClassNode classNode, final MethodNode original, final Map < String, Integer > nameUseCount) throws AnalyzerException { + final List < SwitchEntry > switchEntries = new ArrayList < > (); SwitchEntry entryPoint; - List arguments = new ArrayList<>(); - final List