@@ -258,7 +258,7 @@ function _build_function(target::JuliaTarget, rhss, args...;
258
258
rhss = SparseMatrixCSC (rhss. m, rhss. m, rhss. colptr, rhss. rowval, map (unflatten_long_ops, rhss. nzval))
259
259
else
260
260
rhs_length = length (rhss)
261
- rhss = map ( unflatten_long_ops, rhss)
261
+ rhss = Expression[ unflatten_long_ops (r) for r in rhss]
262
262
end
263
263
264
264
if parallel isa DistributedForm
@@ -268,6 +268,7 @@ function _build_function(target::JuliaTarget, rhss, args...;
268
268
finalsize = rhs_length - (numworks- 1 )* lens
269
269
_rhss = vcat (reduce (vcat,[[getindex (reducevars[i],j) for j in 1 : lens] for i in 1 : numworks- 1 ],init= Expr[]),
270
270
[getindex (reducevars[end ],j) for j in 1 : finalsize])
271
+
271
272
elseif parallel isa DaggerForm
272
273
computevars = [Variable (gensym (:MTComputeVar ))() for i in axes (rhss,1 )]
273
274
reducevar = Variable (gensym (:MTReduceVar ))()
@@ -376,8 +377,7 @@ function _build_function(target::JuliaTarget, rhss, args...;
376
377
dagwrap (ex:: Expr ) = dagwrap (ex, Val (ex. head))
377
378
dagwrap (ex:: Expr , :: Val ) = ex
378
379
dagwrap (ex:: Expr , :: Val{:call} ) = :(Dagger. delayed ($ (ex. args[1 ]))($ (dagwrap .(ex. args[2 : end ])... )))
379
-
380
- new_rhss = dagwrap .(conv .(Array (rhss)))
380
+ new_rhss = dagwrap .(conv .(rhss))
381
381
delayed_exprs = build_expr (:block , [:($ (Symbol (computevars[i])) = Dagger. delayed (identity)($ (new_rhss[i]))) for i in axes (computevars,1 )])
382
382
# TODO : treereduce?
383
383
reduce_expr = quote
0 commit comments