Skip to content

Commit ae3d735

Browse files
Merge pull request #4642 from fabianschuiki/unit-test-discovery
Add UnitTest marker and test discovery utility
2 parents ef43571 + 8c49e6d commit ae3d735

File tree

4 files changed

+327
-1
lines changed

4 files changed

+327
-1
lines changed
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
3+
package chisel3.test
4+
5+
import chisel3.experimental.BaseModule
6+
import chisel3.experimental.hierarchy.Definition
7+
import chisel3.RawModule
8+
import java.io.File
9+
import java.util.jar.JarFile
10+
import scala.collection.JavaConverters._
11+
12+
/** All classes and objects marked as [[UnitTest]] are automatically
13+
* discoverable by the `DiscoverUnitTests` helper.
14+
*/
15+
trait UnitTest
16+
17+
/** Helper to discover all subtypes of [[UnitTest]] in the class path, and call
18+
* their constructors (if they are a class) or ensure that the singleton is
19+
* constructed (if they are an object).
20+
*
21+
* This code is loosely based on the test suite discovery in scalatest, which
22+
* performs the same scan over the classpath JAR files and directories, and
23+
* guesses class names based on the encountered directory structure.
24+
*/
25+
private[chisel3] object DiscoverUnitTests {
26+
27+
/** The callback invoked for each unit test class name and unit test
28+
* constructor.
29+
*/
30+
type Callback = (String, () => Unit) => Unit
31+
32+
/** Discover all tests in the classpath and call `cb` for each. */
33+
def apply(cb: Callback): Unit = classpath().foreach(discoverFile(_, cb))
34+
35+
/** Return the a sequence of files or directories on the classpath. */
36+
private def classpath(): Iterable[File] = System
37+
.getProperty("java.class.path")
38+
.split(File.pathSeparator)
39+
.map(s => if (s.trim.length == 0) "." else s)
40+
.map(new File(_))
41+
42+
/** Discover all tests in a given file. If this is a JAR file, looks through
43+
* its contents and tries to find its classes.
44+
*/
45+
private def discoverFile(file: File, cb: Callback): Unit = file match {
46+
// Unzip JAR files and process the class files they contain.
47+
case _ if file.getPath.toLowerCase.endsWith(".jar") =>
48+
val jarFile = new java.util.jar.JarFile(file)
49+
jarFile.entries.asScala.foreach { jarEntry =>
50+
val name = jarEntry.getName
51+
if (!jarEntry.isDirectory && name.endsWith(".class"))
52+
discoverClass(pathToClassName(name), cb)
53+
}
54+
55+
// Recursively collect any class files contained in directories.
56+
case _ if file.isDirectory =>
57+
def visit(prefix: String, file: File): Unit = {
58+
val name = prefix + "/" + file.getName
59+
if (file.isDirectory) {
60+
for (entry <- file.listFiles)
61+
visit(name, entry)
62+
} else if (name.endsWith(".class")) {
63+
discoverClass(pathToClassName(name), cb)
64+
}
65+
}
66+
for (entry <- file.listFiles)
67+
visit("", entry)
68+
69+
// Ignore any other files that aren't directories.
70+
case _ => ()
71+
}
72+
73+
/** Convert a file path to a class */
74+
private def pathToClassName(path: String): String =
75+
path.replace('/', '.').replace('\\', '.').stripPrefix(".").stripSuffix(".class")
76+
77+
/** Load the given class and check whether it is a subtype of [[UnitTest]]. If
78+
* it is, call the user-provided callback with a function that either calls
79+
* the loaded class' constructor or ensures the loaded object is constructed.
80+
*/
81+
private def discoverClass(className: String, cb: Callback): Unit = {
82+
val clazz =
83+
try {
84+
classOf[UnitTest].getClassLoader.loadClass(className)
85+
} catch {
86+
case _: ClassNotFoundException => return
87+
case _: NoClassDefFoundError => return
88+
case _: ClassCastException => return
89+
case _: UnsupportedClassVersionError => return
90+
}
91+
92+
// Check if it is a subtype of `UnitTest` (and also not the definition of
93+
// `UnitTest` itself).
94+
if (clazz == classOf[UnitTest] || !classOf[UnitTest].isAssignableFrom(clazz))
95+
return
96+
97+
// Check if this is a `BaseModule`, in which case we implicitly wrap its
98+
// constructor in a `Definition(...)` call.
99+
val isModule = classOf[BaseModule].isAssignableFrom(clazz)
100+
101+
// Handle singleton objects by ensuring they are constructed.
102+
try {
103+
val field = clazz.getField("MODULE$")
104+
if (isModule)
105+
cb(className, () => Definition(field.get(null).asInstanceOf[BaseModule]))
106+
else
107+
cb(className, () => field.get(null))
108+
return
109+
} catch {
110+
case e: NoSuchFieldException => ()
111+
}
112+
113+
// Handle classes by calling their constructor.
114+
try {
115+
val ctor = clazz.getConstructor()
116+
if (isModule)
117+
cb(className, () => Definition(ctor.newInstance().asInstanceOf[BaseModule]))
118+
else
119+
cb(className, () => ctor.newInstance())
120+
return
121+
} catch {
122+
case e: NoSuchMethodException => ()
123+
case e: IllegalAccessException => ()
124+
}
125+
}
126+
}
127+
128+
/** A Chisel module that discovers and constructs all [[UnitTest]] subtypes
129+
* discovered in the classpath. This is just here as a convenience top-level
130+
* generator to collect all unit tests. In practice you would likely want to
131+
* use a command line utility that offers some additional filtering capability.
132+
*/
133+
class AllUnitTests extends RawModule {
134+
DiscoverUnitTests((_, gen) => gen())
135+
}
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
package chisel3
2+
3+
import chisel3.test.DiscoverUnitTests
4+
import circt.stage.ChiselStage
5+
import java.io.File
6+
import java.io.PrintStream
7+
import scala.util.matching.Regex
8+
import scopt.OptionParser
9+
10+
/** Utility to discover and generate all unit tests in the classpath. */
11+
object UnitTests {
12+
13+
/** Command line configuration options. */
14+
case class Config(
15+
outputFile: Option[File] = None,
16+
list: Boolean = false,
17+
verbose: Boolean = false,
18+
filters: List[Regex] = List(),
19+
excludes: List[Regex] = List()
20+
)
21+
22+
def main(args: Array[String]): Unit = {
23+
var shouldExit = false
24+
val parser = new OptionParser[Config]("chisel3.UnitTests") {
25+
head("Chisel Unit Test Utility")
26+
help("help").abbr("h")
27+
28+
opt[File]('o', "output")
29+
.text("Output file name (\"-\" for stdout)")
30+
.action((x, c) => c.copy(outputFile = if (!x.getPath.isEmpty && x.getPath != "-") Some(x) else None))
31+
32+
opt[Unit]('l', "list")
33+
.text("List tests instead of building them")
34+
.action((_, c) => c.copy(list = true))
35+
36+
opt[Unit]('v', "verbose")
37+
.text("Print verbose information to stderr")
38+
.action((_, c) => c.copy(verbose = true))
39+
40+
opt[Seq[String]]('f', "filter")
41+
.text("Only consider tests which match at least one filter regex")
42+
.unbounded()
43+
.action((x, c) => c.copy(filters = c.filters ++ x.map(_.r)))
44+
45+
opt[Seq[String]]('x', "exclude")
46+
.text("Ignore tests which match at least one exclusion regex")
47+
.unbounded()
48+
.action((x, c) => c.copy(excludes = c.excludes ++ x.map(_.r)))
49+
50+
// Do not `sys.exit` on `help` to facilitate testing.
51+
override def terminate(exitState: Either[String, Unit]): Unit = {
52+
shouldExit = true
53+
}
54+
}
55+
56+
// Parse the command line options.
57+
val config = parser.parse(args, Config()) match {
58+
case Some(config) => config
59+
case None => return
60+
}
61+
if (shouldExit)
62+
return
63+
64+
// Define the handler that will be called for each discovered unit test, and
65+
// which will decide whether the test is generated or not.
66+
def handler(className: String, gen: () => Unit): Unit = {
67+
// If none of the inclusion filters match, skip this test.
68+
if (!config.filters.isEmpty && !config.filters.exists(_.findFirstMatchIn(className).isDefined)) {
69+
Console.err.println(f"Skipping ${className} (does not match filter)")
70+
return
71+
}
72+
73+
// If any of the exclusion filters match, skip this test.
74+
if (config.excludes.exists(_.findFirstMatchIn(className).isDefined)) {
75+
Console.err.println(f"Skipping ${className} (matches exclude filter)")
76+
return
77+
}
78+
79+
// If we are just listing tests, print the class name and skip this test.
80+
if (config.list) {
81+
println(className)
82+
return
83+
}
84+
85+
// Otherwise generate the test.
86+
if (config.verbose)
87+
Console.err.println(f"Building ${className}")
88+
gen()
89+
}
90+
91+
// If the user only asked for a list of tests, run test discovery without
92+
// setting up any of the Chisel builder stuff in the background. The handler
93+
// will never actually call the Chisel generators in this mode.
94+
if (config.list) {
95+
DiscoverUnitTests(handler)
96+
return
97+
}
98+
99+
// Generate the unit tests.
100+
class AllUnitTests extends RawModule {
101+
DiscoverUnitTests(handler)
102+
}
103+
val chirrtl = ChiselStage.emitCHIRRTL(new AllUnitTests)
104+
105+
// Write the result to the output.
106+
val output: PrintStream = config.outputFile match {
107+
case Some(file) => new PrintStream(file)
108+
case None => Console.out
109+
}
110+
try {
111+
output.print(chirrtl)
112+
} finally {
113+
output.close()
114+
}
115+
}
116+
}

