Skip to content

Commit b7606f0

Browse files
Remove copy when possible from cpu rand using GPU RNG (#568)
* Remove copy when possible from cpu rand using GPU RNG
* Format
1 parent ebd90c5 commit b7606f0

File tree

1 file changed

+18
-7
lines changed

1 file changed

+18
-7
lines changed

lib/mps/random.jl

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -71,18 +71,29 @@ const NormalArray = MtlArray{<:Float32}
7171
end
7272

7373
# CPU arrays
74-
# TODO: use unsafe_wrap when possible
7574
# Fill the CPU array `A` in place with values drawn by the GPU generator `rng` and
# return `A`. An empty `A` is returned untouched.
#
# When `MTL.can_alloc_nocopy` accepts `A`'s pointer and byte size, `A` is wrapped as an
# `MtlArray` and filled directly, so no copy is made. Otherwise a temporary
# `MtlArray` with `SharedStorage` is filled and its contents are copied into `A`.
function Random.rand!(rng::RNG, A::AbstractArray{T, N}) where {T <: Union{UniformTypes...}, N}
    if !isempty(A)
        if !MTL.can_alloc_nocopy(pointer(A), sizeof(A))
            # Fallback: generate into a staging buffer, then copy back to the host array.
            staging = MtlArray{T, N, SharedStorage}(undef, size(A))
            rand!(rng, staging)
            copyto!(A, staging)
        else
            # NOTE(review): this path has no `copyto!`; confirm the GPU `rand!` has
            # completed before `A` is read on the host, and that `unsafe_wrap` accepts
            # every `AbstractArray` that can reach this branch.
            wrapped = unsafe_wrap(MtlArray{T, N}, A)
            rand!(rng, wrapped)
        end
    end
    return A
end
8186
# Fill the CPU `Float32` array `A` in place using `randn!` with the GPU generator `rng`
# and return `A`. An empty `A` is returned untouched.
#
# When `MTL.can_alloc_nocopy` accepts `A`'s pointer and byte size, `A` is wrapped as an
# `MtlArray` and filled directly, so no copy is made. Otherwise a temporary
# `MtlArray` with `SharedStorage` is filled and its contents are copied into `A`.
function Random.randn!(rng::RNG, A::AbstractArray{T, N}) where {T <: Float32, N}
    isempty(A) && return A

    nocopy = MTL.can_alloc_nocopy(pointer(A), sizeof(A))
    if nocopy
        # NOTE(review): this path has no `copyto!`; confirm the GPU `randn!` has
        # completed before `A` is read on the host, and that `unsafe_wrap` accepts
        # every `AbstractArray` that can reach this branch.
        wrapped = unsafe_wrap(MtlArray{T, N}, A)
        randn!(rng, wrapped)
        return A
    end

    # Fallback: generate into a staging buffer, then copy back to the host array.
    staging = MtlArray{T, N, SharedStorage}(undef, size(A))
    randn!(rng, staging)
    copyto!(A, staging)
    return A
end
8798

8899
# Out of place

0 commit comments

Comments
 (0)