Skip to content

Commit d5e3d64

Browse files
committed
Part 2 - reimplement discover macro
- add import mill.given - fix summon of Discover in CodeGen
1 parent 74e6df9 commit d5e3d64

File tree

11 files changed

+208
-110
lines changed

11 files changed

+208
-110
lines changed

bsp/src/mill/bsp/BSP.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
package mill.bsp
22

3+
import mill.given
34
import mill.api.{Ctx, PathRef}
4-
import mill.{Agg, T, Task}
5+
import mill.{Agg, T, Task, given}
56
import mill.define.{Command, Discover, ExternalModule}
67
import mill.main.BuildInfo
78
import mill.eval.Evaluator

bsp/worker/src/mill/bsp/worker/MillBuildServer.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import mill.main.MainModule
1414
import mill.runner.MillBuildRootModule
1515
import mill.scalalib.bsp.{BspModule, JvmBuildTarget, ScalaBuildTarget}
1616
import mill.scalalib.{JavaModule, SemanticDbJavaModule, TestModule}
17+
import mill.given
1718

1819
import java.io.PrintStream
1920
import java.util.concurrent.CompletableFuture

idea/src/mill/idea/GenIdea.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package mill.idea
22

3+
import mill.given
34
import mill.Task
45
import mill.api.Result
56
import mill.define.{Command, Discover, ExternalModule}

