Skip to content

Commit 9660dd3

Browse files
authored
[Issue-#89] Fix call of main methods (refer to the actual method owner not just the method) (#142)
Fixes #89
1 parent 6634d17 commit 9660dd3

File tree

2 files changed

+36
-11
lines changed

2 files changed

+36
-11
lines changed

mainargs/src-3/Macros.scala

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,11 @@ object Macros {
8989
val argSigs = Expr.ofList(argSigsExprs)
9090

9191
val invokeRaw: Expr[(B, Seq[Any]) => T] = {
92-
def callOf(args: Expr[Seq[Any]]) = call(method, '{ Seq( ${ args }) }).asExprOf[T]
93-
'{ ((b: B, params: Seq[Any]) => ${ callOf('{ params }) }) }
92+
93+
def callOf(methodOwner: Expr[Any], args: Expr[Seq[Any]]) =
94+
call(methodOwner, method, '{ Seq($args) }).asExprOf[T]
95+
96+
'{ (b: B, params: Seq[Any]) => ${ callOf('b, 'params) } }
9497
}
9598
'{ MainData.create[T, B](${ Expr(method.name) }, ${ mainAnnotation.asExprOf[mainargs.main] }, ${ argSigs }, ${ invokeRaw }) }
9699
}
@@ -115,8 +118,9 @@ object Macros {
115118
*
116119
*/
117120
private def call(using Quotes)(
118-
method: quotes.reflect.Symbol,
119-
argss: Expr[Seq[Seq[Any]]]
121+
methodOwner: Expr[Any],
122+
method: quotes.reflect.Symbol,
123+
argss: Expr[Seq[Seq[Any]]]
120124
): Expr[_] = {
121125
// Copy pasted from Cask.
122126
// https://github.com/com-lihaoyi/cask/blob/65b9c8e4fd528feb71575f6e5ef7b5e2e16abbd9/cask/src-3/cask/router/Macros.scala#L106
@@ -127,8 +131,6 @@ object Macros {
127131
report.throwError("At least one parameter list must be declared.", method.pos.get)
128132
}
129133

130-
val fct = Ref(method)
131-
132134
val accesses: List[List[Term]] = for (i <- paramss.indices.toList) yield {
133135
for (j <- paramss(i).indices.toList) yield {
134136
val tpe = paramss(i)(j).tree.asInstanceOf[ValDef].tpt.tpe
@@ -137,12 +139,9 @@ object Macros {
137139
}
138140
}
139141

140-
val base = Apply(fct, accesses.head)
141-
val application: Apply = accesses.tail.foldLeft(base)((lhs, args) => Apply(lhs, args))
142-
val expr = application.asExpr
143-
expr
142+
methodOwner.asTerm.select(method).appliedToArgss(accesses).asExpr
144143
}
145-
144+
146145

147146
/** Lookup default values for a method's parameters. */
148147
private def getDefaultParams(using Quotes)(method: quotes.reflect.Symbol): Map[quotes.reflect.Symbol, Expr[Any]] = {
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
package mainargs
2+
import utest._
3+
4+
trait CommandList {
5+
@main
6+
def list(@arg v: String): String = v
7+
}
8+
9+
trait CommandCopy {
10+
@main
11+
def copy(@arg from: String, @arg to: String): (String, String) = (from, to)
12+
}
13+
14+
object Joined extends CommandCopy with CommandList {
15+
@main
16+
def test(@arg from: String, @arg to: String): (String, String) = (from, to)
17+
}
18+
19+
object MultiTraitTests extends TestSuite {
20+
val check = new Checker(ParserForMethods(Joined), allowPositional = true)
21+
val tests = Tests {
22+
test - check(List("copy", "fromArg", "toArg"), Result.Success(("fromArg", "toArg")))
23+
test - check(List("test", "fromArg", "toArg"), Result.Success(("fromArg", "toArg")))
24+
test - check(List("list", "vArg"), Result.Success("vArg"))
25+
}
26+
}

0 commit comments

Comments
 (0)