@@ -24,7 +24,7 @@ i.e., f(u,p,args...) for the out-of-place and scalar functions and
24
24
```julia
25
25
build_function(ex, args...;
26
26
conv = simplified_expr, expression = Val{true},
27
- checkbounds = false, constructor=nothing,
27
+ checkbounds = false,
28
28
linenumbers = false, target = JuliaTarget())
29
29
```
30
30
@@ -46,8 +46,6 @@ Keyword Arguments:
46
46
47
47
- `checkbounds`: For whether to enable bounds checking inside of the generated
48
48
function. Defaults to false, meaning that `@inbounds` is applied.
49
- - `constructor`: Allows for an arbitrary constructor function to be passed in
50
- for handling expressions of "weird" types. Defaults to nothing.
51
49
- `linenumbers`: Determines whether the generated function expression retains
52
50
the line numbers. Defaults to true.
53
51
- `target`: The output target of the compilation process. Possible options are:
104
102
# Scalar output
105
103
function _build_function (target:: JuliaTarget , op:: Operation , args... ;
106
104
conv = simplified_expr, expression = Val{true },
107
- checkbounds = false , constructor = nothing ,
105
+ checkbounds = false ,
108
106
linenumbers = true , headerfun= addheader)
109
107
110
108
argnames = [gensym (:MTKArg ) for i in 1 : length (args)]
165
163
166
164
function _build_function (target:: JuliaTarget , rhss, args... ;
167
165
conv = simplified_expr, expression = Val{true },
168
- checkbounds = false , constructor = nothing ,
166
+ checkbounds = false ,
169
167
linenumbers = false , multithread= nothing ,
170
168
headerfun= addheader, outputidxs= nothing ,
171
169
skipzeros = false , parallel= SerialForm ())
@@ -323,41 +321,39 @@ function _build_function(target::JuliaTarget, rhss, args...;
323
321
324
322
if rhss isa Matrix
325
323
arr_sys_expr = build_expr (:vcat , [build_expr (:row ,[conv (rhs) for rhs ∈ rhss[i,:]]) for i in 1 : size (rhss,1 )])
326
- # : x because ??? what to do in the general case?
327
- _constructor = constructor === nothing ? :($ (first (argnames)) isa ModelingToolkit. StaticArrays. StaticArray ? ModelingToolkit. StaticArrays. SMatrix{$ (size (rhss)... )} : x-> (out = similar (typeof ($ (fargs. args[1 ])),$ (size (rhss)... )); out .= x)) : constructor
328
324
elseif typeof (rhss) <: Array && ! (typeof (rhss) <: Vector )
329
325
vector_form = build_expr (:vect , [conv (rhs) for rhs ∈ rhss])
330
326
arr_sys_expr = :(reshape ($ vector_form,$ (size (rhss)... )))
331
- _constructor = constructor === nothing ? :($ (first (argnames)) isa ModelingToolkit. StaticArrays. StaticArray ? ModelingToolkit. StaticArrays. SArray{$ (size (rhss)... )} : x-> (out = similar (typeof ($ (fargs. args[1 ])),$ (size (rhss)... )); out .= x)) : constructor
332
327
elseif rhss isa SparseMatrixCSC
333
328
vector_form = build_expr (:vect , [conv (rhs) for rhs ∈ nonzeros (rhss)])
334
329
arr_sys_expr = :(SparseMatrixCSC {eltype($(first(argnames))),Int} ($ (size (rhss)... ), $ (rhss. colptr), $ (rhss. rowval), $ vector_form))
335
- # Static and sparse? Probably not a combo that will actually be hit, but give a default anyways
336
- _constructor = constructor === nothing ? :($ (first (argnames)) isa ModelingToolkit. StaticArrays. StaticArray ? ModelingToolkit. StaticArrays. SMatrix{$ (size (rhss)... )} : x-> x) : constructor
337
330
else # Vector
338
331
arr_sys_expr = build_expr (:vect , [conv (rhs) for rhs ∈ rhss])
339
- # Handle vector constructor separately using `typeof(u)` to support things like LabelledArrays
340
- _constructor = constructor === nothing ? :($ (first (argnames)) isa ModelingToolkit. StaticArrays. StaticArray ? ModelingToolkit. StaticArrays. similar_type (typeof ($ (fargs. args[1 ])), eltype (X)) : x-> convert (typeof ($ (fargs. args[1 ])),x)) : constructor
341
332
end
342
333
334
+ xname = gensym (:MTK )
335
+
336
+ arr_sys_expr = (typeof (rhss) <: Vector || typeof (rhss) <: Matrix ) && ! (eltype (rhss) <: AbstractArray ) ? quote
337
+ if typeof ($ (fargs. args[1 ])) <: Union{ModelingToolkit.StaticArrays.SArray,ModelingToolkit.LabelledArrays.SLArray}
338
+ $ xname = ModelingToolkit. StaticArrays. @SArray $ arr_sys_expr
339
+ convert (typeof ($ (fargs. args[1 ])),$ xname)
340
+ else
341
+ $ xname = $ arr_sys_expr
342
+ if ! (typeof ($ (fargs. args[1 ])) <: Array )
343
+ convert (typeof ($ (fargs. args[1 ])),$ xname)
344
+ else
345
+ $ xname
346
+ end
347
+ end
348
+ end : arr_sys_expr
349
+
343
350
let_expr = Expr (:let , var_eqs, tuple_sys_expr)
344
351
arr_let_expr = Expr (:let , var_eqs, arr_sys_expr)
345
352
bounds_block = checkbounds ? let_expr : :(@inbounds begin $ let_expr end )
346
- arr_bounds_block = checkbounds ? arr_let_expr : :(@inbounds begin $ arr_let_expr end )
353
+ oop_bounds_block = checkbounds ? arr_let_expr : :(@inbounds begin $ arr_let_expr end )
347
354
ip_bounds_block = checkbounds ? ip_let_expr : :(@inbounds begin $ ip_let_expr end )
348
355
349
- oop_body_block = :(
350
- # If u is a weird non-StaticArray type and we want a sparse matrix, just do the optimized sparse anyways
351
- if $ (fargs. args[1 ]) isa Array || (! (typeof ($ (fargs. args[1 ])) <: StaticArray ) && $ (rhss isa SparseMatrixCSC))
352
- return $ arr_bounds_block
353
- else
354
- X = $ bounds_block
355
- construct = $ _constructor
356
- return construct (X)
357
- end
358
- )
359
-
360
- oop_ex = headerfun (oop_body_block, fargs, false )
356
+ oop_ex = headerfun (oop_bounds_block, fargs, false )
361
357
iip_ex = headerfun (ip_bounds_block, fargs, true ; X= X)
362
358
363
359
if ! linenumbers
0 commit comments