Skip to content

Commit b666813

Browse files
committed
enforcing calling convention (rng being 0th operand) for simulate/generate ops
1 parent d1be27c commit b666813

File tree

3 files changed

+124
-161
lines changed

3 files changed

+124
-161
lines changed

src/ProbProg.jl

Lines changed: 64 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ function addSampleToTrace(
7171
shape_ptr_array = unsafe_wrap(Array, shape_ptr_array, num_outputs)
7272
sample_ptr_array = unsafe_wrap(Array, sample_ptr_array, num_outputs)
7373

74-
tostore = Any[]
74+
vals = Any[]
7575
for i in 1:num_outputs
7676
ndims = ndims_array[i]
7777
width = width_array[i]
@@ -96,17 +96,14 @@ function addSampleToTrace(
9696
end
9797

9898
if ndims == 0
99-
val = unsafe_load(Ptr{julia_type}(sample_ptr))
100-
push!(tostore, val)
99+
push!(vals, unsafe_load(Ptr{julia_type}(sample_ptr)))
101100
else
102101
shape = unsafe_wrap(Array, shape_ptr, ndims)
103-
push!(
104-
tostore, copy(unsafe_wrap(Array, Ptr{julia_type}(sample_ptr), Tuple(shape)))
105-
)
102+
push!(vals, copy(unsafe_wrap(Array, Ptr{julia_type}(sample_ptr), Tuple(shape))))
106103
end
107104
end
108105

109-
trace.choices[symbol] = tuple(tostore...)
106+
trace.choices[symbol] = tuple(vals...)
110107

111108
return nothing
112109
end
@@ -184,7 +181,8 @@ function addRetvalToTrace(
184181
end
185182
end
186183

187-
trace.retval = length(vals) == 1 ? vals[1] : vals
184+
trace.retval = tuple(vals...)
185+
188186
return nothing
189187
end
190188

@@ -418,56 +416,11 @@ function sample_internal(
418416
sym = TracedUtils.get_attribute_by_name(func2, "sym_name")
419417
fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(sym))
420418

421-
# Specify which outputs to add to the trace.
422-
traced_output_indices = Int[]
423-
for (i, res) in enumerate(linear_results)
424-
if TracedUtils.has_idx(res, resprefix)
425-
push!(traced_output_indices, i - 1)
426-
end
427-
end
428-
429-
# Specify which inputs to pass to logpdf.
430-
traced_input_indices = Int[]
431-
for (i, a) in enumerate(linear_args)
432-
idx, _ = TracedUtils.get_argidx(a, argprefix)
433-
if fnwrap && idx == 1 # TODO: add test for fnwrap
434-
continue
435-
end
436-
437-
if fnwrap
438-
idx -= 1
439-
end
440-
441-
if !(args[idx] isa AbstractRNG)
442-
push!(traced_input_indices, i - 1)
443-
end
444-
end
445-
446419
symbol_addr = reinterpret(UInt64, pointer_from_objref(symbol))
447420
symbol_attr = @ccall MLIR.API.mlir_c.enzymeSymbolAttrGet(
448421
MLIR.IR.context()::MLIR.API.MlirContext, symbol_addr::UInt64
449422
)::MLIR.IR.Attribute
450423

451-
# (out_idx1, in_idx1, out_idx2, in_idx2, ...)
452-
alias_pairs = Int64[]
453-
for (out_idx, res) in enumerate(linear_results)
454-
if TracedUtils.has_idx(res, argprefix)
455-
in_idx = nothing
456-
for (i, arg) in enumerate(linear_args)
457-
if TracedUtils.has_idx(arg, argprefix) &&
458-
TracedUtils.get_idx(arg, argprefix) ==
459-
TracedUtils.get_idx(res, argprefix)
460-
in_idx = i - 1
461-
break
462-
end
463-
end
464-
@assert in_idx !== nothing "Unable to find operand for aliased result"
465-
push!(alias_pairs, out_idx - 1)
466-
push!(alias_pairs, in_idx)
467-
end
468-
end
469-
alias_attr = MLIR.IR.DenseArrayAttribute(alias_pairs)
470-
471424
# Construct MLIR attribute if Julia logpdf function is provided.
472425
logpdf_attr = nothing
473426
if logpdf !== nothing
@@ -504,9 +457,6 @@ function sample_internal(
504457
fn=fn_attr,
505458
logpdf=logpdf_attr,
506459
symbol=symbol_attr,
507-
traced_input_indices=traced_input_indices,
508-
traced_output_indices=traced_output_indices,
509-
alias_map=alias_attr,
510460
name=Base.String(symbol),
511461
)
512462

@@ -547,8 +497,6 @@ function call(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) where {Nar
547497
r isa AbstractConcreteArray ? Array(r) : r
548498
end
549499

550-
@show res
551-
552500
return length(res) == 1 ? res[1] : res
553501
end
554502

@@ -581,19 +529,17 @@ function call_internal(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) w
581529
fnwrap = mlir_fn_res.fnwrapped
582530
func2 = mlir_fn_res.f
583531

584-
@show length(linear_results), linear_results
585-
586532
out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results]
587533
fname = TracedUtils.get_attribute_by_name(func2, "sym_name")
588534
fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname))
589535

590536
inputs = MLIR.IR.Value[]
591537
for a in linear_args
592538
idx, path = TracedUtils.get_argidx(a, argprefix)
593-
if idx == 1 && fnwrap
539+
if idx == 2 && fnwrap
594540
TracedUtils.push_val!(inputs, f, path[3:end])
595541
else
596-
if fnwrap
542+
if fnwrap && idx > 2
597543
idx -= 1
598544
end
599545
TracedUtils.push_val!(inputs, args[idx], path[3:end])
@@ -629,15 +575,14 @@ function call_internal(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) w
629575
return result
630576
end
631577

632-
function simulate(f::Function, args::Vararg{Any,Nargs}) where {Nargs}
633-
old_gc_state = GC.enable(false)
634-
578+
function simulate(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) where {Nargs}
635579
trace = nothing
636-
weight = nothing
637-
res = nothing
638580

581+
compiled_fn = @compile optimize = :probprog simulate_internal(rng, f, args...)
582+
583+
old_gc_state = GC.enable(false)
639584
try
640-
trace, weight, res = @jit optimize = :probprog simulate_internal(f, args...)
585+
trace, _, _ = compiled_fn(rng, f, args...)
641586
finally
642587
GC.enable(old_gc_state)
643588
end
@@ -647,20 +592,29 @@ function simulate(f::Function, args::Vararg{Any,Nargs}) where {Nargs}
647592
return trace, trace.weight
648593
end
649594

650-
function simulate_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs}
595+
function simulate_internal(
596+
rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}
597+
) where {Nargs}
651598
argprefix::Symbol = gensym("simulatearg")
652599
resprefix::Symbol = gensym("simulateresult")
653600
resargprefix::Symbol = gensym("simulateresarg")
654601

