Skip to content

Commit 414521f

Browse files
feat: add @fallback_iip_specialize
1 parent 297044a commit 414521f

File tree

1 file changed

+43
-0
lines changed

1 file changed

+43
-0
lines changed

src/systems/problem_utils.jl

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -986,6 +986,49 @@ function SciMLBase.detect_cycles(sys::AbstractSystem, varmap::Dict{Any, Any}, va
986986
return !isempty(cycles)
987987
end
988988

989+
"""
990+
$(TYPEDSIGNATURES)
991+
992+
Macro for writing problem/function constructors. Expects a function definition with type
993+
parameters for `iip` and `specialize`. Generates fallbacks with
994+
`specialize = SciMLBase.FullSpecialize` and `iip = true`.
995+
"""
996+
macro fallback_iip_specialize(ex)
997+
@assert Meta.isexpr(ex, :function)
998+
fnname, body = ex.args
999+
@assert Meta.isexpr(fnname, :where)
1000+
fnname_call, where_args... = fnname.args
1001+
@assert length(where_args) == 2
1002+
iiparg, specarg = where_args
1003+
1004+
@assert Meta.isexpr(fnname_call, :call)
1005+
fnname_curly, args... = fnname_call.args
1006+
args = map(args) do arg
1007+
Meta.isexpr(arg, :kw) && return arg.args[1]
1008+
return arg
1009+
end
1010+
1011+
@assert Meta.isexpr(fnname_curly, :curly)
1012+
fnname_name, curly_args... = fnname_curly.args
1013+
@assert curly_args == where_args
1014+
1015+
callexpr_iip = Expr(
1016+
:call, Expr(:curly, fnname_name, curly_args[1], SciMLBase.FullSpecialize), args...)
1017+
fnname_iip = Expr(:curly, fnname_name, curly_args[1])
1018+
fncall_iip = Expr(:call, fnname_iip, args...)
1019+
fnwhere_iip = Expr(:where, fncall_iip, where_args[1])
1020+
fn_iip = Expr(:function, fnwhere_iip, callexpr_iip)
1021+
1022+
callexpr_base = Expr(:call, Expr(:curly, fnname_name, true), args...)
1023+
fncall_base = Expr(:call, fnname_name, args...)
1024+
fn_base = Expr(:function, fncall_base, callexpr_base)
1025+
return quote
1026+
$fn_base
1027+
$fn_iip
1028+
Base.@__doc__ $ex
1029+
end
1030+
end
1031+
9891032
##############
9901033
# Legacy functions for backward compatibility
9911034
##############

0 commit comments

Comments
 (0)