| 
 | 1 | +/*  | 
 | 2 | + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one  | 
 | 3 | + * or more contributor license agreements. Licensed under the "Elastic License  | 
 | 4 | + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side  | 
 | 5 | + * Public License v 1"; you may not use this file except in compliance with, at  | 
 | 6 | + * your election, the "Elastic License 2.0", the "GNU Affero General Public  | 
 | 7 | + * License v3.0 only", or the "Server Side Public License, v 1".  | 
 | 8 | + */  | 
 | 9 | + | 
 | 10 | +package org.elasticsearch.entitlement.instrumentation.impl;  | 
 | 11 | + | 
 | 12 | +import org.elasticsearch.entitlement.instrumentation.Instrumenter;  | 
 | 13 | +import org.elasticsearch.entitlement.instrumentation.MethodKey;  | 
 | 14 | +import org.objectweb.asm.AnnotationVisitor;  | 
 | 15 | +import org.objectweb.asm.ClassReader;  | 
 | 16 | +import org.objectweb.asm.ClassVisitor;  | 
 | 17 | +import org.objectweb.asm.ClassWriter;  | 
 | 18 | +import org.objectweb.asm.MethodVisitor;  | 
 | 19 | +import org.objectweb.asm.Opcodes;  | 
 | 20 | +import org.objectweb.asm.Type;  | 
 | 21 | + | 
 | 22 | +import java.io.IOException;  | 
 | 23 | +import java.io.InputStream;  | 
 | 24 | +import java.lang.reflect.Method;  | 
 | 25 | +import java.util.Map;  | 
 | 26 | +import java.util.stream.Stream;  | 
 | 27 | + | 
 | 28 | +import static org.objectweb.asm.ClassWriter.COMPUTE_FRAMES;  | 
 | 29 | +import static org.objectweb.asm.ClassWriter.COMPUTE_MAXS;  | 
 | 30 | +import static org.objectweb.asm.Opcodes.ACC_STATIC;  | 
 | 31 | +import static org.objectweb.asm.Opcodes.GETSTATIC;  | 
 | 32 | +import static org.objectweb.asm.Opcodes.INVOKEINTERFACE;  | 
 | 33 | +import static org.objectweb.asm.Opcodes.INVOKESTATIC;  | 
 | 34 | +import static org.objectweb.asm.Opcodes.INVOKEVIRTUAL;  | 
 | 35 | + | 
 | 36 | +public class InstrumenterImpl implements Instrumenter {  | 
 | 37 | +    /**  | 
 | 38 | +     * To avoid class name collisions during testing without an agent to replace classes in-place.  | 
 | 39 | +     */  | 
 | 40 | +    private final String classNameSuffix;  | 
 | 41 | +    private final Map<MethodKey, Method> instrumentationMethods;  | 
 | 42 | + | 
 | 43 | +    public InstrumenterImpl(String classNameSuffix, Map<MethodKey, Method> instrumentationMethods) {  | 
 | 44 | +        this.classNameSuffix = classNameSuffix;  | 
 | 45 | +        this.instrumentationMethods = instrumentationMethods;  | 
 | 46 | +    }  | 
 | 47 | + | 
 | 48 | +    public ClassFileInfo instrumentClassFile(Class<?> clazz) throws IOException {  | 
 | 49 | +        ClassFileInfo initial = getClassFileInfo(clazz);  | 
 | 50 | +        return new ClassFileInfo(initial.fileName(), instrumentClass(Type.getInternalName(clazz), initial.bytecodes()));  | 
 | 51 | +    }  | 
 | 52 | + | 
 | 53 | +    public static ClassFileInfo getClassFileInfo(Class<?> clazz) throws IOException {  | 
 | 54 | +        String internalName = Type.getInternalName(clazz);  | 
 | 55 | +        String fileName = "/" + internalName + ".class";  | 
 | 56 | +        byte[] originalBytecodes;  | 
 | 57 | +        try (InputStream classStream = clazz.getResourceAsStream(fileName)) {  | 
 | 58 | +            if (classStream == null) {  | 
 | 59 | +                throw new IllegalStateException("Classfile not found in jar: " + fileName);  | 
 | 60 | +            }  | 
 | 61 | +            originalBytecodes = classStream.readAllBytes();  | 
 | 62 | +        }  | 
 | 63 | +        return new ClassFileInfo(fileName, originalBytecodes);  | 
 | 64 | +    }  | 
 | 65 | + | 
 | 66 | +    @Override  | 
 | 67 | +    public byte[] instrumentClass(String className, byte[] classfileBuffer) {  | 
 | 68 | +        ClassReader reader = new ClassReader(classfileBuffer);  | 
 | 69 | +        ClassWriter writer = new ClassWriter(reader, COMPUTE_FRAMES | COMPUTE_MAXS);  | 
 | 70 | +        ClassVisitor visitor = new EntitlementClassVisitor(Opcodes.ASM9, writer, className);  | 
 | 71 | +        reader.accept(visitor, 0);  | 
 | 72 | +        return writer.toByteArray();  | 
 | 73 | +    }  | 
 | 74 | + | 
 | 75 | +    class EntitlementClassVisitor extends ClassVisitor {  | 
 | 76 | +        final String className;  | 
 | 77 | + | 
 | 78 | +        EntitlementClassVisitor(int api, ClassVisitor classVisitor, String className) {  | 
 | 79 | +            super(api, classVisitor);  | 
 | 80 | +            this.className = className;  | 
 | 81 | +        }  | 
 | 82 | + | 
 | 83 | +        @Override  | 
 | 84 | +        public void visit(int version, int access, String name, String signature, String superName, String[] interfaces) {  | 
 | 85 | +            super.visit(version, access, name + classNameSuffix, signature, superName, interfaces);  | 
 | 86 | +        }  | 
 | 87 | + | 
 | 88 | +        @Override  | 
 | 89 | +        public MethodVisitor visitMethod(int access, String name, String descriptor, String signature, String[] exceptions) {  | 
 | 90 | +            var mv = super.visitMethod(access, name, descriptor, signature, exceptions);  | 
 | 91 | +            boolean isStatic = (access & ACC_STATIC) != 0;  | 
 | 92 | +            var key = new MethodKey(  | 
 | 93 | +                className,  | 
 | 94 | +                name,  | 
 | 95 | +                Stream.of(Type.getArgumentTypes(descriptor)).map(Type::getInternalName).toList(),  | 
 | 96 | +                isStatic  | 
 | 97 | +            );  | 
 | 98 | +            var instrumentationMethod = instrumentationMethods.get(key);  | 
 | 99 | +            if (instrumentationMethod != null) {  | 
 | 100 | +                // LOGGER.debug("Will instrument method {}", key);  | 
 | 101 | +                return new EntitlementMethodVisitor(Opcodes.ASM9, mv, isStatic, descriptor, instrumentationMethod);  | 
 | 102 | +            } else {  | 
 | 103 | +                // LOGGER.trace("Will not instrument method {}", key);  | 
 | 104 | +            }  | 
 | 105 | +            return mv;  | 
 | 106 | +        }  | 
 | 107 | +    }  | 
 | 108 | + | 
 | 109 | +    static class EntitlementMethodVisitor extends MethodVisitor {  | 
 | 110 | +        private final boolean instrumentedMethodIsStatic;  | 
 | 111 | +        private final String instrumentedMethodDescriptor;  | 
 | 112 | +        private final Method instrumentationMethod;  | 
 | 113 | +        private boolean hasCallerSensitiveAnnotation = false;  | 
 | 114 | + | 
 | 115 | +        EntitlementMethodVisitor(  | 
 | 116 | +            int api,  | 
 | 117 | +            MethodVisitor methodVisitor,  | 
 | 118 | +            boolean instrumentedMethodIsStatic,  | 
 | 119 | +            String instrumentedMethodDescriptor,  | 
 | 120 | +            Method instrumentationMethod  | 
 | 121 | +        ) {  | 
 | 122 | +            super(api, methodVisitor);  | 
 | 123 | +            this.instrumentedMethodIsStatic = instrumentedMethodIsStatic;  | 
 | 124 | +            this.instrumentedMethodDescriptor = instrumentedMethodDescriptor;  | 
 | 125 | +            this.instrumentationMethod = instrumentationMethod;  | 
 | 126 | +        }  | 
 | 127 | + | 
 | 128 | +        @Override  | 
 | 129 | +        public AnnotationVisitor visitAnnotation(String descriptor, boolean visible) {  | 
 | 130 | +            if (visible && descriptor.endsWith("CallerSensitive;")) {  | 
 | 131 | +                hasCallerSensitiveAnnotation = true;  | 
 | 132 | +            }  | 
 | 133 | +            return super.visitAnnotation(descriptor, visible);  | 
 | 134 | +        }  | 
 | 135 | + | 
 | 136 | +        @Override  | 
 | 137 | +        public void visitCode() {  | 
 | 138 | +            pushEntitlementChecksObject();  | 
 | 139 | +            pushCallerClass();  | 
 | 140 | +            forwardIncomingArguments();  | 
 | 141 | +            invokeInstrumentationMethod();  | 
 | 142 | +            super.visitCode();  | 
 | 143 | +        }  | 
 | 144 | + | 
 | 145 | +        private void pushEntitlementChecksObject() {  | 
 | 146 | +            mv.visitMethodInsn(  | 
 | 147 | +                INVOKESTATIC,  | 
 | 148 | +                "org/elasticsearch/entitlement/api/EntitlementProvider",  | 
 | 149 | +                "checks",  | 
 | 150 | +                "()Lorg/elasticsearch/entitlement/api/EntitlementChecks;",  | 
 | 151 | +                false  | 
 | 152 | +            );  | 
 | 153 | +        }  | 
 | 154 | + | 
 | 155 | +        private void pushCallerClass() {  | 
 | 156 | +            if (hasCallerSensitiveAnnotation) {  | 
 | 157 | +                mv.visitMethodInsn(  | 
 | 158 | +                    INVOKESTATIC,  | 
 | 159 | +                    "jdk/internal/reflect/Reflection",  | 
 | 160 | +                    "getCallerClass",  | 
 | 161 | +                    Type.getMethodDescriptor(Type.getType(Class.class)),  | 
 | 162 | +                    false  | 
 | 163 | +                );  | 
 | 164 | +            } else {  | 
 | 165 | +                mv.visitFieldInsn(  | 
 | 166 | +                    GETSTATIC,  | 
 | 167 | +                    Type.getInternalName(StackWalker.Option.class),  | 
 | 168 | +                    "RETAIN_CLASS_REFERENCE",  | 
 | 169 | +                    Type.getDescriptor(StackWalker.Option.class)  | 
 | 170 | +                );  | 
 | 171 | +                mv.visitMethodInsn(  | 
 | 172 | +                    INVOKESTATIC,  | 
 | 173 | +                    Type.getInternalName(StackWalker.class),  | 
 | 174 | +                    "getInstance",  | 
 | 175 | +                    Type.getMethodDescriptor(Type.getType(StackWalker.class), Type.getType(StackWalker.Option.class)),  | 
 | 176 | +                    false  | 
 | 177 | +                );  | 
 | 178 | +                mv.visitMethodInsn(  | 
 | 179 | +                    INVOKEVIRTUAL,  | 
 | 180 | +                    Type.getInternalName(StackWalker.class),  | 
 | 181 | +                    "getCallerClass",  | 
 | 182 | +                    Type.getMethodDescriptor(Type.getType(Class.class)),  | 
 | 183 | +                    false  | 
 | 184 | +                );  | 
 | 185 | +            }  | 
 | 186 | +        }  | 
 | 187 | + | 
 | 188 | +        private void forwardIncomingArguments() {  | 
 | 189 | +            int localVarIndex = 0;  | 
 | 190 | +            if (instrumentedMethodIsStatic == false) {  | 
 | 191 | +                mv.visitVarInsn(Opcodes.ALOAD, localVarIndex++);  | 
 | 192 | +            }  | 
 | 193 | +            for (Type type : Type.getArgumentTypes(instrumentedMethodDescriptor)) {  | 
 | 194 | +                mv.visitVarInsn(type.getOpcode(Opcodes.ILOAD), localVarIndex);  | 
 | 195 | +                localVarIndex += type.getSize();  | 
 | 196 | +            }  | 
 | 197 | + | 
 | 198 | +        }  | 
 | 199 | + | 
 | 200 | +        private void invokeInstrumentationMethod() {  | 
 | 201 | +            mv.visitMethodInsn(  | 
 | 202 | +                INVOKEINTERFACE,  | 
 | 203 | +                Type.getInternalName(instrumentationMethod.getDeclaringClass()),  | 
 | 204 | +                instrumentationMethod.getName(),  | 
 | 205 | +                Type.getMethodDescriptor(instrumentationMethod),  | 
 | 206 | +                true  | 
 | 207 | +            );  | 
 | 208 | +        }  | 
 | 209 | +    }  | 
 | 210 | + | 
 | 211 | +    // private static final Logger LOGGER = LogManager.getLogger(Instrumenter.class);  | 
 | 212 | + | 
 | 213 | +    public record ClassFileInfo(String fileName, byte[] bytecodes) {}  | 
 | 214 | +}  | 
0 commit comments