602+
wrapper_fn = (all_args...) -> begin
603+
res = f(all_args...)
604+
(all_args[1], (res isa Tuple ? res : (res,))...)
605+
end
606+
607+
args = (rng, args...)
608+
655609
mlir_fn_res = invokelatest(
656610
TracedUtils.make_mlir_fn,
657-
f,
611+
wrapper_fn,
658612
args,
659613
(),
660614
string(f),
661615
false;
662616
do_transpose=false,
663-
args_in_result=:all,
617+
args_in_result=:result,
664618
argprefix,
665619
resprefix,
666620
resargprefix,
@@ -673,21 +627,13 @@ function simulate_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs}
673627
fname = TracedUtils.get_attribute_by_name(func2, "sym_name")
674628
fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname))
675629

676-
# Specify which outputs to add to the trace.
677-
traced_output_indices = Int[]
678-
for (i, res) in enumerate(linear_results)
679-
if TracedUtils.has_idx(res, resprefix)
680-
push!(traced_output_indices, i - 1)
681-
end
682-
end
683-
684630
inputs = MLIR.IR.Value[]
685631
for a in linear_args
686632
idx, path = TracedUtils.get_argidx(a, argprefix)
687-
if idx == 1 && fnwrap
633+
if idx == 2 && fnwrap
688634
TracedUtils.push_val!(inputs, f, path[3:end])
689635
else
690-
if fnwrap
636+
if fnwrap && idx > 2
691637
idx -= 1
692638
end
693639
TracedUtils.push_val!(inputs, args[idx], path[3:end])
@@ -700,30 +646,29 @@ function simulate_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs}
700646
weight_ty = MLIR.IR.TensorType(Int64[], MLIR.IR.Type(Float64))
701647

702648
simulate_op = MLIR.Dialects.enzyme.simulate(
703-
inputs;
704-
trace=trace_ty,
705-
weight=weight_ty,
706-
outputs=out_tys,
707-
fn=fn_attr,
708-
traced_output_indices=traced_output_indices,
649+
inputs; trace=trace_ty, weight=weight_ty, outputs=out_tys, fn=fn_attr
709650
)
710651

711652
for (i, res) in enumerate(linear_results)
712653
resv = MLIR.IR.result(simulate_op, i + 2)
713654
if TracedUtils.has_idx(res, resprefix)
714655
path = TracedUtils.get_idx(res, resprefix)
715656
TracedUtils.set!(result, path[2:end], resv)
716-
elseif TracedUtils.has_idx(res, argprefix)
657+
end
658+
659+
if TracedUtils.has_idx(res, argprefix)
717660
idx, path = TracedUtils.get_argidx(res, argprefix)
718-
if idx == 1 && fnwrap
661+
if idx == 2 && fnwrap
719662
TracedUtils.set!(f, path[3:end], resv)
720663
else
721-
if fnwrap
664+
if fnwrap && idx > 2
722665
idx -= 1
723666
end
724667
TracedUtils.set!(args[idx], path[3:end], resv)
725668
end
726-
else
669+
end
670+
671+
if !TracedUtils.has_idx(res, resprefix) && !TracedUtils.has_idx(res, argprefix)
727672
TracedUtils.set!(res, (), resv)
728673
end
729674
end
@@ -751,24 +696,25 @@ function simulate_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs}
751696
end
752697

