Skip to content

Commit 46298ca

Browse files
Fix unsafe_wrap of a view. (#452)
And some other small fixes.
1 parent cd9e68a commit 46298ca

File tree

4 files changed

+12
-3
lines changed

4 files changed

+12
-3
lines changed

perf/array.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ const m = 512
22
const n = 1000
33

44
for (S, smname) in [(Metal.PrivateStorage,"private"), (Metal.SharedStorage,"shared")]
5-
group = addgroup!(SUITE, "$smname array")
5+
local group = addgroup!(SUITE, "$smname array")
66

77
# generate some arrays
88
cpu_mat = rand(rng, Float32, m, n)

src/array.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -556,7 +556,7 @@ Base.unsafe_convert(::Type{MTL.MTLBuffer}, A::PermutedDimsArray) =
556556
## unsafe_wrap
557557

558558
function Base.unsafe_wrap(::Type{<:Array}, arr::MtlArray{T,N}, dims=size(arr); own=false) where {T,N}
559-
return unsafe_wrap(Array{T,N}, arr.data[], dims; own)
559+
return unsafe_wrap(Array{T,N}, pointer(arr), dims; own)
560560
end
561561

562562
function Base.unsafe_wrap(t::Type{<:Array{T}}, buf::MTLBuffer, dims; own=false) where T

src/device/array.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ Base.show(io::IO, a::MtlDeviceArray) =
149149
Base.show(io::IO, mime::MIME"text/plain", a::MtlDeviceArray) = show(io, a)
150150

151151
@inline function Base.unsafe_view(A::MtlDeviceVector{T}, I::Vararg{Base.ViewIndex,1}) where {T}
152-
ptr = pointer(A) + (I[1].start-1)*sizeof(T)
152+
ptr = pointer(A, I[1].start)
153153
len = I[1].stop - I[1].start + 1
154154
return MtlDeviceArray(len, ptr)
155155
end

test/array.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,15 @@ end
370370
arr2 .+= 1;
371371
@test all(arr2 .== 2)
372372
@test all(marr2 .== 2)
373+
374+
@testset "Issue #451" begin
375+
a = mtl(reshape(Float32.(1:60), 5,4,3);storage=Metal.SharedStorage)
376+
view_a = @view a[:,1:4,2]
377+
b = copy(unsafe_wrap(Array, view_a))
378+
c = Array(view_a)
379+
380+
@test b == c
381+
end
373382
end
374383

375384
@testset "ReshapedArray" begin

0 commit comments

Comments
 (0)