Skip to content

Commit 1e68728

Browse files
authored
fix: fixes for random sampling (#1513)
1 parent 688d041 commit 1e68728

File tree

2 files changed

+27
-6
lines changed

2 files changed

+27
-6
lines changed

src/Overlay.jl

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@ for randfun in (:rand, :randn, :randexp)
4343
@reactant_overlay @noinline function Random.$(randfun)(
4444
rng::AbstractRNG, ::Type{T}, dims::Dims
4545
) where {T}
46-
if T <: ReactantPrimitive
47-
return TracedRandom.$(overload_randfun)(rng, T, dims)
46+
if unwrapped_eltype(T) <: ReactantPrimitive
47+
return TracedRandom.$(overload_randfun)(rng, unwrapped_eltype(T), dims)
4848
end
4949
@warn "Reactant doesn't support sampling of $(T) with the current \
5050
interpreter. Falling back to native interpreter." maxlog = 1
@@ -60,8 +60,10 @@ for randfun in (:rand, :randn, :randexp)
6060
@reactant_overlay @noinline function Random.$(randfun)(
6161
rng::AbstractRNG, ::Type{T}, dim1::Integer, dims::Integer...
6262
) where {T}
63-
if T <: ReactantPrimitive
64-
return TracedRandom.$(overload_randfun)(rng, T, dim1, dims...)
63+
if unwrapped_eltype(T) <: ReactantPrimitive
64+
return TracedRandom.$(overload_randfun)(
65+
rng, unwrapped_eltype(T), dim1, dims...
66+
)
6567
end
6668
@warn "Reactant doesn't support sampling of $(T) with the current \
6769
interpreter. Falling back to native interpreter." maxlog = 1
@@ -72,8 +74,8 @@ for randfun in (:rand, :randn, :randexp)
7274
@reactant_overlay @noinline function Random.$(randfun)(
7375
rng::AbstractRNG, ::Type{T}=Float64
7476
) where {T}
75-
if T <: ReactantPrimitive
76-
return TracedRandom.$(overload_randfun)(rng, T)
77+
if unwrapped_eltype(T) <: ReactantPrimitive
78+
return TracedRandom.$(overload_randfun)(rng, unwrapped_eltype(T))
7779
end
7880
@warn "Reactant doesn't support sampling of $(T) with the current \
7981
interpreter. Falling back to native interpreter." maxlog = 1
@@ -86,6 +88,9 @@ for randfun in (:rand, :randn, :randexp)
8688
)
8789
return TracedRandom.$(overload_randfun!)(rng, A)
8890
end
91+
@reactant_overlay @noinline function Random.$(randfun!)(A::AnyTracedRArray)
92+
return TracedRandom.$(overload_randfun!)(TracedRandom.default_rng(), A)
93+
end
8994
end
9095
end
9196

test/integration/random.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,3 +185,19 @@ end
185185
@test μ 1.0 atol = 0.05 rtol = 0.05
186186
end
187187
end
188+
189+
rand_sample(rng, x) = rand(rng, eltype(x), size(x))
190+
191+
function rand_on_device()
192+
x = Reactant.Ops.fill(0.0f0, (3, 4, 5))
193+
rand!(x)
194+
return x
195+
end
196+
197+
@testset "TracedTypes in Sampling" begin
198+
@test @jit(rand_sample(Reactant.ReactantRNG(), rand(3, 4))) isa
199+
ConcreteRArray{Float64,2}
200+
@test @jit(rand_sample(Reactant.ReactantRNG(), Reactant.to_rarray(rand(3, 4)))) isa
201+
ConcreteRArray{Float64,2}
202+
@test @jit(rand_on_device()) isa ConcreteRArray{Float32,3}
203+
end

0 commit comments

Comments
 (0)