11# host array
22
3- export MtlArray, MtlVector, MtlMatrix, MtlVecOrMat, mtl
3+ export MtlArray, MtlVector, MtlMatrix, MtlVecOrMat, mtl, is_shared, is_managed, is_private
44
55function hasfieldcount (@nospecialize (dt))
66 try
@@ -77,8 +77,16 @@ mutable struct MtlArray{T,N,S} <: AbstractGPUArray{T,N}
7777 function MtlArray {T,N} (data:: DataRef{<:MTLBuffer} , dims:: Dims{N} ;
7878 maxsize:: Int = prod (dims) * sizeof (T), offset:: Int = 0 ) where {T,N}
7979 check_eltype (T)
80- S = convert (MTL. MTLResourceOptions, data[]. storageMode)
81- obj = new {T,N,S} (copy (data), maxsize, offset, dims)
80+ storagemode = data[]. storageMode
81+ if storagemode == MTL. MTLStorageModeShared
82+ obj = new {T,N,Shared} (copy (data), maxsize, offset, dims)
83+ elseif storagemode == MTL. MTLStorageModeManaged
84+ obj = new {T,N,Managed} (copy (data), maxsize, offset, dims)
85+ elseif storagemode == MTL. MTLStorageModePrivate
86+ obj = new {T,N,Private} (copy (data), maxsize, offset, dims)
87+ elseif storagemode == MTL. MTLStorageModeMemoryless
88+ obj = new {T,N,Memoryless} (copy (data), maxsize, offset, dims)
89+ end
8290 finalizer (unsafe_free!, obj)
8391 end
8492end
@@ -90,6 +98,10 @@ device(A::MtlArray) = A.data[].device
9098storagemode (x:: MtlArray ) = storagemode (typeof (x))
9199storagemode (:: Type{<:MtlArray{<:Any,<:Any,S}} ) where {S} = S
92100
101+ is_shared (a:: MtlArray ) = storagemode (a) == Shared
102+ is_managed (a:: MtlArray ) = storagemode (a) == Managed
103+ is_private (a:: MtlArray ) = storagemode (a) == Private
104+ is_memoryless (a:: MtlArray ) = storagemode (a) == Memoryless
93105
94106# # convenience constructors
95107
@@ -144,15 +156,42 @@ Base.elsize(::Type{<:MtlArray{T}}) where {T} = sizeof(T)
144156Base. size (x:: MtlArray ) = x. dims
145157Base. sizeof (x:: MtlArray ) = Base. elsize (x) * length (x)
146158
147- Base. pointer (x:: MtlArray{T} ) where {T} = Base. unsafe_convert (MtlPointer{T}, x)
148- @inline function Base. pointer (x:: MtlArray{T} , i:: Integer ) where T
149- Base. unsafe_convert (MtlPointer{T}, x) + Base. _memory_offset (x, i)
159+ @inline function Base. pointer (x:: MtlArray{T} , i:: Integer = 1 ; storage= Private) where {T}
160+ PT = if storage == Private
161+ MtlPointer{T}
162+ elseif storage == Shared || storage == Managed
163+ Ptr{T}
164+ else
165+ error (" unknown memory type" )
166+ end
167+ Base. unsafe_convert (PT, x) + Base. _memory_offset (x, i)
150168end
151169
152- Base. unsafe_convert (:: Type{Ptr{S}} , x:: MtlArray{T} ) where {S, T} =
153- throw (ArgumentError (" cannot take the CPU address of a $(typeof (x)) " ))
154- Base. unsafe_convert (:: Type{MtlPointer{T}} , x:: MtlArray ) where {T} =
155- MtlPointer {T} (x. data[], x. offset* Base. elsize (x))
170+
171+ function Base. unsafe_convert (:: Type{MtlPointer{T}} , x:: MtlArray ) where {T}
172+ buf = x. data[]
173+ MtlPointer {T} (buf, x. offset* Base. elsize (x))
174+ end
175+
176+ function Base. unsafe_convert (:: Type{Ptr{S}} , x:: MtlArray{T} ) where {S, T}
177+ buf = x. data[]
178+ if is_private (x)
179+ throw (ArgumentError (" cannot take the CPU address of a $(typeof (x)) " ))
180+ end
181+ convert (Ptr{T}, buf) + x. offset* Base. elsize (x)
182+ end
183+
184+
185+ # # indexing
186+ function Base. getindex (x:: MtlArray{T,N,S} , I:: Int ) where {T,N,S<: Union{Shared,Managed} }
187+ @boundscheck checkbounds (x, I)
188+ unsafe_load (pointer (x, I; storage= S))
189+ end
190+
191+ function Base. setindex! (x:: MtlArray{T,N,S} , v, I:: Int ) where {T,N,S<: Union{Shared,Managed} }
192+ @boundscheck checkbounds (x, I)
193+ unsafe_store! (pointer (x, I; storage= S), v)
194+ end
156195
157196
158197# # interop with other arrays
@@ -354,7 +393,7 @@ Uses Adapt.jl to act inside some wrapper structs.
354393
355394```jldoctests
356395julia> mtl(ones(3)')
357- 1×3 adjoint(::MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate }) with eltype Float32:
396+ 1×3 adjoint(::MtlVector{Float32, Private }) with eltype Float32:
358397 1.0 1.0 1.0
359398
360399julia> mtl(zeros(1,3); storage=Shared)
@@ -365,13 +404,13 @@ julia> mtl(1:3)
3654041:3
366405
367406julia> MtlArray(1:3)
368- 3-element MtlVector{Int64, Metal.MTL.MTLResourceStorageModePrivate }:
407+ 3-element MtlVector{Int64, Private }:
369408 1
370409 2
371410 3
372411
373412julia> mtl[1,2,3]
374- 3-element MtlVector{Int64, Metal.MTL.MTLResourceStorageModePrivate }:
413+ 3-element MtlVector{Int64, Private }:
375414 1
376415 2
377416 3
@@ -433,8 +472,9 @@ Base.unsafe_convert(::Type{MTL.MTLBuffer}, A::PermutedDimsArray) =
433472
434473# # unsafe_wrap
435474
436- Base. unsafe_wrap (t:: Type{<:Array} , arr:: MtlArray , dims; own= false ) =
437- unsafe_wrap (t, arr. data[], dims; own= own)
475+ function Base. unsafe_wrap (:: Type{<:Array} , arr:: MtlArray{T,N} , dims= size (arr); own= false ) where {T,N}
476+ return unsafe_wrap (Array{T,N}, arr. data[], dims; own= own)
477+ end
438478
439479function Base. unsafe_wrap (t:: Type{<:Array{T}} , buf:: MTLBuffer , dims; own= false ) where T
440480 ptr = convert (Ptr{T}, buf)
0 commit comments