diff --git a/src/MOL_utils.jl b/src/MOL_utils.jl index a4357df81..972dc4181 100644 --- a/src/MOL_utils.jl +++ b/src/MOL_utils.jl @@ -311,3 +311,28 @@ function ex2term(term, v) name = Symbol("⟦" * string(term) * "⟧") return setname(similarterm(exdv, rename(operation(exdv), name), arguments(exdv)), name) end + +function _to_kw(ex) + @assert ex.head == :call + args = ex.args + @assert all(x isa Symbol for x in args) + u = args[1] + xs = args[2:end] + kw_fn_name = Symbol(ex.args[1], "_kw") + kws = [] + for arg in xs + kw = Expr(:kw, arg, esc(arg)) + push!(kws, kw) + end + params = Expr(:parameters, kws...) + + def_call_ex = Expr(:call, esc(kw_fn_name), params) + block_call_ex = Expr(:call, esc(u), xs...) + + body = Expr(:block, block_call_ex) + Expr(:function, def_call_ex, body) +end + +macro to_kw(ex) + _to_kw(ex) +end diff --git a/src/MethodOfLines.jl b/src/MethodOfLines.jl index e1c6dff5e..327b14cea 100644 --- a/src/MethodOfLines.jl +++ b/src/MethodOfLines.jl @@ -76,7 +76,7 @@ include("error_analysis.jl") include("scalar_discretization.jl") include("MOL_discretization.jl") -export MOLFiniteDifference, discretize, symbolic_discretize, ODEFunctionExpr, generate_code, grid_align, edge_align, center_align, get_discrete +export MOLFiniteDifference, discretize, symbolic_discretize, ODEFunctionExpr, generate_code, grid_align, edge_align, center_align, get_discrete, @to_kw export UpwindScheme, WENOScheme end diff --git a/test/utils_test.jl b/test/utils_test.jl index d7be033c4..7c134d2e2 100644 --- a/test/utils_test.jl +++ b/test/utils_test.jl @@ -131,3 +131,21 @@ end I = CartesianIndex(-1, 2) @test MethodOfLines._wrapperiodic(I, 2, 1, 4) == CartesianIndex(2, 2) end + +@testset "@to_kw" begin + @variables t + N = 2 + x = Symbolics.variables(:x, 1:N) + + for (i, xi) in enumerate(x) + @eval $(Symbol(:x, i)) = $xi + end + + @variables u(..) + @to_kw u(t, x1, x2) + @test isequal(u_kw(t=0), u(0, x1, x2)) + + @variables foo(..) + @to_kw foo(t, x1, x2) + @test isequal(foo_kw(t=0), foo(0, x1, x2)) +end