Skip to content

Commit 351fc0f

Browse files
Clean up Discover macro and codegen (#4461)
* Make `Discover` return a `class` (that can be evolved by adding fields) rather than a `Tuple` (which cannot) * Simplify handling of `millDiscover` flags, in particular we do not need them to be defined for subfolder base modules * Remove unused `ObjectDataInstrument`, `Snippet`, `ObjectData` * Remove `MILL_SPLICED_CODE_START_MARKER` --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
1 parent 737dec9 commit 351fc0f

File tree

18 files changed

+135
-517
lines changed

18 files changed

+135
-517
lines changed

integration/failure/root-subfolder-module-collision/src/RootSubfolderModuleCollisionTests.scala

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,9 @@ object RootSubfolderModuleCollisionTests extends UtestIntegrationTestSuite {
1010
import tester._
1111
val res = eval(("resolve", "_"))
1212
assert(res.isSuccess == false)
13-
assert(res.err.contains("cannot override final member"))
14-
assert(res.err.contains(
15-
" final lazy val sub: _root_.build_.sub.package_.type = _root_.build_.sub.package_ // subfolder module referenc"
16-
))
13+
assert(res.err.contains("Reference to sub is ambiguous."))
14+
assert(res.err.contains("It is both defined in class package_"))
15+
assert(res.err.contains("and inherited subsequently in class package_"))
1716
}
1817
}
1918
}

integration/failure/subfolder-missing-build-prefix/resources/build.mill

Lines changed: 0 additions & 5 deletions
This file was deleted.

integration/failure/subfolder-missing-build-prefix/resources/sub/package.mill

Lines changed: 0 additions & 6 deletions
This file was deleted.

integration/failure/subfolder-missing-build-prefix/src/SubfolderMissingBuildPrefix.scala

Lines changed: 0 additions & 16 deletions
This file was deleted.

integration/feature/docannotations/src/DocAnnotationsTests.scala renamed to integration/feature/inspect/src/InspectTests.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ package mill.integration
33
import mill.testkit.UtestIntegrationTestSuite
44
import utest._
55