main/api/src/mill/api/Retry.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,9 @@ case class Retry(
3838
if (timeoutMillis == -1) t(retryCount)
3939
else {
4040
val result = Promise[T]
41-
val thread = new Thread(() => {
41+
val thread = new Thread({() =>
4242
result.complete(scala.util.Try(t(retryCount)))
43-
})
43+
}: Runnable)
4444
thread.start()
4545
Await.result(result.future, Duration.apply(timeoutMillis, TimeUnit.MILLISECONDS))
4646
}

main/define/src/mill/define/Discover.scala

Lines changed: 133 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -40,102 +40,149 @@ object Discover {
4040
def apply[T](value: Map[Class[_], Seq[mainargs.MainData[_, _]]]): Discover =
4141
new Discover(value.view.mapValues((Nil, _)).toMap)
4242

43-
def apply[T]: Discover = ??? // macro Router.applyImpl[T]
43+
inline def apply[T]: Discover = ${ Router.applyImpl[T] }
4444

45-
// private class Router(val ctx: blackbox.Context) extends mainargs.Macros(ctx) {
46-
// import c.universe._
45+
private object Router {
46+
import quoted.*
47+
import mainargs.Macros.*
48+
import scala.util.control.NonFatal
4749

48-
// def applyImpl[T: WeakTypeTag]: Expr[Discover] = {
49-
// val seen = mutable.Set.empty[Type]
50-
// def rec(tpe: Type): Unit = {
51-
// if (!seen(tpe)) {
52-
// seen.add(tpe)
53-
// for {
54-
// m <- tpe.members.toList.sortBy(_.name.toString)
55-
// memberTpe = m.typeSignature
56-
// if memberTpe.resultType <:< typeOf[mill.define.Module] && memberTpe.paramLists.isEmpty
57-
// } rec(memberTpe.resultType)
50+
def applyImpl[T: Type](using Quotes): Expr[Discover] = {
51+
import quotes.reflect.*
52+
val seen = mutable.Set.empty[TypeRepr]
53+
val crossSym = Symbol.requiredClass("mill.define.Cross")
54+
val crossArg = crossSym.typeMembers.filter(_.isTypeParam).head
55+
val moduleSym = Symbol.requiredClass("mill.define.Module")
56+
val deprecatedSym = Symbol.requiredClass("scala.deprecated")
57+
def rec(tpe: TypeRepr): Unit = {
58+
if (seen.add(tpe)) {
59+
val typeSym = tpe.typeSymbol
60+
for {
61+
// for some reason mill.define.Foreign has NoSymbol as type member.
62+
m <- typeSym.fieldMembers.filterNot(_ == Symbol.noSymbol).toList.sortBy(_.name.toString)
63+
memberTpe = {
64+
if m == Symbol.noSymbol then
65+
report.errorAndAbort(s"no symbol found in $typeSym typemembers ${typeSym.typeMembers}", typeSym.pos.getOrElse(Position.ofMacroExpansion))
66+
// try tpe.memberType(m)
67+
// catch {
68+
// case NonFatal(err) =>
69+
// // report.errorAndAbort(s"Error getting member type for $m in $typeSym: ${err}", m.pos.getOrElse(Position.ofMacroExpansion))
70+
// tpe.memberType(m.typeRef.dealias.typeSymbol)
71+
// }
72+
m.termRef
73+
}
74+
if memberTpe.baseClasses.contains(moduleSym)
75+
} rec(memberTpe)
5876

59-
// if (tpe <:< typeOf[mill.define.Cross[_]]) {
60-
// val inner = typeOf[Cross[_]]
61-
// .typeSymbol
62-
// .asClass
63-
// .typeParams
64-
// .head
65-
// .asType
66-
// .toType
67-
// .asSeenFrom(tpe, typeOf[Cross[_]].typeSymbol)
77+
if (tpe.baseClasses.contains(crossSym)) {
78+
val arg = tpe.memberType(crossArg)
79+
val argSym = arg.typeSymbol
80+
rec(tpe.memberType(argSym))
81+
}
82+
}
83+
}
84+
rec(TypeRepr.of[T])
6885

69-
// rec(inner)
70-
// }
71-
// }
72-
// }
73-
// rec(weakTypeOf[T])
86+
def methodReturn(tpe: TypeRepr): TypeRepr = tpe match
87+
case MethodType(_, _, res) => res
88+
case ByNameType(tpe) => tpe
89+
case _ => tpe
7490

75-
// def assertParamListCounts(
76-
// methods: Iterable[MethodSymbol],
77-
// cases: (Type, Int, String)*
78-
// ): Unit = {
79-
// for (m <- methods.toList) {
80-
// cases
81-
// .find { case (tt, n, label) =>
82-
// m.returnType <:< tt && !(m.returnType <:< weakTypeOf[Nothing])
83-
// }
84-
// .foreach { case (tt, n, label) =>
85-
// if (m.paramLists.length != n) c.abort(
86-
// m.pos,
87-
// s"$label definitions must have $n parameter list" + (if (n == 1) "" else "s")
88-
// )
89-
// }
90-
// }
91-
// }
91+
def assertParamListCounts(
92+
curCls: TypeRepr,
93+
methods: Iterable[Symbol],
94+
cases: (TypeRepr, Int, String)*
95+
): Unit = {
96+
for (m <- methods.toList) {
97+
cases
98+
.find { case (tt, n, label) =>
99+
val mType = curCls.memberType(m)
100+
val returnType = methodReturn(mType)
101+
returnType <:< tt && !(returnType <:< TypeRepr.of[Nothing])
102+
}
103+
.foreach { case (tt, n, label) =>
104+
if (m.paramSymss.length != n) report.errorAndAbort(
105+
s"$label definitions must have $n parameter list" + (if (n == 1) "" else "s"),
106+
m.pos.getOrElse(Position.ofMacroExpansion)
107+
)
108+
}
109+
}
110+
}
92111

93112
// Make sure we sort the types and methods to keep the output deterministic;
94113
// otherwise the compiler likes to give us stuff in random orders, which
95114
// causes the code to be generated in random order resulting in code hashes
96115
// changing unnecessarily
97-
// val mapping = for {
98-
// discoveredModuleType <- seen.toSeq.sortBy(_.typeSymbol.fullName)
99-
// curCls = discoveredModuleType
100-
// methods = getValsOrMeths(curCls)
101-
// overridesRoutes = {
102-
// assertParamListCounts(
103-
// methods,
104-
// (weakTypeOf[mill.define.Command[_]], 1, "`Task.Command`"),
105-
// (weakTypeOf[mill.define.Target[_]], 0, "Target")
106-
// )
116+
val mapping = for {
117+
discoveredModuleType <- seen.toSeq.sortBy(_.typeSymbol.fullName)
118+
curCls = discoveredModuleType
119+
methods = curCls.typeSymbol.methodMembers.filterNot(m => m.isSuperAccessor || m.hasAnnotation(deprecatedSym) || m.flags.is(Flags.Synthetic | Flags.Invisible | Flags.Private | Flags.Protected)) // getValsOrMeths(curCls) replaced by equivalent from Scala 3 mainargs
120+
overridesRoutes = {
121+
assertParamListCounts(
122+
curCls,
123+
methods,
124+
(TypeRepr.of[mill.define.Command[?]], 1, "`Task.Command`"),
125+
(TypeRepr.of[mill.define.Target[?]], 0, "Target")
126+
)
107127

108-
// Tuple2(
109-
// for {
110-
// m <- methods.toList.sortBy(_.fullName)
111-
// if m.returnType <:< weakTypeOf[mill.define.NamedTask[_]]
112-
// } yield m.name.decoded,
113-
// for {
114-
// m <- methods.toList.sortBy(_.fullName)
115-
// if m.returnType <:< weakTypeOf[mill.define.Command[_]]
116-
// } yield extractMethod(
117-
// m.name,
118-
// m.paramLists.flatten,
119-
// m.pos,
120-
// m.annotations.find(_.tree.tpe =:= typeOf[mainargs.main]),
121-
// curCls,
122-
// weakTypeOf[Any]
123-
// )
124-
// )
125-
// }
126-
// if overridesRoutes._1.nonEmpty || overridesRoutes._2.nonEmpty
127-
// } yield {
128-
// // by wrapping the `overridesRoutes` in a lambda function we kind of work around
129-
// // the problem of generating a *huge* macro method body that finally exceeds the
130-
// // JVM's maximum allowed method size
131-
// val overridesLambda = q"(() => $overridesRoutes)()"
132-
// val lhs = q"classOf[${discoveredModuleType.typeSymbol.asClass}]"
133-
// q"$lhs -> $overridesLambda"
134-
// }
128+
def sortedMethods(sub: TypeRepr): Seq[Symbol] =
129+
for {
130+
m <- methods.toList.sortBy(_.fullName)
131+
mType = curCls.memberType(m)
132+
returnType = methodReturn(mType)
133+
if returnType <:< sub
134+
} yield m
135135

136-
// c.Expr[Discover](
137-
// q"_root_.mill.define.Discover.apply2(_root_.scala.collection.immutable.Map(..$mapping))"
138-
// )
139-
// }
140-
// }
136+
Tuple2(
137+
for {
138+
m <- sortedMethods(sub = TypeRepr.of[mill.define.NamedTask[?]])
139+
} yield m.name,//.decoded // we don't need to decode the name in Scala 3
140+
for {
141+
m <- sortedMethods(sub = TypeRepr.of[mill.define.Command[?]])
142+
} yield curCls.asType match {
143+
case '[t] =>
144+
val expr =
145+
try
146+
createMainData[Any, t](
147+
m,
148+
m.annotations.find(_.tpe =:= TypeRepr.of[mainargs.main]).getOrElse('{new mainargs.main()}.asTerm),
149+
m.paramSymss
150+
).asExprOf[mainargs.MainData[?, ?]]
151+
catch {
152+
case NonFatal(e) =>
153+
val (before, Array(after, _*)) = e.getStackTrace().span(e => !(e.getClassName() == "mill.define.Discover$Router$" && e.getMethodName() == "applyImpl")): @unchecked
154+
val trace = (before :+ after).map(_.toString).mkString("trace:\n", "\n", "\n...")
155+
report.errorAndAbort(s"Error generating maindata for ${m.fullName}: ${e}\n$trace", m.pos.getOrElse(Position.ofMacroExpansion))
156+
}
157+
// report.warning(s"generated maindata for ${m.fullName}:\n${expr.asTerm.show}", m.pos.getOrElse(Position.ofMacroExpansion))
158+
expr
159+
}
160+
)
161+
}
162+
if overridesRoutes._1.nonEmpty || overridesRoutes._2.nonEmpty
163+
} yield {
164+
val (names, mainDataExprs) = overridesRoutes
165+
val mainDatas = Expr.ofList(mainDataExprs)
166+
// by wrapping the `overridesRoutes` in a lambda function we kind of work around
167+
// the problem of generating a *huge* macro method body that finally exceeds the
168+
// JVM's maximum allowed method size
169+
val overridesLambda = '{
170+
def pair() = (${Expr(names)}, $mainDatas)
171+
pair()
172+
}
173+
val lhs = Ref(defn.Predef_classOf).appliedToType(discoveredModuleType.widen).asExprOf[Class[?]]
174+
'{$lhs -> $overridesLambda}
175+
}
176+
177+
val expr: Expr[Discover] =
178+
'{
179+
// TODO: we can not import this here, so we have to import at the use site now, or redesign?
180+
// import mill.main.TokenReaders.*
181+
Discover.apply2(Map(${Varargs(mapping)}*))
182+
}
183+
// TODO: if needed for debugging, we can re-enable this
184+
// report.warning(s"generated discovery for ${TypeRepr.of[T].show}:\n${expr.asTerm.show}", TypeRepr.of[T].typeSymbol.pos.getOrElse(Position.ofMacroExpansion))
185+
expr
186+
}
187+
}
141188
}

main/test/src/mill/main/MainModuleTests.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
package mill.main
22

3+
import mill.given
34
import mill.api.{PathRef, Result, Val}
4-
import mill.{Agg, T, Task}
5+
import mill.{Agg, T, Task, given}
56
import mill.define.{Cross, Discover, Module}
67
import mill.testkit.UnitTester
78
import mill.testkit.TestBaseModule

0 commit comments

Comments
 (0)