Skip to content

Commit 136f9ef

Browse files
authored
fix: incorrect scope for seed (#1343)
* fix: incorrect scope * fix: use copyto!
1 parent d749657 commit 136f9ef

File tree

2 files changed

+32
-5
lines changed

2 files changed

+32
-5
lines changed

src/stdlibs/Random.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ end
4343
end
4444

4545
@noinline function Random.seed!(rng::TracedRNG, seed::TracedRArray{UInt64,1})
46-
rng.seed.mlir_data = seed.mlir_data
46+
copyto!(rng.seed, seed)
4747
return rng
4848
end
4949

@@ -82,7 +82,7 @@ Base.copy(rng::TracedRNG) = TracedRNG(copy(rng.seed), rng.algorithm)
8282
) where {T,N}
8383
length(A) == 0 && return A
8484
res = Ops.rng_bit_generator(T, rng.seed, [size(A)...]; rng.algorithm)
85-
rng.seed.mlir_data = res.output_state.mlir_data
85+
copyto!(rng.seed, res.output_state)
8686
TracedUtils.set_mlir_data!(A, res.output.mlir_data)
8787
return A
8888
end
@@ -92,7 +92,7 @@ end
9292
) where {T,N}
9393
length(A) == 0 && return A
9494
res = Ops.randn(T, rng.seed, [size(A)...]; rng.algorithm)
95-
rng.seed.mlir_data = res.output_state.mlir_data
95+
copyto!(rng.seed, res.output_state)
9696
TracedUtils.set_mlir_data!(A, res.output.mlir_data)
9797
return A
9898
end
@@ -102,7 +102,7 @@ end
102102
) where {T,N}
103103
length(A) == 0 && return A
104104
res = Ops.randexp(T, rng.seed, [size(A)...]; rng.algorithm)
105-
rng.seed.mlir_data = res.output_state.mlir_data
105+
copyto!(rng.seed, res.output_state)
106106
TracedUtils.set_mlir_data!(A, res.output.mlir_data)
107107
return A
108108
end

test/autodiff.jl

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using Enzyme, Reactant, Test
1+
using Enzyme, Reactant, Test, Random
22

33
square(x) = x * 2
44

@@ -228,3 +228,30 @@ vector_forward_ad(x) = Enzyme.autodiff(Forward, fn, BatchDuplicated(x, Enzyme.on
228228
@test res[1][3] res_enz[1][3]
229229
@test res[1][4] res_enz[1][4]
230230
end
231+
232+
function simple_forward(x, st)
233+
rng = copy(st.rng)
234+
y = similar(x)
235+
rand!(rng, y)
236+
return x .+ y, (; rng)
237+
end
238+
239+
function gradient_fn(x, st)
240+
stₙ = Ref{Any}(nothing)
241+
function lfn(x, st_old)
242+
y, st_new = simple_forward(x, st_old)
243+
stₙ[] = st_new
244+
return sum(abs2, y)
245+
end
246+
return Enzyme.gradient(Reverse, lfn, x, Const(st)), stₙ[]
247+
end
248+
249+
@testset "seed" begin
250+
x = Reactant.to_rarray(rand(2, 2))
251+
st = (; rng=Reactant.ConcreteRNG())
252+
253+
@test begin
254+
hlo = @code_hlo gradient_fn(x, st)
255+
contains(repr(hlo), "stablehlo.rng_bit_generator")
256+
end
257+
end

0 commit comments

Comments
 (0)