@@ -182,6 +182,12 @@ function _is_array_of_symbolics(O)
182182 any (x -> symbolic_type (x) != NotSymbolic () || _is_array_of_symbolics (x), O))
183183end
184184
185+ # workaround for https://github.com/JuliaSparse/SparseArrays.jl/issues/599
186+ function _is_array_of_symbolics (O:: SparseMatrixCSC )
187+ return symbolic_type (eltype (O)) != NotSymbolic () ||
188+ any (x -> symbolic_type (x) != NotSymbolic () || _is_array_of_symbolics (x), findnz (O)[3 ])
189+ end
190+
185191function toexpr (O, st)
186192 if issym (O)
187193 O = substitute_name (O, st)
@@ -190,7 +196,7 @@ function toexpr(O, st)
190196 O = substitute_name (O, st)
191197
192198 if _is_array_of_symbolics (O)
193- return toexpr (MakeArray (O, typeof (O)), st)
199+ return issparse (O) ? toexpr ( MakeSparseArray (O)) : toexpr (MakeArray (O, typeof (O)), st)
194200 end
195201 ! iscall (O) && return O
196202 op = operation (O)
697703
698704@inline newsym (:: Type{T} ) where T = Sym {T} (gensym (" cse" ))
699705
706+ """
707+ $(TYPEDSIGNATURES)
708+
709+ Return `true` if CSE should descend inside `sym`, which has operation `f` and
710+ arguments `args...`.
711+ """
712+ function cse_inside_expr (sym, f, args... )
713+ return true
714+ end
715+
700716"""
701717$(SIGNATURES)
702718
@@ -721,16 +737,30 @@ function topological_sort(graph)
721737 return visited[node]
722738 end
723739 if iscall (node)
740+ op = operation (node)
741+ args = arguments (node)
742+ if ! cse_inside_expr (node, op, args... )
743+ visited[node] = node
744+ return node
745+ end
724746 args = map (dfs, arguments (node))
725747 # use `term` instead of `maketerm` because we only care about the operation being performed
726748 # and not the representation. This avoids issues with `newsym` symbols not having sizes, etc.
727- new_node = term (operation (node), args... )
749+ new_node = term (operation (node), args... ; type = symtype (node) )
728750 sym = newsym (symtype (new_node))
729751 push! (sorted_nodes, sym ← new_node)
730752 visited[node] = sym
731753 return sym
732754 elseif _is_array_of_symbolics (node)
733- new_node = map (dfs, node)
755+ # workaround for https://github.com/JuliaSparse/SparseArrays.jl/issues/599
756+ if issparse (node)
757+ new_node = copy (node)
758+ for (i, j, v) in zip (findnz (node)... )
759+ new_node[i, j] = dfs (v)
760+ end
761+ else
762+ new_node = map (dfs, node)
763+ end
734764 sym = newsym (typeof (new_node))
735765 push! (sorted_nodes, sym ← new_node)
736766 visited[node] = sym
0 commit comments