Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,11 @@ public Instrumenter newInstrumenter(Class<?> clazz, Map<MethodKey, CheckMethod>
return InstrumenterImpl.create(clazz, methods);
}

@Override
public Map<MethodKey, CheckMethod> lookupMethods(Class<?> checkerClass) throws IOException {
Map<MethodKey, CheckMethod> methodsToInstrument = new HashMap<>();
private interface CheckerMethodVisitor {
void visit(Class<?> currentClass, int access, String checkerMethodName, String checkerMethodDescriptor);
}

private void visitClassAndSupers(Class<?> checkerClass, CheckerMethodVisitor checkerMethodVisitor) throws ClassNotFoundException {
Set<Class<?>> visitedClasses = new HashSet<>();
ArrayDeque<Class<?>> classesToVisit = new ArrayDeque<>(Collections.singleton(checkerClass));
while (classesToVisit.isEmpty() == false) {
Expand All @@ -57,67 +58,76 @@ public Map<MethodKey, CheckMethod> lookupMethods(Class<?> checkerClass) throws I
}
visitedClasses.add(currentClass);

var classFileInfo = InstrumenterImpl.getClassFileInfo(currentClass);
ClassReader reader = new ClassReader(classFileInfo.bytecodes());
ClassVisitor visitor = new ClassVisitor(Opcodes.ASM9) {
try {
var classFileInfo = InstrumenterImpl.getClassFileInfo(currentClass);
ClassReader reader = new ClassReader(classFileInfo.bytecodes());
ClassVisitor visitor = new ClassVisitor(Opcodes.ASM9) {

@Override
public void visit(int version, int access, String name, String signature, String superName, String[] interfaces) {
super.visit(version, access, name, signature, superName, interfaces);
try {
if (OBJECT_INTERNAL_NAME.equals(superName) == false) {
classesToVisit.add(Class.forName(Type.getObjectType(superName).getClassName()));
@Override
public void visit(int version, int access, String name, String signature, String superName, String[] interfaces) {
super.visit(version, access, name, signature, superName, interfaces);
try {
if (OBJECT_INTERNAL_NAME.equals(superName) == false) {
classesToVisit.add(Class.forName(Type.getObjectType(superName).getClassName()));
}
for (var interfaceName : interfaces) {
classesToVisit.add(Class.forName(Type.getObjectType(interfaceName).getClassName()));
}
} catch (ClassNotFoundException e) {
throw new IllegalArgumentException("Cannot inspect checker class " + currentClass.getName(), e);
}
for (var interfaceName : interfaces) {
classesToVisit.add(Class.forName(Type.getObjectType(interfaceName).getClassName()));
}
} catch (ClassNotFoundException e) {
throw new IllegalArgumentException("Cannot inspect checker class " + checkerClass.getName(), e);
}
}

@Override
public MethodVisitor visitMethod(
int access,
String checkerMethodName,
String checkerMethodDescriptor,
String signature,
String[] exceptions
) {
var mv = super.visitMethod(access, checkerMethodName, checkerMethodDescriptor, signature, exceptions);
if (checkerMethodName.startsWith(InstrumentationService.CHECK_METHOD_PREFIX)) {
var checkerMethodArgumentTypes = Type.getArgumentTypes(checkerMethodDescriptor);
var methodToInstrument = parseCheckerMethodSignature(checkerMethodName, checkerMethodArgumentTypes);

var checkerParameterDescriptors = Arrays.stream(checkerMethodArgumentTypes).map(Type::getDescriptor).toList();
var checkMethod = new CheckMethod(
Type.getInternalName(currentClass),
checkerMethodName,
checkerParameterDescriptors
);

methodsToInstrument.putIfAbsent(methodToInstrument, checkMethod);
@Override
public MethodVisitor visitMethod(
int access,
String checkerMethodName,
String checkerMethodDescriptor,
String signature,
String[] exceptions
) {
var mv = super.visitMethod(access, checkerMethodName, checkerMethodDescriptor, signature, exceptions);
checkerMethodVisitor.visit(currentClass, access, checkerMethodName, checkerMethodDescriptor);
return mv;
}
return mv;
}
};
reader.accept(visitor, 0);
};
reader.accept(visitor, 0);
} catch (IOException e) {
throw new ClassNotFoundException("Cannot find a definition for class [" + checkerClass.getName() + "]", e);
}
}
}

@Override
public Map<MethodKey, CheckMethod> lookupMethods(Class<?> checkerClass) throws ClassNotFoundException {
Map<MethodKey, CheckMethod> methodsToInstrument = new HashMap<>();

visitClassAndSupers(checkerClass, (currentClass, access, checkerMethodName, checkerMethodDescriptor) -> {
if (checkerMethodName.startsWith(InstrumentationService.CHECK_METHOD_PREFIX)) {
var checkerMethodArgumentTypes = Type.getArgumentTypes(checkerMethodDescriptor);
var methodToInstrument = parseCheckerMethodSignature(checkerMethodName, checkerMethodArgumentTypes);

var checkerParameterDescriptors = Arrays.stream(checkerMethodArgumentTypes).map(Type::getDescriptor).toList();
var checkMethod = new CheckMethod(Type.getInternalName(currentClass), checkerMethodName, checkerParameterDescriptors);
methodsToInstrument.putIfAbsent(methodToInstrument, checkMethod);
}
});

return methodsToInstrument;
}

@SuppressForbidden(reason = "Need access to abstract methods (protected/package internal) in base class")
@Override
public InstrumentationInfo lookupImplementationMethod(
Class<?> targetSuperclass,
String methodName,
String targetMethodName,
Class<?> implementationClass,
Class<?> checkerClass,
String checkMethodName,
Class<?>... parameterTypes
) throws NoSuchMethodException, ClassNotFoundException {

var targetMethod = targetSuperclass.getDeclaredMethod(methodName, parameterTypes);
var targetMethod = targetSuperclass.getDeclaredMethod(targetMethodName, parameterTypes);
var implementationMethod = implementationClass.getMethod(targetMethod.getName(), targetMethod.getParameterTypes());
validateTargetMethod(implementationClass, targetMethod, implementationMethod);

Expand All @@ -128,33 +138,15 @@ public InstrumentationInfo lookupImplementationMethod(

CheckMethod[] checkMethod = new CheckMethod[1];

try {
InstrumenterImpl.ClassFileInfo classFileInfo = InstrumenterImpl.getClassFileInfo(checkerClass);
ClassReader reader = new ClassReader(classFileInfo.bytecodes());
ClassVisitor visitor = new ClassVisitor(Opcodes.ASM9) {
@Override
public MethodVisitor visitMethod(
int access,
String methodName,
String methodDescriptor,
String signature,
String[] exceptions
) {
var mv = super.visitMethod(access, methodName, methodDescriptor, signature, exceptions);
if (methodName.equals(checkMethodName)) {
var methodArgumentTypes = Type.getArgumentTypes(methodDescriptor);
if (Arrays.equals(methodArgumentTypes, checkMethodArgumentTypes)) {
var checkerParameterDescriptors = Arrays.stream(methodArgumentTypes).map(Type::getDescriptor).toList();
checkMethod[0] = new CheckMethod(Type.getInternalName(checkerClass), methodName, checkerParameterDescriptors);
}
}
return mv;
visitClassAndSupers(checkerClass, (currentClass, access, methodName, methodDescriptor) -> {
if (methodName.equals(checkMethodName)) {
var methodArgumentTypes = Type.getArgumentTypes(methodDescriptor);
if (Arrays.equals(methodArgumentTypes, checkMethodArgumentTypes)) {
var checkerParameterDescriptors = Arrays.stream(methodArgumentTypes).map(Type::getDescriptor).toList();
checkMethod[0] = new CheckMethod(Type.getInternalName(currentClass), methodName, checkerParameterDescriptors);
}
};
reader.accept(visitor, 0);
} catch (IOException e) {
throw new ClassNotFoundException("Cannot find a definition for class [" + checkerClass.getName() + "]", e);
}
}
});

if (checkMethod[0] == null) {
throw new NoSuchMethodException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import org.elasticsearch.test.ESTestCase;
import org.objectweb.asm.Type;

import java.io.IOException;
import java.util.List;
import java.util.Map;

Expand Down Expand Up @@ -90,7 +89,7 @@ interface TestCheckerMixed {
void checkInstanceMethodManual(Class<?> clazz, TestTargetBaseClass that, int x, String y);
}

public void testInstrumentationTargetLookup() throws IOException {
public void testInstrumentationTargetLookup() throws ClassNotFoundException {
Map<MethodKey, CheckMethod> checkMethods = instrumentationService.lookupMethods(TestChecker.class);

assertThat(checkMethods, aMapWithSize(3));
Expand Down Expand Up @@ -143,7 +142,7 @@ public void testInstrumentationTargetLookup() throws IOException {
);
}

public void testInstrumentationTargetLookupWithOverloads() throws IOException {
public void testInstrumentationTargetLookupWithOverloads() throws ClassNotFoundException {
Map<MethodKey, CheckMethod> checkMethods = instrumentationService.lookupMethods(TestCheckerOverloads.class);

assertThat(checkMethods, aMapWithSize(2));
Expand Down Expand Up @@ -175,7 +174,7 @@ public void testInstrumentationTargetLookupWithOverloads() throws IOException {
);
}

public void testInstrumentationTargetLookupWithDerivedClass() throws IOException {
public void testInstrumentationTargetLookupWithDerivedClass() throws ClassNotFoundException {
Map<MethodKey, CheckMethod> checkMethods = instrumentationService.lookupMethods(TestCheckerDerived2.class);

assertThat(checkMethods, aMapWithSize(4));
Expand Down Expand Up @@ -244,7 +243,7 @@ public void testInstrumentationTargetLookupWithDerivedClass() throws IOException
);
}

public void testInstrumentationTargetLookupWithCtors() throws IOException {
public void testInstrumentationTargetLookupWithCtors() throws ClassNotFoundException {
Map<MethodKey, CheckMethod> checkMethods = instrumentationService.lookupMethods(TestCheckerCtors.class);

assertThat(checkMethods, aMapWithSize(2));
Expand Down Expand Up @@ -276,7 +275,7 @@ public void testInstrumentationTargetLookupWithCtors() throws IOException {
);
}

public void testInstrumentationTargetLookupWithExtraMethods() throws IOException {
public void testInstrumentationTargetLookupWithExtraMethods() throws ClassNotFoundException {
Map<MethodKey, CheckMethod> checkMethods = instrumentationService.lookupMethods(TestCheckerMixed.class);

assertThat(checkMethods, aMapWithSize(1));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import java.net.Socket;
import java.net.SocketAddress;
import java.net.SocketImplFactory;
import java.net.URI;
import java.net.URL;
import java.net.URLStreamHandler;
import java.net.URLStreamHandlerFactory;
Expand All @@ -43,17 +44,24 @@
import java.nio.channels.SocketChannel;
import java.nio.channels.spi.SelectorProvider;
import java.nio.charset.Charset;
import java.nio.file.AccessMode;
import java.nio.file.CopyOption;
import java.nio.file.DirectoryStream;
import java.nio.file.FileStore;
import java.nio.file.LinkOption;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.nio.file.attribute.FileAttribute;
import java.nio.file.attribute.UserPrincipal;
import java.nio.file.spi.FileSystemProvider;
import java.security.cert.CertStoreParameters;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Properties;
import java.util.Set;
import java.util.TimeZone;
import java.util.concurrent.ExecutorService;

import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.HttpsURLConnection;
Expand Down Expand Up @@ -491,8 +499,75 @@ public interface EntitlementChecker {
void check$java_nio_file_Files$$setOwner(Class<?> callerClass, Path path, UserPrincipal principal);

// file system providers
void check$java_nio_file_spi_FileSystemProvider$(Class<?> callerClass);

void checkNewFileSystem(Class<?> callerClass, FileSystemProvider that, URI uri, Map<String, ?> env);

void checkNewFileSystem(Class<?> callerClass, FileSystemProvider that, Path path, Map<String, ?> env);

void checkNewInputStream(Class<?> callerClass, FileSystemProvider that, Path path, OpenOption... options);

void checkNewOutputStream(Class<?> callerClass, FileSystemProvider that, Path path, OpenOption... options);

void checkNewFileChannel(
Class<?> callerClass,
FileSystemProvider that,
Path path,
Set<? extends OpenOption> options,
FileAttribute<?>... attrs
);

void checkNewAsynchronousFileChannel(
Class<?> callerClass,
FileSystemProvider that,
Path path,
Set<? extends OpenOption> options,
ExecutorService executor,
FileAttribute<?>... attrs
);

void checkNewByteChannel(
Class<?> callerClass,
FileSystemProvider that,
Path path,
Set<? extends OpenOption> options,
FileAttribute<?>... attrs
);

void checkNewDirectoryStream(Class<?> callerClass, FileSystemProvider that, Path dir, DirectoryStream.Filter<? super Path> filter);

void checkCreateDirectory(Class<?> callerClass, FileSystemProvider that, Path dir, FileAttribute<?>... attrs);

void checkCreateSymbolicLink(Class<?> callerClass, FileSystemProvider that, Path link, Path target, FileAttribute<?>... attrs);

void checkCreateLink(Class<?> callerClass, FileSystemProvider that, Path link, Path existing);

void checkDelete(Class<?> callerClass, FileSystemProvider that, Path path);

void checkDeleteIfExists(Class<?> callerClass, FileSystemProvider that, Path path);

void checkReadSymbolicLink(Class<?> callerClass, FileSystemProvider that, Path link);

void checkCopy(Class<?> callerClass, FileSystemProvider that, Path source, Path target, CopyOption... options);

void checkMove(Class<?> callerClass, FileSystemProvider that, Path source, Path target, CopyOption... options);

void checkIsSameFile(Class<?> callerClass, FileSystemProvider that, Path path, Path path2);

void checkIsHidden(Class<?> callerClass, FileSystemProvider that, Path path);

void checkGetFileStore(Class<?> callerClass, FileSystemProvider that, Path path);

void checkCheckAccess(Class<?> callerClass, FileSystemProvider that, Path path, AccessMode... modes);

void checkGetFileAttributeView(Class<?> callerClass, FileSystemProvider that, Path path, Class<?> type, LinkOption... options);

void checkReadAttributes(Class<?> callerClass, FileSystemProvider that, Path path, Class<?> type, LinkOption... options);

void checkReadAttributes(Class<?> callerClass, FileSystemProvider that, Path path, String attributes, LinkOption... options);

void checkSetAttribute(Class<?> callerClass, FileSystemProvider that, Path path, String attribute, Object value, LinkOption... options);

// file store
void checkGetFileStoreAttributeView(Class<?> callerClass, FileStore that, Class<?> type);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@

import java.lang.foreign.FunctionDescriptor;
import java.lang.foreign.Linker;
import java.nio.file.LinkOption;
import java.nio.file.Path;
import java.nio.file.spi.FileSystemProvider;

/**
* Interface with Java20 "stable" functions and types.
Expand All @@ -32,4 +35,8 @@ public interface Java20StableEntitlementChecker extends EntitlementChecker {
FunctionDescriptor function,
Linker.Option... options
);

void checkReadAttributesIfExists(Class<?> callerClass, FileSystemProvider that, Path path, Class<?> type, LinkOption... options);

void checkExists(Class<?> callerClass, FileSystemProvider that, Path path, LinkOption... options);
}
Loading