Skip to content

Commit d1be27c

Browse files
committed
enforcing calling convention (rng being the 0th operand) for sample & untraced call ops
1 parent f4c4a88 commit d1be27c

File tree

2 files changed

+126
-38
lines changed

2 files changed

+126
-38
lines changed

src/ProbProg.jl

Lines changed: 73 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,23 @@ function __init__()
349349
end
350350

351351
function sample(
352+
rng::AbstractRNG,
353+
f::Function,
354+
args::Vararg{Any,Nargs};
355+
symbol::Symbol=gensym("sample"),
356+
logpdf::Union{Nothing,Function}=nothing,
357+
) where {Nargs}
358+
res = sample_internal(rng, f, args...; symbol, logpdf)
359+
360+
@assert res isa Tuple && length(res) >= 1 && res[1] isa AbstractRNG "Expected first result to be RNG"
361+
362+
res = res[2:end]
363+
364+
return length(res) == 1 ? res[1] : res
365+
end
366+
367+
function sample_internal(
368+
rng::AbstractRNG,
352369
f::Function,
353370
args::Vararg{Any,Nargs};
354371
symbol::Symbol=gensym("sample"),
@@ -358,15 +375,22 @@ function sample(
358375
resprefix::Symbol = gensym("sampleresult")
359376
resargprefix::Symbol = gensym("sampleresarg")
360377

378+
wrapper_fn = (all_args...) -> begin
379+
res = f(all_args...)
380+
(all_args[1], (res isa Tuple ? res : (res,))...)
381+
end
382+
383+
args = (rng, args...)
384+
361385
mlir_fn_res = invokelatest(
362386
TracedUtils.make_mlir_fn,
363-
f,
387+
wrapper_fn,
364388
args,
365389
(),
366390
string(f),
367391
false;
368392
do_transpose=false,
369-
args_in_result=:all,
393+
args_in_result=:result,
370394
argprefix,
371395
resprefix,
372396
resargprefix,
@@ -378,10 +402,13 @@ function sample(
378402
inputs = MLIR.IR.Value[]
379403
for a in linear_args
380404
idx, path = TracedUtils.get_argidx(a, argprefix)
381-
if idx == 1 && fnwrap
405+
if idx == 2 && fnwrap
382406
TracedUtils.push_val!(inputs, f, path[3:end])
383407
else
384-
idx -= fnwrap ? 1 : 0
408+
if fnwrap && idx > 1
409+
idx -= 1
410+
end
411+
385412
TracedUtils.push_val!(inputs, args[idx], path[3:end])
386413
end
387414
end
@@ -464,7 +491,7 @@ function sample(
464491
string(logpdf),
465492
false;
466493
do_transpose=false,
467-
args_in_result=:all,
494+
args_in_result=:result,
468495
)
469496

470497
logpdf_sym = TracedUtils.get_attribute_by_name(logpdf_mlir.f, "sym_name")
@@ -485,46 +512,67 @@ function sample(
485512

486513
for (i, res) in enumerate(linear_results)
487514
resv = MLIR.IR.result(sample_op, i)
515+
488516
if TracedUtils.has_idx(res, resprefix)
489517
path = TracedUtils.get_idx(res, resprefix)
490518
TracedUtils.set!(result, path[2:end], resv)
491-
elseif TracedUtils.has_idx(res, argprefix)
519+
end
520+
521+
if TracedUtils.has_idx(res, argprefix)
492522
idx, path = TracedUtils.get_argidx(res, argprefix)
493-
if idx == 1 && fnwrap
523+
if fnwrap && idx == 2
494524
TracedUtils.set!(f, path[3:end], resv)
495525
else
496-
if fnwrap
526+
if fnwrap && idx > 2
497527
idx -= 1
498528
end
499529
TracedUtils.set!(args[idx], path[3:end], resv)
500530
end
501-
else
531+
end
532+
533+
if !TracedUtils.has_idx(res, resprefix) && !TracedUtils.has_idx(res, argprefix)
502534
TracedUtils.set!(res, (), resv)
503535
end
504536
end
505537

506538
return result
507539
end
508540

509-
function call(f::Function, args::Vararg{Any,Nargs}) where {Nargs}
510-
res = @jit optimize = :probprog call_internal(f, args...)
511-
return res isa AbstractConcreteArray ? Array(res) : res
541+
function call(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) where {Nargs}
542+
res = @jit optimize = :probprog call_internal(rng, f, args...)
543+
544+
@assert res isa Tuple && length(res) >= 1 && res[1] isa AbstractRNG "Expected first result to be RNG"
545+
546+
res = map(res[2:end]) do r
547+
r isa AbstractConcreteArray ? Array(r) : r
548+
end
549+
550+
@show res
551+
552+
return length(res) == 1 ? res[1] : res
512553
end
513554

514-
function call_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs}
555+
function call_internal(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) where {Nargs}
515556
argprefix::Symbol = gensym("callarg")
516557
resprefix::Symbol = gensym("callresult")
517558
resargprefix::Symbol = gensym("callresarg")
518559

560+
wrapper_fn = (all_args...) -> begin
561+
res = f(all_args...)
562+
(all_args[1], (res isa Tuple ? res : (res,))...)
563+
end
564+
565+
args = (rng, args...)
566+
519567
mlir_fn_res = invokelatest(
520568
TracedUtils.make_mlir_fn,
521-
f,
569+
wrapper_fn,
522570
args,
523571
(),
524572
string(f),
525573
false;
526574
do_transpose=false,
527-
args_in_result=:all,
575+
args_in_result=:result,
528576
argprefix,
529577
resprefix,
530578
resargprefix,
@@ -533,6 +581,8 @@ function call_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs}
533581
fnwrap = mlir_fn_res.fnwrapped
534582
func2 = mlir_fn_res.f
535583

584+
@show length(linear_results), linear_results
585+
536586
out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results]
537587
fname = TracedUtils.get_attribute_by_name(func2, "sym_name")
538588
fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname))
@@ -557,17 +607,21 @@ function call_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs}
557607
if TracedUtils.has_idx(res, resprefix)
558608
path = TracedUtils.get_idx(res, resprefix)
559609
TracedUtils.set!(result, path[2:end], resv)
560-
elseif TracedUtils.has_idx(res, argprefix)
610+
end
611+
612+
if TracedUtils.has_idx(res, argprefix)
561613
idx, path = TracedUtils.get_argidx(res, argprefix)
562-
if idx == 1 && fnwrap
614+
if fnwrap && idx == 2
563615
TracedUtils.set!(f, path[3:end], resv)
564616
else
565-
if fnwrap
617+
if fnwrap && idx > 2
566618
idx -= 1
567619
end
568620
TracedUtils.set!(args[idx], path[3:end], resv)
569621
end
570-
else
622+
end
623+
624+
if !TracedUtils.has_idx(res, resprefix) && !TracedUtils.has_idx(res, argprefix)
571625
TracedUtils.set!(res, (), resv)
572626
end
573627
end

