@@ -40,102 +40,149 @@ object Discover {
4040 def apply [T ](value : Map [Class [_], Seq [mainargs.MainData [_, _]]]): Discover =
4141 new Discover (value.view.mapValues((Nil , _)).toMap)
4242
43- def apply [T ]: Discover = ??? // macro Router.applyImpl[T]
43+ inline def apply [T ]: Discover = $ { Router .applyImpl[T ] }
4444
45- // private class Router(val ctx: blackbox.Context) extends mainargs.Macros(ctx) {
46- // import c.universe._
45+ private object Router {
46+ import quoted .*
47+ import mainargs .Macros .*
48+ import scala .util .control .NonFatal
4749
48- // def applyImpl[T: WeakTypeTag]: Expr[Discover] = {
49- // val seen = mutable.Set.empty[Type]
50- // def rec(tpe: Type): Unit = {
51- // if (!seen(tpe)) {
52- // seen.add(tpe)
53- // for {
54- // m <- tpe.members.toList.sortBy(_.name.toString)
55- // memberTpe = m.typeSignature
56- // if memberTpe.resultType <:< typeOf[mill.define.Module] && memberTpe.paramLists.isEmpty
57- // } rec(memberTpe.resultType)
50+ def applyImpl [T : Type ](using Quotes ): Expr [Discover ] = {
51+ import quotes .reflect .*
52+ val seen = mutable.Set .empty[TypeRepr ]
53+ val crossSym = Symbol .requiredClass(" mill.define.Cross" )
54+ val crossArg = crossSym.typeMembers.filter(_.isTypeParam).head
55+ val moduleSym = Symbol .requiredClass(" mill.define.Module" )
56+ val deprecatedSym = Symbol .requiredClass(" scala.deprecated" )
57+ def rec (tpe : TypeRepr ): Unit = {
58+ if (seen.add(tpe)) {
59+ val typeSym = tpe.typeSymbol
60+ for {
61+ // for some reason mill.define.Foreign has NoSymbol as type member.
62+ m <- typeSym.fieldMembers.filterNot(_ == Symbol .noSymbol).toList.sortBy(_.name.toString)
63+ memberTpe = {
64+ if m == Symbol .noSymbol then
65+ report.errorAndAbort(s " no symbol found in $typeSym typemembers ${typeSym.typeMembers}" , typeSym.pos.getOrElse(Position .ofMacroExpansion))
66+ // try tpe.memberType(m)
67+ // catch {
68+ // case NonFatal(err) =>
69+ // // report.errorAndAbort(s"Error getting member type for $m in $typeSym: ${err}", m.pos.getOrElse(Position.ofMacroExpansion))
70+ // tpe.memberType(m.typeRef.dealias.typeSymbol)
71+ // }
72+ m.termRef
73+ }
74+ if memberTpe.baseClasses.contains(moduleSym)
75+ } rec(memberTpe)
5876
59- // if (tpe <:< typeOf[mill.define.Cross[_]]) {
60- // val inner = typeOf[Cross[_]]
61- // .typeSymbol
62- // .asClass
63- // .typeParams
64- // .head
65- // .asType
66- // .toType
67- // .asSeenFrom(tpe, typeOf[Cross[_]].typeSymbol)
77+ if (tpe.baseClasses.contains(crossSym)) {
78+ val arg = tpe.memberType(crossArg)
79+ val argSym = arg.typeSymbol
80+ rec(tpe.memberType(argSym))
81+ }
82+ }
83+ }
84+ rec(TypeRepr .of[T ])
6885
69- // rec(inner)
70- // }
71- // }
72- // }
73- // rec(weakTypeOf[T])
86+ def methodReturn (tpe : TypeRepr ): TypeRepr = tpe match
87+ case MethodType (_, _, res) => res
88+ case ByNameType (tpe) => tpe
89+ case _ => tpe
7490
75- // def assertParamListCounts(
76- // methods: Iterable[MethodSymbol],
77- // cases: (Type, Int, String)*
78- // ): Unit = {
79- // for (m <- methods.toList) {
80- // cases
81- // .find { case (tt, n, label) =>
82- // m.returnType <:< tt && !(m.returnType <:< weakTypeOf[Nothing])
83- // }
84- // .foreach { case (tt, n, label) =>
85- // if (m.paramLists.length != n) c.abort(
86- // m.pos,
87- // s"$label definitions must have $n parameter list" + (if (n == 1) "" else "s")
88- // )
89- // }
90- // }
91- // }
91+ def assertParamListCounts (
92+ curCls : TypeRepr ,
93+ methods : Iterable [Symbol ],
94+ cases : (TypeRepr , Int , String )*
95+ ): Unit = {
96+ for (m <- methods.toList) {
97+ cases
98+ .find { case (tt, n, label) =>
99+ val mType = curCls.memberType(m)
100+ val returnType = methodReturn(mType)
101+ returnType <:< tt && ! (returnType <:< TypeRepr .of[Nothing ])
102+ }
103+ .foreach { case (tt, n, label) =>
104+ if (m.paramSymss.length != n) report.errorAndAbort(
105+ s " $label definitions must have $n parameter list " + (if (n == 1 ) " " else " s" ),
106+ m.pos.getOrElse(Position .ofMacroExpansion)
107+ )
108+ }
109+ }
110+ }
92111
93112 // Make sure we sort the types and methods to keep the output deterministic;
94113 // otherwise the compiler likes to give us stuff in random orders, which
95114 // causes the code to be generated in random order resulting in code hashes
96115 // changing unnecessarily
97- // val mapping = for {
98- // discoveredModuleType <- seen.toSeq.sortBy(_.typeSymbol.fullName)
99- // curCls = discoveredModuleType
100- // methods = getValsOrMeths(curCls)
101- // overridesRoutes = {
102- // assertParamListCounts(
103- // methods,
104- // (weakTypeOf[mill.define.Command[_]], 1, "`Task.Command`"),
105- // (weakTypeOf[mill.define.Target[_]], 0, "Target")
106- // )
116+ val mapping = for {
117+ discoveredModuleType <- seen.toSeq.sortBy(_.typeSymbol.fullName)
118+ curCls = discoveredModuleType
119+ methods = curCls.typeSymbol.methodMembers.filterNot(m => m.isSuperAccessor || m.hasAnnotation(deprecatedSym) || m.flags.is(Flags .Synthetic | Flags .Invisible | Flags .Private | Flags .Protected )) // getValsOrMeths(curCls) replaced by equivalent from Scala 3 mainargs
120+ overridesRoutes = {
121+ assertParamListCounts(
122+ curCls,
123+ methods,
124+ (TypeRepr .of[mill.define.Command [? ]], 1 , " `Task.Command`" ),
125+ (TypeRepr .of[mill.define.Target [? ]], 0 , " Target" )
126+ )
107127
108- // Tuple2(
109- // for {
110- // m <- methods.toList.sortBy(_.fullName)
111- // if m.returnType <:< weakTypeOf[mill.define.NamedTask[_]]
112- // } yield m.name.decoded,
113- // for {
114- // m <- methods.toList.sortBy(_.fullName)
115- // if m.returnType <:< weakTypeOf[mill.define.Command[_]]
116- // } yield extractMethod(
117- // m.name,
118- // m.paramLists.flatten,
119- // m.pos,
120- // m.annotations.find(_.tree.tpe =:= typeOf[mainargs.main]),
121- // curCls,
122- // weakTypeOf[Any]
123- // )
124- // )
125- // }
126- // if overridesRoutes._1.nonEmpty || overridesRoutes._2.nonEmpty
127- // } yield {
128- // // by wrapping the `overridesRoutes` in a lambda function we kind of work around
129- // // the problem of generating a *huge* macro method body that finally exceeds the
130- // // JVM's maximum allowed method size
131- // val overridesLambda = q"(() => $overridesRoutes)()"
132- // val lhs = q"classOf[${discoveredModuleType.typeSymbol.asClass}]"
133- // q"$lhs -> $overridesLambda"
134- // }
128+ def sortedMethods (sub : TypeRepr ): Seq [Symbol ] =
129+ for {
130+ m <- methods.toList.sortBy(_.fullName)
131+ mType = curCls.memberType(m)
132+ returnType = methodReturn(mType)
133+ if returnType <:< sub
134+ } yield m
135135
136- // c.Expr[Discover](
137- // q"_root_.mill.define.Discover.apply2(_root_.scala.collection.immutable.Map(..$mapping))"
138- // )
139- // }
140- // }
136+ Tuple2 (
137+ for {
138+ m <- sortedMethods(sub = TypeRepr .of[mill.define.NamedTask [? ]])
139+ } yield m.name,// .decoded // we don't need to decode the name in Scala 3
140+ for {
141+ m <- sortedMethods(sub = TypeRepr .of[mill.define.Command [? ]])
142+ } yield curCls.asType match {
143+ case ' [t] =>
144+ val expr =
145+ try
146+ createMainData[Any , t](
147+ m,
148+ m.annotations.find(_.tpe =:= TypeRepr .of[mainargs.main]).getOrElse(' {new mainargs.main()}.asTerm),
149+ m.paramSymss
150+ ).asExprOf[mainargs.MainData [? , ? ]]
151+ catch {
152+ case NonFatal (e) =>
153+ val (before, Array (after, _* )) = e.getStackTrace().span(e => ! (e.getClassName() == " mill.define.Discover$Router$" && e.getMethodName() == " applyImpl" )): @ unchecked
154+ val trace = (before :+ after).map(_.toString).mkString(" trace:\n " , " \n " , " \n ..." )
155+ report.errorAndAbort(s " Error generating maindata for ${m.fullName}: ${e}\n $trace" , m.pos.getOrElse(Position .ofMacroExpansion))
156+ }
157+ // report.warning(s"generated maindata for ${m.fullName}:\n${expr.asTerm.show}", m.pos.getOrElse(Position.ofMacroExpansion))
158+ expr
159+ }
160+ )
161+ }
162+ if overridesRoutes._1.nonEmpty || overridesRoutes._2.nonEmpty
163+ } yield {
164+ val (names, mainDataExprs) = overridesRoutes
165+ val mainDatas = Expr .ofList(mainDataExprs)
166+ // by wrapping the `overridesRoutes` in a lambda function we kind of work around
167+ // the problem of generating a *huge* macro method body that finally exceeds the
168+ // JVM's maximum allowed method size
169+ val overridesLambda = ' {
170+ def pair () = ($ {Expr (names)}, $mainDatas)
171+ pair()
172+ }
173+ val lhs = Ref (defn.Predef_classOf ).appliedToType(discoveredModuleType.widen).asExprOf[Class [? ]]
174+ ' {$lhs -> $overridesLambda}
175+ }
176+
177+ val expr : Expr [Discover ] =
178+ ' {
179+ // TODO: we can not import this here, so we have to import at the use site now, or redesign?
180+ // import mill.main.TokenReaders.*
181+ Discover .apply2(Map ($ {Varargs (mapping)}* ))
182+ }
183+ // TODO: if needed for debugging, we can re-enable this
184+ // report.warning(s"generated discovery for ${TypeRepr.of[T].show}:\n${expr.asTerm.show}", TypeRepr.of[T].typeSymbol.pos.getOrElse(Position.ofMacroExpansion))
185+ expr
186+ }
187+ }
141188}
0 commit comments