753698
function generate(
754-
f::Function, args::Vararg{Any,Nargs}; constraint::Constraint=Dict{Symbol,Any}()
699+
rng::AbstractRNG,
700+
f::Function,
701+
args::Vararg{Any,Nargs};
702+
constraint::Constraint=Dict{Symbol,Any}(),
755703
) where {Nargs}
756704
trace = nothing
757-
weight = nothing
758-
res = nothing
759705

760706
constraint_ptr = ConcreteRNumber(reinterpret(UInt64, pointer_from_objref(constraint)))
761707
constrained_symbols = collect(keys(constraint))
762708

763-
function wrapper_fn(constraint_ptr, args...)
764-
return generate_internal(f, args...; constraint_ptr, constrained_symbols)
709+
function wrapper_fn(rng, constraint_ptr, args...)
710+
return generate_internal(rng, f, args...; constraint_ptr, constrained_symbols)
765711
end
766712

767-
compiled_fn = @compile optimize = :probprog wrapper_fn(constraint_ptr, args...)
713+
compiled_fn = @compile optimize = :probprog wrapper_fn(rng, constraint_ptr, args...)
768714

769715
old_gc_state = GC.enable(false)
770716
try
771-
trace, weight, res = compiled_fn(constraint_ptr, args...)
717+
trace, _, _ = compiled_fn(rng, constraint_ptr, args...)
772718
finally
773719
GC.enable(old_gc_state)
774720
end
@@ -779,6 +725,7 @@ function generate(
779725
end
780726

781727
function generate_internal(
728+
rng::AbstractRNG,
782729
f::Function,
783730
args::Vararg{Any,Nargs};
784731
constraint_ptr::TracedRNumber,
@@ -788,15 +735,22 @@ function generate_internal(
788735
resprefix::Symbol = gensym("generateresult")
789736
resargprefix::Symbol = gensym("generateresarg")
790737

738+
wrapper_fn = (all_args...) -> begin
739+
res = f(all_args...)
740+
(all_args[1], (res isa Tuple ? res : (res,))...)
741+
end
742+
743+
args = (rng, args...)
744+
791745
mlir_fn_res = invokelatest(
792746
TracedUtils.make_mlir_fn,
793-
f,
747+
wrapper_fn,
794748
args,
795749
(),
796750
string(f),
797751
false;
798752
do_transpose=false,
799-
args_in_result=:all,
753+
args_in_result=:result,
800754
argprefix,
801755
resprefix,
802756
resargprefix,
@@ -809,21 +763,13 @@ function generate_internal(
809763
fname = TracedUtils.get_attribute_by_name(func2, "sym_name")
810764
fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname))
811765

812-
# Specify which outputs to add to the trace.
813-
traced_output_indices = Int[]
814-
for (i, res) in enumerate(linear_results)
815-
if TracedUtils.has_idx(res, resprefix)
816-
push!(traced_output_indices, i - 1)
817-
end
818-
end
819-
820766
inputs = MLIR.IR.Value[]
821767
for a in linear_args
822768
idx, path = TracedUtils.get_argidx(a, argprefix)
823-
if idx == 1 && fnwrap
769+
if idx == 2 && fnwrap
824770
TracedUtils.push_val!(inputs, f, path[3:end])
825771
else
826-
if fnwrap
772+
if fnwrap && idx > 2
827773
idx -= 1
828774
end
829775
TracedUtils.push_val!(inputs, args[idx], path[3:end])
@@ -865,25 +811,28 @@ function generate_internal(
865811
outputs=out_tys,
866812
fn=fn_attr,
867813
constrained_symbols=MLIR.IR.Attribute(constrained_symbols_attr),
868-
traced_output_indices,
869814
)
870815

871816
for (i, res) in enumerate(linear_results)
872817
resv = MLIR.IR.result(generate_op, i + 2)
873818
if TracedUtils.has_idx(res, resprefix)
874819
path = TracedUtils.get_idx(res, resprefix)
875820
TracedUtils.set!(result, path[2:end], resv)
876-
elseif TracedUtils.has_idx(res, argprefix)
821+
end
822+
823+
if TracedUtils.has_idx(res, argprefix)
877824
idx, path = TracedUtils.get_argidx(res, argprefix)
878-
if idx == 1 && fnwrap
825+
if idx == 2 && fnwrap
879826
TracedUtils.set!(f, path[3:end], resv)
880827
else
881-
if fnwrap
828+
if fnwrap && idx > 2
882829
idx -= 1
883830
end
884831
TracedUtils.set!(args[idx], path[3:end], resv)
885832
end
886-
else
833+
end
834+
835+
if !TracedUtils.has_idx(res, resprefix) && !TracedUtils.has_idx(res, argprefix)
887836
TracedUtils.set!(res, (), resv)
888837
end
889838
end

0 commit comments

Comments
 (0)