Skip to content

Commit bb8259f

Browse files
cuTENSOR: Preserve storage type when multiplying (#2775)
1 parent 71bc923 commit bb8259f

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

lib/cutensor/src/interfaces.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ function Base.:(*)(A::CuTensor, B::CuTensor)
3030
B_sizes = map(x->size(B,x[1]), B_uniqs)
3131
A_inds = map(x->x[2], A_uniqs)
3232
B_inds = map(x->x[2], B_uniqs)
33-
C = CuTensor(CUDA.zeros(tC, Dims(vcat(A_sizes, B_sizes))), vcat(A_inds, B_inds))
33+
C = CuTensor(fill!(similar(B.data, tC, Dims(vcat(A_sizes, B_sizes))), zero(tC)), vcat(A_inds, B_inds))
3434
return mul!(C, A, B)
3535
end
3636

0 commit comments

Comments
 (0)