Skip to content

Commit b1f8d8a

Browse files
committed
Use @eval macro, add support for tracing arbitrary Functions.
1 parent 19ea4c1 commit b1f8d8a

File tree

3 files changed

+32
-9
lines changed

3 files changed

+32
-9
lines changed

README.md

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
# TracedRandom.jl
22

33
Allows for the optional specification of traced addresses (i.e. variable names)
4-
in calls to `rand` and other primitive functions in `Random`:
4+
in calls to `rand` and other primitive functions in `Random`. Providing this
5+
information allows ordinary Julia code to be "probabilistic-programming-ready".
56

67
```julia
78
julia> rand(:u, Float64, 10)
@@ -21,9 +22,17 @@ julia> randperm(:perm)
2122
4
2223
```
2324

24-
By default, the addresses (`:x`, `:z` and `:perm` in the examples above)
25-
are ignored, but they can be intercepted via meta-programming by
26-
probabilistic programming systems such as [`Gen`](https://www.gen.dev/) and
27-
[`Jaynes`](https://femtomc.github.io/Jaynes.jl/) in order to perform inference
28-
over the addressed random variables. Addresses can be specified as `Symbol`s,
29-
or as pairs from symbols to other types (`Pair{Symbol,<:Any}`).
25+
In addition, a call to some `fn::Function` can be annotated with an address
26+
by wrapping it in `rand` call:
27+
```julia
28+
julia> gaussian_mixture(μs) = randn(:z) + μs[rand(:k, 1:length(μs))]
29+
julia> rand(:x, gaussian_mixture, [1, 10])
30+
9.594800995267331
31+
```
32+
33+
By default, the addresses (`:x`, `:z`, `:k` and `:perm` in the examples above)
34+
are ignored, but they can be intercepted via meta-programming
35+
(see [`Genify.jl`](https://github.com/probcomp/Genify.jl])) to support inference
36+
in probabilistic programming systems such as [`Gen`](https://www.gen.dev/).
37+
Addresses can be specified as `Symbol`s, or as pairs from symbols
38+
to other types (`Pair{Symbol,<:Any}`).

src/TracedRandom.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@ export rand!, randn!, randexp, randexp!,
88

99
const Address = Union{Symbol,Pair{Symbol}}
1010

11+
Base.rand(::Address, fn::Function, args...) = fn(args...)
12+
Base.rand(::AbstractRNG, ::Address, fn::Function, args...) = fn(args...)
13+
1114
Base.rand(::Address, args...) = rand(args...)
1215
Base.rand(rng::AbstractRNG, ::Address, args...) = rand(rng, args...)
1316

@@ -18,8 +21,8 @@ for fn in (:rand!, :randn!, :randexp, :randexp!, :bitrand, :randstring,
1821
:randsubseq, :randsubseq!, :shuffle, :shuffle!,
1922
:randperm, :randperm!, :randcycle, :randcycle!)
2023
fn = GlobalRef(Random, fn)
21-
eval(:($fn(::Address, args...) = $fn(args...)))
22-
eval(:($fn(rng::AbstractRNG, ::Address, args...) = $fn(rng, args...)))
24+
@eval $fn(::Address, args...) = $fn(args...)
25+
@eval $fn(rng::AbstractRNG, ::Address, args...) = $fn(rng, args...)
2326
end
2427

2528
end

test/runtests.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,16 @@
11
using TracedRandom, Random, Test
22

3+
@testset "Addressed function calls" begin
4+
5+
gaussian_mixture(μs) = randn(:z) + μs[rand(:k, 1:length(μs))]
6+
Random.seed!(0)
7+
traced = rand(:y, gaussian_mixture, [1, 10, 100])
8+
Random.seed!(0)
9+
untraced = gaussian_mixture([1, 10, 100])
10+
@test traced == untraced
11+
12+
end
13+
314
@testset "Addressed calls with global RNG" begin
415

516
for fn in (rand, randn, randexp, bitrand, randstring)

0 commit comments

Comments
 (0)