Skip to content

Commit f29ec53

Browse files
better default constructors in the weird cases
1 parent 0b3b4f5 commit f29ec53

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

src/utils.jl

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ function flatten_expr!(x)
3333
x
3434
end
3535

36+
default_tensor_constructor(x) = x->(out = similar(typeof(u),$(size(rhss)...)); out .= x)
37+
3638
function build_function(rhss, vs, ps = (), args = (), conv = simplified_expr, expression = Val{true};
3739
checkbounds = false, constructor=nothing, linenumbers = true)
3840
_vs = map(x-> x isa Operation ? x.op : x, vs)
@@ -59,11 +61,11 @@ function build_function(rhss, vs, ps = (), args = (), conv = simplified_expr, ex
5961
if rhss isa Matrix
6062
arr_sys_expr = build_expr(:vcat, [build_expr(:row,[conv(rhs) for rhs rhss[i,:]]) for i in 1:size(rhss,1)])
6163
# : x because ??? what to do in the general case?
62-
_constructor = constructor === nothing ? :(u isa ModelingToolkit.StaticArrays.StaticArray ? ModelingToolkit.StaticArrays.SMatrix{$(size(rhss)...)} : x->x) : constructor
64+
_constructor = constructor === nothing ? :(u isa ModelingToolkit.StaticArrays.StaticArray ? ModelingToolkit.StaticArrays.SMatrix{$(size(rhss)...)} : default_tensor_constructor) : constructor
6365
elseif typeof(rhss) <: Array && !(typeof(rhss) <: Vector)
6466
vector_form = build_expr(:vect, [conv(rhs) for rhs rhss])
6567
arr_sys_expr = :(reshape($vector_form,$(size(rhss)...)))
66-
_constructor = constructor === nothing ? :(u isa ModelingToolkit.StaticArrays.StaticArray ? ModelingToolkit.StaticArrays.SArray{$(size(rhss)...)} : x->x) : constructor
68+
_constructor = constructor === nothing ? :(u isa ModelingToolkit.StaticArrays.StaticArray ? ModelingToolkit.StaticArrays.SArray{$(size(rhss)...)} : default_tensor_constructor) : constructor
6769
elseif rhss isa SparseMatrixCSC
6870
vector_form = build_expr(:vect, [conv(rhs) for rhs nonzeros(rhss)])
6971
arr_sys_expr = :(SparseMatrixCSC{eltype(u),Int}($(size(rhss)...), $(rhss.colptr), $(rhss.rowval), $vector_form))
@@ -85,15 +87,16 @@ function build_function(rhss, vs, ps = (), args = (), conv = simplified_expr, ex
8587

8688
oop_ex = :(
8789
($(fargs.args...),) -> begin
88-
if $(fargs.args[1]) isa Array
90+
# If u is a weird non-StaticArray type and we want a sparse matrix, just do the optimized sparse anyways
91+
if $(fargs.args[1]) isa Array || (!($(fargs.args[1]) <: StaticArray) && $(rhss isa SparseMatrixCSC))
8992
return $arr_bounds_block
9093
else
9194
X = $bounds_block
95+
T = promote_type(map(typeof,X)...)
96+
map(T,X)
97+
construct = $_constructor
98+
return construct(X)
9299
end
93-
T = promote_type(map(typeof,X)...)
94-
map(T,X)
95-
construct = $_constructor
96-
construct(X)
97100
end
98101
)
99102

0 commit comments

Comments
 (0)