Skip to content

Commit 1dc299c

Browse files
committed
Support address splicing into the current namespace via the here constant.
1 parent 75bba13 commit 1dc299c

File tree

3 files changed

+23
-2
lines changed

3 files changed

+23
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "TracedRandom"
22
uuid = "03576162-b2ff-4022-a931-e67b9900fd45"
33
authors = ["Xuan <[email protected]>"]
4-
version = "0.1.0"
4+
version = "0.1.1"
55

66
[deps]
77
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

src/TracedRandom.jl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,27 @@ module TracedRandom
22

33
using Random
44

5-
export rand!, randn!, randexp, randexp!,
5+
export here, rand!, randn!, randexp, randexp!,
66
bitrand, randstring, randsubseq, randsubseq!,
77
shuffle, shuffle!, randperm, randperm!, randcycle, randcycle!
88

99
const Address = Union{Symbol,Pair{Symbol}}
1010

11+
struct Here end
12+
13+
"""
14+
here
15+
16+
Special address that refers to the address namespace of the calling context.
17+
Supported when [`rand`](@ref) is called on a non-primitive stochastic function,
18+
causing all named random variables sampled by that function to be spliced into
19+
the address namespace of the calling context.
20+
"""
21+
const here = Here()
22+
23+
Base.rand(::Here, fn::Function, args...) = fn(args...)
24+
Base.rand(::AbstractRNG, ::Here, fn::Function, args...) = fn(args...)
25+
1126
Base.rand(::Address, fn::Function, args...) = fn(args...)
1227
Base.rand(::AbstractRNG, ::Address, fn::Function, args...) = fn(args...)
1328

test/runtests.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@ Random.seed!(0)
99
untraced = gaussian_mixture([1, 10, 100])
1010
@test traced == untraced
1111

12+
Random.seed!(0)
13+
traced = rand(here, gaussian_mixture, [1, 10, 100])
14+
Random.seed!(0)
15+
untraced = gaussian_mixture([1, 10, 100])
16+
@test traced == untraced
17+
1218
end
1319

1420
@testset "Addressed calls with global RNG" begin

0 commit comments

Comments
 (0)