6-
object DocAnnotationsTests extends UtestIntegrationTestSuite {
6+
object InspectTests extends UtestIntegrationTestSuite {
77
def globMatches(glob: String, input: String): Boolean = {
88
StringContext
99
.glob(

integration/feature/scala-3-syntax/resources/build.mill

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,24 +5,27 @@ import $packages._
55
import $file.foo.Box
66
import $file.foo.{given Box[Int]}
77

8+
9+
given Cross.ToSegments[DayValue](d => List(d.toString))
10+
11+
given mainargs.TokensReader.Simple[DayValue] with
12+
def shortName = "day"
13+
14+
def read(strs: Seq[String]) =
15+
try
16+
Right(DayValue.valueOf(strs.head))
17+
catch
18+
case _: Exception => Left("not a day")
19+
20+
enum DayValue:
21+
case Monday, Tuesday, Wednesday, Thursday, Friday, Saturday, Sunday
22+
823
object `package` extends RootModule:
924

1025
def someTopLevelCommand(): Command[Unit] = Task.Command:
1126
println(s"Hello, world! ${summon[Box[Int]]} ${build.sub.subTask()}")
1227
end someTopLevelCommand
1328

14-
given Cross.ToSegments[DayValue](d => List(d.toString))
15-
16-
given mainargs.TokensReader.Simple[DayValue] with
17-
def shortName = "day"
18-
def read(strs: Seq[String]) =
19-
try
20-
Right(DayValue.valueOf(strs.head))
21-
catch
22-
case _: Exception => Left("not a day")
23-
24-
enum DayValue:
25-
case Monday, Tuesday, Wednesday, Thursday, Friday, Saturday, Sunday
2629

2730
object day extends Cross[DayModule](DayValue.values.toSeq)
2831

main/define/src/mill/define/Discover.scala

Lines changed: 61 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -14,30 +14,13 @@ import scala.collection.mutable
1414
* the `Task.Command` methods we find. This mapping from `Class[_]` to `MainData`
1515
* can then be used later to look up the `MainData` for any module.
1616
*/
17-
case class Discover private (
18-
value: Map[
19-
Class[_],
20-
(Seq[String], Seq[mainargs.MainData[_, _]], Seq[String])
21-
],
22-
dummy: Int = 0 /* avoid conflict with Discover.apply(value: Map) below*/
23-
) {
24-
@deprecated("Binary compatibility shim", "Mill 0.11.4")
25-
private[define] def this(value: Map[Class[_], Seq[mainargs.MainData[_, _]]]) =
26-
this(value.view.mapValues((Nil, _, Nil)).toMap)
27-
@deprecated("Binary compatibility shim", "Mill 0.11.4")
28-
private[define] def copy(value: Map[Class[_], Seq[mainargs.MainData[_, _]]]): Discover = {
29-
new Discover(value.view.mapValues((Nil, _, Nil)).toMap, dummy)
30-
}
31-
}
17+
class Discover(val classInfo: Map[Class[_], Discover.Node], val allNames: Seq[String])
3218

3319
object Discover {
34-
def apply2[T](value: Map[Class[_], (Seq[String], Seq[mainargs.MainData[_, _]], Seq[String])])
35-
: Discover =
36-
new Discover(value)
37-
38-
@deprecated("Binary compatibility shim", "Mill 0.11.4")
39-
def apply[T](value: Map[Class[_], Seq[mainargs.MainData[_, _]]]): Discover =
40-
new Discover(value.view.mapValues((Nil, _, Nil)).toMap)
20+
class Node(
21+
val entryPoints: Seq[mainargs.MainData[_, _]],
22+
val declaredNames: Seq[String]
23+
)
4124

4225
inline def apply[T]: Discover = ${ Router.applyImpl[T] }
4326

@@ -46,7 +29,7 @@ object Discover {
4629
import mainargs.Macros.*
4730
import scala.util.control.NonFatal
4831

49-
def applyImpl[T: Type](using Quotes): Expr[Discover] = {
32+
def applyImpl[T: Type](using quotes: Quotes): Expr[Discover] = {
5033
import quotes.reflect.*
5134
val seen = mutable.Set.empty[TypeRepr]
5235
val moduleSym = Symbol.requiredClass("mill.define.Module")
@@ -62,10 +45,8 @@ object Discover {
6245
} {
6346
rec(memberTpe)
6447
memberTpe.asType match {
65-
case '[mill.define.Cross[m]] =>
66-
rec(TypeRepr.of[m])
67-
case _ =>
68-
() // no cross argument to extract
48+
case '[mill.define.Cross[m]] => rec(TypeRepr.of[m])
49+
case _ => () // no cross argument to extract
6950
}
7051
}
7152
}
@@ -107,89 +88,91 @@ object Discover {
10788
)
10889
)
10990

91+
def sortedMethods(curCls: TypeRepr, sub: TypeRepr, methods: Seq[Symbol]): Seq[Symbol] =
92+
for {
93+
m <- methods.toList.sortBy(_.fullName)
94+
mType = curCls.memberType(m)
95+
returnType = methodReturn(mType)
96+
if returnType <:< sub
97+
} yield m
98+
11099
// Make sure we sort the types and methods to keep the output deterministic;
111100
// otherwise the compiler likes to give us stuff in random orders, which
112101
// causes the code to be generated in random order resulting in code hashes
113102
// changing unnecessarily
114-
val mapping = for {
103+
val mapping: Seq[(Expr[(Class[_], Node)], Seq[String])] = for {
115104
discoveredModuleType <- seen.toSeq.sortBy(_.typeSymbol.fullName)
116105
curCls = discoveredModuleType
117106
methods = filterDefs(curCls.typeSymbol.methodMembers)
118107
declMethods = filterDefs(curCls.typeSymbol.declaredMethods)
119-
overridesRoutes = {
108+
_ = {
120109
assertParamListCounts(
121110
curCls,
122111
methods,
123112
(TypeRepr.of[mill.define.Command[?]], 1, "`Task.Command`"),
124113
(TypeRepr.of[mill.define.Target[?]], 0, "Target")
125114
)
115+
}
126116

127-
def sortedMethods(sub: TypeRepr, methods: Seq[Symbol] = methods): Seq[Symbol] =
128-
for {
129-
m <- methods.toList.sortBy(_.fullName)
130-
mType = curCls.memberType(m)
131-
returnType = methodReturn(mType)
132-
if returnType <:< sub
133-
} yield m
134-
135-
Tuple3(
136-
for {
137-
m <- sortedMethods(sub = TypeRepr.of[mill.define.NamedTask[?]])
138-
} yield m.name, // .decoded // we don't need to decode the name in Scala 3
139-
for {
140-
m <- sortedMethods(sub = TypeRepr.of[mill.define.Command[?]])
141-
} yield curCls.asType match {
142-
case '[t] =>
143-
val expr =
144-
try
145-
createMainData[Any, t](
146-
m,
147-
m.annotations.find(_.tpe =:= TypeRepr.of[mainargs.main]).getOrElse('{
148-
new mainargs.main()
149-
}.asTerm),
150-
m.paramSymss
151-
).asExprOf[mainargs.MainData[?, ?]]
152-
catch {
153-
case NonFatal(e) =>
154-
val (before, Array(after, _*)) = e.getStackTrace().span(e =>
155-
!(e.getClassName() == "mill.define.Discover$Router$" && e.getMethodName() == "applyImpl")
156-
): @unchecked
157-
val trace =
158-
(before :+ after).map(_.toString).mkString("trace:\n", "\n", "\n...")
159-
report.errorAndAbort(
160-
s"Error generating maindata for ${m.fullName}: ${e}\n$trace",
161-
m.pos.getOrElse(Position.ofMacroExpansion)
162-
)
163-
}
164-
// report.warning(s"generated maindata for ${m.fullName}:\n${expr.asTerm.show}", m.pos.getOrElse(Position.ofMacroExpansion))
165-
expr
166-
},
167-
for
168-
m <- sortedMethods(sub = TypeRepr.of[mill.define.Task[?]], methods = declMethods)
169-
yield m.name.toString
170-
)
117+
names =
118+
sortedMethods(curCls, sub = TypeRepr.of[mill.define.NamedTask[?]], methods).map(_.name)
119+
entryPoints = for {
120+
m <- sortedMethods(curCls, sub = TypeRepr.of[mill.define.Command[?]], methods)
121+
} yield curCls.asType match {
122+
case '[t] =>
123+
val expr =
124+
try
125+
createMainData[Any, t](
126+
m,
127+
m.annotations.find(_.tpe =:= TypeRepr.of[mainargs.main]).getOrElse('{
128+
new mainargs.main()
129+
}.asTerm),
130+
m.paramSymss
131+
).asExprOf[mainargs.MainData[?, ?]]
132+
catch {
133+
case NonFatal(e) =>
134+
val (before, Array(after, _*)) = e.getStackTrace().span(e =>
135+
!(e.getClassName() == "mill.define.Discover$Router$" && e.getMethodName() == "applyImpl")
136+
): @unchecked
137+
val trace =
138+
(before :+ after).map(_.toString).mkString("trace:\n", "\n", "\n...")
139+
report.errorAndAbort(
140+
s"Error generating maindata for ${m.fullName}: ${e}\n$trace",
141+
m.pos.getOrElse(Position.ofMacroExpansion)
142+
)
143+
}
144+
expr
171145
}
172-
if overridesRoutes._1.nonEmpty || overridesRoutes._2.nonEmpty || overridesRoutes._3.nonEmpty
146+
declaredNames =
147+
sortedMethods(
148+
curCls,
149+
sub = TypeRepr.of[mill.define.NamedTask[?]],
150+
declMethods
151+
).map(_.name)
152+
if names.nonEmpty || entryPoints.nonEmpty
173153
} yield {
174-
val (names, mainDataExprs, taskNames) = overridesRoutes
175154
// by wrapping the `overridesRoutes` in a lambda function we kind of work around
176155
// the problem of generating a *huge* macro method body that finally exceeds the
177156
// JVM's maximum allowed method size
178157
val overridesLambda = '{
179-
def triple() = (${ Expr(names) }, ${ Expr.ofList(mainDataExprs) }, ${ Expr(taskNames) })
158+
def triple() =
159+
new Node(${ Expr.ofList(entryPoints) }, ${ Expr(declaredNames) })
180160
triple()
181161
}
182162
val lhs =
183163
Ref(defn.Predef_classOf).appliedToType(discoveredModuleType.widen).asExprOf[Class[?]]
184-
'{ $lhs -> $overridesLambda }
164+
('{ $lhs -> $overridesLambda }, names)
185165
}
186166

187167
val expr: Expr[Discover] =
188168
'{
189169
// TODO: we can not import this here, so we have to import at the use site now, or redesign?
190170
// import mill.main.TokenReaders.*
191171
// import mill.api.JsonFormatters.*
192-
Discover.apply2(Map(${ Varargs(mapping) }*))
172+
new Discover(
173+
Map[Class[_], Node](${ Varargs(mapping.map(_._1)) }*),
174+
${ Expr(mapping.iterator.flatMap(_._2).distinct.toList.sorted) }
175+
)
193176
}
194177
// TODO: if needed for debugging, we can re-enable this
195178
// report.warning(s"generated discovery for ${TypeRepr.of[T].show}:\n${expr.asTerm.show}", TypeRepr.of[T].typeSymbol.pos.getOrElse(Position.ofMacroExpansion))

main/resolve/src/mill/resolve/Resolve.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,9 +160,9 @@ object Resolve {
160160
nullCommandDefaults: Boolean,
161161
allowPositionalCommandArgs: Boolean
162162
): Iterable[Either[String, Command[_]]] = for {
163-
(cls, (names, entryPoints, _)) <- discover.value
163+
(cls, node) <- discover.classInfo
164164
if cls.isAssignableFrom(target.getClass)
165-
ep <- entryPoints
165+
ep <- node.entryPoints
166166
if ep.name == name
167167
} yield {
168168
def withNullDefault(a: mainargs.ArgSig): mainargs.ArgSig = {
@@ -303,7 +303,7 @@ trait Resolve[T] {
303303
) match {
304304
case ResolveCore.Success(value) => Right(value)
305305
case ResolveCore.NotFound(segments, found, next, possibleNexts) =>
306-
val allPossibleNames = rootModule.millDiscover.value.values.flatMap(_._1).toSet
306+
val allPossibleNames = rootModule.millDiscover.allNames.toSet
307307
Left(ResolveNotFoundHandler(
308308
selector = sel,
309309
segments = segments,

0 commit comments

Comments
 (0)