Skip to content

Commit ee470bc

Browse files
authored
Use zinc worker for DiscoverTestMain and GetTestTasks (#6125)
This should reduce the performance cost of spawning a new JVM in scenarios where we already have a JVM worker running
1 parent e8f1012 commit ee470bc

File tree

11 files changed

+126
-86
lines changed

11 files changed

+126
-86
lines changed

ci/mill-bootstrap.patch

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
diff --git a/integration/package.mill b/integration/package.mill
2+
index 5c388fafb6e..31763326edf 100644
3+
--- a/integration/package.mill
4+
+++ b/integration/package.mill
5+
@@ -75,7 +75,8 @@ object `package` extends mill.Module {
6+
testReportXml(),
7+
javaHome().map(_.path),
8+
testParallelism(),
9+
- testLogLevel()
10+
+ testLogLevel(),
11+
+ jvmWorker = jvmWorker().worker()
12+
)
13+
testModuleUtil.runTests()
14+
}

libs/javalib/api/src/mill/javalib/api/JvmWorkerApi.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,4 +47,14 @@ trait JvmWorkerApi {
4747
javaHome: Option[os.Path],
4848
args: Seq[String]
4949
)(using ctx: JvmWorkerApi.Ctx): Boolean
50+
51+
def discoverTests(
52+
value: mill.javalib.api.internal.ZincDiscoverTests,
53+
javaHome: Option[os.Path]
54+
)(using ctx: JvmWorkerApi.Ctx): Seq[String] = Nil
55+
56+
def getTestTasks(
57+
value: mill.javalib.api.internal.ZincGetTestTasks,
58+
javaHome: Option[os.Path]
59+
)(using ctx: JvmWorkerApi.Ctx): Seq[String] = Nil
5060
}

libs/javalib/api/src/mill/javalib/api/internal/zinc_operations.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,14 @@ case class ZincScaladocJar(
3636
scalacPluginClasspath: Seq[PathRef],
3737
args: Seq[String]
3838
) derives upickle.ReadWriter
39+
40+
case class ZincDiscoverTests(runCp: Seq[os.Path], testCp: Seq[os.Path], framework: String)
41+
derives upickle.ReadWriter
42+
43+
case class ZincGetTestTasks(
44+
runCp: Seq[os.Path],
45+
testCp: Seq[os.Path],
46+
framework: String,
47+
selectors: Seq[String],
48+
args: Seq[String]
49+
) derives upickle.ReadWriter

libs/javalib/src/mill/javalib/TestModule.scala

Lines changed: 10 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,8 @@ import mill.javalib.bsp.BspModule
1313
import mill.util.Jvm
1414
import mill.api.JsonFormatters.given
1515
import mill.constants.EnvVars
16-
import mill.javalib.testrunner.{
17-
DiscoverTestsMain,
18-
Framework,
19-
TestArgs,
20-
TestResult,
21-
TestRunner,
22-
TestRunnerUtils
23-
}
16+
import mill.javalib.api.internal.ZincDiscoverTests
17+
import mill.javalib.testrunner.{Framework, TestArgs, TestResult, TestRunner, TestRunnerUtils}
2418

2519
import java.nio.file.Path
2620

@@ -67,27 +61,15 @@ trait TestModule
6761
* Test classes (often called test suites) discovered by the configured [[testFramework]].
6862
*/
6963
def discoveredTestClasses: T[Seq[String]] = Task {
70-
val classes = if (javaHome().isDefined) {
71-
Jvm.callProcess(
72-
mainClass = "mill.javalib.testrunner.DiscoverTestsMain",
73-
classPath = jvmWorker().scalalibClasspath().map(_.path).toVector,
74-
mainArgs =
75-
runClasspath().flatMap(p => Seq("--runCp", p.path.toString())) ++
76-
testClasspath().flatMap(p => Seq("--testCp", p.path.toString())) ++
77-
Seq("--framework", testFramework()),
78-
javaHome = javaHome().map(_.path),
79-
stdin = os.Inherit,
80-
stdout = os.Pipe,
81-
cwd = Task.dest
82-
).out.lines()
83-
} else {
84-
DiscoverTestsMain.main0(
64+
val discoveredTests = jvmWorker().worker().discoverTests(
65+
ZincDiscoverTests(
8566
runClasspath().map(_.path),
8667
testClasspath().map(_.path),
8768
testFramework()
88-
)
89-
}
90-
classes.sorted
69+
),
70+
javaHome().map(_.path)
71+
)
72+
discoveredTests.sorted
9173
}
9274

