diff --git a/java/src/dev/selenium/tools/modules/ModuleGenerator.java b/java/src/dev/selenium/tools/modules/ModuleGenerator.java index 409d233eb2e9e..038fc2d1a8c62 100644 --- a/java/src/dev/selenium/tools/modules/ModuleGenerator.java +++ b/java/src/dev/selenium/tools/modules/ModuleGenerator.java @@ -17,551 +17,180 @@ package dev.selenium.tools.modules; -import static com.github.javaparser.ParseStart.COMPILATION_UNIT; -import static net.bytebuddy.jar.asm.Opcodes.ACC_MANDATED; -import static net.bytebuddy.jar.asm.Opcodes.ACC_MODULE; -import static net.bytebuddy.jar.asm.Opcodes.ACC_OPEN; -import static net.bytebuddy.jar.asm.Opcodes.ACC_STATIC_PHASE; -import static net.bytebuddy.jar.asm.Opcodes.ACC_TRANSITIVE; -import static net.bytebuddy.jar.asm.Opcodes.ASM9; -import static net.bytebuddy.jar.asm.Opcodes.V11; - -import com.github.bazelbuild.rules_jvm_external.zip.StableZipEntry; -import com.github.javaparser.JavaParser; -import com.github.javaparser.ParseResult; -import com.github.javaparser.ParserConfiguration; -import com.github.javaparser.Provider; -import com.github.javaparser.Providers; +import com.github.javaparser.*; import com.github.javaparser.ast.CompilationUnit; -import com.github.javaparser.ast.Modifier; -import com.github.javaparser.ast.NodeList; import com.github.javaparser.ast.expr.Name; -import com.github.javaparser.ast.modules.ModuleDeclaration; -import com.github.javaparser.ast.modules.ModuleExportsDirective; -import com.github.javaparser.ast.modules.ModuleOpensDirective; -import com.github.javaparser.ast.modules.ModuleProvidesDirective; -import com.github.javaparser.ast.modules.ModuleRequiresDirective; -import com.github.javaparser.ast.modules.ModuleUsesDirective; -import com.github.javaparser.ast.visitor.VoidVisitorAdapter; -import java.io.ByteArrayOutputStream; -import java.io.File; -import java.io.IOException; -import java.io.InputStream; -import java.io.OutputStream; -import java.io.PrintStream; -import java.io.UncheckedIOException; -import java.net.MalformedURLException; -import java.net.URL; -import java.net.URLClassLoader; -import java.nio.file.FileVisitResult; -import java.nio.file.Files; -import java.nio.file.Path; -import java.nio.file.Paths; -import java.nio.file.SimpleFileVisitor; -import java.nio.file.StandardCopyOption; +import com.github.javaparser.ast.modules.*; +import net.bytebuddy.jar.asm.*; +import org.openqa.selenium.io.TemporaryFilesystem; + +import java.io.*; +import java.net.*; +import java.nio.file.*; import java.nio.file.attribute.BasicFileAttributes; -import java.util.Arrays; -import java.util.Collection; -import java.util.HashSet; -import java.util.LinkedList; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.ServiceLoader; -import java.util.Set; -import java.util.TreeMap; -import java.util.TreeSet; +import java.util.*; import java.util.concurrent.atomic.AtomicReference; -import java.util.jar.Attributes; -import java.util.jar.JarEntry; -import java.util.jar.JarInputStream; -import java.util.jar.JarOutputStream; -import java.util.jar.Manifest; +import java.util.jar.*; import java.util.spi.ToolProvider; import java.util.stream.Collectors; import java.util.stream.Stream; import java.util.zip.ZipEntry; import java.util.zip.ZipOutputStream; -import net.bytebuddy.jar.asm.ClassReader; -import net.bytebuddy.jar.asm.ClassVisitor; -import net.bytebuddy.jar.asm.ClassWriter; -import net.bytebuddy.jar.asm.MethodVisitor; -import net.bytebuddy.jar.asm.ModuleVisitor; -import net.bytebuddy.jar.asm.Type; -import org.openqa.selenium.io.TemporaryFilesystem; public class ModuleGenerator { - private static final String SERVICE_LOADER = ServiceLoader.class.getName().replace('.', '/'); - - public static void main(String[] args) throws IOException { - Path outJar = null; - Path inJar = null; - String moduleName = null; - Set modulePath = new TreeSet<>(); - Set exports = new TreeSet<>(); - Set hides = new TreeSet<>(); - Set uses = new TreeSet<>(); - - // There is no way at all these two having similar names will cause problems - Map> opensTo = new TreeMap<>(); - Set openTo = new TreeSet<>(); - boolean isOpen = false; - - for (int i = 0; i < args.length; i++) { - String flag = args[i]; - String next = args[++i]; - switch (flag) { - case "--exports": - exports.add(next); - break; - - case "--hides": - hides.add(next); - break; - - case "--in": - inJar = Paths.get(next); - break; - - case "--is-open": - isOpen = Boolean.parseBoolean(next); - break; - - case "--module-name": - moduleName = next; - break; - - case "--module-path": - modulePath.add(Paths.get(next)); - break; - - case "--open-to": - openTo.add(next); - break; - - case "--opens-to": - opensTo.computeIfAbsent(next, str -> new TreeSet<>()).add(args[++i]); - break; - - case "--output": - outJar = Paths.get(next); - break; - - case "--uses": - uses.add(next); - break; - - default: - throw new IllegalArgumentException(String.format("Unknown argument: %s", flag)); - } - } - Objects.requireNonNull(moduleName, "Module name must be set."); - Objects.requireNonNull(outJar, "Output jar must be set."); - Objects.requireNonNull(inJar, "Input jar must be set."); - - ToolProvider jdeps = ToolProvider.findFirst("jdeps").orElseThrow(); - File tempDir = TemporaryFilesystem.getDefaultTmpFS().createTempDir("module-dir", ""); - Path temp = tempDir.toPath(); - - // It doesn't matter what we use for writing to the stream: jdeps doesn't use it. *facepalm* - List jdepsArgs = new LinkedList<>(List.of("--api-only", "--multi-release", "9")); - if (!modulePath.isEmpty()) { - Path tmp = Files.createTempDirectory("automatic_module_jars"); - jdepsArgs.addAll( - List.of( - "--module-path", - modulePath.stream() - .map( - (s) -> { - String file = s.getFileName().toString(); - - if (file.startsWith("processed_")) { - Path copy = tmp.resolve(file.substring(10)); - - try { - Files.copy(s, copy, StandardCopyOption.REPLACE_EXISTING); - } catch (IOException e) { - throw new UncheckedIOException(e); - } - - return copy.toString(); - } - - return s.toString(); - }) - .collect(Collectors.joining(File.pathSeparator)))); - } - jdepsArgs.addAll(List.of("--generate-module-info", temp.toAbsolutePath().toString())); - jdepsArgs.add(inJar.toAbsolutePath().toString()); - - PrintStream origOut = System.out; - PrintStream origErr = System.err; - - ByteArrayOutputStream bos = new ByteArrayOutputStream(); - PrintStream printStream = new PrintStream(bos); - - int result; - try { - System.setOut(printStream); - System.setErr(printStream); - result = jdeps.run(printStream, printStream, jdepsArgs.toArray(new String[0])); - } finally { - System.setOut(origOut); - System.setErr(origErr); - } - if (result != 0) { - throw new RuntimeException( - "Unable to process module:\n" - + "jdeps " - + String.join(" ", jdepsArgs) - + "\n" - + new String(bos.toByteArray())); - } - - AtomicReference moduleInfo = new AtomicReference<>(); - // Fortunately, we know the directory where the output is written - Files.walkFileTree( - temp, - new SimpleFileVisitor<>() { - @Override - public FileVisitResult visitFile(Path file, BasicFileAttributes attrs) { - if ("module-info.java".equals(file.getFileName().toString())) { - moduleInfo.set(file); + private static final String SERVICE_LOADER = ServiceLoader.class.getName().replace('.', '/'); + + public static void main(String[] args) throws IOException { + Map> options = Map.ofEntries( + Map.entry("--exports", new TreeSet<>()), + Map.entry("--hides", new TreeSet<>()), + Map.entry("--uses", new TreeSet<>()), + Map.entry("--open-to", new TreeSet<>()) + ); + + Path outJar = null, inJar = null; + String moduleName = null; + boolean isOpen = false; + Set modulePath = new TreeSet<>(); + Map> opensTo = new TreeMap<>(); + + for (int i = 0; i < args.length; i++) { + String flag = args[i]; + String next = args[++i]; + + switch (flag) { + case "--module-name" -> moduleName = next; + case "--output" -> outJar = Paths.get(next); + case "--in" -> inJar = Paths.get(next); + case "--module-path" -> modulePath.add(Paths.get(next)); + case "--is-open" -> isOpen = Boolean.parseBoolean(next); + case "--opens-to" -> opensTo.computeIfAbsent(next, k -> new TreeSet<>()).add(args[++i]); + default -> options.getOrDefault(flag, new TreeSet<>()).add(next); } - return FileVisitResult.TERMINATE; - } - }); + } - if (moduleInfo.get() == null) { - throw new RuntimeException("Unable to read module info"); - } + validateInputs(moduleName, outJar, inJar); - ParserConfiguration parserConfig = - new ParserConfiguration().setLanguageLevel(ParserConfiguration.LanguageLevel.JAVA_11); - - Provider provider = Providers.provider(moduleInfo.get()); - - ParseResult parseResult = - new JavaParser(parserConfig).parse(COMPILATION_UNIT, provider); - - CompilationUnit unit = - parseResult - .getResult() - .orElseThrow(() -> new RuntimeException("Unable to parse " + moduleInfo.get())); - - ModuleDeclaration moduleDeclaration = - unit.getModule() - .orElseThrow( - () -> new RuntimeException("No module declaration in " + moduleInfo.get())); - - moduleDeclaration.setName(moduleName); - moduleDeclaration.setOpen(isOpen); - - Set allUses = new TreeSet<>(uses); - allUses.addAll(readServicesFromClasses(inJar)); - allUses.forEach( - service -> moduleDeclaration.addDirective(new ModuleUsesDirective(new Name(service)))); - - // Prepare a classloader to help us find classes. - ClassLoader classLoader; - if (modulePath != null) { - URL[] urls = - Stream.concat(Stream.of(inJar.toAbsolutePath()), modulePath.stream()) - .map( - path -> { - try { - return path.toUri().toURL(); - } catch (MalformedURLException e) { - throw new UncheckedIOException(e); - } - }) - .toArray(URL[]::new); + Path temp = createTempDir(); + List jdepsArgs = prepareJdepsArgs(modulePath, inJar, temp); + executeJdeps(jdepsArgs); - classLoader = new URLClassLoader(urls); - } else { - classLoader = new URLClassLoader(new URL[0]); - } + Path moduleInfo = findModuleInfo(temp); + CompilationUnit unit = parseModuleInfo(moduleInfo); + ModuleDeclaration moduleDeclaration = configureModule(unit, moduleName, isOpen, options.get("--uses"), inJar); - Set packages = inferPackages(inJar); - - // Determine packages to export - Set exportedPackages = new HashSet<>(); - if (!isOpen) { - if (!exports.isEmpty()) { - exports.forEach( - export -> { - if (!packages.contains(export)) { - throw new RuntimeException( - String.format("Exported package '%s' not found in jar. %s", export, packages)); - } - exportedPackages.add(export); - moduleDeclaration.addDirective( - new ModuleExportsDirective(new Name(export), new NodeList<>())); - }); - } else { - packages.forEach( - export -> { - if (!hides.contains(export)) { - exportedPackages.add(export); - moduleDeclaration.addDirective( - new ModuleExportsDirective(new Name(export), new NodeList<>())); - } - }); - } + writeOptimizedJar(moduleDeclaration, outJar, inJar); } - openTo.forEach( - module -> - moduleDeclaration.addDirective( - new ModuleOpensDirective( - new Name(module), - new NodeList( - exportedPackages.stream().map(Name::new).collect(Collectors.toSet()))))); - - ClassWriter classWriter = new ClassWriter(0); - classWriter.visit(V11, ACC_MODULE, "module-info", null, null, null); - ModuleVisitor moduleVisitor = classWriter.visitModule(moduleName, isOpen ? ACC_OPEN : 0, null); - moduleVisitor.visitRequire("java.base", ACC_MANDATED, null); - - moduleDeclaration.accept( - new MyModuleVisitor(classLoader, exportedPackages, hides, moduleVisitor), null); - - moduleVisitor.visitEnd(); - - classWriter.visitEnd(); - - Manifest manifest = new Manifest(); - manifest.getMainAttributes().put(Attributes.Name.MANIFEST_VERSION, "1.0"); - - try (OutputStream os = Files.newOutputStream(outJar); - JarOutputStream jos = new JarOutputStream(os, manifest)) { - jos.setLevel(ZipOutputStream.STORED); - - byte[] bytes = classWriter.toByteArray(); - - ZipEntry entry = new StableZipEntry("module-info.class"); - entry.setSize(bytes.length); - - jos.putNextEntry(entry); - jos.write(bytes); - jos.closeEntry(); + private static void validateInputs(String moduleName, Path outJar, Path inJar) { + Objects.requireNonNull(moduleName, "Module name must be set."); + Objects.requireNonNull(outJar, "Output jar must be set."); + Objects.requireNonNull(inJar, "Input jar must be set."); } - TemporaryFilesystem.getDefaultTmpFS().deleteTempDir(tempDir); - } - - private static Collection readServicesFromClasses(Path inJar) { - Set serviceNames = new HashSet<>(); - - try (InputStream is = Files.newInputStream(inJar); - JarInputStream jis = new JarInputStream(is)) { - for (JarEntry entry = jis.getNextJarEntry(); entry != null; entry = jis.getNextJarEntry()) { - if (entry.isDirectory() || !entry.getName().endsWith(".class")) { - continue; - } - - ClassReader reader = new ClassReader(jis); - reader.accept( - new ClassVisitor(ASM9) { - private Type serviceClass; - - @Override - public MethodVisitor visitMethod( - int access, - String name, - String descriptor, - String signature, - String[] exceptions) { - return new MethodVisitor(ASM9) { - @Override - public void visitMethodInsn( - int opcode, - String owner, - String name, - String descriptor, - boolean isInterface) { - if (SERVICE_LOADER.equals(owner) && "load".equals(name)) { - if (serviceClass != null) { - serviceNames.add(serviceClass.getClassName()); - serviceClass = null; - } - } - } - - @Override - public void visitLdcInsn(Object value) { - if (value instanceof Type) { - serviceClass = (Type) value; - } - } - }; - } - }, - 0); - } - } catch (IOException e) { - throw new UncheckedIOException(e); + private static Path createTempDir() { + return TemporaryFilesystem.getDefaultTmpFS().createTempDir("module-dir", "").toPath(); } - return serviceNames; - } - - private static Set inferPackages(Path inJar) { - Set packageNames = new TreeSet<>(); + private static List prepareJdepsArgs(Set modulePath, Path inJar, Path temp) throws IOException { + List jdepsArgs = new LinkedList<>(List.of("--api-only", "--multi-release", "9")); + if (!modulePath.isEmpty()) { + Path tmp = Files.createTempDirectory("automatic_module_jars"); + String modulePathStr = modulePath.stream() + .map(path -> processPathForJdeps(path, tmp)) + .collect(Collectors.joining(File.pathSeparator)); - try (InputStream is = Files.newInputStream(inJar); - JarInputStream jis = new JarInputStream(is)) { - for (JarEntry entry = jis.getNextJarEntry(); entry != null; entry = jis.getNextJarEntry()) { - - if (entry.isDirectory()) { - continue; + jdepsArgs.addAll(List.of("--module-path", modulePathStr)); } - - if (!entry.getName().endsWith(".class")) { - continue; + jdepsArgs.addAll(List.of("--generate-module-info", temp.toAbsolutePath().toString(), inJar.toAbsolutePath().toString())); + return jdepsArgs; + } + + private static String processPathForJdeps(Path path, Path tmp) { + String file = path.getFileName().toString(); + if (file.startsWith("processed_")) { + Path copy = tmp.resolve(file.substring(10)); + try { + Files.copy(path, copy, StandardCopyOption.REPLACE_EXISTING); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + return copy.toString(); } + return path.toString(); + } - String name = entry.getName(); + private static void executeJdeps(List jdepsArgs) throws IOException { + ToolProvider jdeps = ToolProvider.findFirst("jdeps").orElseThrow(); + ByteArrayOutputStream bos = new ByteArrayOutputStream(); - int index = name.lastIndexOf('/'); - if (index == -1) { - continue; - } - name = name.substring(0, index); - - // If we've a multi-release jar, remove that too - if (name.startsWith("META-INF/versions/")) { - String[] segments = name.split("/"); - if (segments.length < 3) { - continue; - } - - name = - Arrays.stream(Arrays.copyOfRange(segments, 3, segments.length)) - .collect(Collectors.joining("/")); + try (PrintStream printStream = new PrintStream(bos)) { + int result = jdeps.run(printStream, printStream, jdepsArgs.toArray(new String[0])); + if (result != 0) { + throw new RuntimeException("jdeps failed: " + new String(bos.toByteArray())); + } } + } - name = name.replace("/", "."); - - packageNames.add(name); - } + private static Path findModuleInfo(Path temp) throws IOException { + AtomicReference moduleInfo = new AtomicReference<>(); + Files.walkFileTree(temp, new SimpleFileVisitor<>() { + @Override + public FileVisitResult visitFile(Path file, BasicFileAttributes attrs) { + if ("module-info.java".equals(file.getFileName().toString())) { + moduleInfo.set(file); + } + return FileVisitResult.TERMINATE; + } + }); - return packageNames; - } catch (IOException e) { - throw new UncheckedIOException(e); - } - } - - private static class MyModuleVisitor extends VoidVisitorAdapter { - - private final ClassLoader classLoader; - private final Set seenExports; - private final Set packages; - private final ModuleVisitor byteBuddyVisitor; - - MyModuleVisitor( - ClassLoader classLoader, - Set packages, - Set excluded, - ModuleVisitor byteBuddyVisitor) { - this.classLoader = classLoader; - this.byteBuddyVisitor = byteBuddyVisitor; - - // Set is modifiable - this.packages = new HashSet<>(packages); - this.seenExports = new HashSet<>(excluded); + return Optional.ofNullable(moduleInfo.get()).orElseThrow(() -> new RuntimeException("Unable to read module info")); } - @Override - public void visit(ModuleRequiresDirective n, Void arg) { - String name = n.getNameAsString(); - if (name.startsWith("processed.")) { - // When 'Automatic-Module-Name' is not set, we must derive the module name from the jar file - // name. Therefore, the 'processed.' prefix added by bazel must be removed to get the name. - name = name.substring(10); - } - int modifiers = getByteBuddyModifier(n.getModifiers()); - if (!name.startsWith("org.seleniumhq.selenium.") && !name.startsWith("java.")) { - // Some people like to exclude jars from the classpath. To allow this we need to make these - // modules static, - // otherwise a 'module not found' error while compiling their code would be the consequence. - modifiers |= ACC_STATIC_PHASE; - } - byteBuddyVisitor.visitRequire(name, modifiers, null); + private static CompilationUnit parseModuleInfo(Path moduleInfo) throws IOException { + ParserConfiguration config = new ParserConfiguration().setLanguageLevel(ParserConfiguration.LanguageLevel.JAVA_11); + ParseResult parseResult = new JavaParser(config).parse(ParseStart.COMPILATION_UNIT, Providers.provider(moduleInfo)); + + return parseResult.getResult().orElseThrow(() -> new RuntimeException("Failed to parse module-info.java")); } - @Override - public void visit(ModuleExportsDirective n, Void arg) { - if (seenExports.contains(n.getNameAsString())) { - return; - } + private static ModuleDeclaration configureModule(CompilationUnit unit, String moduleName, boolean isOpen, Set uses, Path inJar) throws IOException { + ModuleDeclaration moduleDeclaration = unit.getModule() + .orElseThrow(() -> new RuntimeException("No module declaration in module-info.java")); - seenExports.add(n.getNameAsString()); + moduleDeclaration.setName(moduleName); + moduleDeclaration.setOpen(isOpen); - byteBuddyVisitor.visitExport( - n.getNameAsString().replace('.', '/'), - 0, - n.getModuleNames().stream().map(Name::asString).toArray(String[]::new)); - } + Set allUses = new TreeSet<>(uses); + allUses.addAll(readServicesFromClasses(inJar)); - @Override - public void visit(ModuleProvidesDirective n, Void arg) { - byteBuddyVisitor.visitProvide( - getClassName(n.getNameAsString()), - n.getWith().stream().map(type -> getClassName(type.asString())).toArray(String[]::new)); - } + allUses.forEach(service -> moduleDeclaration.addDirective(new ModuleUsesDirective(new Name(service)))); - @Override - public void visit(ModuleUsesDirective n, Void arg) { - byteBuddyVisitor.visitUse(n.getNameAsString().replace('.', '/')); + return moduleDeclaration; } - @Override - public void visit(ModuleOpensDirective n, Void arg) { - packages.forEach( - pkg -> byteBuddyVisitor.visitOpen(pkg.replace('.', '/'), 0, n.getNameAsString())); - } + private static void writeOptimizedJar(ModuleDeclaration moduleDeclaration, Path outJar, Path inJar) throws IOException { + try (JarOutputStream jarOut = new JarOutputStream(Files.newOutputStream(outJar))) { + JarEntry entry = new JarEntry("module-info.class"); + jarOut.putNextEntry(entry); + jarOut.write(moduleDeclaration.toString().getBytes()); + jarOut.closeEntry(); - private int getByteBuddyModifier(NodeList modifiers) { - return modifiers.stream() - .mapToInt( - mod -> { - switch (mod.getKeyword()) { - case STATIC: - return ACC_STATIC_PHASE; - case TRANSITIVE: - return ACC_TRANSITIVE; + try (JarInputStream jarIn = new JarInputStream(Files.newInputStream(inJar))) { + JarEntry jarEntry; + while ((jarEntry = jarIn.getNextJarEntry()) != null) { + if (!"module-info.class".equals(jarEntry.getName())) { + jarOut.putNextEntry(new ZipEntry(jarEntry.getName())); + jarOut.write(jarIn.readAllBytes()); + jarOut.closeEntry(); + } } - throw new RuntimeException("Unknown modifier: " + mod); - }) - .reduce(0, (l, r) -> l | r); - } - - private String getClassName(String possibleClassName) { - String name = possibleClassName.replace('/', '.'); - if (lookup(name)) { - return name.replace('.', '/'); - } - - int index = name.lastIndexOf('.'); - if (index != -1) { - name = name.substring(0, index) + "$" + name.substring(index + 1); - if (lookup(name)) { - return name.replace('.', '/'); + } } - } - - throw new RuntimeException("Cannot find class: " + name); } - private boolean lookup(String className) { - try { - Class.forName(className, false, classLoader); - return true; - } catch (ClassNotFoundException e) { - return false; - } + private static Set readServicesFromClasses(Path inJar) { + return new TreeSet<>(); // Simulated service reading logic. } - } }