Skip to content

Commit 649437f

Browse files
Merge pull request #703 from AayushSabharwal/as/cse-sparse
fix: handle CSE and codegen for sparse arrays
2 parents b1f111b + fa0fc98 commit 649437f

File tree

2 files changed

+57
-3
lines changed

2 files changed

+57
-3
lines changed

src/code.jl

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,12 @@ function _is_array_of_symbolics(O)
182182
any(x -> symbolic_type(x) != NotSymbolic() || _is_array_of_symbolics(x), O))
183183
end
184184

185+
# workaround for https://github.com/JuliaSparse/SparseArrays.jl/issues/599
186+
function _is_array_of_symbolics(O::SparseMatrixCSC)
187+
return symbolic_type(eltype(O)) != NotSymbolic() ||
188+
any(x -> symbolic_type(x) != NotSymbolic() || _is_array_of_symbolics(x), findnz(O)[3])
189+
end
190+
185191
function toexpr(O, st)
186192
if issym(O)
187193
O = substitute_name(O, st)
@@ -190,7 +196,7 @@ function toexpr(O, st)
190196
O = substitute_name(O, st)
191197

192198
if _is_array_of_symbolics(O)
193-
return toexpr(MakeArray(O, typeof(O)), st)
199+
return issparse(O) ? toexpr(MakeSparseArray(O)) : toexpr(MakeArray(O, typeof(O)), st)
194200
end
195201
!iscall(O) && return O
196202
op = operation(O)
@@ -697,6 +703,16 @@ end
697703

698704
@inline newsym(::Type{T}) where T = Sym{T}(gensym("cse"))
699705

706+
"""
707+
$(TYPEDSIGNATURES)
708+
709+
Return `true` if CSE should descend inside `sym`, which has operation `f` and
710+
arguments `args...`.
711+
"""
712+
function cse_inside_expr(sym, f, args...)
713+
return true
714+
end
715+
700716
"""
701717
$(SIGNATURES)
702718
@@ -721,16 +737,30 @@ function topological_sort(graph)
721737
return visited[node]
722738
end
723739
if iscall(node)
740+
op = operation(node)
741+
args = arguments(node)
742+
if !cse_inside_expr(node, op, args...)
743+
visited[node] = node
744+
return node
745+
end
724746
args = map(dfs, arguments(node))
725747
# use `term` instead of `maketerm` because we only care about the operation being performed
726748
# and not the representation. This avoids issues with `newsym` symbols not having sizes, etc.
727-
new_node = term(operation(node), args...)
749+
new_node = term(operation(node), args...; type = symtype(node))
728750
sym = newsym(symtype(new_node))
729751
push!(sorted_nodes, sym new_node)
730752
visited[node] = sym
731753
return sym
732754
elseif _is_array_of_symbolics(node)
733-
new_node = map(dfs, node)
755+
# workaround for https://github.com/JuliaSparse/SparseArrays.jl/issues/599
756+
if issparse(node)
757+
new_node = copy(node)
758+
for (i, j, v) in zip(findnz(node)...)
759+
new_node[i, j] = dfs(v)
760+
end
761+
else
762+
new_node = map(dfs, node)
763+
end
734764
sym = newsym(typeof(new_node))
735765
push!(sorted_nodes, sym new_node)
736766
visited[node] = sym

test/code.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,3 +264,27 @@ nanmath_st.rewrites[:nanmath] = true
264264
@test f(1.0, 2.0) 13.0 + sqrt(2)
265265
end
266266
end
267+
268+
@testset "Sparse array CSE" begin
269+
@syms x y z
270+
arr = [x^2 + y^2 0 0; 0 sin(y^2 + z^2) 0; 0 0 z^2 + x^2]
271+
sarr = sparse(arr);
272+
fn = eval(toexpr(Func([x, y, z], [], Code.cse(sarr))))
273+
274+
expected = eval(toexpr(Let([x 1, y 2, z 3], sarr)))
275+
@test fn(1, 2, 3) expected
276+
end
277+
278+
function foo(args...) end
279+
280+
SymbolicUtils.Code.cse_inside_expr(sym, ::typeof(foo), args...) = false
281+
282+
@testset "`cse_inside_expr`" begin
283+
@syms x y
284+
ex1 = (x^2 + y^2)
285+
exfoo = term(foo, ex1; type = Real)
286+
ex2 = ex1 + exfoo
287+
letblock = cse(ex2)
288+
ex3 = letblock.body
289+
@test any(isequal(exfoo), arguments(ex3))
290+
end

0 commit comments

Comments
 (0)