Skip to content

Commit b16e3f9

Browse files
committed
Part 8 - implement Cross.Factory macro (TODO split shims out)
- fix scanning of Cross modules in Discover macro
1 parent 87b44bc commit b16e3f9

File tree

5 files changed

+397
-111
lines changed

5 files changed

+397
-111
lines changed

example/depth/cross/6-axes-extension/build.mill

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ trait FooModule3 extends FooModule2 with Cross.Module3[String, Int, Boolean] {
5656

5757
> mill show foo3[b,2,false].param3
5858
error: ...object foo3 extends Cross[FooModule3](("a", 1), ("b", 2))
59-
error: ... ^
60-
error: ...value _3 is not a member of (String, Int)
61-
*/
59+
error: ... ^^^^^^^^
60+
error: ...expected at least 3 elements, got 2...
61+
error: ...object foo3 extends Cross[FooModule3](("a", 1), ("b", 2))
62+
error: ... ^^^^^^^^
63+
error: ...expected at least 3 elements, got 2...
64+
*/

main/define/src/mill/define/Cross.scala

Lines changed: 252 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@ package mill.define
22

33
import mill.api.{BuildScriptException, Lazy}
44

5-
import language.experimental.macros
65
import scala.collection.mutable
76
import scala.reflect.ClassTag
8-
import scala.reflect.macros.blackbox
7+
8+
import scala.quoted.*
99

1010
object Cross {
1111

@@ -136,100 +136,264 @@ object Cross {
136136
* expression of type `Any`, but type-checking on the macro- expanded code
137137
* provides some degree of type-safety.
138138
*/
139-
implicit def make[M <: Module[_]](t: Any): Factory[M] = ??? // macro makeImpl[M]
140-
// def makeImpl[T: c.WeakTypeTag](c: blackbox.Context)(t: c.Expr[Any]): c.Expr[Factory[T]] = {
141-
// import c.universe._
142-
// val tpe = weakTypeOf[T]
143-
144-
// if (!tpe.typeSymbol.isClass) {
145-
// c.abort(c.enclosingPosition, s"Cross type $tpe must be trait")
146-
// }
147-
148-
// if (!tpe.typeSymbol.asClass.isTrait) abortOldStyleClass(c)(tpe)
149-
150-
// val wrappedT = if (t.tree.tpe <:< typeOf[Seq[_]]) t.tree else q"_root_.scala.Seq($t)"
151-
// val v1 = c.freshName(TermName("v1"))
152-
// val ctx0 = c.freshName(TermName("ctx0"))
153-
// val concreteCls = c.freshName(TypeName(tpe.typeSymbol.name.toString))
154-
155-
// val newTrees = collection.mutable.Buffer.empty[Tree]
156-
// var valuesTree: Tree = null
157-
// var pathSegmentsTree: Tree = null
158-
159-
// val segments = q"_root_.mill.define.Cross.ToSegments"
160-
// if (tpe <:< typeOf[Module[_]]) {
161-
// newTrees.append(q"override def crossValue = $v1")
162-
// pathSegmentsTree = q"$segments($v1)"
163-
// valuesTree = q"$wrappedT.map(List(_))"
164-
// } else c.abort(
165-
// c.enclosingPosition,
166-
// s"Cross type $tpe must implement Cross.Module[T]"
167-
// )
168-
169-
// if (tpe <:< typeOf[Module2[_, _]]) {
170-
// // For `Module2` and above, `crossValue` is no longer the entire value,
171-
// // but instead is just the first element of a tuple
172-
// newTrees.clear()
173-
// newTrees.append(q"override def crossValue = $v1._1")
174-
// newTrees.append(q"override def crossValue2 = $v1._2")
175-
// pathSegmentsTree = q"$segments($v1._1) ++ $segments($v1._2)"
176-
// valuesTree = q"$wrappedT.map(_.productIterator.toList)"
177-
// }
178-
179-
// if (tpe <:< typeOf[Module3[_, _, _]]) {
180-
// newTrees.append(q"override def crossValue3 = $v1._3")
181-
// pathSegmentsTree = q"$segments($v1._1) ++ $segments($v1._2) ++ $segments($v1._3)"
182-
// }
183-
184-
// if (tpe <:< typeOf[Module4[_, _, _, _]]) {
185-
// newTrees.append(q"override def crossValue4 = $v1._4")
186-
// pathSegmentsTree =
187-
// q"$segments($v1._1) ++ $segments($v1._2) ++ $segments($v1._3) ++ $segments($v1._4)"
188-
// }
189-
190-
// if (tpe <:< typeOf[Module5[_, _, _, _, _]]) {
191-
// newTrees.append(q"override def crossValue5 = $v1._5")
192-
// pathSegmentsTree =
193-
// q"$segments($v1._1) ++ $segments($v1._2) ++ $segments($v1._3) ++ $segments($v1._4) ++ $segments($v1._5)"
194-
// }
195-
196-
// // We need to create a `class $concreteCls` here, rather than just
197-
// // creating an anonymous sub-type of $tpe, because our task resolution
198-
// // logic needs to use java reflection to identify sub-modules and java
199-
// // reflect can only properly identify nested `object`s inside Scala
200-
// // `object` and `class`es.
201-
// val tree = q"""
202-
// new mill.define.Cross.Factory[$tpe](
203-
// makeList = $wrappedT.map{($v1: ${tq""}) =>
204-
// class $concreteCls()(implicit ctx: mill.define.Ctx) extends $tpe{..$newTrees}
205-
// (classOf[$concreteCls], ($ctx0: ${tq""}) => new $concreteCls()($ctx0))
206-
// },
207-
// crossSegmentsList = $wrappedT.map(($v1: ${tq""}) => $pathSegmentsTree ),
208-
// crossValuesListLists = $valuesTree,
209-
// crossValuesRaw = $wrappedT
210-
// ).asInstanceOf[${weakTypeOf[Factory[T]]}]
211-
// """
212-
213-
// c.Expr[Factory[T]](tree)
214-
// }
215-
216-
def abortOldStyleClass(c: blackbox.Context)(tpe: c.Type): Nothing = {
139+
implicit inline def make[M <: Module[_]](inline t: Any): Factory[M] = ${ makeImpl[M]('t) }
140+
def makeImpl[T: Type](using Quotes)(t: Expr[Any]): Expr[Factory[T]] = {
141+
import quotes.reflect.*
142+
143+
val shims = ShimService.reflect
144+
145+
val tpe = TypeRepr.of[T]
146+
147+
val cls = tpe.classSymbol.getOrElse(
148+
report.errorAndAbort(s"Cross type ${tpe.show} must be trait", Position.ofMacroExpansion)
149+
)
150+
151+
if (!cls.flags.is(Flags.Trait)) abortOldStyleClass(tpe)
152+
153+
val wrappedT: Expr[Seq[Any]] = t match
154+
case '{ $t1: Seq[elems] } => t1
155+
case '{ $t1: t1 } => '{ Seq.apply($t1) }
156+
157+
val elems0: Type[?] = t match {
158+
case '{ $t1: Seq[elems] } => TypeRepr.of[elems].widen.asType
159+
case '{ $t1: elems } => TypeRepr.of[elems].widen.asType
160+
}
161+
val elemTypes: (Expr[Seq[Seq[Any]]], Seq[(Type[?], Expr[?] => Expr[?])]) = {
162+
def select[T: Type](n: Int): Expr[?] => Expr[T] = {
163+
elems0 match {
164+
case '[type elems1 <: NonEmptyTuple; `elems1`] =>
165+
arg => arg match {
166+
case '{ $arg: `elems1` } =>
167+
'{ $arg.apply(${Expr(n)}) }.asExprOf[T]
168+
}
169+
}
170+
}
171+
def asSeq(tpe: Type[?], n: Int): Seq[(Type[?], Expr[?] => Expr[?])] = tpe match {
172+
case '[t *: ts] => (Type.of[t], select[t](n)) +: asSeq(Type.of[ts], n + 1)
173+
case '[EmptyTuple] => Nil
174+
}
175+
elems0 match {
176+
case '[type elems <: Tuple; `elems`] =>
177+
val wrappedElems = wrappedT.asExprOf[Seq[elems]]
178+
(
179+
'{ $wrappedElems.map(_.productIterator.toList) },
180+
asSeq(elems0, 0)
181+
)
182+
case '[t] =>
183+
(
184+
'{ $wrappedT.map(List(_)) },
185+
List((Type.of[t], identity))
186+
)
187+
}
188+
}
189+
190+
def exPair(n: Int): (Type[?], Expr[?] => Expr[?]) = {
191+
elemTypes(1).lift(n).getOrElse(
192+
report.errorAndAbort(
193+
s"expected at least ${n + 1} elements, got ${elemTypes(1).size}",
194+
Position.ofMacroExpansion
195+
)
196+
)
197+
}
198+
199+
def exType(n: Int): TypeRepr = {
200+
val (elemType, _) = exPair(n)
201+
elemType match
202+
case '[t] => TypeRepr.of[t]
203+
}
204+
205+
def exTerm(n: Int): Expr[?] => Expr[?] = {
206+
exPair(n)(1)
207+
}
208+
209+
def mkSegmentsCall[T: Type](t: Expr[T]): Expr[List[String]] = {
210+
val summonCall = Expr.summon[ToSegments[T]].getOrElse(
211+
report.errorAndAbort(s"Could not summon ToSegments[${Type.show[T]}]", Position.ofMacroExpansion)
212+
)
213+
'{mill.define.Cross.ToSegments[T]($t)(using $summonCall) }
214+
}
215+
216+
def mkSegmentsCallN(n: Int)(arg: Expr[?]): Expr[List[String]] = {
217+
exTerm(n)(arg) match {
218+
case '{ $v1: t1 } => mkSegmentsCall[t1](v1)
219+
}
220+
}
221+
222+
def newGetter(name: String, res: TypeRepr, flags: Flags = Flags.Override): Symbol => Symbol =
223+
cls =>
224+
Symbol.newMethod(
225+
parent = cls,
226+
name = name,
227+
tpe = ByNameType(res),
228+
flags = flags,
229+
privateWithin = Symbol.noSymbol
230+
)
231+
def newField(name: String, res: TypeRepr, flags: Flags): Symbol => Symbol =
232+
cls =>
233+
Symbol.newVal(
234+
parent = cls,
235+
name = name,
236+
tpe = res,
237+
flags = flags,
238+
privateWithin = Symbol.noSymbol
239+
)
240+
241+
def newGetterTree(name: String, rhs: Expr[?] => Expr[?]): (Symbol, Expr[?]) => Statement = {
242+
(cls, arg) =>
243+
val sym = cls.declaredMethod(name)
244+
.headOption
245+
.getOrElse(report.errorAndAbort(s"could not find method $name in $cls", Position.ofMacroExpansion))
246+
DefDef(sym, _ => Some(rhs(arg).asTerm))
247+
}
248+
249+
def newValTree(name: String, rhs: Option[Term]): (Symbol, Expr[?]) => Statement = {
250+
(cls, _) =>
251+
val sym = {
252+
val sym0 = cls.declaredField(name)
253+
if sym0 != Symbol.noSymbol then sym0
254+
else report.errorAndAbort(s"could not find field $name in $cls", Position.ofMacroExpansion)
255+
}
256+
ValDef(sym, rhs)
257+
}
258+
259+
extension (sym: Symbol) {
260+
def mkRef(debug: => String): Ref = {
261+
if sym.isTerm then
262+
Ref(sym)
263+
else
264+
report.errorAndAbort(s"could not ref ${debug}, it was not a term")
265+
}
266+
}
267+
268+
val newSyms = List.newBuilder[Symbol => Symbol]
269+
val newTrees = collection.mutable.Buffer.empty[(Symbol, Expr[?]) => Statement]
270+
val valuesTree: Expr[Seq[Seq[Any]]] = elemTypes(0)
271+
val pathSegmentsTrees = List.newBuilder[Expr[?] => Expr[List[String]]]
272+
273+
def pushElemTrees(n: Int): Unit = {
274+
val name = s"crossValue${if n > 0 then (n + 1).toString else ""}"
275+
newSyms += newGetter(name, res = exType(n))
276+
newTrees += newGetterTree(name, rhs = exTerm(n))
277+
pathSegmentsTrees += mkSegmentsCallN(n)
278+
}
279+
280+
newSyms += newField(
281+
"local_ctx",
282+
res = TypeRepr.of[mill.define.Ctx],
283+
flags = Flags.PrivateLocal | Flags.ParamAccessor)
284+
285+
newTrees += newValTree("local_ctx", rhs = None)
286+
287+
if tpe <:< TypeRepr.of[Module[?]] then
288+
pushElemTrees(0)
289+
else
290+
report.errorAndAbort(
291+
s"Cross type ${tpe.show} must implement Cross.Module[T]",
292+
Position.ofMacroExpansion
293+
)
294+
295+
if tpe <:< TypeRepr.of[Module2[?, ?]] then
296+
pushElemTrees(1)
297+
298+
if tpe <:< TypeRepr.of[Module3[?, ?, ?]] then
299+
pushElemTrees(2)
300+
301+
if (tpe <:< TypeRepr.of[Module4[?, ?, ?, ?]])
302+
pushElemTrees(3)
303+
304+
if (tpe <:< TypeRepr.of[Module5[?, ?, ?, ?, ?]])
305+
pushElemTrees(4)
306+
307+
val pathSegmentsTree: Expr[?] => Expr[List[String]] =
308+
pathSegmentsTrees.result().reduceLeft((a, b) => arg => '{ ${a(arg)} ++ ${b(arg)} })
309+
310+
def newCtor(cls: Symbol): (List[String], List[TypeRepr]) =
311+
(List("local_ctx"), List(TypeRepr.of[mill.define.Ctx]))
312+
313+
def newClassDecls(cls: Symbol): List[Symbol] = {
314+
newSyms.result().map(_(cls))
315+
}
316+
317+
def clsFactory()(using Quotes): Symbol = {
318+
shims.Symbol.newClass(
319+
parent = cls,
320+
name = s"${cls.name}_impl",
321+
parents = List(TypeRepr.of[mill.define.Module.BaseClass], tpe),
322+
ctor = newCtor,
323+
decls = newClassDecls,
324+
selfType = None
325+
)
326+
}
327+
328+
// We need to create a `class $concreteCls` here, rather than just
329+
// creating an anonymous sub-type of $tpe, because our task resolution
330+
// logic needs to use java reflection to identify sub-modules and java
331+
// reflect can only properly identify nested `object`s inside Scala
332+
// `object` and `class`es.
333+
elems0 match {
334+
case '[elems] =>
335+
val wrappedElems = wrappedT.asExprOf[Seq[elems]]
336+
val ref = '{
337+
new mill.define.Cross.Factory[T](
338+
makeList = $wrappedElems.map { (v2: elems) =>
339+
${
340+
val concreteCls = clsFactory()
341+
val concreteClsDef = shims.ClassDef(
342+
cls = concreteCls,
343+
parents = {
344+
val parentCtor =
345+
New(TypeTree.of[mill.define.Module.BaseClass]).select(
346+
TypeRepr.of[mill.define.Module.BaseClass].typeSymbol.primaryConstructor
347+
)
348+
val parentApp =
349+
parentCtor.appliedToNone.appliedTo(
350+
concreteCls.declaredField("local_ctx").mkRef(s"${concreteCls} field local_ctx")
351+
)
352+
List(parentApp, TypeTree.of[T])
353+
},
354+
body = newTrees.toList.map(_(concreteCls, 'v2))
355+
)
356+
val clsOf = Ref(defn.Predef_classOf).appliedToType(concreteCls.typeRef)
357+
def newCls(ctx0: Expr[mill.define.Ctx]): Expr[T] = {
358+
New(TypeTree.ref(concreteCls))
359+
.select(concreteCls.primaryConstructor)
360+
.appliedTo(ctx0.asTerm)
361+
.asExprOf[T]
362+
}
363+
Block(
364+
List(concreteClsDef),
365+
'{ (${clsOf.asExprOf[Class[?]]}, (ctx0: mill.define.Ctx) => ${newCls('ctx0)}) }.asTerm
366+
).asExprOf[(Class[?], mill.define.Ctx => T)]
367+
}
368+
},
369+
crossSegmentsList = $wrappedElems.map((segArg: elems) => ${pathSegmentsTree('segArg)}),
370+
crossValuesListLists = $valuesTree,
371+
crossValuesRaw = $wrappedT
372+
)(using compiletime.summonInline[reflect.ClassTag[T]])
373+
}
374+
// report.errorAndAbort(s"made factory ${ref.show}")
375+
ref
376+
}
377+
}
378+
379+
def abortOldStyleClass(using Quotes)(tpe: quotes.reflect.TypeRepr): Nothing = {
380+
import quotes.reflect.*
381+
217382
val primaryConstructorArgs =
218-
tpe.typeSymbol.asClass.primaryConstructor.typeSignature.paramLists.head
383+
tpe.classSymbol.get.primaryConstructor.paramSymss.head
219384

220385
val oldArgStr = primaryConstructorArgs
221-
.map { s => s"${s.name}: ${s.typeSignature}" }
386+
.map { s => s"${s.name}: ${s.termRef.widen.show}" }
222387
.mkString(", ")
223388

224389
def parenWrap(s: String) =
225390
if (primaryConstructorArgs.size == 1) s
226391
else s"($s)"
227392

228-
val newTypeStr = primaryConstructorArgs.map(_.typeSignature.toString).mkString(", ")
229-
val newForwarderStr = primaryConstructorArgs.map(_.name.toString).mkString(", ")
393+
val newTypeStr = primaryConstructorArgs.map(_.termRef.widen.show).mkString(", ")
394+
val newForwarderStr = primaryConstructorArgs.map(_.name).mkString(", ")
230395

231-
c.abort(
232-
c.enclosingPosition,
396+
report.errorAndAbort(
233397
s"""
234398
|Cross type ${tpe.typeSymbol.name} must be trait, not a class. Please change:
235399
|
@@ -255,7 +419,8 @@ object Cross {
255419
|you may remove it. If you do not have this definition, you can
256420
|preserve the old behavior via `def millSourcePath = super.millSourcePath / crossValue`
257421
|
258-
|""".stripMargin
422+
|""".stripMargin,
423+
Position.ofMacroExpansion
259424
)
260425
}
261426
}

0 commit comments

Comments
 (0)