|
| 1 | +// SPDX-License-Identifier: Apache-2.0 |
| 2 | + |
| 3 | +package chisel3.test |
| 4 | + |
| 5 | +import chisel3.experimental.BaseModule |
| 6 | +import chisel3.experimental.hierarchy.Definition |
| 7 | +import chisel3.RawModule |
| 8 | +import java.io.File |
| 9 | +import java.util.jar.JarFile |
| 10 | +import scala.collection.JavaConverters._ |
| 11 | + |
| 12 | +/** All classes and objects marked as [[UnitTest]] are automatically |
| 13 | + * discoverable by the `DiscoverUnitTests` helper. |
| 14 | + */ |
| 15 | +trait UnitTest |
| 16 | + |
| 17 | +/** Helper to discover all subtypes of [[UnitTest]] in the class path, and call |
| 18 | + * their constructors (if they are a class) or ensure that the singleton is |
| 19 | + * constructed (if they are an object). |
| 20 | + * |
| 21 | + * This code is loosely based on the test suite discovery in scalatest, which |
| 22 | + * performs the same scan over the classpath JAR files and directories, and |
| 23 | + * guesses class names based on the encountered directory structure. |
| 24 | + */ |
| 25 | +private[chisel3] object DiscoverUnitTests { |
| 26 | + |
| 27 | + /** The callback invoked for each unit test class name and unit test |
| 28 | + * constructor. |
| 29 | + */ |
| 30 | + type Callback = (String, () => Unit) => Unit |
| 31 | + |
| 32 | + /** Discover all tests in the classpath and call `cb` for each. */ |
| 33 | + def apply(cb: Callback): Unit = classpath().foreach(discoverFile(_, cb)) |
| 34 | + |
| 35 | + /** Return the a sequence of files or directories on the classpath. */ |
| 36 | + private def classpath(): Iterable[File] = System |
| 37 | + .getProperty("java.class.path") |
| 38 | + .split(File.pathSeparator) |
| 39 | + .map(s => if (s.trim.length == 0) "." else s) |
| 40 | + .map(new File(_)) |
| 41 | + |
| 42 | + /** Discover all tests in a given file. If this is a JAR file, looks through |
| 43 | + * its contents and tries to find its classes. |
| 44 | + */ |
| 45 | + private def discoverFile(file: File, cb: Callback): Unit = file match { |
| 46 | + // Unzip JAR files and process the class files they contain. |
| 47 | + case _ if file.getPath.toLowerCase.endsWith(".jar") => |
| 48 | + val jarFile = new java.util.jar.JarFile(file) |
| 49 | + jarFile.entries.asScala.foreach { jarEntry => |
| 50 | + val name = jarEntry.getName |
| 51 | + if (!jarEntry.isDirectory && name.endsWith(".class")) |
| 52 | + discoverClass(pathToClassName(name), cb) |
| 53 | + } |
| 54 | + |
| 55 | + // Recursively collect any class files contained in directories. |
| 56 | + case _ if file.isDirectory => |
| 57 | + def visit(prefix: String, file: File): Unit = { |
| 58 | + val name = prefix + "/" + file.getName |
| 59 | + if (file.isDirectory) { |
| 60 | + for (entry <- file.listFiles) |
| 61 | + visit(name, entry) |
| 62 | + } else if (name.endsWith(".class")) { |
| 63 | + discoverClass(pathToClassName(name), cb) |
| 64 | + } |
| 65 | + } |
| 66 | + for (entry <- file.listFiles) |
| 67 | + visit("", entry) |
| 68 | + |
| 69 | + // Ignore any other files that aren't directories. |
| 70 | + case _ => () |
| 71 | + } |
| 72 | + |
| 73 | + /** Convert a file path to a class */ |
| 74 | + private def pathToClassName(path: String): String = |
| 75 | + path.replace('/', '.').replace('\\', '.').stripPrefix(".").stripSuffix(".class") |
| 76 | + |
| 77 | + /** Load the given class and check whether it is a subtype of [[UnitTest]]. If |
| 78 | + * it is, call the user-provided callback with a function that either calls |
| 79 | + * the loaded class' constructor or ensures the loaded object is constructed. |
| 80 | + */ |
| 81 | + private def discoverClass(className: String, cb: Callback): Unit = { |
| 82 | + val clazz = |
| 83 | + try { |
| 84 | + classOf[UnitTest].getClassLoader.loadClass(className) |
| 85 | + } catch { |
| 86 | + case _: ClassNotFoundException => return |
| 87 | + case _: NoClassDefFoundError => return |
| 88 | + case _: ClassCastException => return |
| 89 | + case _: UnsupportedClassVersionError => return |
| 90 | + } |
| 91 | + |
| 92 | + // Check if it is a subtype of `UnitTest` (and also not the definition of |
| 93 | + // `UnitTest` itself). |
| 94 | + if (clazz == classOf[UnitTest] || !classOf[UnitTest].isAssignableFrom(clazz)) |
| 95 | + return |
| 96 | + |
| 97 | + // Check if this is a `BaseModule`, in which case we implicitly wrap its |
| 98 | + // constructor in a `Definition(...)` call. |
| 99 | + val isModule = classOf[BaseModule].isAssignableFrom(clazz) |
| 100 | + |
| 101 | + // Handle singleton objects by ensuring they are constructed. |
| 102 | + try { |
| 103 | + val field = clazz.getField("MODULE$") |
| 104 | + if (isModule) |
| 105 | + cb(className, () => Definition(field.get(null).asInstanceOf[BaseModule])) |
| 106 | + else |
| 107 | + cb(className, () => field.get(null)) |
| 108 | + return |
| 109 | + } catch { |
| 110 | + case e: NoSuchFieldException => () |
| 111 | + } |
| 112 | + |
| 113 | + // Handle classes by calling their constructor. |
| 114 | + try { |
| 115 | + val ctor = clazz.getConstructor() |
| 116 | + if (isModule) |
| 117 | + cb(className, () => Definition(ctor.newInstance().asInstanceOf[BaseModule])) |
| 118 | + else |
| 119 | + cb(className, () => ctor.newInstance()) |
| 120 | + return |
| 121 | + } catch { |
| 122 | + case e: NoSuchMethodException => () |
| 123 | + case e: IllegalAccessException => () |
| 124 | + } |
| 125 | + } |
| 126 | +} |
| 127 | + |
| 128 | +/** A Chisel module that discovers and constructs all [[UnitTest]] subtypes |
| 129 | + * discovered in the classpath. This is just here as a convenience top-level |
| 130 | + * generator to collect all unit tests. In practice you would likely want to |
| 131 | + * use a command line utility that offers some additional filtering capability. |
| 132 | + */ |
| 133 | +class AllUnitTests extends RawModule { |
| 134 | + DiscoverUnitTests((_, gen) => gen()) |
| 135 | +} |
0 commit comments