@@ -110,7 +110,18 @@ const MtlMatrix{T,S} = MtlArray{T,2,S}
110110const MtlVecOrMat{T,S} = Union{MtlVector{T,S},MtlMatrix{T,S}}
111111
112112# default to private memory
113- const DefaultStorageMode = Private
113+ const DefaultStorageMode = let str = @load_preference (" default_storage" , " Private" )
114+ if str == " Private"
115+ Private
116+ elseif str == " Shared"
117+ Shared
118+ elseif str == " Managed"
119+ Managed
120+ else
121+ error (" unknown default storage mode: $default_storage " )
122+ end
123+ end
124+
114125MtlArray {T,N} (:: UndefInitializer , dims:: Dims{N} ) where {T,N} =
115126 MtlArray {T,N,DefaultStorageMode} (undef, dims)
116127
@@ -170,14 +181,16 @@ end
170181
171182function Base. unsafe_convert (:: Type{MtlPointer{T}} , x:: MtlArray ) where {T}
172183 buf = x. data[]
184+ synchronize ()
173185 MtlPointer {T} (buf, x. offset* Base. elsize (x))
174186 end
175187
176188function Base. unsafe_convert (:: Type{Ptr{S}} , x:: MtlArray{T} ) where {S, T}
177- buf = x. data[]
178189 if is_private (x)
179190 throw (ArgumentError (" cannot take the CPU address of a $(typeof (x)) " ))
180191 end
192+ synchronize ()
193+ buf = x. data[]
181194 convert (Ptr{T}, buf) + x. offset* Base. elsize (x)
182195end
183196
@@ -237,7 +250,7 @@ Base.convert(::Type{T}, x::T) where T <: MtlArray = x
237250Base. unsafe_convert (:: Type{<:Ptr} , x:: MtlArray ) =
238251 throw (ArgumentError (" cannot take the host address of a $(typeof (x)) " ))
239252
240- Base. unsafe_convert (t :: Type{MTL.MTLBuffer} , x:: MtlArray ) = x. data[]
253+ Base. unsafe_convert (:: Type{MTL.MTLBuffer} , x:: MtlArray ) = x. data[]
241254
242255
243256# # interop with ObjC libraries
0 commit comments