diff --git a/src/utils.jl b/src/utils.jl index 6c0150d..92d325a 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -226,7 +226,7 @@ isshortdef(ex) = (@capture(ex, (fcall_ = body_)) && function longdef1(ex) if @capture(ex, (arg_ -> body_)) - @q function ($arg,) $(body.args...) end + Expr(:function, arg isa Symbol ? :($arg,) : arg, body) elseif isshortdef(ex) @assert @capture(ex, (fcall_ = body_)) Expr(:function, fcall, body) @@ -289,14 +289,31 @@ function splitdef(fdef) function (fcall_ | fcall_) body_ end), "Not a function definition: $fdef") fcall_nowhere, whereparams = gatherwheres(fcall) - @assert(@capture(fcall_nowhere, ((func_(args__; kwargs__)) | - (func_(args__; kwargs__)::rtype_) | - (func_(args__)) | - (func_(args__)::rtype_))), - error_msg) - @assert(@capture(func, (fname_{params__} | fname_)), error_msg) - di = Dict(:name=>fname, :args=>args, - :kwargs=>(kwargs===nothing ? [] : kwargs), :body=>body) + func = args = kwargs = rtype = nothing + if @capture(fcall_nowhere, ((func_(args__; kwargs__)) | + (func_(args__; kwargs__)::rtype_) | + (func_(args__)) | + (func_(args__)::rtype_))) + elseif isexpr(fcall_nowhere, :tuple) + if length(fcall_nowhere.args) > 1 && isexpr(fcall_nowhere.args[1], :parameters) + args = fcall_nowhere.args[2:end] + kwargs = fcall_nowhere.args[1].args + else + args = fcall_nowhere.args + end + elseif isexpr(fcall_nowhere, :(::)) + args = Any[fcall_nowhere] + else + throw(ArgumentError(error_msg)) + end + if func !== nothing + @assert(@capture(func, (fname_{params__} | fname_)), error_msg) + di = Dict(:name=>fname, :args=>args, + :kwargs=>(kwargs===nothing ? [] : kwargs), :body=>body) + else + params = nothing + di = Dict(:args=>args, :kwargs=>(kwargs===nothing ? [] : kwargs), :body=>body) + end if rtype !== nothing; di[:rtype] = rtype end if whereparams !== nothing; di[:whereparams] = whereparams end if params !== nothing; di[:params] = params end @@ -313,33 +330,54 @@ function combinedef(dict::Dict) params = get(dict, :params, []) wparams = get(dict, :whereparams, []) body = block(dict[:body]) - name = dict[:name] - name_param = isempty(params) ? name : :($name{$(params...)}) - # We need the `if` to handle parametric inner/outer constructors like - # SomeType{X}(x::X) where X = SomeType{X}(x, x+2) - if isempty(wparams) - if rtype==nothing - @q(function $name_param($(dict[:args]...); - $(dict[:kwargs]...)) - $(body.args...) - end) + if haskey(dict, :name) + name = dict[:name] + name_param = isempty(params) ? name : :($name{$(params...)}) + # We need the `if` to handle parametric inner/outer constructors like + # SomeType{X}(x::X) where X = SomeType{X}(x, x+2) + if isempty(wparams) + if rtype==nothing + @q(function $name_param($(dict[:args]...); + $(dict[:kwargs]...)) + $(body.args...) + end) + else + @q(function $name_param($(dict[:args]...); + $(dict[:kwargs]...))::$rtype + $(body.args...) + end) + end else - @q(function $name_param($(dict[:args]...); - $(dict[:kwargs]...))::$rtype - $(body.args...) - end) + if rtype==nothing + @q(function $name_param($(dict[:args]...); + $(dict[:kwargs]...)) where {$(wparams...)} + $(body.args...) + end) + else + @q(function $name_param($(dict[:args]...); + $(dict[:kwargs]...))::$rtype where {$(wparams...)} + $(body.args...) + end) + end end else - if rtype==nothing - @q(function $name_param($(dict[:args]...); - $(dict[:kwargs]...)) where {$(wparams...)} - $(body.args...) - end) + if isempty(dict[:kwargs]) + arg = :($(dict[:args]...),) + else + arg = Expr(:tuple, Expr(:parameters, dict[:kwargs]...), dict[:args]...) + end + if isempty(wparams) + if rtype==nothing + @q($arg -> $body) + else + @q(($arg::$rtype) -> $body) + end else - @q(function $name_param($(dict[:args]...); - $(dict[:kwargs]...))::$rtype where {$(wparams...)} - $(body.args...) - end) + if rtype==nothing + @q(($arg where {$(wparams...)}) -> $body) + else + @q(($arg::$rtype where {$(wparams...)}) -> $body) + end end end end diff --git a/test/split.jl b/test/split.jl index 52540e0..940a68c 100644 --- a/test/split.jl +++ b/test/split.jl @@ -54,6 +54,20 @@ let # Parametric outer constructor @splitcombine Foo{A}(a::A) where A = Foo{A, A}(a,a) @test Foo{Int}(2) == Foo{Int, Int}(2, 2) + + @test (@splitcombine x -> x + 2)(10) === 12 + @test (@splitcombine (a, b=2; c=3, d=4) -> a+b+c+d)(1; d=10) === 16 + @test (@splitcombine ((a, b)::Tuple{Int,Int} -> a + b))((1, 2)) == 3 + @test (@splitcombine ((a::T) where {T}) -> T)([]) === Vector{Any} + @test (@splitcombine ((x::T, y::Vector{U}) where T <: U where U) -> (T, U))(1, Number[2.0]) == + (Int, Number) + @test (@splitcombine () -> @zeroarg)() == 1 + @test (@splitcombine () -> @onearg 1)() == 2 + @test (@splitcombine function (x) x + 2 end)(10) === 12 + @test (@splitcombine function (a::T) where {T} T end)([]) === Vector{Any} + @test (@splitcombine function (x::T, y::Vector{U}) where T <: U where U + (T, U) + end)(1, Number[2.0]) == (Int, Number) end @testset "combinestructdef, splitstructdef" begin