Skip to content

Commit 5f49993

Browse files
Merge pull request #490 from SciML/autoopt
Changes to make auto-optimize work
2 parents f792b15 + cc062be commit 5f49993

File tree

9 files changed

+121
-152
lines changed

9 files changed

+121
-152
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
1212
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
1313
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1414
GeneralizedGenerated = "6b9d7cbe-bcb9-11e9-073f-15a7a543e2eb"
15+
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
1516
Latexify = "23fbe1c1-3f47-55db-b15f-69d7ec21a316"
1617
LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d"
1718
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/ModelingToolkit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module ModelingToolkit
22

33
using DiffEqBase, Distributed
4-
using StaticArrays, LinearAlgebra, SparseArrays
4+
using StaticArrays, LinearAlgebra, SparseArrays, LabelledArrays
55
using Latexify, Unitful, ArrayInterface
66
using MacroTools
77
using UnPack: @unpack

src/build_function.jl

Lines changed: 21 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ i.e., f(u,p,args...) for the out-of-place and scalar functions and
2424
```julia
2525
build_function(ex, args...;
2626
conv = simplified_expr, expression = Val{true},
27-
checkbounds = false, constructor=nothing,
27+
checkbounds = false,
2828
linenumbers = false, target = JuliaTarget())
2929
```
3030
@@ -46,8 +46,6 @@ Keyword Arguments:
4646
4747
- `checkbounds`: For whether to enable bounds checking inside of the generated
4848
function. Defaults to false, meaning that `@inbounds` is applied.
49-
- `constructor`: Allows for an arbitrary constructor function to be passed in
50-
for handling expressions of "weird" types. Defaults to nothing.
5149
- `linenumbers`: Determines whether the generated function expression retains
5250
the line numbers. Defaults to true.
5351
- `target`: The output target of the compilation process. Possible options are:
@@ -104,7 +102,7 @@ end
104102
# Scalar output
105103
function _build_function(target::JuliaTarget, op::Operation, args...;
106104
conv = simplified_expr, expression = Val{true},
107-
checkbounds = false, constructor=nothing,
105+
checkbounds = false,
108106
linenumbers = true, headerfun=addheader)
109107

110108
argnames = [gensym(:MTKArg) for i in 1:length(args)]
@@ -165,7 +163,7 @@ end
165163

166164
function _build_function(target::JuliaTarget, rhss, args...;
167165
conv = simplified_expr, expression = Val{true},
168-
checkbounds = false, constructor=nothing,
166+
checkbounds = false,
169167
linenumbers = false, multithread=nothing,
170168
headerfun=addheader, outputidxs=nothing,
171169
skipzeros = false, parallel=SerialForm())
@@ -323,41 +321,39 @@ function _build_function(target::JuliaTarget, rhss, args...;
323321

324322
if rhss isa Matrix
325323
arr_sys_expr = build_expr(:vcat, [build_expr(:row,[conv(rhs) for rhs rhss[i,:]]) for i in 1:size(rhss,1)])
326-
# : x because ??? what to do in the general case?
327-
_constructor = constructor === nothing ? :($(first(argnames)) isa ModelingToolkit.StaticArrays.StaticArray ? ModelingToolkit.StaticArrays.SMatrix{$(size(rhss)...)} : x->(out = similar(typeof($(fargs.args[1])),$(size(rhss)...)); out .= x)) : constructor
328324
elseif typeof(rhss) <: Array && !(typeof(rhss) <: Vector)
329325
vector_form = build_expr(:vect, [conv(rhs) for rhs rhss])
330326
arr_sys_expr = :(reshape($vector_form,$(size(rhss)...)))
331-
_constructor = constructor === nothing ? :($(first(argnames)) isa ModelingToolkit.StaticArrays.StaticArray ? ModelingToolkit.StaticArrays.SArray{$(size(rhss)...)} : x->(out = similar(typeof($(fargs.args[1])),$(size(rhss)...)); out .= x)) : constructor
332327
elseif rhss isa SparseMatrixCSC
333328
vector_form = build_expr(:vect, [conv(rhs) for rhs nonzeros(rhss)])
334329
arr_sys_expr = :(SparseMatrixCSC{eltype($(first(argnames))),Int}($(size(rhss)...), $(rhss.colptr), $(rhss.rowval), $vector_form))
335-
# Static and sparse? Probably not a combo that will actually be hit, but give a default anyways
336-
_constructor = constructor === nothing ? :($(first(argnames)) isa ModelingToolkit.StaticArrays.StaticArray ? ModelingToolkit.StaticArrays.SMatrix{$(size(rhss)...)} : x->x) : constructor
337330
else # Vector
338331
arr_sys_expr = build_expr(:vect, [conv(rhs) for rhs rhss])
339-
# Handle vector constructor separately using `typeof(u)` to support things like LabelledArrays
340-
_constructor = constructor === nothing ? :($(first(argnames)) isa ModelingToolkit.StaticArrays.StaticArray ? ModelingToolkit.StaticArrays.similar_type(typeof($(fargs.args[1])), eltype(X)) : x->convert(typeof($(fargs.args[1])),x)) : constructor
341332
end
342333

334+
xname = gensym(:MTK)
335+
336+
arr_sys_expr = (typeof(rhss) <: Vector || typeof(rhss) <: Matrix) && !(eltype(rhss) <: AbstractArray) ? quote
337+
if typeof($(fargs.args[1])) <: Union{ModelingToolkit.StaticArrays.SArray,ModelingToolkit.LabelledArrays.SLArray}
338+
$xname = ModelingToolkit.StaticArrays.@SArray $arr_sys_expr
339+
convert(typeof($(fargs.args[1])),$xname)
340+
else
341+
$xname = $arr_sys_expr
342+
if !(typeof($(fargs.args[1])) <: Array)
343+
convert(typeof($(fargs.args[1])),$xname)
344+
else
345+
$xname
346+
end
347+
end
348+
end : arr_sys_expr
349+
343350
let_expr = Expr(:let, var_eqs, tuple_sys_expr)
344351
arr_let_expr = Expr(:let, var_eqs, arr_sys_expr)
345352
bounds_block = checkbounds ? let_expr : :(@inbounds begin $let_expr end)
346-
arr_bounds_block = checkbounds ? arr_let_expr : :(@inbounds begin $arr_let_expr end)
353+
oop_bounds_block = checkbounds ? arr_let_expr : :(@inbounds begin $arr_let_expr end)
347354
ip_bounds_block = checkbounds ? ip_let_expr : :(@inbounds begin $ip_let_expr end)
348355

349-
oop_body_block = :(
350-
# If u is a weird non-StaticArray type and we want a sparse matrix, just do the optimized sparse anyways
351-
if $(fargs.args[1]) isa Array || (!(typeof($(fargs.args[1])) <: StaticArray) && $(rhss isa SparseMatrixCSC))
352-
return $arr_bounds_block
353-
else
354-
X = $bounds_block
355-
construct = $_constructor
356-
return construct(X)
357-
end
358-
)
359-
360-
oop_ex = headerfun(oop_body_block, fargs, false)
356+
oop_ex = headerfun(oop_bounds_block, fargs, false)
361357
iip_ex = headerfun(ip_bounds_block, fargs, true; X=X)
362358

363359
if !linenumbers

0 commit comments

Comments
 (0)