Skip to content

Commit 9e9f967

Browse files
Changes to make auto-optimize work
- Remove Wfact handling: the giant tuples made it too hard to scale in code size, and this is now unused - Fallbacks for putting vectors into ODEProblem(sys,...) constructors - Improved sparsity and simplification handling in extra pieces
1 parent f792b15 commit 9e9f967

File tree

5 files changed

+70
-146
lines changed

5 files changed

+70
-146
lines changed

src/build_function.jl

Lines changed: 4 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ i.e., f(u,p,args...) for the out-of-place and scalar functions and
2424
```julia
2525
build_function(ex, args...;
2626
conv = simplified_expr, expression = Val{true},
27-
checkbounds = false, constructor=nothing,
27+
checkbounds = false,
2828
linenumbers = false, target = JuliaTarget())
2929
```
3030
@@ -46,8 +46,6 @@ Keyword Arguments:
4646
4747
- `checkbounds`: For whether to enable bounds checking inside of the generated
4848
function. Defaults to false, meaning that `@inbounds` is applied.
49-
- `constructor`: Allows for an arbitrary constructor function to be passed in
50-
for handling expressions of "weird" types. Defaults to nothing.
5149
- `linenumbers`: Determines whether the generated function expression retains
5250
the line numbers. Defaults to true.
5351
- `target`: The output target of the compilation process. Possible options are:
@@ -104,7 +102,7 @@ end
104102
# Scalar output
105103
function _build_function(target::JuliaTarget, op::Operation, args...;
106104
conv = simplified_expr, expression = Val{true},
107-
checkbounds = false, constructor=nothing,
105+
checkbounds = false,
108106
linenumbers = true, headerfun=addheader)
109107

110108
argnames = [gensym(:MTKArg) for i in 1:length(args)]
@@ -165,7 +163,7 @@ end
165163

166164
function _build_function(target::JuliaTarget, rhss, args...;
167165
conv = simplified_expr, expression = Val{true},
168-
checkbounds = false, constructor=nothing,
166+
checkbounds = false,
169167
linenumbers = false, multithread=nothing,
170168
headerfun=addheader, outputidxs=nothing,
171169
skipzeros = false, parallel=SerialForm())
@@ -323,40 +321,22 @@ function _build_function(target::JuliaTarget, rhss, args...;
323321

324322
if rhss isa Matrix
325323
arr_sys_expr = build_expr(:vcat, [build_expr(:row,[conv(rhs) for rhs rhss[i,:]]) for i in 1:size(rhss,1)])
326-
# : x because ??? what to do in the general case?
327-
_constructor = constructor === nothing ? :($(first(argnames)) isa ModelingToolkit.StaticArrays.StaticArray ? ModelingToolkit.StaticArrays.SMatrix{$(size(rhss)...)} : x->(out = similar(typeof($(fargs.args[1])),$(size(rhss)...)); out .= x)) : constructor
328324
elseif typeof(rhss) <: Array && !(typeof(rhss) <: Vector)
329325
vector_form = build_expr(:vect, [conv(rhs) for rhs rhss])
330326
arr_sys_expr = :(reshape($vector_form,$(size(rhss)...)))
331-
_constructor = constructor === nothing ? :($(first(argnames)) isa ModelingToolkit.StaticArrays.StaticArray ? ModelingToolkit.StaticArrays.SArray{$(size(rhss)...)} : x->(out = similar(typeof($(fargs.args[1])),$(size(rhss)...)); out .= x)) : constructor
332327
elseif rhss isa SparseMatrixCSC
333328
vector_form = build_expr(:vect, [conv(rhs) for rhs nonzeros(rhss)])
334329
arr_sys_expr = :(SparseMatrixCSC{eltype($(first(argnames))),Int}($(size(rhss)...), $(rhss.colptr), $(rhss.rowval), $vector_form))
335-
# Static and sparse? Probably not a combo that will actually be hit, but give a default anyways
336-
_constructor = constructor === nothing ? :($(first(argnames)) isa ModelingToolkit.StaticArrays.StaticArray ? ModelingToolkit.StaticArrays.SMatrix{$(size(rhss)...)} : x->x) : constructor
337330
else # Vector
338331
arr_sys_expr = build_expr(:vect, [conv(rhs) for rhs rhss])
339-
# Handle vector constructor separately using `typeof(u)` to support things like LabelledArrays
340-
_constructor = constructor === nothing ? :($(first(argnames)) isa ModelingToolkit.StaticArrays.StaticArray ? ModelingToolkit.StaticArrays.similar_type(typeof($(fargs.args[1])), eltype(X)) : x->convert(typeof($(fargs.args[1])),x)) : constructor
341332
end
342333

343334
let_expr = Expr(:let, var_eqs, tuple_sys_expr)
344335
arr_let_expr = Expr(:let, var_eqs, arr_sys_expr)
345336
bounds_block = checkbounds ? let_expr : :(@inbounds begin $let_expr end)
346-
arr_bounds_block = checkbounds ? arr_let_expr : :(@inbounds begin $arr_let_expr end)
337+
oop_body_block = checkbounds ? arr_let_expr : :(@inbounds begin $arr_let_expr end)
347338
ip_bounds_block = checkbounds ? ip_let_expr : :(@inbounds begin $ip_let_expr end)
348339

349-
oop_body_block = :(
350-
# If u is a weird non-StaticArray type and we want a sparse matrix, just do the optimized sparse anyways
351-
if $(fargs.args[1]) isa Array || (!(typeof($(fargs.args[1])) <: StaticArray) && $(rhss isa SparseMatrixCSC))
352-
return $arr_bounds_block
353-
else
354-
X = $bounds_block
355-
construct = $_constructor
356-
return construct(X)
357-
end
358-
)
359-
360340
oop_ex = headerfun(oop_body_block, fargs, false)
361341
iip_ex = headerfun(ip_bounds_block, fargs, true; X=X)
362342

0 commit comments

Comments
 (0)