Skip to content

Commit 17d4e5d

Browse files
authored
Synthetize main classes for @mainarg.main methods in Scala scripts for Intellij/@scala.main compatibility (#6084)
* `@scala.main` synthesizes a main class for each main method, which IntelliJ expects when you click the "run" button next to the main method `def`. * `@mainargs.main` does not, as it expects a single `def main` method to call `mainargs.Parser(this).runOrExit` to dispatch to the various annotated entrypoints * This PR bridges the gap by: * Hackily parsing out the name of the annotated main methods from the `mainargs.Parser(this).` macro expansion in the synthesized `def main` method (Ideally we could fish out the annotations with ASM, but despite `@mainargs.main` being a `ClassFileAnnotation` it does not appear in the bytecode), * Generates synthetic main classes corresponding to each one that forward to the respective annotated method * Filters them out during `allLocalMainClasses` discovery, so that normal `run` only runs the primary synthetic `main`, while `runMain` or similar (e.g. what IntelliJ does) can run the others Covered by an integration test and tested manually in IntelliJ to make sure the run button works in single-main and multi-main scenarios <img width="976" height="652" alt="Screenshot 2025-11-01 at 11 15 03 PM" src="https://github.com/user-attachments/assets/a9a4a02e-f122-4977-9be4-c3e2bd55e960" />
1 parent 3504b6b commit 17d4e5d

File tree

12 files changed

+264
-6
lines changed

12 files changed

+264
-6
lines changed

integration/feature/auxiliary-class-files/resources/build.mill

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ object app extends Module {
99
trait Common extends ScalaModule {
1010
def moduleDir = super.moduleDir / os.up
1111
def scalaVersion = "3.4.0"
12+
// force this on for testing even though there's only one file
13+
def zincIncrementalCompilation = true
1214
}
1315

1416
object jvm extends Common
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
@main
2+
def main1(text: String) = println(text + "123")
3+
4+
@main
5+
def main2(text: String) = println(text + "456")
6+
7+
@main
8+
def main3(text: String) = println(text + "789")
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
@scala.main
2+
def main1(text: String) = println(text + "OMG")
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
@scala.main
2+
def main1(text: String) = println(text + "ABC")
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
@main
2+
def main1(text: String) = println(text + "XYZ")
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package mill.integration
2+
3+
import mill.testkit.UtestIntegrationTestSuite
4+
import utest._
5+
6+
// Make sure that for scripts with multiple `@main` methods (aliased to `@mainargs.main`),
7+
// we generate synthetic main classes for each one following the name of the class that we
8+
// can run using `runMain`. This mimics the behavior of `@scala.main`, and allows interop
9+
// with tools that expect that behavior such as the IntelliJ `run` button
10+
object ScriptMainForwarderClassesTests extends UtestIntegrationTestSuite {
11+
val tests: Tests = Tests {
12+
test("test") - integrationTest { tester =>
13+
import tester._
14+
// When using `run`, we ignore the synthetic main classes when picking a main method,
15+
// so in this case we run the default `_MillScriptMain` method which delegates
16+
// to the relevant method internally based on the first token
17+
val res0 = eval(("Multi.scala:run", "main1", "--text", "HELLO"))
18+
assert(res0.out == "HELLO123")
19+
20+
// Multi-main script have forwarder classes synthesized for each @mainargs.main method,
21+
// which forward to `_MillScriptMain` but pass the name as the first param to disambiguate
22+
val res1 = eval(("Multi.scala:runMain", "main1", "--text", "hello"))
23+
assert(res1.out == "hello123")
24+
25+
val res2 = eval(("Multi.scala:runMain", "main2", "--text", "world"))
26+
assert(res2.out == "world456")
27+
28+
val res3 = eval(("Multi.scala:runMain", "main3", "--text", "moooo"))
29+
assert(res3.out == "moooo789")
30+
31+
// Single-main script, forwarder class must *not* pass the method name as
32+
// the first parameter
33+
val res4 = eval(("Single.scala:runMain", "main1", "--text", "iamcow"))
34+
assert(res4.out == "iamcowXYZ")
35+
36+
// scala.main method takes priority over synthetic _MillScriptMain method
37+
val res5 = eval(("ScalaMain.scala:run", "hearmemoo"))
38+
assert(res5.out == "hearmemooABC")
39+
40+
// `def main(args: Array[String]): Unit` method takes priority over synthetic _MillScriptMain method
41+
val res6 = eval(("RawMainSignature.scala:run", "iweightwiceasmuchasyou"))
42+
assert(res6.out == "iweightwiceasmuchasyouOMG")
43+
}
44+
}
45+
}

libs/javalib/src/mill/javalib/JavaModule.scala

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -836,9 +836,12 @@ trait JavaModule
836836
).equalsIgnoreCase("true")
837837
}
838838

839-
def zincIncrementalCompilation: T[Boolean] = Task {
840-
true
841-
}
839+
/**
840+
* Whether to turn on zinc incremental compilation or not, as it can speed things up
841+
* by skipping some source files but also adds some performance overhead. Defaults
842+
* to turning it on if there is more than one source file being compiled
843+
*/
844+
def zincIncrementalCompilation: T[Boolean] = Task { allSourceFiles().length > 1 }
842845

