Skip to content

Commit 94e9576

Browse files
authored
feat: tracing Random.jl functionality correctly (EnzymeAD#363)
* refactor: move stdlib overloads to a different directory * fix: Ops.rng_bit_generator * feat: initial prototype for random number generation * feat: add support for scalar sampling * feat: efficient sampling for non-native RNGs * fix: handling floating point sampling * feat: use the override macro * fix: use `@noinline` * feat: support randexp * feat: override seeding inside interpreter * refactor: move things into a module * refactor: rework how the overlays are implemented * docs: add internal api to the docs * test: include floating point tests * test: setup testing * feat: overlay all generators * test: ensure distributions are correct * test: overlay generation * fix: test whether we can call into the non-overlayed version * fix: try marking TracedRandom in whitelist * fix: workaround the AbsInt issues for now * fix: throw errors for now instead of crashing
1 parent 0713d99 commit 94e9576

File tree

17 files changed

+690
-16
lines changed

17 files changed

+690
-16
lines changed

.github/workflows/CI.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,12 @@ jobs:
5050
version: '1.10'
5151
assertions: true
5252
test_group: neural_networks
53+
- os: ubuntu-20.04
54+
arch: x64
55+
libReactant: packaged
56+
version: '1.10'
57+
assertions: true
58+
test_group: integration
5359
- os: ubuntu-20.04
5460
arch: x86
5561
libReactant: packaged

Project.toml

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
1414
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1515
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
1616
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
17+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1718
ReactantCore = "a3311ec8-5e00-46d5-b541-4f83e724a433"
1819
Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0"
1920
Scratch = "6c6a2e73-6563-6170-7368-637461726353"
@@ -23,17 +24,19 @@ AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
2324
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
2425
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
2526
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
27+
Random123 = "74087812-796a-5b5d-8853-05524746bad3"
2628
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2729
YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df"
2830

29-
[sources.ReactantCore]
30-
path = "lib/ReactantCore"
31+
[sources]
32+
ReactantCore = {path = "lib/ReactantCore"}
3133

3234
[extensions]
3335
ReactantAbstractFFTsExt = "AbstractFFTs"
3436
ReactantArrayInterfaceExt = "ArrayInterface"
3537
ReactantCUDAExt = "CUDA"
3638
ReactantNNlibExt = "NNlib"
39+
ReactantRandom123Ext = "Random123"
3740
ReactantStatisticsExt = "Statistics"
3841
ReactantYaoBlocksExt = "YaoBlocks"
3942

@@ -51,6 +54,8 @@ LinearAlgebra = "1.10"
5154
NNlib = "0.9.26"
5255
OrderedCollections = "1"
5356
Preferences = "1.4"
57+
Random = "1.10"
58+
Random123 = "1.7"
5459
ReactantCore = "0.1.3"
5560
Reactant_jll = "0.0.26"
5661
Scratch = "1.2"

docs/make.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ pages = [
4343
],
4444
"MLIR API" => "api/mlirc.md",
4545
"XLA" => "api/xla.md",
46+
"Internal API" => "api/internal.md",
4647
],
4748
]
4849

docs/src/.vitepress/config.mts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,8 @@ export default defineConfig({
7878
{ text: "MLIR API", link: "/api/mlirc" },
7979
{ text: "XLA", link: "/api/xla" },
8080
],
81-
}
81+
},
82+
{ text: "Internal API", link: "/api/internal" },
8283
],
8384
},
8485
{
@@ -132,6 +133,7 @@ export default defineConfig({
132133
{ text: "XLA", link: "/api/xla" },
133134
],
134135
},
136+
{ text: "Internal API", link: "/api/internal" },
135137
],
136138
},
137139
},

docs/src/api/internal.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
```@meta
2+
CollapsedDocStrings = true
3+
```
4+
5+
# Internal API
6+
7+
These functions are not part of the public API and are subject to change at any time.
8+
9+
```@docs
10+
Reactant.REDUB_ARGUMENTS_NAME
11+
Reactant.within_reactant_interpreter
12+
```

ext/ReactantRandom123Ext.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
module ReactantRandom123Ext
2+
3+
using Random123: Threefry4x, Threefry2x, Philox4x, Philox2x
4+
using Reactant: TracedRandom
5+
6+
TracedRandom.rng_algorithm(::Threefry4x) = "THREE_FRY"
7+
TracedRandom.rng_algorithm(::Threefry2x) = "THREE_FRY"
8+
TracedRandom.rng_algorithm(::Philox4x) = "PHILOX"
9+
TracedRandom.rng_algorithm(::Philox2x) = "PHILOX"
10+
11+
end

