Skip to content

Commit add4021

Browse files
committed
don't rely on map for infering type of unflatten_long_ops
1 parent c1285c7 commit add4021

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

src/build_function.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ function _build_function(target::JuliaTarget, rhss, args...;
258258
rhss = SparseMatrixCSC(rhss.m, rhss.m, rhss.colptr, rhss.rowval, map(unflatten_long_ops, rhss.nzval))
259259
else
260260
rhs_length = length(rhss)
261-
rhss = map(unflatten_long_ops, rhss)
261+
rhss = Expression[unflatten_long_ops(r) for r in rhss]
262262
end
263263

264264
if parallel isa DistributedForm
@@ -268,6 +268,7 @@ function _build_function(target::JuliaTarget, rhss, args...;
268268
finalsize = rhs_length - (numworks-1)*lens
269269
_rhss = vcat(reduce(vcat,[[getindex(reducevars[i],j) for j in 1:lens] for i in 1:numworks-1],init=Expr[]),
270270
[getindex(reducevars[end],j) for j in 1:finalsize])
271+
271272
elseif parallel isa DaggerForm
272273
computevars = [Variable(gensym(:MTComputeVar))() for i in axes(rhss,1)]
273274
reducevar = Variable(gensym(:MTReduceVar))()
@@ -376,8 +377,7 @@ function _build_function(target::JuliaTarget, rhss, args...;
376377
dagwrap(ex::Expr) = dagwrap(ex, Val(ex.head))
377378
dagwrap(ex::Expr, ::Val) = ex
378379
dagwrap(ex::Expr, ::Val{:call}) = :(Dagger.delayed($(ex.args[1]))($(dagwrap.(ex.args[2:end])...)))
379-
380-
new_rhss = dagwrap.(conv.(Array(rhss)))
380+
new_rhss = dagwrap.(conv.(rhss))
381381
delayed_exprs = build_expr(:block, [:($(Symbol(computevars[i])) = Dagger.delayed(identity)($(new_rhss[i]))) for i in axes(computevars,1)])
382382
# TODO: treereduce?
383383
reduce_expr = quote

0 commit comments

Comments
 (0)