Skip to content

Commit 31a8b3c

Browse files
authored
Reduce size of generated code in Discover macro (#4463)
Fixes #3893 - Recurse over all module parent classes, in addition to member return types - Remove `allNames`, since it can now be derived from the `declaredNames` of every module class - Make `entryPoints` only store the entrypoints specifically declared for each class, and not those inherited Reduces the size of `out/mill-build/compile.dest/classes//build_/package_$.class` in the Mill repo from 820K to 73K. Haven't measured but I would expect a decrease in compile times as well
1 parent 351fc0f commit 31a8b3c

File tree

5 files changed

+81
-89
lines changed

5 files changed

+81
-89
lines changed

integration/feature/inspect/src/InspectTests.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@ object InspectTests extends UtestIntegrationTestSuite {
6363
doc
6464
)
6565

66-
assert(eval(("inspect", "core.run")).isSuccess)
66+
val res2 = eval(("inspect", "core.run"))
67+
assert(res2.isSuccess)
6768
val run = out("inspect").json.str
6869

6970
assertGlobMatches(

integration/feature/scala-3-syntax/resources/sub/package.mill

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import mill.*
88
assert(1 + 1 == 2)
99

1010
// modifiers also allowed at top-level
11-
private def subCommand(): Command[Unit] = Task.Command:
11+
def subCommand(): Command[Unit] = Task.Command:
1212
println("Hello, sub-world!")
1313

1414
// top-level object with no extends clause

main/define/src/mill/define/Discover.scala

Lines changed: 71 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,21 @@ import scala.collection.mutable
1414
* the `Task.Command` methods we find. This mapping from `Class[_]` to `MainData`
1515
* can then be used later to look up the `MainData` for any module.
1616
*/
17-
class Discover(val classInfo: Map[Class[_], Discover.Node], val allNames: Seq[String])
17+
class Discover(val classInfo: Map[Class[_], Discover.ClassInfo]) {
18+
def resolveEntrypoint(cls: Class[_], name: String) = {
19+
val res = for {
20+
(cls2, node) <- classInfo
21+
if cls2.isAssignableFrom(cls)
22+
ep <- node.entryPoints
23+
if ep.name == name
24+
} yield ep
25+
26+
res.headOption
27+
}
28+
}
1829

1930
object Discover {
20-
class Node(
31+
class ClassInfo(
2132
val entryPoints: Seq[mainargs.MainData[_, _]],
2233
val declaredNames: Seq[String]
2334
)
@@ -37,14 +48,19 @@ object Discover {
3748
def rec(tpe: TypeRepr): Unit = {
3849
if (seen.add(tpe)) {
3950
val typeSym = tpe.typeSymbol
40-
for {
51+
val memberTypes: Seq[TypeRepr] = for {
4152
m <- typeSym.fieldMembers ++ typeSym.methodMembers
4253
if m != Symbol.noSymbol
43-
memberTpe = m.termRef
44-
if memberTpe.baseClasses.contains(moduleSym)
54+
} yield m.termRef
55+
56+
val parentTypes: Seq[TypeRepr] = tpe.baseClasses.map(_.typeRef)
57+
58+
for {
59+
tpe <- memberTypes ++ parentTypes
60+
if tpe.baseClasses.contains(moduleSym)
4561
} {
46-
rec(memberTpe)
47-
memberTpe.asType match {
62+
rec(tpe)
63+
tpe.asType match {
4864
case '[mill.define.Cross[m]] => rec(TypeRepr.of[m])
4965
case _ => () // no cross argument to extract
5066
}
@@ -100,83 +116,63 @@ object Discover {
100116
// otherwise the compiler likes to give us stuff in random orders, which
101117
// causes the code to be generated in random order resulting in code hashes
102118
// changing unnecessarily
103-
val mapping: Seq[(Expr[(Class[_], Node)], Seq[String])] = for {
104-
discoveredModuleType <- seen.toSeq.sortBy(_.typeSymbol.fullName)
105-
curCls = discoveredModuleType
106-
methods = filterDefs(curCls.typeSymbol.methodMembers)
107-
declMethods = filterDefs(curCls.typeSymbol.declaredMethods)
108-
_ = {
119+
val mapping: Seq[(TypeRepr, (Seq[scala.quoted.Expr[mainargs.MainData[?, ?]]], Seq[String]))] =
120+
for {
121+
curCls <- seen.toSeq.sortBy(_.typeSymbol.fullName)
122+
} yield {
123+
val declMethods = filterDefs(curCls.typeSymbol.declaredMethods)
109124
assertParamListCounts(
110125
curCls,
111-
methods,
126+
declMethods,
112127
(TypeRepr.of[mill.define.Command[?]], 1, "`Task.Command`"),
113128
(TypeRepr.of[mill.define.Target[?]], 0, "Target")
114129
)
115-
}
116130

117-
names =
118-
sortedMethods(curCls, sub = TypeRepr.of[mill.define.NamedTask[?]], methods).map(_.name)
119-
entryPoints = for {
120-
m <- sortedMethods(curCls, sub = TypeRepr.of[mill.define.Command[?]], methods)
121-
} yield curCls.asType match {
122-
case '[t] =>
123-
val expr =
124-
try
125-
createMainData[Any, t](
126-
m,
127-
m.annotations.find(_.tpe =:= TypeRepr.of[mainargs.main]).getOrElse('{
128-
new mainargs.main()
129-
}.asTerm),
130-
m.paramSymss
131-
).asExprOf[mainargs.MainData[?, ?]]
132-
catch {
133-
case NonFatal(e) =>
134-
val (before, Array(after, _*)) = e.getStackTrace().span(e =>
135-
!(e.getClassName() == "mill.define.Discover$Router$" && e.getMethodName() == "applyImpl")
136-
): @unchecked
137-
val trace =
138-
(before :+ after).map(_.toString).mkString("trace:\n", "\n", "\n...")
139-
report.errorAndAbort(
140-
s"Error generating maindata for ${m.fullName}: ${e}\n$trace",
141-
m.pos.getOrElse(Position.ofMacroExpansion)
142-
)
143-
}
144-
expr
145-
}
146-
declaredNames =
147-
sortedMethods(
148-
curCls,
149-
sub = TypeRepr.of[mill.define.NamedTask[?]],
150-
declMethods
151-
).map(_.name)
152-
if names.nonEmpty || entryPoints.nonEmpty
153-
} yield {
154-
// by wrapping the `overridesRoutes` in a lambda function we kind of work around
155-
// the problem of generating a *huge* macro method body that finally exceeds the
156-
// JVM's maximum allowed method size
157-
val overridesLambda = '{
158-
def triple() =
159-
new Node(${ Expr.ofList(entryPoints) }, ${ Expr(declaredNames) })
160-
triple()
131+
val names =
132+
sortedMethods(
133+
curCls,
134+
sub = TypeRepr.of[mill.define.NamedTask[?]],
135+
declMethods
136+
).map(_.name)
137+
val entryPoints = for {
138+
m <- sortedMethods(curCls, sub = TypeRepr.of[mill.define.Command[?]], declMethods)
139+
} yield curCls.asType match {
140+
case '[t] =>
141+
val expr =
142+
try
143+
createMainData[Any, t](
144+
m,
145+
m.annotations
146+
.find(_.tpe =:= TypeRepr.of[mainargs.main])
147+
.getOrElse('{ new mainargs.main() }.asTerm),
148+
m.paramSymss
149+
).asExprOf[mainargs.MainData[?, ?]]
150+
catch {
151+
case NonFatal(e) =>
152+
report.errorAndAbort(
153+
s"Error generating maindata for ${m.fullName}: ${e}\n${e.getStackTrace().mkString("\n")}",
154+
m.pos.getOrElse(Position.ofMacroExpansion)
155+
)
156+
}
157+
expr
158+
}
159+
160+
(curCls.widen, (entryPoints, names))
161161
}
162-
val lhs =
163-
Ref(defn.Predef_classOf).appliedToType(discoveredModuleType.widen).asExprOf[Class[?]]
164-
('{ $lhs -> $overridesLambda }, names)
162+
163+
val mappingExpr = mapping.collect {
164+
case (cls, (entryPoints, names)) if entryPoints.nonEmpty || names.nonEmpty =>
165+
// by wrapping the `overridesRoutes` in a lambda function we kind of work around
166+
// the problem of generating a *huge* macro method body that finally exceeds the
167+
// JVM's maximum allowed method size
168+
'{
169+
def func() = new ClassInfo(${ Expr.ofList(entryPoints.toList) }, ${ Expr(names) })
170+
171+
(${ Ref(defn.Predef_classOf).appliedToType(cls).asExprOf[Class[?]] }, func())
172+
}
165173
}
166174

167-
val expr: Expr[Discover] =
168-
'{
169-
// TODO: we can not import this here, so we have to import at the use site now, or redesign?
170-
// import mill.main.TokenReaders.*
171-
// import mill.api.JsonFormatters.*
172-
new Discover(
173-
Map[Class[_], Node](${ Varargs(mapping.map(_._1)) }*),
174-
${ Expr(mapping.iterator.flatMap(_._2).distinct.toList.sorted) }
175-
)
176-
}
177-
// TODO: if needed for debugging, we can re-enable this
178-
// report.warning(s"generated discovery for ${TypeRepr.of[T].show}:\n${expr.asTerm.show}", TypeRepr.of[T].typeSymbol.pos.getOrElse(Position.ofMacroExpansion))
179-
expr
175+
'{ new Discover(Map[Class[_], ClassInfo](${ Varargs(mappingExpr) }*)) }
180176
}
181177
}
182178
}

main/resolve/src/mill/resolve/Resolve.scala

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ object Resolve {
142142
val invoked = invokeCommand0(
143143
p,
144144
r.segments.last.value,
145-
rootModule.millDiscover.asInstanceOf[Discover],
145+
rootModule.millDiscover,
146146
args,
147147
nullCommandDefaults,
148148
allowPositionalCommandArgs
@@ -159,11 +159,8 @@ object Resolve {
159159
rest: Seq[String],
160160
nullCommandDefaults: Boolean,
161161
allowPositionalCommandArgs: Boolean
162-
): Iterable[Either[String, Command[_]]] = for {
163-
(cls, node) <- discover.classInfo
164-
if cls.isAssignableFrom(target.getClass)
165-
ep <- node.entryPoints
166-
if ep.name == name
162+
): Option[Either[String, Command[_]]] = for {
163+
ep <- discover.resolveEntrypoint(target.getClass, name)
167164
} yield {
168165
def withNullDefault(a: mainargs.ArgSig): mainargs.ArgSig = {
169166
if (a.default.nonEmpty) a
@@ -303,7 +300,8 @@ trait Resolve[T] {
303300
) match {
304301
case ResolveCore.Success(value) => Right(value)
305302
case ResolveCore.NotFound(segments, found, next, possibleNexts) =>
306-
val allPossibleNames = rootModule.millDiscover.allNames.toSet
303+
val allPossibleNames =
304+
rootModule.millDiscover.classInfo.values.flatMap(_.declaredNames).toSet
307305
Left(ResolveNotFoundHandler(
308306
selector = sel,
309307
segments = segments,

main/src/mill/main/MainModule.scala

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -265,15 +265,12 @@ trait MainModule extends BaseModule0 {
265265
val mainDataOpt = evaluator
266266
.rootModule
267267
.millDiscover
268-
.classInfo
269-
.get(t.ctx.enclosingCls)
270-
.flatMap(_.entryPoints.find(_.name == t.ctx.segments.last.value))
271-
.headOption
268+
.resolveEntrypoint(t.ctx.enclosingCls, t.ctx.segments.last.value)
272269

273270
mainDataOpt match {
274271
case Some(mainData) if mainData.renderedArgSigs.nonEmpty =>
275272
val rendered = mainargs.Renderer.formatMainMethodSignature(
276-
mainDataOpt.get,
273+
mainData,
277274
leftIndent = 2,
278275
totalWidth = 100,
279276
leftColWidth = mainargs.Renderer.getLeftColWidth(mainData.renderedArgSigs),

0 commit comments

Comments
 (0)