Skip to content

Commit d79acb2

Browse files
authored
feat: rework special concrete/traced rng handling (#1347)
* feat: rework special concrete/traced rng handling * chore: bump reactant version
1 parent a6b32db commit d79acb2

File tree

7 files changed

+38
-114
lines changed

7 files changed

+38
-114
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Reactant"
22
uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
33
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>", "Sergio Sánchez Ramírez <[email protected]>", "Paul Berg <[email protected]>", "Avik Pal <[email protected]>", "Mosè Giordano <[email protected]>"]
4-
version = "0.2.120"
4+
version = "0.2.121"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/Overlay.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ end
2727
end
2828

2929
@reactant_overlay @noinline function TracedRandom.default_rng()
30-
return TracedRNG(
30+
return ReactantRNG(
3131
TracedUtils.promote_to(TracedRArray{UInt64,1}, TracedRandom.make_seed()), "DEFAULT"
3232
)
3333
end

src/Reactant.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,8 +160,8 @@ use_overlayed_version(::TracedRArray) = true
160160
use_overlayed_version(::TracedRNumber) = true
161161
use_overlayed_version(::Number) = false
162162
use_overlayed_version(::MissingTracedValue) = true
163-
use_overlayed_version(::TracedRNG) = true
164163
use_overlayed_version(::AbstractArray{<:TracedRNumber}) = true
164+
use_overlayed_version(rng::ReactantRNG) = use_overlayed_version(rng.seed)
165165

166166
function use_overlayed_version(x::AbstractArray)
167167
a = ancestor(x)

src/Tracing.jl

Lines changed: 0 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -348,23 +348,6 @@ Base.@nospecializeinfer function traced_type_inner(
348348
end
349349
end
350350

351-
Base.@nospecializeinfer function traced_type_inner(
352-
@nospecialize(T::Type{<:ConcreteRNG}),
353-
seen,
354-
mode::TraceMode,
355-
@nospecialize(track_numbers::Type),
356-
@nospecialize(sharding),
357-
@nospecialize(runtime)
358-
)
359-
if mode == ConcreteToTraced
360-
return TracedRNG
361-
elseif mode == TracedToConcrete
362-
return T
363-
else
364-
throw("Unsupported mode: $mode")
365-
end
366-
end
367-
368351
Base.@nospecializeinfer function traced_type_inner(
369352
@nospecialize(T::Type{<:MissingTracedValue}), @nospecialize(args...)
370353
)
@@ -451,29 +434,6 @@ Base.@nospecializeinfer function traced_type_inner(
451434
end
452435
end
453436

454-
Base.@nospecializeinfer function traced_type_inner(
455-
@nospecialize(T::Type{<:TracedRNG}),
456-
seen,
457-
mode::TraceMode,
458-
@nospecialize(track_numbers::Type),
459-
@nospecialize(sharding),
460-
@nospecialize(runtime)
461-
)
462-
if mode == ConcreteToTraced
463-
throw("TracedRNG cannot be traced")
464-
elseif mode == TracedToConcrete
465-
return ConcreteRNG{
466-
traced_type_inner(
467-
TracedRArray{UInt64,1}, seen, mode, track_numbers, sharding, runtime
468-
),
469-
}
470-
elseif mode == TracedTrack || mode == NoStopTracedTrack || mode == TracedSetPath
471-
return T
472-
else
473-
throw("Unsupported mode: $mode")
474-
end
475-
end
476-
477437
Base.@nospecializeinfer function traced_type_inner(
478438
@nospecialize(A::Type{AbstractArray}),
479439
seen,

src/Types.jl

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -87,12 +87,6 @@ const AnyTracedRVector{T} = AnyTracedRArray{T,1}
8787
const AnyTracedRMatrix{T} = AnyTracedRArray{T,2}
8888
const AnyTracedRVecOrMat{T} = Union{AnyTracedRVector{T},AnyTracedRMatrix{T}}
8989

90-
## TracedRNG
91-
struct TracedRNG <: Random.AbstractRNG
92-
seed::TracedRArray{UInt64,1}
93-
algorithm::String
94-
end
95-
9690
# Concrete Types
9791
## ConcretePJRTNumber
9892
mutable struct ConcretePJRTNumber{T,D,S<:Sharding.ShardInfo} <: AbstractConcreteNumber{T}
@@ -442,12 +436,16 @@ function ConcreteIFRTArray{T,N}(x::AnyConcreteIFRTArray; kwargs...) where {T,N}
442436
)
443437
end
444438

445-
## ConcreteRNG
446-
mutable struct ConcreteRNG{S<:AbstractConcreteArray} <: Random.AbstractRNG
439+
# RNGs
440+
struct ReactantRNG{S<:Union{<:AbstractConcreteArray{UInt64,1},TracedRArray{UInt64,1}}} <:
441+
Random.AbstractRNG
447442
seed::S
448-
const algorithm::String
443+
algorithm::String
449444
end
450445

446+
Base.@deprecate_binding ConcreteRNG ReactantRNG
447+
Base.@deprecate_binding TracedRNG ReactantRNG
448+
451449
## Aliases based on the set preferences
452450
if XLA.REACTANT_XLA_RUNTIME == "PJRT"
453451
const ConcreteRArray = ConcretePJRTArray

src/stdlibs/Random.jl

Lines changed: 27 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -5,80 +5,45 @@ module TracedRandom
55
# 2. https://github.com/JuliaRandom/Random123.jl/blob/master/src/common.jl
66

77
using ..Reactant:
8-
Reactant,
9-
TracedRArray,
10-
TracedRNumber,
11-
ConcreteRNG,
12-
TracedRNG,
13-
AnyTracedRArray,
14-
Reactant,
15-
TracedUtils,
16-
Ops,
17-
AbstractConcreteArray,
18-
AbstractConcreteNumber,
19-
unwrapped_eltype
8+
Reactant, TracedRArray, TracedRNumber, ReactantRNG, AnyTracedRArray, TracedUtils, Ops
209
using Random: Random, AbstractRNG
2110

2211
@noinline make_seed(rng::AbstractRNG=Random.RandomDevice()) =
2312
Random.rand!(rng, Vector{UInt64}(undef, 2))
2413

25-
@noinline function Random.seed!(rng::TracedRNG, seed::Number)
14+
@noinline function Random.seed!(rng::ReactantRNG, seed::Number)
2615
if seed isa TracedRNumber
2716
error("Passing in `TracedRNumber` as a seed is not supported. Please pass in a \
2817
`TracedRArray` of the appropriate size instead.")
2918
end
3019

3120
seed = reinterpret(UInt64, Random.hash_seed(seed))
32-
return Random.seed!(
33-
rng, TracedUtils.promote_to(TracedRArray{UInt64,1}, seed[1:length(rng.seed)])
34-
)
21+
return Random.seed!(rng, seed[1:length(rng.seed)])
3522
end
3623

37-
@noinline function Random.seed!(rng::TracedRNG, seed::AbstractVector{<:Integer})
38-
return Random.seed!(rng, UInt64.(seed))
39-
end
40-
41-
@noinline function Random.seed!(rng::TracedRNG, seed::AbstractVector{UInt64})
42-
return Random.seed!(rng, TracedUtils.promote_to(TracedRArray{UInt64,1}, seed))
43-
end
44-
45-
@noinline function Random.seed!(rng::TracedRNG, seed::TracedRArray{UInt64,1})
46-
copyto!(rng.seed, seed)
24+
@noinline function Random.seed!(rng::ReactantRNG, seed::AbstractVector)
25+
rng.seed .= seed
4726
return rng
4827
end
4928

50-
@noinline function Random.seed!(rng::ConcreteRNG, seed::Number)
51-
seed isa AbstractConcreteNumber && (seed = unwrapped_eltype(seed)(seed))
52-
seed = reinterpret(UInt64, Random.hash_seed(seed))
53-
return Random.seed!(rng, Reactant.to_rarray(seed))
54-
end
55-
56-
@noinline function Random.seed!(rng::ConcreteRNG, seed::AbstractVector{<:Integer})
57-
return Random.seed!(rng, seed)
58-
end
59-
60-
@noinline function Random.seed!(rng::ConcreteRNG, seed::AbstractVector{UInt64})
61-
return Random.seed!(rng, Reactant.to_rarray(seed))
62-
end
29+
Base.copy(rng::ReactantRNG) = ReactantRNG(copy(rng.seed), rng.algorithm)
6330

64-
@noinline function Random.seed!(rng::ConcreteRNG, seed::AbstractConcreteArray{UInt64,1})
65-
Base.copyto!(rng.seed, seed)
66-
return rng
31+
@noinline function ReactantRNG()
32+
if Reactant.within_compile()
33+
return ReactantRNG(TracedUtils.promote_to(TracedRArray{UInt64,1}, make_seed()))
34+
else
35+
return ReactantRNG(Reactant.to_rarray(make_seed()))
36+
end
6737
end
38+
@noinline ReactantRNG(seed::AbstractVector) = ReactantRNG(seed, "DEFAULT")
6839

69-
Base.copy(rng::ConcreteRNG) = ConcreteRNG(copy(rng.seed), rng.algorithm)
70-
Base.copy(rng::TracedRNG) = TracedRNG(copy(rng.seed), rng.algorithm)
40+
@noinline default_rng() = ReactantRNG()
7141

72-
@noinline ConcreteRNG() = ConcreteRNG(Reactant.to_rarray(make_seed()))
73-
@noinline ConcreteRNG(seed::AbstractConcreteArray{UInt64,1}) = ConcreteRNG(seed, "DEFAULT")
74-
75-
@noinline default_rng() = ConcreteRNG()
76-
77-
@noinline rng_algorithm(rng::TracedRNG) = rng.algorithm
42+
@noinline rng_algorithm(rng::ReactantRNG) = rng.algorithm
7843
@noinline rng_algorithm(::AbstractRNG) = "DEFAULT"
7944

8045
@noinline function internal_overload_rand!(
81-
rng::TracedRNG, A::AnyTracedRArray{T,N}
46+
rng::ReactantRNG{<:TracedRArray}, A::AnyTracedRArray{T,N}
8247
) where {T,N}
8348
length(A) == 0 && return A
8449
res = Ops.rng_bit_generator(T, rng.seed, [size(A)...]; rng.algorithm)
@@ -88,7 +53,7 @@ Base.copy(rng::TracedRNG) = TracedRNG(copy(rng.seed), rng.algorithm)
8853
end
8954

9055
@noinline function internal_overload_randn!(
91-
rng::TracedRNG, A::AnyTracedRArray{T,N}
56+
rng::ReactantRNG{<:TracedRArray}, A::AnyTracedRArray{T,N}
9257
) where {T,N}
9358
length(A) == 0 && return A
9459
res = Ops.randn(T, rng.seed, [size(A)...]; rng.algorithm)
@@ -98,7 +63,7 @@ end
9863
end
9964

10065
@noinline function internal_overload_randexp!(
101-
rng::TracedRNG, A::AnyTracedRArray{T,N}
66+
rng::ReactantRNG{<:TracedRArray}, A::AnyTracedRArray{T,N}
10267
) where {T,N}
10368
length(A) == 0 && return A
10469
res = Ops.randexp(T, rng.seed, [size(A)...]; rng.algorithm)
@@ -114,25 +79,25 @@ for randfun in (:rand, :randn, :randexp)
11479

11580
@eval begin
11681
@noinline function $(overload_randfun)(
117-
rng::TracedRNG, ::Type{T}, dims::Dims
82+
rng::ReactantRNG{<:TracedRArray}, ::Type{T}, dims::Dims
11883
) where {T}
11984
return $(overload_randfun!)(
12085
rng, TracedRArray{T,length(dims)}((), nothing, dims)
12186
)
12287
end
12388

124-
@noinline function $(overload_randfun)(rng::TracedRNG, dims::Dims)
89+
@noinline function $(overload_randfun)(rng::ReactantRNG{<:TracedRArray}, dims::Dims)
12590
return $(overload_randfun)(rng, Float64, dims)
12691
end
12792

12893
@noinline function $(overload_randfun)(
129-
rng::TracedRNG, dim1::Integer, dims::Integer...
94+
rng::ReactantRNG{<:TracedRArray}, dim1::Integer, dims::Integer...
13095
)
13196
return $(overload_randfun)(rng, Dims((dim1, dims...)))
13297
end
13398

13499
@noinline function $(overload_randfun)(
135-
rng::TracedRNG, ::Type{T}, dim1::Integer, dims::Integer...
100+
rng::ReactantRNG{<:TracedRArray}, ::Type{T}, dim1::Integer, dims::Integer...
136101
) where {T}
137102
return $(overload_randfun)(rng, T, Dims((dim1, dims...)))
138103
end
@@ -142,7 +107,9 @@ for randfun in (:rand, :randn, :randexp)
142107
end
143108

144109
# scalars
145-
@noinline function $(overload_randfun)(rng::TracedRNG, ::Type{T}=Float64) where {T}
110+
@noinline function $(overload_randfun)(
111+
rng::ReactantRNG{<:TracedRArray}, ::Type{T}=Float64
112+
) where {T}
146113
A = TracedUtils.promote_to(TracedRArray{T,0}, fill(T(0)))
147114
$(overload_randfun!)(rng, A)
148115
return TracedRNumber{T}((), A.mlir_data)
@@ -157,14 +124,14 @@ for randfun in (:rand, :randn, :randexp, :rand!, :randn!, :randexp!)
157124
internal_overload_randfun = Symbol(:internal_overload_, randfun)
158125
@eval begin
159126
@noinline function $(overload_randfun)(rng::AbstractRNG, args...)
160-
rng = TracedRNG(
127+
rng = ReactantRNG(
161128
TracedUtils.promote_to(TracedRArray{UInt64,1}, make_seed(rng)),
162129
rng_algorithm(rng),
163130
)
164131
return $(internal_overload_randfun)(rng, args...)
165132
end
166133

167-
@noinline function $(overload_randfun)(rng::TracedRNG, args...)
134+
@noinline function $(overload_randfun)(rng::ReactantRNG, args...)
168135
return $(internal_overload_randfun)(rng, args...)
169136
end
170137
end

src/utils.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,7 @@ function should_rewrite_call(@nospecialize(ft))
110110
# Don't rewrite primitive ops, tracing utilities, or any MLIR-based functions
111111
if has_ancestor(mod, Reactant.Ops) ||
112112
has_ancestor(mod, Reactant.TracedUtils) ||
113-
has_ancestor(mod, Reactant.MLIR) ||
114-
has_ancestor(mod, Reactant.TracedRandom)
113+
has_ancestor(mod, Reactant.MLIR)
115114
return false
116115
end
117116
if string(mod) == "CUDA"

0 commit comments

Comments
 (0)