Skip to content

Commit b82e713

Browse files
Merge pull request #581 from SciML/s/unflatten
unflatten long operations in build_function
2 parents 929f6bc + 0a278a7 commit b82e713

File tree

2 files changed

+27
-2
lines changed

2 files changed

+27
-2
lines changed

src/build_function.jl

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,16 @@ function add_integrator_header(ex, fargs, iip; X=gensym(:MTIIPVar))
8787
wrappedex
8888
end
8989

90+
function unflatten_long_ops(op, N=4)
91+
rule1 = @rule((+)((~~x)) => length(~~x) > N ?
92+
+(+((~~x)[1:N]...) + (+)((~~x)[N+1:end]...)) : nothing)
93+
rule2 = @rule((*)((~~x)) => length(~~x) > N ?
94+
*(*((~~x)[1:N]...) * (*)((~~x)[N+1:end]...)) : nothing)
95+
96+
op = to_symbolic(op)
97+
Rewriters.Fixpoint(Rewriters.Postwalk(Rewriters.Chain([rule1, rule2])))(op) |> to_mtk
98+
end
99+
90100
# Scalar output
91101
function _build_function(target::JuliaTarget, op::Operation, args...;
92102
conv = simplified_expr, expression = Val{true},
@@ -97,9 +107,10 @@ function _build_function(target::JuliaTarget, op::Operation, args...;
97107
arg_pairs = map(vars_to_pairs,zip(argnames,args))
98108
ls = reduce(vcat,first.(arg_pairs))
99109
rs = reduce(vcat,last.(arg_pairs))
100-
var_eqs = Expr(:(=), ModelingToolkit.build_expr(:tuple, ls), ModelingToolkit.build_expr(:tuple, rs))
110+
var_eqs = Expr(:(=), ModelingToolkit.build_expr(:tuple, ls), ModelingToolkit.build_expr(:tuple, unflatten_long_ops.(rs)))
101111

102112
fname = gensym(:ModelingToolkitFunction)
113+
op = unflatten_long_ops(op)
103114
out_expr = conv(op)
104115
let_expr = Expr(:let, var_eqs, Expr(:block, out_expr))
105116
bounds_block = checkbounds ? let_expr : :(@inbounds begin $let_expr end)
@@ -242,7 +253,13 @@ function _build_function(target::JuliaTarget, rhss, args...;
242253
oidx = isnothing(outputidxs) ? (i -> i) : (i -> outputidxs[i])
243254
X = gensym(:MTIIPVar)
244255

245-
rhs_length = rhss isa SparseMatrixCSC ? length(rhss.nzval) : length(rhss)
256+
if rhss isa SparseMatrixCSC
257+
rhs_length = length(rhss.nzval)
258+
rhss = SparseMatrixCSC(rhss.m, rhss.m, rhss.colptr, rhss.rowval, map(unflatten_long_ops, rhss.nzval))
259+
else
260+
rhs_length = length(rhss)
261+
rhss = [unflatten_long_ops(r) for r in rhss]
262+
end
246263

247264
if parallel isa DistributedForm
248265
numworks = Distributed.nworkers()
@@ -251,6 +268,7 @@ function _build_function(target::JuliaTarget, rhss, args...;
251268
finalsize = rhs_length - (numworks-1)*lens
252269
_rhss = vcat(reduce(vcat,[[getindex(reducevars[i],j) for j in 1:lens] for i in 1:numworks-1],init=Expr[]),
253270
[getindex(reducevars[end],j) for j in 1:finalsize])
271+
254272
elseif parallel isa DaggerForm
255273
computevars = [Variable(gensym(:MTComputeVar))() for i in axes(rhss,1)]
256274
reducevar = Variable(gensym(:MTReduceVar))()

test/build_function.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,10 @@ h_julia_scalar(a, b, c, d, e, g) = a[1] + b[1] + c[1] + c[2] + c[3] + d[1] + e[1
6464
h_str_scalar = ModelingToolkit.build_function(h_scalar, [a], [b], [c1, c2, c3], [d], [e], [g])
6565
h_oop_scalar = eval(h_str_scalar)
6666
@test h_oop_scalar(inputs...) == h_julia_scalar(inputs...)
67+
68+
@variables z[1:100]
69+
@test isequal(simplify(ModelingToolkit.unflatten_long_ops(sum(z))),
70+
simplify(sum(z)))
71+
72+
@test isequal(simplify(ModelingToolkit.unflatten_long_ops(prod(z))),
73+
simplify(prod(z)))

0 commit comments

Comments
 (0)