9375
/**
@@ -266,7 +248,8 @@ trait TestModule
266248
javaHome().map(_.path),
267249
testParallelism(),
268250
testLogLevel(),
269-
propagateEnv()
251+
propagateEnv(),
252+
jvmWorker().worker()
270253
)
271254
testModuleUtil.runTests()
272255
}

libs/javalib/src/mill/javalib/TestModuleUtil.scala

Lines changed: 9 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ import mill.api.Logger
1717

1818
import java.util.concurrent.ConcurrentHashMap
1919
import mill.api.BuildCtx
20-
import mill.javalib.testrunner.{GetTestTasksMain, TestArgs, TestResult, TestRunnerUtils}
20+
import mill.javalib.api.internal.ZincGetTestTasks
21+
import mill.javalib.testrunner.{TestArgs, TestResult, TestRunnerUtils}
2122
import os.Path
2223

2324
import scala.annotation.unused
@@ -46,7 +47,8 @@ final class TestModuleUtil(
4647
javaHome: Option[os.Path],
4748
testParallelism: Boolean,
4849
testLogLevel: TestReporter.LogLevel,
49-
propagateEnv: Boolean = true
50+
propagateEnv: Boolean = true,
51+
jvmWorker: mill.javalib.api.JvmWorkerApi
5052
)(using ctx: mill.api.TaskCtx) {
5153

5254
private val (jvmArgs, props) = TestModuleUtil.loadArgsAndProps(useArgsFile, forkArgs)
@@ -78,32 +80,16 @@ final class TestModuleUtil(
7880
// test group requires spawning a JVM which can take 1+ seconds to realize there are no
7981
// tests to run and shut down
8082

81-
val discoveredTests = if (javaHome.isDefined) {
82-
Jvm.callProcess(
83-
mainClass = "mill.javalib.testrunner.GetTestTasksMain",
84-
classPath = scalalibClasspath.map(_.path).toVector,
85-
mainArgs =
86-
(runClasspath ++ testrunnerEntrypointClasspath).flatMap(p =>
87-
Seq("--runCp", p.path.toString)
88-
) ++
89-
testClasspath.flatMap(p => Seq("--testCp", p.path.toString)) ++
90-
Seq("--framework", testFramework) ++
91-
selectors.flatMap(s => Seq("--selectors", s)) ++
92-
args.flatMap(s => Seq("--args", s)),
93-
javaHome = javaHome,
94-
stdin = os.Inherit,
95-
stdout = os.Pipe,
96-
cwd = Task.dest
97-
).out.lines().toSet
98-
} else {
99-
GetTestTasksMain.main0(
83+
val discoveredTests = jvmWorker.getTestTasks(
84+
ZincGetTestTasks(
10085
(runClasspath ++ testrunnerEntrypointClasspath).map(_.path),
10186
testClasspath.map(_.path),
10287
testFramework,
10388
selectors,
10489
args
105-
).toSet
106-
}
90+
),
91+
javaHome
92+
).toSet
10793

10894
filteredClassLists0.map(_.filter(discoveredTests)).filter(_.nonEmpty)
10995
}

libs/javalib/testrunner/src/mill/javalib/testrunner/DiscoverTestsMain.scala

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,8 @@ package mill.javalib.testrunner
33
import mill.api.daemon.internal.internal
44

55
@internal object DiscoverTestsMain {
6-
import mill.api.JsonFormatters.PathTokensReader
7-
8-
@mainargs.main
9-
def main(runCp: Seq[os.Path], testCp: Seq[os.Path], framework: String): Unit = {
10-
main0(runCp, testCp, framework).foreach(println)
11-
}
12-
def main0(runCp: Seq[os.Path], testCp: Seq[os.Path], framework: String): Seq[String] = {
6+
def apply(args0: mill.javalib.api.internal.ZincDiscoverTests): Seq[String] = {
7+
import args0.*
138
mill.util.Jvm.withClassLoader(
149
classPath = runCp,
1510
sharedPrefixes = Seq("sbt.testing.")
@@ -24,6 +19,4 @@ import mill.api.daemon.internal.internal
2419
}
2520
}
2621
}
27-
28-
def main(args: Array[String]): Unit = mainargs.Parser(this).runOrExit(args.toSeq)
2922
}

libs/javalib/testrunner/src/mill/javalib/testrunner/GetTestTasksMain.scala

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,25 +3,9 @@ package mill.javalib.testrunner
33
import mill.api.daemon.internal.internal
44

55
@internal object GetTestTasksMain {
6-
import mill.api.JsonFormatters.PathTokensReader
7-
@mainargs.main
8-
def main(
9-
runCp: Seq[os.Path],
10-
testCp: Seq[os.Path],
11-
framework: String,
12-
selectors: Seq[String],
13-
args: Seq[String]
14-
): Unit = {
15-
main0(runCp, testCp, framework, selectors, args).foreach(println)
16-
}
176

18-
def main0(
19-
runCp: Seq[os.Path],
20-
testCp: Seq[os.Path],
21-
framework: String,
22-
selectors: Seq[String],
23-
args: Seq[String]
24-
): Seq[String] = {
7+
def apply(args0: mill.javalib.api.internal.ZincGetTestTasks): Seq[String] = {
8+
import args0.*
259
val globFilter = TestRunnerUtils.globFilter(selectors)
2610
mill.util.Jvm.withClassLoader(
2711
classPath = runCp,
@@ -38,6 +22,4 @@ import mill.api.daemon.internal.internal
3822
.toSeq
3923
}
4024
}
41-
42-
def main(args: Array[String]): Unit = mainargs.Parser(this).runOrExit(args.toSeq)
4325
}

libs/javalib/worker/src/mill/javalib/worker/JvmWorkerImpl.scala

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,32 @@ class JvmWorkerImpl(args: JvmWorkerArgs) extends JvmWorkerApi with AutoCloseable
8585
result.result
8686
}
8787

88+
override def discoverTests(
89+
op: mill.javalib.api.internal.ZincDiscoverTests,
90+
javaHome: Option[os.Path]
91+
)(using ctx: JvmWorkerApi.Ctx): Seq[String] = {
92+
given RequestId = requestIds.next()
93+
94+
val zinc = zincApi(javaHome, JavaRuntimeOptions(Seq.empty))
95+
val result = Timed(zinc.discoverTests(op))
96+
fileLog(s"discoverTests took ${result.durationPretty}")
97+
result.result
98+
}
99+
100+
override def getTestTasks(
101+
op: mill.javalib.api.internal.ZincGetTestTasks,
102+
javaHome: Option[os.Path]
103+
)(using
104+
ctx: JvmWorkerApi.Ctx
105+
): Seq[String] = {
106+
given RequestId = requestIds.next()
107+
108+
val zinc = zincApi(javaHome, JavaRuntimeOptions(Seq.empty))
109+
val result = Timed(zinc.getTestTasks(op))
110+
fileLog(s"getTestTasks took ${result.durationPretty}")
111+
result.result
112+
}
113+
88114
override def close(): Unit = {
89115
zincLocalWorker.close()
90116
subprocessCache.close()
@@ -274,10 +300,7 @@ class JvmWorkerImpl(args: JvmWorkerArgs) extends JvmWorkerApi with AutoCloseable
274300
}
275301

276302
/** Gives you API for the [[zincLocalWorker]] instance. */
277-
private def localZincApi(
278-
zincCtx: ZincWorker.InvocationContext,
279-
log: Logger
280-
): ZincApi = {
303+
private def localZincApi(zincCtx: ZincWorker.InvocationContext, log: Logger): ZincApi = {
281304
val zincDeps = ZincWorker.InvocationDependencies(
282305
log = log,
283306
consoleOut = ConsoleOut.printStreamOut(log.streams.err),
@@ -413,6 +436,20 @@ class JvmWorkerImpl(args: JvmWorkerArgs) extends JvmWorkerApi with AutoCloseable
413436
rpcClient(msg)
414437
}
415438
}
439+
440+
override def discoverTests(op: mill.javalib.api.internal.ZincDiscoverTests): Seq[String] = {
441+
withRpcClient(serverRpcToClientHandler(reporter = None, log, cacheKey)) { rpcClient =>
442+
val msg = ZincWorkerRpcServer.ClientToServer.DiscoverTests(op)
443+
rpcClient(msg)
444+
}
445+
}
446+
447+
override def getTestTasks(op: mill.javalib.api.internal.ZincGetTestTasks): Seq[String] = {
448+
withRpcClient(serverRpcToClientHandler(reporter = None, log, cacheKey)) { rpcClient =>
449+
val msg = ZincWorkerRpcServer.ClientToServer.GetTestTasks(op)
450+
rpcClient(msg)
451+
}
452+
}
416453
}
417454
}
418455

