@@ -87,6 +87,16 @@ function add_integrator_header(ex, fargs, iip; X=gensym(:MTIIPVar))
87
87
wrappedex
88
88
end
89
89
90
+ function unflatten_long_ops (op, N= 4 )
91
+ rule1 = @rule ((+ )((~~ x)) => length (~~ x) > N ?
92
+ + (+ ((~~ x)[1 : N]. .. ) + (+ )((~~ x)[N+ 1 : end ]. .. )) : nothing )
93
+ rule2 = @rule ((* )((~~ x)) => length (~~ x) > N ?
94
+ * (* ((~~ x)[1 : N]. .. ) * (* )((~~ x)[N+ 1 : end ]. .. )) : nothing )
95
+
96
+ op = to_symbolic (op)
97
+ Rewriters. Fixpoint (Rewriters. Postwalk (Rewriters. Chain ([rule1, rule2])))(op) |> to_mtk
98
+ end
99
+
90
100
# Scalar output
91
101
function _build_function (target:: JuliaTarget , op:: Operation , args... ;
92
102
conv = simplified_expr, expression = Val{true },
@@ -97,9 +107,10 @@ function _build_function(target::JuliaTarget, op::Operation, args...;
97
107
arg_pairs = map (vars_to_pairs,zip (argnames,args))
98
108
ls = reduce (vcat,first .(arg_pairs))
99
109
rs = reduce (vcat,last .(arg_pairs))
100
- var_eqs = Expr (:(= ), ModelingToolkit. build_expr (:tuple , ls), ModelingToolkit. build_expr (:tuple , rs ))
110
+ var_eqs = Expr (:(= ), ModelingToolkit. build_expr (:tuple , ls), ModelingToolkit. build_expr (:tuple , unflatten_long_ops .(rs) ))
101
111
102
112
fname = gensym (:ModelingToolkitFunction )
113
+ op = unflatten_long_ops (op)
103
114
out_expr = conv (op)
104
115
let_expr = Expr (:let , var_eqs, Expr (:block , out_expr))
105
116
bounds_block = checkbounds ? let_expr : :(@inbounds begin $ let_expr end )
@@ -242,7 +253,13 @@ function _build_function(target::JuliaTarget, rhss, args...;
242
253
oidx = isnothing (outputidxs) ? (i -> i) : (i -> outputidxs[i])
243
254
X = gensym (:MTIIPVar )
244
255
245
- rhs_length = rhss isa SparseMatrixCSC ? length (rhss. nzval) : length (rhss)
256
+ if rhss isa SparseMatrixCSC
257
+ rhs_length = length (rhss. nzval)
258
+ rhss = SparseMatrixCSC (rhss. m, rhss. m, rhss. colptr, rhss. rowval, map (unflatten_long_ops, rhss. nzval))
259
+ else
260
+ rhs_length = length (rhss)
261
+ rhss = [unflatten_long_ops (r) for r in rhss]
262
+ end
246
263
247
264
if parallel isa DistributedForm
248
265
numworks = Distributed. nworkers ()
@@ -251,6 +268,7 @@ function _build_function(target::JuliaTarget, rhss, args...;
251
268
finalsize = rhs_length - (numworks- 1 )* lens
252
269
_rhss = vcat (reduce (vcat,[[getindex (reducevars[i],j) for j in 1 : lens] for i in 1 : numworks- 1 ],init= Expr[]),
253
270
[getindex (reducevars[end ],j) for j in 1 : finalsize])
271
+
254
272
elseif parallel isa DaggerForm
255
273
computevars = [Variable (gensym (:MTComputeVar ))() for i in axes (rhss,1 )]
256
274
reducevar = Variable (gensym (:MTReduceVar ))()
0 commit comments