File tree Expand file tree Collapse file tree 1 file changed +18
-7
lines changed Expand file tree Collapse file tree 1 file changed +18
-7
lines changed Original file line number Diff line number Diff line change @@ -71,18 +71,29 @@ const NormalArray = MtlArray{<:Float32}
7171end
7272
7373# CPU arrays
74- # TODO : use unsafe_wrap when possible
7574function Random. rand! (rng:: RNG , A:: AbstractArray{T, N} ) where {T <: Union{UniformTypes...} , N}
7675 isempty (A) && return A
77- B = MtlArray {T, N, SharedStorage} (undef, size (A))
78- rand! (rng, B)
79- return copyto! (A, B)
76+ if MTL. can_alloc_nocopy (pointer (A), sizeof (A))
77+ mtlA = unsafe_wrap (MtlArray{T, N}, A)
78+ rand! (rng, mtlA)
79+ else
80+ B = MtlArray {T, N, SharedStorage} (undef, size (A))
81+ rand! (rng, B)
82+ copyto! (A, B)
83+ end
84+ return A
8085end
8186function Random. randn! (rng:: RNG , A:: AbstractArray{T, N} ) where {T <: Float32 , N}
8287 isempty (A) && return A
83- B = MtlArray {T, N, SharedStorage} (undef, size (A))
84- randn! (rng, B)
85- return copyto! (A, B)
88+ if MTL. can_alloc_nocopy (pointer (A), sizeof (A))
89+ mtlA = unsafe_wrap (MtlArray{T, N}, A)
90+ randn! (rng, mtlA)
91+ else
92+ B = MtlArray {T, N, SharedStorage} (undef, size (A))
93+ randn! (rng, B)
94+ copyto! (A, B)
95+ end
96+ return A
8697end
8798
8899# Out of place
You can’t perform that action at this time.
0 commit comments