test/probprog/sample.jl

Lines changed: 53 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,80 @@
11
using Reactant, Test, Random
2-
using Reactant: ProbProg
2+
using Reactant: ProbProg, ReactantRNG
33

44
normal(rng, μ, σ, shape) = μ .+ σ .* randn(rng, shape)
55

6-
function one_sample(seed, μ, σ, shape)
7-
rng = Random.default_rng()
8-
Random.seed!(rng, seed)
9-
s = ProbProg.sample(normal, rng, μ, σ, shape)
6+
function one_sample(rng, μ, σ, shape)
7+
s = ProbProg.sample(rng, normal, μ, σ, shape)
108
return s
119
end
1210

13-
function two_samples(seed, μ, σ, shape)
14-
rng = Random.default_rng()
15-
Random.seed!(rng, seed)
16-
_ = ProbProg.sample(normal, rng, μ, σ, shape)
17-
t = ProbProg.sample(normal, rng, μ, σ, shape)
11+
function two_samples(rng, μ, σ, shape)
12+
_ = ProbProg.sample(rng, normal, μ, σ, shape)
13+
t = ProbProg.sample(rng, normal, μ, σ, shape)
14+
return t
15+
end
16+
17+
function compose(rng, μ, σ, shape)
18+
s = ProbProg.sample(rng, normal, μ, σ, shape)
19+
t = ProbProg.sample(rng, normal, s, σ, shape)
1820
return t
1921
end
2022

2123
@testset "test" begin
22-
@testset "sample_hlo" begin
24+
@testset "normal_hlo" begin
2325
shape = (10,)
2426
seed = Reactant.to_rarray(UInt64[1, 4])
27+
rng = ReactantRNG(seed)
2528
μ = Reactant.ConcreteRNumber(0.0)
2629
σ = Reactant.ConcreteRNumber(1.0)
27-
before = @code_hlo optimize = false ProbProg.call_internal(
28-
one_sample, seed, μ, σ, shape
29-
)
30+
31+
code = @code_hlo optimize = false ProbProg.sample(rng, normal, μ, σ, shape)
32+
@test contains(repr(code), "enzyme.sample")
33+
end
34+
35+
@testset "two_samples_hlo" begin
36+
shape = (10,)
37+
seed = Reactant.to_rarray(UInt64[1, 4])
38+
rng = ReactantRNG(seed)
39+
μ = Reactant.ConcreteRNumber(0.0)
40+
σ = Reactant.ConcreteRNumber(1.0)
41+
42+
code = @code_hlo optimize = false ProbProg.sample(rng, two_samples, μ, σ, shape)
43+
@test contains(repr(code), "enzyme.sample")
44+
end
45+
46+
@testset "compose" begin
47+
shape = (10,)
48+
seed = Reactant.to_rarray(UInt64[1, 4])
49+
rng = ReactantRNG(seed)
50+
μ = Reactant.ConcreteRNumber(0.0)
51+
σ = Reactant.ConcreteRNumber(1.0)
52+
53+
before = @code_hlo optimize = false ProbProg.call(rng, compose, μ, σ, shape)
3054
@test contains(repr(before), "enzyme.sample")
31-
after = @code_hlo optimize = :probprog ProbProg.call_internal(
32-
two_samples, seed, μ, σ, shape
33-
)
55+
56+
after = @code_hlo optimize = :probprog ProbProg.call(rng, compose, μ, σ, shape)
3457
@test !contains(repr(after), "enzyme.sample")
3558
end
3659

3760
@testset "rng_state" begin
3861
shape = (10,)
62+
3963
seed = Reactant.to_rarray(UInt64[1, 4])
4064
μ = Reactant.ConcreteRNumber(0.0)
4165
σ = Reactant.ConcreteRNumber(1.0)
42-
X = ProbProg.call(one_sample, seed, μ, σ, shape)
43-
Y = ProbProg.call(two_samples, seed, μ, σ, shape)
66+
67+
rng1 = ReactantRNG(copy(seed))
68+
69+
X = ProbProg.call(rng1, one_sample, μ, σ, shape)
70+
@test !all(rng1.seed .== seed)
71+
72+
rng2 = ReactantRNG(copy(seed))
73+
Y = ProbProg.call(rng2, two_samples, μ, σ, shape)
74+
75+
@test !all(rng2.seed .== seed)
76+
@test !all(rng2.seed .== rng1.seed)
77+
4478
@test !all(X .≈ Y)
4579
end
4680
end

0 commit comments

Comments
 (0)