Skip to content

Commit ea35575

Browse files
authored
Merge pull request #448 from SciML/myb/skipzeros
Add skipzeros
2 parents 8d8b6f9 + 04220a2 commit ea35575

File tree

5 files changed

+71
-20
lines changed

5 files changed

+71
-20
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ModelingToolkit"
22
uuid = "961ee093-0014-501f-94e3-6117800e7a78"
33
authors = ["Chris Rackauckas <[email protected]>"]
4-
version = "3.8.1"
4+
version = "3.9.0"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/build_function.jl

Lines changed: 47 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ function _build_function(target::JuliaTarget, rhss, args...;
167167
checkbounds = false, constructor=nothing,
168168
linenumbers = false, multithread=nothing,
169169
headerfun=addheader, outputidxs=nothing,
170-
parallel=SerialForm())
170+
skipzeros = false, parallel=SerialForm())
171171

172172
if multithread isa Bool
173173
@warn("multithraded is deprecated for the parallel argument. See the documentation.")
@@ -202,18 +202,55 @@ function _build_function(target::JuliaTarget, rhss, args...;
202202
_rhss = rhss
203203
end
204204

205-
if is_array_array_sparse_matrix(rhss) # Array of arrays of sparse matrices
206-
ip_sys_exprs = reduce(vcat,[vec(reduce(vcat,[vec([:($X[$i][$j].nzval[$k] = $(conv(rhs))) for (k, rhs) enumerate(rhsel2.nzval)]) for (j, rhsel2) enumerate(rhsel)], init=Expr[])) for (i,rhsel) enumerate(_rhss)],init=Expr[])
207-
elseif is_array_array_matrix(rhss) # Array of arrays of arrays
208-
ip_sys_exprs = reduce(vcat,[vec(reduce(vcat,[vec([:($X[$i][$j][$k] = $(conv(rhs))) for (k, rhs) enumerate(rhsel2)]) for (j, rhsel2) enumerate(rhsel)], init=Expr[])) for (i,rhsel) enumerate(_rhss)], init=Expr[])
209-
elseif is_array_sparse_matrix(rhss) # Array of sparse matrices
210-
ip_sys_exprs = reduce(vcat,[vec([:($X[$i].nzval[$j] = $(conv(rhs))) for (j, rhs) enumerate(rhsel.nzval)]) for (i,rhsel) enumerate(_rhss)], init=Expr[])
205+
ip_sys_exprs = Expr[]
206+
if is_array_array_sparse_matrix(rhss) # Array of arrays of sparse matrices
207+
for (i, rhsel) enumerate(_rhss)
208+
for (j, rhsel2) enumerate(rhsel)
209+
for (k, rhs) enumerate(rhsel2.nzval)
210+
rhs′ = conv(rhs)
211+
(skipzeros && rhs′ isa Number && iszero(rhs′)) && continue
212+
push!(ip_sys_exprs, :($X[$i][$j].nzval[$k] = $rhs′))
213+
end
214+
end
215+
end
216+
elseif is_array_array_matrix(rhss) # Array of arrays of arrays
217+
for (i, rhsel) enumerate(_rhss)
218+
for (j, rhsel2) enumerate(rhsel)
219+
for (k, rhs) enumerate(rhsel2)
220+
rhs′ = conv(rhs)
221+
(skipzeros && rhs′ isa Number && iszero(rhs′)) && continue
222+
push!(ip_sys_exprs, :($X[$i][$j][$k] = $rhs′))
223+
end
224+
end
225+
end
226+
elseif is_array_sparse_matrix(rhss) # Array of sparse matrices
227+
for (i, rhsel) enumerate(_rhss)
228+
for (j, rhs) enumerate(rhsel.nzval)
229+
rhs′ = conv(rhs)
230+
(skipzeros && rhs′ isa Number && iszero(rhs′)) && continue
231+
push!(ip_sys_exprs, :($X[$i].nzval[$j] = $rhs′))
232+
end
233+
end
211234
elseif is_array_matrix(rhss) # Array of arrays
212-
ip_sys_exprs = reduce(vcat,[vec([:($X[$i][$j] = $(conv(rhs))) for (j, rhs) enumerate(rhsel)]) for (i,rhsel) enumerate(_rhss)], init=Expr[])
235+
for (i, rhsel) enumerate(_rhss)
236+
for (j, rhs) enumerate(rhsel)
237+
rhs′ = conv(rhs)
238+
(skipzeros && rhs′ isa Number && iszero(rhs′)) && continue
239+
push!(ip_sys_exprs, :($X[$i][$j] = $rhs′))
240+
end
241+
end
213242
elseif rhss isa SparseMatrixCSC
214-
ip_sys_exprs = [:($X.nzval[$i] = $(conv(rhs))) for (i, rhs) enumerate(_rhss)]
243+
for (i, rhs) enumerate(_rhss)
244+
rhs′ = conv(rhs)
245+
(skipzeros && rhs′ isa Number && iszero(rhs′)) && continue
246+
push!(ip_sys_exprs, :($X.nzval[$i] = $rhs′))
247+
end
215248
else
216-
ip_sys_exprs = [:($X[$(oidx(i))] = $(conv(rhs))) for (i, rhs) enumerate(_rhss)]
249+
for (i, rhs) enumerate(_rhss)
250+
rhs′ = conv(rhs)
251+
(skipzeros && rhs′ isa Number && iszero(rhs′)) && continue
252+
push!(ip_sys_exprs, :($X[$(oidx(i))] = $rhs′))
253+
end
217254
end
218255

219256
ip_let_expr = Expr(:let, var_eqs, build_expr(:block, ip_sys_exprs))

src/utils.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,4 +133,3 @@ end
133133
_substitute(expr, ks, vs) = substituter(ks, vs)(expr)
134134

135135
@deprecate substitute_expr!(expr,s) substitute(expr,s)
136-

test/build_function.jl

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,33 @@ using ModelingToolkit, Test
22
@variables a b c1 c2 c3 d e g
33

44
# Multiple argument matrix
5-
h = [a + b + c1 + c2; c3 + d + e + g] # uses the same number of arguments as our application
6-
h_julia(a, b, c, d, e, g) = [a[1] + b[1] + c[1] + c[2]; c[3] + d[1] + e[1] + g[1]]
5+
h = [a + b + c1 + c2,
6+
c3 + d + e + g,
7+
0] # uses the same number of arguments as our application
8+
h_julia(a, b, c, d, e, g) = [a[1] + b[1] + c[1] + c[2],
9+
c[3] + d[1] + e[1] + g[1],
10+
0]
711
function h_julia!(out, a, b, c, d, e, g)
8-
out .= [a[1] + b[1] + c[1] + c[2]; c[3] + d[1] + e[1] + g[1]]
12+
out .= [a[1] + b[1] + c[1] + c[2], c[3] + d[1] + e[1] + g[1], 0]
913
end
1014

1115
h_str = ModelingToolkit.build_function(h, [a], [b], [c1, c2, c3], [d], [e], [g])
1216
h_oop = eval(h_str[1])
1317
h_ip! = eval(h_str[2])
18+
h_ip_skip! = eval(ModelingToolkit.build_function(h, [a], [b], [c1, c2, c3], [d], [e], [g], skipzeros=true)[2])
1419
inputs = ([1], [2], [3, 4, 5], [6], [7], [8])
1520

1621
@test h_oop(inputs...) == h_julia(inputs...)
17-
out_1 = Array{Int64}(undef, 2)
22+
out_1 = similar(h, Int)
1823
out_2 = similar(out_1)
1924
h_ip!(out_1, inputs...)
2025
h_julia!(out_2, inputs...)
2126
@test out_1 == out_2
27+
fill!(out_1, 10)
28+
h_ip_skip!(out_1, inputs...)
29+
@test out_1[3] == 10
30+
out_1[3] = 0
31+
@test out_1 == out_2
2232

2333
# Multiple input matrix, some unused arguments
2434
h_skip = [a + b + c1; c2 + c3 + g] # skip d, e

test/build_function_arrayofarray.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,16 @@ function h_dense_arraymat_julia!(out, x)
1818
out[3] .= [a[1] c[1]; 1 0]
1919
end
2020

21-
h_dense_arraymat_str = ModelingToolkit.build_function(h_dense_arraymat, [a, b, c])
22-
h_dense_arraymat_ip! = eval(h_dense_arraymat_str[2])
23-
out_1_arraymat = [Array{Int64}(undef, 2, 2) for i in 1:3]
24-
out_2_arraymat = [similar(x) for x in out_1_arraymat]
21+
h_dense_arraymat_ip! = eval(ModelingToolkit.build_function(h_dense_arraymat, [a, b, c])[2])
22+
h_dense_arraymat_ip_skip! = eval(ModelingToolkit.build_function(h_dense_arraymat, [a, b, c], skipzeros=true)[2])
23+
out_1_arraymat = [fill(42, 2, 2) for i in 1:3]
24+
out_2_arraymat = deepcopy(out_1_arraymat)
2525
h_dense_arraymat_julia!(out_1_arraymat, input)
26+
h_dense_arraymat_ip_skip!(out_2_arraymat, input)
27+
@test all(isequal(42), out_2_arraymat[2])
28+
foreach(mat->fill!(mat, 0), out_2_arraymat)
29+
h_dense_arraymat_ip_skip!(out_2_arraymat, input)
30+
@test out_1_arraymat == out_2_arraymat
2631
h_dense_arraymat_ip!(out_2_arraymat, input)
2732
@test out_1_arraymat == out_2_arraymat
2833

0 commit comments

Comments
 (0)