libs/javalib/worker/src/mill/javalib/zinc/ZincApi.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,8 @@ trait ZincApi {
2626
def scaladocJar(
2727
op: ZincScaladocJar
2828
): Boolean
29+
30+
def discoverTests(value: mill.javalib.api.internal.ZincDiscoverTests): Seq[String]
31+
32+
def getTestTasks(value: mill.javalib.api.internal.ZincGetTestTasks): Seq[String]
2933
}

libs/javalib/worker/src/mill/javalib/zinc/ZincWorker.scala

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,7 @@ import java.util.Optional
3333
import scala.collection.mutable
3434

3535
/** @param jobs number of parallel jobs */
36-
class ZincWorker(
37-
jobs: Int
38-
) extends AutoCloseable { self =>
36+
class ZincWorker(jobs: Int) extends AutoCloseable { self =>
3937
private val incrementalCompiler = new sbt.internal.inc.IncrementalCompilerImpl()
4038
private val compilerBridgeLocks: mutable.Map[String, MemoryLock] = mutable.Map.empty
4139

@@ -304,6 +302,12 @@ class ZincWorker(
304302

305303
override def scaladocJar(op: ZincScaladocJar): Boolean =
306304
self.scaladocJar(op, deps.compilerBridge)
305+
306+
override def discoverTests(op: mill.javalib.api.internal.ZincDiscoverTests): Seq[String] =
307+
mill.javalib.testrunner.DiscoverTestsMain(op)
308+
309+
override def getTestTasks(op: mill.javalib.api.internal.ZincGetTestTasks): Seq[String] =
310+
mill.javalib.testrunner.GetTestTasksMain(op)
307311
}
308312

309313
def close(): Unit = {

0 commit comments

Comments
 (0)