Skip to content

Commit ad2a65b

Browse files
committed
Widen random signatures to support wrappers.
1 parent 816c105 commit ad2a65b

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

src/host/random.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ struct RNG <: AbstractRNG
6767
end
6868

6969
# return an instance of GPUArrays.RNG suitable for the requested array type
70-
default_rng(::Type{<:AbstractGPUArray}) = error("Not implemented") # COV_EXCL_LINE
70+
default_rng(::Type{<:AnyGPUArray}) = error("Not implemented") # COV_EXCL_LINE
7171

7272
make_seed(rng::RNG) = make_seed(rng, rand(UInt))
7373
function make_seed(rng::RNG, n::Integer)
@@ -81,7 +81,7 @@ function Random.seed!(rng::RNG, seed::Vector{UInt32})
8181
return
8282
end
8383

84-
function Random.rand!(rng::RNG, A::AbstractGPUArray{T}) where T <: Number
84+
function Random.rand!(rng::RNG, A::AnyGPUArray{T}) where T <: Number
8585
gpu_call(A, rng.state) do ctx, a, randstates
8686
idx = linear_index(ctx)
8787
idx > length(a) && return
@@ -91,7 +91,7 @@ function Random.rand!(rng::RNG, A::AbstractGPUArray{T}) where T <: Number
9191
A
9292
end
9393

94-
function Random.randn!(rng::RNG, A::AbstractGPUArray{T}) where T <: Number
94+
function Random.randn!(rng::RNG, A::AnyGPUArray{T}) where T <: Number
9595
threads = (length(A) - 1) ÷ 2 + 1
9696
length(A) == 0 && return
9797
gpu_call(A, rng.state; total_threads = threads) do ctx, a, randstates

0 commit comments

Comments
 (0)