Skip to content

Commit 202008d

Browse files
committed
NDArray construction bug fix
1 parent d8370f7 commit 202008d

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

lib/mps/ndarray.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ function MPSNDArray(arr::MtlArray{T,N}) where {T,N}
114114
arrsize = size(arr)
115115
@assert arrsize[1] * sizeof(T) % 16 == 0 "First dimension of input MtlArray must have a byte size divisible by 16"
116116
desc = MPSNDArrayDescriptor(T, arrsize)
117-
return MPSNDArray(arr.data[], UInt(arr.offset), desc)
117+
return MPSNDArray(arr.data[], UInt(arr.offset) * sizeof(T), desc)
118118
end
119119

120120
function Metal.MtlArray(ndarr::MPSNDArray; storage = Metal.DefaultStorageMode, async = false)

0 commit comments

Comments
 (0)