Skip to content

Commit 1c59cd7

Browse files
wsmosesavik-pal
andauthored
Random: update seed array in place (#1292)
* Random: update seed array in place * fix * Update src/Types.jl --------- Co-authored-by: Avik Pal <[email protected]>
1 parent 8a5fa74 commit 1c59cd7

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

src/Types.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,9 @@ const AnyTracedRMatrix{T} = AnyTracedRArray{T,2}
8080
const AnyTracedRVecOrMat{T} = Union{AnyTracedRVector{T},AnyTracedRMatrix{T}}
8181

8282
## TracedRNG
83-
mutable struct TracedRNG <: Random.AbstractRNG
83+
struct TracedRNG <: Random.AbstractRNG
8484
seed::TracedRArray{UInt64,1}
85-
const algorithm::String
85+
algorithm::String
8686
end
8787

8888
# Concrete Types

src/stdlibs/Random.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ end
4343
end
4444

4545
@noinline function Random.seed!(rng::TracedRNG, seed::TracedRArray{UInt64,1})
46-
rng.seed = seed
46+
rng.seed.mlir_data = seed.mlir_data
4747
return rng
4848
end
4949

@@ -62,7 +62,7 @@ end
6262
end
6363

6464
@noinline function Random.seed!(rng::ConcreteRNG, seed::AbstractConcreteArray{UInt64,1})
65-
rng.seed = seed
65+
Base.copyto!(rng.seed, seed)
6666
return rng
6767
end
6868

@@ -82,7 +82,7 @@ Base.copy(rng::TracedRNG) = TracedRNG(copy(rng.seed), rng.algorithm)
8282
) where {T,N}
8383
length(A) == 0 && return A
8484
res = Ops.rng_bit_generator(T, rng.seed, [size(A)...]; rng.algorithm)
85-
rng.seed = res.output_state
85+
rng.seed.mlir_data = res.output_state.mlir_data
8686
TracedUtils.set_mlir_data!(A, res.output.mlir_data)
8787
return A
8888
end
@@ -92,7 +92,7 @@ end
9292
) where {T,N}
9393
length(A) == 0 && return A
9494
res = Ops.randn(T, rng.seed, [size(A)...]; rng.algorithm)
95-
rng.seed = res.output_state
95+
rng.seed.mlir_data = res.output_state.mlir_data
9696
TracedUtils.set_mlir_data!(A, res.output.mlir_data)
9797
return A
9898
end
@@ -102,7 +102,7 @@ end
102102
) where {T,N}
103103
length(A) == 0 && return A
104104
res = Ops.randexp(T, rng.seed, [size(A)...]; rng.algorithm)
105-
rng.seed = res.output_state
105+
rng.seed.mlir_data = res.output_state.mlir_data
106106
TracedUtils.set_mlir_data!(A, res.output.mlir_data)
107107
return A
108108
end

0 commit comments

Comments
 (0)