843846
/**
844847
* Compiles the current module to generate compiled classfiles/bytecode.
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
package mill.script.asm
2+
3+
import org.objectweb.asm
4+
5+
object AsmWorkerImpl {
6+
7+
def generateSyntheticClasses(classesDir: java.nio.file.Path, mainMethods: Array[String]): Unit = {
8+
mainMethods.foreach { methodName =>
9+
generateSyntheticMainClass(os.Path(classesDir), methodName, mainMethods.size > 1)
10+
}
11+
}
12+
13+
def findMainArgsMethods(classesDir: java.nio.file.Path): Array[String] = {
14+
val mainMethods = collection.mutable.ArrayBuffer[String]()
15+
16+
// Look for _MillScriptMain$ class which contains the mainargs.Parser code
17+
val millScriptMainClass = os.Path(classesDir) / "_MillScriptMain$.class"
18+
19+
if (os.exists(millScriptMainClass)) {
20+
val reader = new asm.ClassReader(os.read.bytes(millScriptMainClass))
21+
22+
val visitor = new asm.ClassVisitor(asm.Opcodes.ASM9) {
23+
override def visitMethod(
24+
access: Int,
25+
name: String,
26+
descriptor: String,
27+
signature: String,
28+
exceptions: Array[String]
29+
): asm.MethodVisitor = {
30+
new asm.MethodVisitor(asm.Opcodes.ASM9) {
31+
private val stringsSinceLastCreate = collection.mutable.ArrayBuffer[String]()
32+
33+
override def visitLdcInsn(value: Any): Unit = {
34+
value match {
35+
case s: String => stringsSinceLastCreate += s
36+
case _ =>
37+
}
38+
super.visitLdcInsn(value)
39+
}
40+
41+
override def visitMethodInsn(
42+
opcode: Int,
43+
owner: String,
44+
methodName: String,
45+
descriptor: String,
46+
isInterface: Boolean
47+
): Unit = {
48+
// Look for MainData.create calls which include the method name as first parameter
49+
if (owner.contains("MainData") && methodName == "create") {
50+
// The first string constant before MainData.create is the method name
51+
if (stringsSinceLastCreate.nonEmpty) {
52+
val potentialMethodName = stringsSinceLastCreate.head
53+
mainMethods += potentialMethodName
54+
stringsSinceLastCreate.clear()
55+
}
56+
}
57+
58+
super.visitMethodInsn(opcode, owner, methodName, descriptor, isInterface)
59+
}
60+
}
61+
}
62+
}
63+
64+
reader.accept(visitor, 0)
65+
}
66+
67+
mainMethods.toArray.distinct
68+
}
69+
70+
def generateSyntheticMainClass(
71+
classesDir: os.Path,
72+
methodName: String,
73+
multiMain: Boolean
74+
): Unit = {
75+
val templateClassName = if (multiMain) "TemplateMultiMainClass" else "TemplateSingleMainClass"
76+
val templateBytes = os.read.bytes(
77+
os.resource(getClass.getClassLoader) / os.SubPath(s"mill/script/asm/$templateClassName.class")
78+
)
79+
val reader = new asm.ClassReader(templateBytes)
80+
val writer = new asm.ClassWriter(reader, 0)
81+
82+
val visitor = new asm.ClassVisitor(asm.Opcodes.ASM9, writer) {
83+
override def visit(
84+
version: Int,
85+
access: Int,
86+
name: String,
87+
signature: String,
88+
superName: String,
89+
interfaces: Array[String]
90+
): Unit = {
91+
super.visit(version, access, methodName, signature, superName, interfaces)
92+
}
93+
94+
override def visitMethod(
95+
access: Int,
96+
name: String,
97+
descriptor: String,
98+
signature: String,
99+
exceptions: Array[String]
100+
): asm.MethodVisitor = {
101+
val mv = super.visitMethod(access, name, descriptor, signature, exceptions)
102+
103+
new asm.MethodVisitor(asm.Opcodes.ASM9, mv) {
104+
override def visitLdcInsn(value: Any): Unit = {
105+
// Replace "TEMPLATE_METHOD_NAME" with actual method name
106+
if (value == "TEMPLATE_METHOD_NAME") super.visitLdcInsn(methodName)
107+
else super.visitLdcInsn(value)
108+
}
109+
110+
override def visitMethodInsn(
111+
opcode: Int,
112+
owner: String,
113+
name: String,
114+
descriptor: String,
115+
isInterface: Boolean
116+
): Unit = {
117+
// Replace TemplateMainClass.main call with _MillScriptMain$.main
118+
if (owner == s"mill/script/asm/$templateClassName" && name == "main") {
119+
super.visitMethodInsn(opcode, "_MillScriptMain", name, descriptor, isInterface)
120+
} else {
121+
super.visitMethodInsn(opcode, owner, name, descriptor, isInterface)
122+
}
123+
}
124+
}
125+
}
126+
}
127+
128+
reader.accept(visitor, 0)
129+
130+
// Write the modified class file
131+
val classBytes = writer.toByteArray
132+
os.write.over(classesDir / s"$methodName.class", classBytes)
133+
}
134+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
package mill.script.asm;
2+
3+
public class TemplateMultiMainClass {
4+
public static void main(String[] args) {
5+
String[] newArgs = new String[args.length + 1];
6+
newArgs[0] = "TEMPLATE_METHOD_NAME";
7+
System.arraycopy(args, 0, newArgs, 1, args.length);
8+
TemplateMultiMainClass.main(newArgs);
9+
}
10+
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
package mill.script.asm;
2+
3+
public class TemplateSingleMainClass {
4+
public static void main(String[] args) {
5+
TemplateSingleMainClass.main(args);
6+
}
7+
}

0 commit comments

Comments
 (0)