Skip to content

Commit 69d0432

Browse files
shashiYingboMa
andcommitted
SparseMatrixCSC goes to biuld function
Co-authored-by: "Yingbo Ma" <[email protected]>
1 parent 2ac1f7d commit 69d0432

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

src/build_function.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ function unflatten_long_ops(op, N=4)
9494
*(*((~~x)[1:N]...) * (*)((~~x)[N+1:end]...)) : nothing)
9595

9696
op = to_symbolic(op)
97-
Rewriters.Fixpoint(Rewriters.Postwalk(Rewriters.Chain([rule1, rule2])))(op) |> to_mtk
97+
Rewriters.Fixpoint(Rewriters.Postwalk(Rewriters.Chain([rule1, rule2])))(op)
9898
end
9999

100100
# Scalar output
@@ -399,8 +399,13 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;
399399
end
400400
end
401401

402+
if rhss isa SparseMatrixCSC
403+
rhss′ = map(convunflatten_long_ops, rhss.nzval)
404+
else
405+
rhss′ = [conv(unflatten_long_ops(r)) for r in rhss]
406+
end
402407

403-
tuple_sys_expr = build_expr(:tuple, [conv(rhs) for rhs rhss])
408+
tuple_sys_expr = build_expr(:tuple, rhss)
404409

405410
if rhss isa Matrix
406411
arr_sys_expr = build_expr(:vcat, [build_expr(:row,[conv(rhs) for rhs rhss[i,:]]) for i in 1:size(rhss,1)])

0 commit comments

Comments
 (0)