src/test/scala/chiselTests/ChiselSpec.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ trait FileCheck extends BeforeAndAfterEachTestData { this: Suite =>
238238
// Filecheck needs the thing to check in a file
239239
os.write.over(checkFile.get, check)
240240
val extraArgs = os.Shellable(fileCheckArgs)
241-
os.proc("FileCheck", checkFile.get, extraArgs).call(stdin = in)
241+
os.proc("FileCheck", "--allow-empty", checkFile.get, extraArgs).call(stdin = in)
242242
}
243243

244244
/** Elaborate a Module to FIRRTL and check the FIRRTL with FileCheck */
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
3+
package chiselTests
4+
5+
import chisel3._
6+
import chisel3.test._
7+
import java.io.{ByteArrayOutputStream, PrintStream}
8+
import org.scalatest.flatspec.AnyFlatSpec
9+
import org.scalatest.matchers.should.Matchers
10+
11+
class UnitTestMainSpec extends AnyFlatSpec with Matchers with FileCheck {
12+
def check(args: Seq[String])(checkOut: String, checkErr: String): Unit = {
13+
val outStream = new ByteArrayOutputStream()
14+
val errStream = new ByteArrayOutputStream()
15+
Console.withOut(new PrintStream(outStream)) {
16+
Console.withErr(new PrintStream(errStream)) {
17+
UnitTests.main(args.toArray)
18+
}
19+
}
20+
if (!checkOut.isEmpty) fileCheckString(outStream.toString)(checkOut)
21+
if (!checkErr.isEmpty) fileCheckString(errStream.toString)(checkErr)
22+
}
23+
24+
def checkOutAndErr(args: String*)(checkOut: String, checkErr: String): Unit = check(args)(checkOut, checkErr)
25+
26+
def checkOut(args: String*)(checkOut: String): Unit = check(args)(checkOut, "")
27+
def checkErr(args: String*)(checkErr: String): Unit = check(args)("", checkErr)
28+
29+
it should "print a help page" in {
30+
checkOut("-h")("""
31+
// CHECK: Chisel Unit Test Utility
32+
// CHECK: Usage:
33+
// CHECK: -h, --help
34+
""")
35+
}
36+
37+
it should "list unit tests" in {
38+
checkOutAndErr("-l", "-f", "^chiselTests\\.sampleTests\\.")(
39+
"""
40+
// CHECK-DAG: chiselTests.sampleTests.ClassTest
41+
// CHECK-DAG: chiselTests.sampleTests.ObjectTest
42+
// CHECK-DAG: chiselTests.sampleTests.ModuleTest
43+
""",
44+
"""
45+
// CHECK-NOT: Hello from
46+
"""
47+
)
48+
}
49+
50+
it should "execute unit test constructors" in {
51+
checkErr("-f", "^chiselTests\\.sampleTests\\.")("""
52+
// CHECK-DAG: Hello from class test
53+
// CHECK-DAG: Hello from object test
54+
// CHECK-DAG: Hello from module test
55+
""")
56+
}
57+
58+
it should "generate unit test FIRRTL" in {
59+
checkOut("-f", "^chiselTests\\.sampleTests\\.")("""
60+
// CHECK: module ModuleTest :
61+
""")
62+
}
63+
}
64+
65+
package sampleTests {
66+
class ClassTest extends UnitTest {
67+
Console.err.println("Hello from class test")
68+
}
69+
object ObjectTest extends UnitTest {
70+
Console.err.println("Hello from object test")
71+
}
72+
class ModuleTest extends RawModule with UnitTest {
73+
Console.err.println("Hello from module test")
74+
}
75+
}

0 commit comments

Comments
 (0)