Skip to content

Commit 8539e10

Browse files
fix: fix bugs in @fallback_iip_specialize, handle static array problems
1 parent 070de8c commit 8539e10

File tree

1 file changed

+51
-3
lines changed

1 file changed

+51
-3
lines changed

src/systems/problem_utils.jl

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1249,35 +1249,83 @@ parameters for `iip` and `specialize`. Generates fallbacks with
12491249
"""
12501250
macro fallback_iip_specialize(ex)
12511251
@assert Meta.isexpr(ex, :function)
1252+
# fnname is ODEProblem{iip, spec}(args...) where {iip, spec}
1253+
# body is function body
12521254
fnname, body = ex.args
12531255
@assert Meta.isexpr(fnname, :where)
1256+
# fnname_call is ODEProblem{iip, spec}(args...)
1257+
# where_args are `iip, spec`
12541258
fnname_call, where_args... = fnname.args
12551259
@assert length(where_args) == 2
12561260
iiparg, specarg = where_args
12571261

12581262
@assert Meta.isexpr(fnname_call, :call)
1263+
# fnname_curly is ODEProblem{iip, spec}
12591264
fnname_curly, args... = fnname_call.args
1260-
args = map(args) do arg
1265+
# the function should have keyword arguments
1266+
@assert Meta.isexpr(args[1], :parameters)
1267+
1268+
# arguments to call with
1269+
call_args = map(args) do arg
1270+
# keyword args are in `Expr(:parameters)` so any `Expr(:kw)` here
1271+
# are optional positional arguments. Analyze `:(f(a, b = 1; k = 1, l...))`
1272+
# to understand
12611273
Meta.isexpr(arg, :kw) && return arg.args[1]
12621274
return arg
12631275
end
1276+
call_kwargs = map(call_args[1].args) do arg
1277+
Meta.isexpr(arg, :...) && return arg
1278+
@assert Meta.isexpr(arg, :kw)
1279+
return Expr(:kw, arg.args[1], arg.args[1])
1280+
end
1281+
call_args[1] = Expr(:parameters, call_kwargs...)
12641282

12651283
@assert Meta.isexpr(fnname_curly, :curly)
1284+
# fnname_name is `ODEProblem`
1285+
# curly_args is `iip, spec`
12661286
fnname_name, curly_args... = fnname_curly.args
12671287
@assert curly_args == where_args
12681288

1289+
# callexpr_iip is `ODEProblem{iip, FullSpecialize}(call_args...)`
12691290
callexpr_iip = Expr(
1270-
:call, Expr(:curly, fnname_name, curly_args[1], SciMLBase.FullSpecialize), args...)
1291+
:call, Expr(:curly, fnname_name, curly_args[1], SciMLBase.FullSpecialize), call_args...)
1292+
# `ODEProblem{iip}`
12711293
fnname_iip = Expr(:curly, fnname_name, curly_args[1])
1294+
# `ODEProblem{iip}(args...)`
12721295
fncall_iip = Expr(:call, fnname_iip, args...)
1296+
# ODEProblem{iip}(args...) where {iip}
12731297
fnwhere_iip = Expr(:where, fncall_iip, where_args[1])
12741298
fn_iip = Expr(:function, fnwhere_iip, callexpr_iip)
12751299

1276-
callexpr_base = Expr(:call, Expr(:curly, fnname_name, true), args...)
1300+
# `ODEProblem{true}(call_args...)`
1301+
callexpr_base = Expr(:call, Expr(:curly, fnname_name, true), call_args...)
1302+
# `ODEProblem(args...)`
12771303
fncall_base = Expr(:call, fnname_name, args...)
12781304
fn_base = Expr(:function, fncall_base, callexpr_base)
1305+
1306+
# Handle case when this is a problem constructor and `u0map` is a `StaticArray`,
1307+
# where `iip` should default to `false`.
1308+
fn_sarr = nothing
1309+
if endswith(string(fnname_name), "Problem")
1310+
# args should at least contain an argument for the `u0map`
1311+
@assert length(args) > 3
1312+
u0_arg = args[3]
1313+
# should not have a type-annotation
1314+
@assert !Meta.isexpr(u0_arg, :(::))
1315+
if Meta.isexpr(u0_arg, :kw)
1316+
argname, default = u0_arg.args
1317+
u0_arg = Expr(:kw, Expr(:(::), argname, StaticArray), default)
1318+
else
1319+
u0_arg = Expr(:(::), u0_arg, StaticArray)
1320+
end
1321+
1322+
callexpr_sarr = Expr(:call, Expr(:curly, fnname_name, false), call_args...)
1323+
fncall_sarr = Expr(:call, fnname_name, args[1], args[2], u0_arg, args[4:end]...)
1324+
fn_sarr = Expr(:function, fncall_sarr, callexpr_sarr)
1325+
end
12791326
return quote
12801327
$fn_base
1328+
$fn_sarr
12811329
$fn_iip
12821330
Base.@__doc__ $ex
12831331
end |> esc

0 commit comments

Comments
 (0)