@@ -35,11 +35,12 @@ function Base.unsafe_copyto!(dev::MTLDevice, dst::MtlPtr{T}, src::Ptr{T}, N::Int
3535 storage_type = dst. buffer. storageMode
3636 if storage_type == MTL. MTLStorageModePrivate
3737 # stage through a shared buffer
38- # shared = alloc(dev, N*sizeof(T), src; storage=Shared)
39- # unsafe_copyto!(dev, dst, pointer(shared), N; queue, async=false)
40- # free(shared)
41- tmp_buf = alloc (dev, N* sizeof (T), src; storage= Shared) # CPU -> GPU (Shared)
42- unsafe_copyto! (dev, MtlPtr {T} (dst. buffer, dst. offset), MtlPtr {T} (tmp_buf, 0 ), N; queue, async= false ) # GPU (Shared) -> GPU (Private)
38+ nocopy = MTL. can_alloc_nocopy (src, N* sizeof (T))
39+ tmp_buf = alloc (dev, N* sizeof (T), src; storage= Shared, nocopy)
40+
41+ # copy to the private buffer
42+ unsafe_copyto! (dev, MtlPtr {T} (dst. buffer, dst. offset), MtlPtr {T} (tmp_buf, 0 ), N;
43+ queue, async= (nocopy && async))
4344 free (tmp_buf)
4445 elseif storage_type == MTL. MTLStorageModeShared
4546 unsafe_copyto! (convert (Ptr{T}, dst), src, N)
5455function Base. unsafe_copyto! (dev:: MTLDevice , dst:: Ptr{T} , src:: MtlPtr{T} , N:: Integer ;
5556 queue:: MTLCommandQueue = global_queue (dev), async:: Bool = false ) where T
5657 storage_type = src. buffer. storageMode
57- if storage_type == MTL. MTLStorageModePrivate
58+ if storage_type == MTL. MTLStorageModePrivate
5859 # stage through a shared buffer
59- shared = alloc (dev, N* sizeof (T); storage= Shared)
60- unsafe_copyto! (dev, MtlPtr {T} (shared, 0 ), MtlPtr {T} (src. buffer, src. offset), N; queue, async= false )
61- unsafe_copyto! (dst, convert (Ptr{T}, shared), N)
62- free (shared)
60+ nocopy = MTL. can_alloc_nocopy (dst, N* sizeof (T))
61+ tmp_buf = if nocopy
62+ alloc (dev, N* sizeof (T), dst; storage= Shared, nocopy)
63+ else
64+ alloc (dev, N* sizeof (T); storage= Shared)
65+ end
66+ unsafe_copyto! (dev, MtlPtr {T} (tmp_buf, 0 ), MtlPtr {T} (src. buffer, src. offset), N;
67+ queue, async= (nocopy && async))
68+
69+ # copy from the shared buffer
70+ if ! nocopy
71+ unsafe_copyto! (dst, convert (Ptr{T}, tmp_buf), N)
72+ end
73+ free (tmp_buf)
6374 elseif storage_type == MTL. MTLStorageModeShared
6475 unsafe_copyto! (dst, convert (Ptr{T}, src), N)
6576 elseif storage_type == MTL. MTLStorageModeManaged
0 commit comments