src/Ops.jl

Lines changed: 136 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1016,19 +1016,150 @@ end
10161016
end
10171017

10181018
# random ops
1019+
"""
1020+
rng_bit_generator(
1021+
::Type{T},
1022+
seed::TracedRArray{UInt64,1},
1023+
shape;
1024+
algorithm::String="DEFAULT",
1025+
location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
1026+
)
1027+
1028+
Generate a random array of type `T` with the given shape and seed from a uniform random
1029+
distribution between 0 and 1. Returns a NamedTuple with the following fields:
1030+
1031+
- `output_state`: The state of the random number generator after the operation.
1032+
- `output`: The generated array.
1033+
1034+
# Arguments
1035+
1036+
- `T`: The type of the generated array.
1037+
- `seed`: The seed for the random number generator.
1038+
- `shape`: The shape of the generated array.
1039+
- `algorithm`: The algorithm to use for generating the random numbers. Defaults to
1040+
"DEFAULT". Other options include "PHILOX" and "THREE_FRY".
1041+
"""
10191042
@noinline function rng_bit_generator(
1043+
::Type{T},
10201044
seed::TracedRArray{UInt64,1},
10211045
shape;
10221046
algorithm::String="DEFAULT",
10231047
location=mlir_stacktrace("rng_bit_generator", @__FILE__, @__LINE__),
1024-
)
1025-
output = MLIR.IR.TensorType(TracedRArray{UInt64,1}, shape)
1048+
) where {T<:Integer}
1049+
@assert algorithm in ("DEFAULT", "PHILOX", "THREE_FRY")
1050+
if algorithm == "PHILOX"
1051+
@assert length(seed) (2, 3)
1052+
elseif algorithm == "THREE_FRY"
1053+
@assert length(seed) == 2
1054+
end
1055+
1056+
output = MLIR.IR.TensorType(shape, MLIR.IR.Type(T))
1057+
output_state = MLIR.IR.TensorType(size(seed), MLIR.IR.Type(UInt64))
10261058
rng_algorithm = MLIR.API.stablehloRngAlgorithmAttrGet(MLIR.IR.context(), algorithm)
1027-
op = stablehlo.rng_bit_generator(seed.mlir_data; output, rng_algorithm, location)
1059+
op = stablehlo.rng_bit_generator(
1060+
seed.mlir_data; output, output_state, rng_algorithm, location
1061+
)
10281062
return (;
1029-
output_state=TracedRArray{UInt64,1}((), MLIR.IR.result(op, 1), MLIR.IR.size(seed)),
1030-
output=TracedRArray{T,length(shape)}((), MLIR.IR.result(op, 2), shape),
1063+
output_state=TracedRArray{UInt64,1}((), MLIR.IR.result(op, 1), size(seed)),
1064+
output=TracedRArray{T,length(shape)}((), MLIR.IR.result(op, 2), Tuple(shape)),
1065+
)
1066+
end
1067+
1068+
@noinline function rng_bit_generator(
1069+
::Type{T},
1070+
seed::TracedRArray{UInt64,1},
1071+
shape;
1072+
algorithm::String="DEFAULT",
1073+
location=mlir_stacktrace("rng_bit_generator", @__FILE__, @__LINE__),
1074+
) where {T<:AbstractFloat}
1075+
nbits = sizeof(T) * 8
1076+
uT = nbits == 16 ? UInt16 : (nbits == 32 ? UInt32 : UInt64)
1077+
(; output_state, output) = rng_bit_generator(uT, seed, shape; algorithm, location)
1078+
output = divide(
1079+
convert(TracedRArray{T,ndims(output)}, output),
1080+
constant(fill(T(typemax(uT)), Tuple(shape)); location),
1081+
)
1082+
return (; output_state, output)
1083+
end
1084+
1085+
"""
1086+
randn(
1087+
::Type{T},
1088+
seed::TracedRArray{UInt64,1},
1089+
shape;
1090+
algorithm::String="DEFAULT",
1091+
location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
10311092
)
1093+
1094+
Generate a random array of type `T` with the given shape and seed from a standard normal
1095+
distribution of mean 0 and standard deviation 1. Returns a NamedTuple with the following
1096+
fields:
1097+
1098+
- `output_state`: The state of the random number generator after the operation.
1099+
- `output`: The generated array.
1100+
1101+
# Arguments
1102+
1103+
- `T`: The type of the generated array.
1104+
- `seed`: The seed for the random number generator.
1105+
- `shape`: The shape of the generated array.
1106+
- `algorithm`: The algorithm to use for generating the random numbers. Defaults to
1107+
"DEFAULT". Other options include "PHILOX" and "THREE_FRY".
1108+
"""
1109+
@noinline function randn(
1110+
::Type{T},
1111+
seed::TracedRArray{UInt64,1},
1112+
shape;
1113+
algorithm::String="DEFAULT",
1114+
location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
1115+
) where {T}
1116+
res = rng_bit_generator(T, seed, shape; algorithm, location)
1117+
rand_uniform = res.output
1118+
seed = res.output_state
1119+
scaled_uniform = subtract(
1120+
multiply(rand_uniform, constant(fill(T(2), size(rand_uniform)))),
1121+
constant(fill(T(1), size(rand_uniform))),
1122+
)
1123+
probit = erf_inv(scaled_uniform)
1124+
rand_normal = multiply(probit, constant(fill(Base.sqrt(T(2)), size(rand_uniform))))
1125+
return (; output_state=seed, output=rand_normal)
1126+
end
1127+
1128+
"""
1129+
randexp(
1130+
::Type{T},
1131+
seed::TracedRArray{UInt64,1},
1132+
shape;
1133+
algorithm::String="DEFAULT",
1134+
location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
1135+
)
1136+
1137+
Generate a random array of type `T` with the given shape and seed from an exponential
1138+
distribution with rate 1. Returns a NamedTuple with the following fields:
1139+
1140+
- `output_state`: The state of the random number generator after the operation.
1141+
- `output`: The generated array.
1142+
1143+
# Arguments
1144+
1145+
- `T`: The type of the generated array.
1146+
- `seed`: The seed for the random number generator.
1147+
- `shape`: The shape of the generated array.
1148+
- `algorithm`: The algorithm to use for generating the random numbers. Defaults to
1149+
"DEFAULT". Other options include "PHILOX" and "THREE_FRY".
1150+
"""
1151+
@noinline function randexp(
1152+
::Type{T},
1153+
seed::TracedRArray{UInt64,1},
1154+
shape;
1155+
algorithm::String="DEFAULT",
1156+
location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
1157+
) where {T}
1158+
res = rng_bit_generator(T, seed, shape; algorithm, location)
1159+
rand_uniform = res.output
1160+
seed = res.output_state
1161+
rand_exp = negate(log_plus_one(negate(rand_uniform)))
1162+
return (; output_state=seed, output=rand_exp)
10321163
end
10331164

