diff --git a/src/parsing.jl b/src/parsing.jl index 79a674d..cc848bb 100644 --- a/src/parsing.jl +++ b/src/parsing.jl @@ -169,6 +169,18 @@ function parse_function(lhs::Union{Symbol, Expr}, rhs::Expr; autovec::Bool=true, end end +# helper: detect x -> fn(x, wt, ...) and synthesize (x, w) -> fn(x, w, ...) +function detect_twoarg(ex) + if ex isa Expr && @capture(ex, x_->body_) + if body isa Expr && @capture(body, fn_(x_, with_Symbol, rest__)) + return :( (x, w) -> $(fn)(x, w, $(rest...)) ), with + elseif body isa Expr && @capture(body, fn_(x_, with_Symbol)) + return :( (x, w) -> $(fn)(x, w) ), with + end + end + return nothing, nothing +end + # Not exported # Note: `parse_across` currently does not support the use of numbers for selecting columns function parse_across(vars::Union{Expr,Symbol}, funcs::Union{Expr,Symbol}) @@ -184,26 +196,69 @@ function parse_across(vars::Union{Expr,Symbol}, funcs::Union{Expr,Symbol}) end func_array = Union{Expr,Symbol}[] # expression containing functions + needs_w = Bool[] # <— tracks whether each func wants (x,w) + with_sym = nothing # mark that this function should be called as f(x, w) if funcs isa Symbol push!(func_array, esc(funcs)) # fixes bug where single function is used inside across + push!(needs_w, false) elseif @capture(funcs, (args__,)) for arg in args if arg isa Symbol push!(func_array, esc(arg)) + push!(needs_w, false) else - push!(func_array, esc(parse_tidy(arg; from_across=true))) # fixes bug with compound and anonymous functions getting wrapped in Cols() + twoarg, with = detect_twoarg(arg) + if twoarg === nothing + push!(func_array, esc(parse_tidy(arg; from_across=true))) + push!(needs_w, false) + else + with_sym === nothing && (with_sym = with) + push!(func_array, esc(twoarg)) + push!(needs_w, true) + end end end - else # for compound functions like mean or anonymous functions - push!(func_array, esc(funcs)) + else + twoarg, with = detect_twoarg(funcs) + if twoarg === nothing + push!(func_array, esc(funcs)) + push!(needs_w, false) + else + with_sym = with + push!(func_array, esc(twoarg)) + push!(needs_w, true) + end end num_funcs = length(func_array) - return :(Cols($(src...)) .=> reshape([$(func_array...)], 1, $num_funcs)) + if with_sym === nothing + return :(Cols($(src...)) .=> reshape([$(func_array...)], 1, $num_funcs)) + end + + return :(AsTable(Cols($(src...), $(QuoteNode(with_sym)))) => (tbl -> begin + w = getproperty(tbl, $(QuoteNode(with_sym))) + acc = Pair{Symbol,Any}[] + @inbounds for nm in propertynames(tbl) + nm === $(QuoteNode(with_sym)) && continue + x = getproperty(tbl, nm) + eltype(x) <: Number || continue + $(let pushes = Expr[] + for (i, f) in enumerate(func_array) + sfx = "_" * string(i) + call_ex = needs_w[i] ? :($(f)(x, w)) : :($(f)(x)) + push!(pushes, :(push!(acc, Symbol(string(nm), $sfx) => $call_ex))) + end + Expr(:block, pushes...) + end) + end + (; acc...) + end) => AsTable) end + + # Not exported function parse_desc(tidy_expr::Union{Expr,Symbol}) tidy_expr, found_n, found_row_number = parse_interpolation(tidy_expr)