@@ -43,8 +43,8 @@ for randfun in (:rand, :randn, :randexp)
43
43
@reactant_overlay @noinline function Random. $ (randfun)(
44
44
rng:: AbstractRNG , :: Type{T} , dims:: Dims
45
45
) where {T}
46
- if T <: ReactantPrimitive
47
- return TracedRandom.$ (overload_randfun)(rng, T , dims)
46
+ if unwrapped_eltype (T) <: ReactantPrimitive
47
+ return TracedRandom.$ (overload_randfun)(rng, unwrapped_eltype (T) , dims)
48
48
end
49
49
@warn " Reactant doesn't support sampling of $(T) with the current \
50
50
interpreter. Falling back to native interpreter." maxlog = 1
@@ -60,8 +60,10 @@ for randfun in (:rand, :randn, :randexp)
60
60
@reactant_overlay @noinline function Random. $ (randfun)(
61
61
rng:: AbstractRNG , :: Type{T} , dim1:: Integer , dims:: Integer...
62
62
) where {T}
63
- if T <: ReactantPrimitive
64
- return TracedRandom.$ (overload_randfun)(rng, T, dim1, dims... )
63
+ if unwrapped_eltype (T) <: ReactantPrimitive
64
+ return TracedRandom.$ (overload_randfun)(
65
+ rng, unwrapped_eltype (T), dim1, dims...
66
+ )
65
67
end
66
68
@warn " Reactant doesn't support sampling of $(T) with the current \
67
69
interpreter. Falling back to native interpreter." maxlog = 1
@@ -72,8 +74,8 @@ for randfun in (:rand, :randn, :randexp)
72
74
@reactant_overlay @noinline function Random. $ (randfun)(
73
75
rng:: AbstractRNG , :: Type{T} = Float64
74
76
) where {T}
75
- if T <: ReactantPrimitive
76
- return TracedRandom.$ (overload_randfun)(rng, T )
77
+ if unwrapped_eltype (T) <: ReactantPrimitive
78
+ return TracedRandom.$ (overload_randfun)(rng, unwrapped_eltype (T) )
77
79
end
78
80
@warn " Reactant doesn't support sampling of $(T) with the current \
79
81
interpreter. Falling back to native interpreter." maxlog = 1
@@ -86,6 +88,9 @@ for randfun in (:rand, :randn, :randexp)
86
88
)
87
89
return TracedRandom.$ (overload_randfun!)(rng, A)
88
90
end
91
+ @reactant_overlay @noinline function Random. $ (randfun!)(A:: AnyTracedRArray )
92
+ return TracedRandom.$ (overload_randfun!)(TracedRandom. default_rng (), A)
93
+ end
89
94
end
90
95
end
91
96
0 commit comments