10341165
# functional ops

src/Overlay.jl

Lines changed: 94 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,23 @@
33
# correctly. Once that (https://github.com/timholy/Revise.jl/issues/646) is resolved
44
# we should move all the reactant_overrides to relevant files.
55

6+
# Helper Function to determine if we are inside the ReactantInterpreter
7+
"""
8+
within_reactant_interpreter()
9+
10+
Returns `true` if we are currently inside the ReactantInterpreter.
11+
"""
12+
@noinline within_reactant_interpreter() = false
13+
@reactant_overlay @noinline within_reactant_interpreter() = true
14+
615
# Compiling within a compile should return simply the original function
716
@reactant_overlay function Compiler.compile(
817
f, args; client=nothing, optimize=true, sync=false
918
)
1019
return f
1120
end
1221

13-
# Enzyme overrides
22+
# Enzyme.jl overlays
1423
@reactant_overlay @noinline function Enzyme.autodiff_deferred(
1524
rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs}
1625
) where {FA<:Annotation,A<:Annotation,Nargs}
@@ -22,3 +31,87 @@ end
2231
) where {FA<:Annotation,A<:Annotation,Nargs}
2332
return overload_autodiff(rmode, f, rt, args...)
2433
end
34+
35+
# Random.jl overlays
36+
@reactant_overlay @noinline function Random.default_rng()
37+
return call_with_reactant(TracedRandom.default_rng)
38+
end
39+
40+
## Only problematic edge case here is the direct `<randfun!>(rng, A::AbstractArray)` call
41+
## We can't directly overlay that call without breaking the semantics of inplace update
42+
for randfun in (:rand, :randn, :randexp)
43+
randfun! = Symbol(randfun, :!)
44+
overload_randfun = Symbol(:overload_, randfun)
45+
overload_randfun! = Symbol(:overload_, randfun!)
46+
47+
@eval begin
48+
@reactant_overlay @noinline function Random.$(randfun)(
49+
rng::AbstractRNG, ::Type{T}, dims::Dims
50+
) where {T}
51+
if T <: ReactantPrimitive
52+
return TracedRandom.$(overload_randfun)(rng, T, dims)
53+
end
54+
return error(
55+
"Reactant doesn't support sampling of $(T) with the current interpreter."
56+
)
57+
# XXX: The following will lead to illegal instruction
58+
# @warn "Reactant doesn't support sampling of $(T) with the current \
59+
# interpreter. Falling back to native interpreter." maxlog = 1
60+
# return Random.$(randfun)(rng, T, dims)
61+
end
62+
63+
@reactant_overlay @noinline function Random.$(randfun)(
64+
rng::AbstractRNG, dim1::Integer, dims::Integer...
65+
)
66+
return TracedRandom.$(overload_randfun)(rng, dim1, dims...)
67+
end
68+
69+
@reactant_overlay @noinline function Random.$(randfun)(
70+
rng::AbstractRNG, ::Type{T}, dim1::Integer, dims::Integer...
71+
) where {T}
72+
if T <: ReactantPrimitive
73+
return TracedRandom.$(overload_randfun)(rng, T, dim1, dims...)
74+
end
75+
return error(
76+
"Reactant doesn't support sampling of $(T) with the current interpreter."
77+
)
78+
# XXX: The following will lead to illegal instruction
79+
# @warn "Reactant doesn't support sampling of $(T) with the current \
80+
# interpreter. Falling back to native interpreter." maxlog = 1
81+
# return Random.$(randfun)(rng, T, dim1, dims...)
82+
end
83+
84+
# scalars
85+
@reactant_overlay @noinline function Random.$(randfun)(
86+
rng::AbstractRNG, ::Type{T}=Float64
87+
) where {T}
88+
if T <: ReactantPrimitive
89+
return TracedRandom.$(overload_randfun)(rng, T)
90+
end
91+
return error(
92+
"Reactant doesn't support sampling of $(T) with the current interpreter."
93+
)
94+
# XXX: The following will lead to illegal instruction
95+
# @warn "Reactant doesn't support sampling of $(T) with the current \
96+
# interpreter. Falling back to native interpreter." maxlog = 1
97+
# return Random.$(randfun)(rng, T)
98+
end
99+
100+
# inplace
101+
@reactant_overlay @noinline function Random.$(randfun!)(
102+
rng::AbstractRNG, A::AnyTracedRArray
103+
)
104+
return TracedRandom.$(overload_randfun!)(rng, A)
105+
end
106+
107+
# XXX: Uncomment once AbsInt issues with recursive calls are resolved
108+
# @reactant_overlay @noinline function Random.$(randfun!)(
109+
# rng::AbstractRNG, A::AbstractArray
110+
# )
111+
# @warn "Directly writing to an array using Random.jl functions inside \
112+
# ReactantInterpreter will generate a constant array in the IR. Use with \
113+
# caution." maxlog = 1
114+
# return Random.$(randfun!)(rng, A)
115+
# end
116+
end
117+
end

src/Reactant.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ module Reactant
33
using ReactantCore: ReactantCore, @trace, MissingTracedValue
44

55
using LinearAlgebra: LinearAlgebra
6+
using Random: Random, AbstractRNG
7+
68
using Adapt: Adapt, WrappedArray
79
using GPUArraysCore: GPUArraysCore, @allowscalar, allowscalar # keep this import to allow users to do `Reactant.allowscalar(false)`
810

@@ -122,7 +124,14 @@ include("TracedRArray.jl")
122124

123125
include("ConcreteRArray.jl")
124126

125-
include("linear_algebra.jl")
127+
mutable struct TracedRNG <: Random.AbstractRNG
128+
seed::Union{ConcreteRArray{UInt64,1},TracedRArray{UInt64,1}}
129+
const algorithm::String
130+
end
131+
132+
# StdLib Overloads
133+
include("stdlibs/LinearAlgebra.jl")
134+
include("stdlibs/Random.jl")
126135

127136
const TracedType = Union{TracedRArray,TracedRNumber,MissingTracedValue}
128137

File renamed without changes.

0 commit comments

Comments
 (0)