@@ -5,80 +5,45 @@ module TracedRandom
5
5
# 2. https://github.com/JuliaRandom/Random123.jl/blob/master/src/common.jl
6
6
7
7
using .. Reactant:
8
- Reactant,
9
- TracedRArray,
10
- TracedRNumber,
11
- ConcreteRNG,
12
- TracedRNG,
13
- AnyTracedRArray,
14
- Reactant,
15
- TracedUtils,
16
- Ops,
17
- AbstractConcreteArray,
18
- AbstractConcreteNumber,
19
- unwrapped_eltype
8
+ Reactant, TracedRArray, TracedRNumber, ReactantRNG, AnyTracedRArray, TracedUtils, Ops
20
9
using Random: Random, AbstractRNG
21
10
22
11
@noinline make_seed (rng:: AbstractRNG = Random. RandomDevice ()) =
23
12
Random. rand! (rng, Vector {UInt64} (undef, 2 ))
24
13
25
- @noinline function Random. seed! (rng:: TracedRNG , seed:: Number )
14
+ @noinline function Random. seed! (rng:: ReactantRNG , seed:: Number )
26
15
if seed isa TracedRNumber
27
16
error (" Passing in `TracedRNumber` as a seed is not supported. Please pass in a \
28
17
`TracedRArray` of the appropriate size instead." )
29
18
end
30
19
31
20
seed = reinterpret (UInt64, Random. hash_seed (seed))
32
- return Random. seed! (
33
- rng, TracedUtils. promote_to (TracedRArray{UInt64,1 }, seed[1 : length (rng. seed)])
34
- )
21
+ return Random. seed! (rng, seed[1 : length (rng. seed)])
35
22
end
36
23
37
- @noinline function Random. seed! (rng:: TracedRNG , seed:: AbstractVector{<:Integer} )
38
- return Random. seed! (rng, UInt64 .(seed))
39
- end
40
-
41
- @noinline function Random. seed! (rng:: TracedRNG , seed:: AbstractVector{UInt64} )
42
- return Random. seed! (rng, TracedUtils. promote_to (TracedRArray{UInt64,1 }, seed))
43
- end
44
-
45
- @noinline function Random. seed! (rng:: TracedRNG , seed:: TracedRArray{UInt64,1} )
46
- copyto! (rng. seed, seed)
24
+ @noinline function Random. seed! (rng:: ReactantRNG , seed:: AbstractVector )
25
+ rng. seed .= seed
47
26
return rng
48
27
end
49
28
50
- @noinline function Random. seed! (rng:: ConcreteRNG , seed:: Number )
51
- seed isa AbstractConcreteNumber && (seed = unwrapped_eltype (seed)(seed))
52
- seed = reinterpret (UInt64, Random. hash_seed (seed))
53
- return Random. seed! (rng, Reactant. to_rarray (seed))
54
- end
55
-
56
- @noinline function Random. seed! (rng:: ConcreteRNG , seed:: AbstractVector{<:Integer} )
57
- return Random. seed! (rng, seed)
58
- end
59
-
60
- @noinline function Random. seed! (rng:: ConcreteRNG , seed:: AbstractVector{UInt64} )
61
- return Random. seed! (rng, Reactant. to_rarray (seed))
62
- end
29
+ Base. copy (rng:: ReactantRNG ) = ReactantRNG (copy (rng. seed), rng. algorithm)
63
30
64
- @noinline function Random. seed! (rng:: ConcreteRNG , seed:: AbstractConcreteArray{UInt64,1} )
65
- Base. copyto! (rng. seed, seed)
66
- return rng
31
+ @noinline function ReactantRNG ()
32
+ if Reactant. within_compile ()
33
+ return ReactantRNG (TracedUtils. promote_to (TracedRArray{UInt64,1 }, make_seed ()))
34
+ else
35
+ return ReactantRNG (Reactant. to_rarray (make_seed ()))
36
+ end
67
37
end
38
+ @noinline ReactantRNG (seed:: AbstractVector ) = ReactantRNG (seed, " DEFAULT" )
68
39
69
- Base. copy (rng:: ConcreteRNG ) = ConcreteRNG (copy (rng. seed), rng. algorithm)
70
- Base. copy (rng:: TracedRNG ) = TracedRNG (copy (rng. seed), rng. algorithm)
40
+ @noinline default_rng () = ReactantRNG ()
71
41
72
- @noinline ConcreteRNG () = ConcreteRNG (Reactant. to_rarray (make_seed ()))
73
- @noinline ConcreteRNG (seed:: AbstractConcreteArray{UInt64,1} ) = ConcreteRNG (seed, " DEFAULT" )
74
-
75
- @noinline default_rng () = ConcreteRNG ()
76
-
77
- @noinline rng_algorithm (rng:: TracedRNG ) = rng. algorithm
42
+ @noinline rng_algorithm (rng:: ReactantRNG ) = rng. algorithm
78
43
@noinline rng_algorithm (:: AbstractRNG ) = " DEFAULT"
79
44
80
45
@noinline function internal_overload_rand! (
81
- rng:: TracedRNG , A:: AnyTracedRArray{T,N}
46
+ rng:: ReactantRNG{<:TracedRArray} , A:: AnyTracedRArray{T,N}
82
47
) where {T,N}
83
48
length (A) == 0 && return A
84
49
res = Ops. rng_bit_generator (T, rng. seed, [size (A)... ]; rng. algorithm)
@@ -88,7 +53,7 @@ Base.copy(rng::TracedRNG) = TracedRNG(copy(rng.seed), rng.algorithm)
88
53
end
89
54
90
55
@noinline function internal_overload_randn! (
91
- rng:: TracedRNG , A:: AnyTracedRArray{T,N}
56
+ rng:: ReactantRNG{<:TracedRArray} , A:: AnyTracedRArray{T,N}
92
57
) where {T,N}
93
58
length (A) == 0 && return A
94
59
res = Ops. randn (T, rng. seed, [size (A)... ]; rng. algorithm)
98
63
end
99
64
100
65
@noinline function internal_overload_randexp! (
101
- rng:: TracedRNG , A:: AnyTracedRArray{T,N}
66
+ rng:: ReactantRNG{<:TracedRArray} , A:: AnyTracedRArray{T,N}
102
67
) where {T,N}
103
68
length (A) == 0 && return A
104
69
res = Ops. randexp (T, rng. seed, [size (A)... ]; rng. algorithm)
@@ -114,25 +79,25 @@ for randfun in (:rand, :randn, :randexp)
114
79
115
80
@eval begin
116
81
@noinline function $ (overload_randfun)(
117
- rng:: TracedRNG , :: Type{T} , dims:: Dims
82
+ rng:: ReactantRNG{<:TracedRArray} , :: Type{T} , dims:: Dims
118
83
) where {T}
119
84
return $ (overload_randfun!)(
120
85
rng, TracedRArray {T,length(dims)} ((), nothing , dims)
121
86
)
122
87
end
123
88
124
- @noinline function $ (overload_randfun)(rng:: TracedRNG , dims:: Dims )
89
+ @noinline function $ (overload_randfun)(rng:: ReactantRNG{<:TracedRArray} , dims:: Dims )
125
90
return $ (overload_randfun)(rng, Float64, dims)
126
91
end
127
92
128
93
@noinline function $ (overload_randfun)(
129
- rng:: TracedRNG , dim1:: Integer , dims:: Integer...
94
+ rng:: ReactantRNG{<:TracedRArray} , dim1:: Integer , dims:: Integer...
130
95
)
131
96
return $ (overload_randfun)(rng, Dims ((dim1, dims... )))
132
97
end
133
98
134
99
@noinline function $ (overload_randfun)(
135
- rng:: TracedRNG , :: Type{T} , dim1:: Integer , dims:: Integer...
100
+ rng:: ReactantRNG{<:TracedRArray} , :: Type{T} , dim1:: Integer , dims:: Integer...
136
101
) where {T}
137
102
return $ (overload_randfun)(rng, T, Dims ((dim1, dims... )))
138
103
end
@@ -142,7 +107,9 @@ for randfun in (:rand, :randn, :randexp)
142
107
end
143
108
144
109
# scalars
145
- @noinline function $ (overload_randfun)(rng:: TracedRNG , :: Type{T} = Float64) where {T}
110
+ @noinline function $ (overload_randfun)(
111
+ rng:: ReactantRNG{<:TracedRArray} , :: Type{T} = Float64
112
+ ) where {T}
146
113
A = TracedUtils. promote_to (TracedRArray{T,0 }, fill (T (0 )))
147
114
$ (overload_randfun!)(rng, A)
148
115
return TracedRNumber {T} ((), A. mlir_data)
@@ -157,14 +124,14 @@ for randfun in (:rand, :randn, :randexp, :rand!, :randn!, :randexp!)
157
124
internal_overload_randfun = Symbol (:internal_overload_ , randfun)
158
125
@eval begin
159
126
@noinline function $ (overload_randfun)(rng:: AbstractRNG , args... )
160
- rng = TracedRNG (
127
+ rng = ReactantRNG (
161
128
TracedUtils. promote_to (TracedRArray{UInt64,1 }, make_seed (rng)),
162
129
rng_algorithm (rng),
163
130
)
164
131
return $ (internal_overload_randfun)(rng, args... )
165
132
end
166
133
167
- @noinline function $ (overload_randfun)(rng:: TracedRNG , args... )
134
+ @noinline function $ (overload_randfun)(rng:: ReactantRNG , args... )
168
135
return $ (internal_overload_randfun)(rng, args... )
169
136
end
170
137
end
0 commit comments