@@ -71,7 +71,7 @@ function addSampleToTrace(
71
71
shape_ptr_array = unsafe_wrap (Array, shape_ptr_array, num_outputs)
72
72
sample_ptr_array = unsafe_wrap (Array, sample_ptr_array, num_outputs)
73
73
74
- tostore = Any[]
74
+ vals = Any[]
75
75
for i in 1 : num_outputs
76
76
ndims = ndims_array[i]
77
77
width = width_array[i]
@@ -96,17 +96,14 @@ function addSampleToTrace(
96
96
end
97
97
98
98
if ndims == 0
99
- val = unsafe_load (Ptr {julia_type} (sample_ptr))
100
- push! (tostore, val)
99
+ push! (vals, unsafe_load (Ptr {julia_type} (sample_ptr)))
101
100
else
102
101
shape = unsafe_wrap (Array, shape_ptr, ndims)
103
- push! (
104
- tostore, copy (unsafe_wrap (Array, Ptr {julia_type} (sample_ptr), Tuple (shape)))
105
- )
102
+ push! (vals, copy (unsafe_wrap (Array, Ptr {julia_type} (sample_ptr), Tuple (shape))))
106
103
end
107
104
end
108
105
109
- trace. choices[symbol] = tuple (tostore ... )
106
+ trace. choices[symbol] = tuple (vals ... )
110
107
111
108
return nothing
112
109
end
@@ -184,7 +181,8 @@ function addRetvalToTrace(
184
181
end
185
182
end
186
183
187
- trace. retval = length (vals) == 1 ? vals[1 ] : vals
184
+ trace. retval = tuple (vals... )
185
+
188
186
return nothing
189
187
end
190
188
@@ -418,56 +416,11 @@ function sample_internal(
418
416
sym = TracedUtils. get_attribute_by_name (func2, " sym_name" )
419
417
fn_attr = MLIR. IR. FlatSymbolRefAttribute (Base. String (sym))
420
418
421
- # Specify which outputs to add to the trace.
422
- traced_output_indices = Int[]
423
- for (i, res) in enumerate (linear_results)
424
- if TracedUtils. has_idx (res, resprefix)
425
- push! (traced_output_indices, i - 1 )
426
- end
427
- end
428
-
429
- # Specify which inputs to pass to logpdf.
430
- traced_input_indices = Int[]
431
- for (i, a) in enumerate (linear_args)
432
- idx, _ = TracedUtils. get_argidx (a, argprefix)
433
- if fnwrap && idx == 1 # TODO : add test for fnwrap
434
- continue
435
- end
436
-
437
- if fnwrap
438
- idx -= 1
439
- end
440
-
441
- if ! (args[idx] isa AbstractRNG)
442
- push! (traced_input_indices, i - 1 )
443
- end
444
- end
445
-
446
419
symbol_addr = reinterpret (UInt64, pointer_from_objref (symbol))
447
420
symbol_attr = @ccall MLIR. API. mlir_c. enzymeSymbolAttrGet (
448
421
MLIR. IR. context ():: MLIR.API.MlirContext , symbol_addr:: UInt64
449
422
):: MLIR.IR.Attribute
450
423
451
- # (out_idx1, in_idx1, out_idx2, in_idx2, ...)
452
- alias_pairs = Int64[]
453
- for (out_idx, res) in enumerate (linear_results)
454
- if TracedUtils. has_idx (res, argprefix)
455
- in_idx = nothing
456
- for (i, arg) in enumerate (linear_args)
457
- if TracedUtils. has_idx (arg, argprefix) &&
458
- TracedUtils. get_idx (arg, argprefix) ==
459
- TracedUtils. get_idx (res, argprefix)
460
- in_idx = i - 1
461
- break
462
- end
463
- end
464
- @assert in_idx != = nothing " Unable to find operand for aliased result"
465
- push! (alias_pairs, out_idx - 1 )
466
- push! (alias_pairs, in_idx)
467
- end
468
- end
469
- alias_attr = MLIR. IR. DenseArrayAttribute (alias_pairs)
470
-
471
424
# Construct MLIR attribute if Julia logpdf function is provided.
472
425
logpdf_attr = nothing
473
426
if logpdf != = nothing
@@ -504,9 +457,6 @@ function sample_internal(
504
457
fn= fn_attr,
505
458
logpdf= logpdf_attr,
506
459
symbol= symbol_attr,
507
- traced_input_indices= traced_input_indices,
508
- traced_output_indices= traced_output_indices,
509
- alias_map= alias_attr,
510
460
name= Base. String (symbol),
511
461
)
512
462
@@ -547,8 +497,6 @@ function call(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) where {Nar
547
497
r isa AbstractConcreteArray ? Array (r) : r
548
498
end
549
499
550
- @show res
551
-
552
500
return length (res) == 1 ? res[1 ] : res
553
501
end
554
502
@@ -581,19 +529,17 @@ function call_internal(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) w
581
529
fnwrap = mlir_fn_res. fnwrapped
582
530
func2 = mlir_fn_res. f
583
531
584
- @show length (linear_results), linear_results
585
-
586
532
out_tys = [MLIR. IR. type (TracedUtils. get_mlir_data (res)) for res in linear_results]
587
533
fname = TracedUtils. get_attribute_by_name (func2, " sym_name" )
588
534
fn_attr = MLIR. IR. FlatSymbolRefAttribute (Base. String (fname))
589
535
590
536
inputs = MLIR. IR. Value[]
591
537
for a in linear_args
592
538
idx, path = TracedUtils. get_argidx (a, argprefix)
593
- if idx == 1 && fnwrap
539
+ if idx == 2 && fnwrap
594
540
TracedUtils. push_val! (inputs, f, path[3 : end ])
595
541
else
596
- if fnwrap
542
+ if fnwrap && idx > 2
597
543
idx -= 1
598
544
end
599
545
TracedUtils. push_val! (inputs, args[idx], path[3 : end ])
@@ -629,15 +575,14 @@ function call_internal(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) w
629
575
return result
630
576
end
631
577
632
- function simulate (f:: Function , args:: Vararg{Any,Nargs} ) where {Nargs}
633
- old_gc_state = GC. enable (false )
634
-
578
+ function simulate (rng:: AbstractRNG , f:: Function , args:: Vararg{Any,Nargs} ) where {Nargs}
635
579
trace = nothing
636
- weight = nothing
637
- res = nothing
638
580
581
+ compiled_fn = @compile optimize = :probprog simulate_internal (rng, f, args... )
582
+
583
+ old_gc_state = GC. enable (false )
639
584
try
640
- trace, weight, res = @jit optimize = :probprog simulate_internal ( f, args... )
585
+ trace, _, _ = compiled_fn (rng, f, args... )
641
586
finally
642
587
GC. enable (old_gc_state)
643
588
end
@@ -647,20 +592,29 @@ function simulate(f::Function, args::Vararg{Any,Nargs}) where {Nargs}
647
592
return trace, trace. weight
648
593
end
649
594
650
- function simulate_internal (f:: Function , args:: Vararg{Any,Nargs} ) where {Nargs}
595
+ function simulate_internal (
596
+ rng:: AbstractRNG , f:: Function , args:: Vararg{Any,Nargs}
597
+ ) where {Nargs}
651
598
argprefix:: Symbol = gensym (" simulatearg" )
652
599
resprefix:: Symbol = gensym (" simulateresult" )
653
600
resargprefix:: Symbol = gensym (" simulateresarg" )
654
601
602
+ wrapper_fn = (all_args... ) -> begin
603
+ res = f (all_args... )
604
+ (all_args[1 ], (res isa Tuple ? res : (res,)). .. )
605
+ end
606
+
607
+ args = (rng, args... )
608
+
655
609
mlir_fn_res = invokelatest (
656
610
TracedUtils. make_mlir_fn,
657
- f ,
611
+ wrapper_fn ,
658
612
args,
659
613
(),
660
614
string (f),
661
615
false ;
662
616
do_transpose= false ,
663
- args_in_result= :all ,
617
+ args_in_result= :result ,
664
618
argprefix,
665
619
resprefix,
666
620
resargprefix,
@@ -673,21 +627,13 @@ function simulate_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs}
673
627
fname = TracedUtils. get_attribute_by_name (func2, " sym_name" )
674
628
fn_attr = MLIR. IR. FlatSymbolRefAttribute (Base. String (fname))
675
629
676
- # Specify which outputs to add to the trace.
677
- traced_output_indices = Int[]
678
- for (i, res) in enumerate (linear_results)
679
- if TracedUtils. has_idx (res, resprefix)
680
- push! (traced_output_indices, i - 1 )
681
- end
682
- end
683
-
684
630
inputs = MLIR. IR. Value[]
685
631
for a in linear_args
686
632
idx, path = TracedUtils. get_argidx (a, argprefix)
687
- if idx == 1 && fnwrap
633
+ if idx == 2 && fnwrap
688
634
TracedUtils. push_val! (inputs, f, path[3 : end ])
689
635
else
690
- if fnwrap
636
+ if fnwrap && idx > 2
691
637
idx -= 1
692
638
end
693
639
TracedUtils. push_val! (inputs, args[idx], path[3 : end ])
@@ -700,30 +646,29 @@ function simulate_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs}
700
646
weight_ty = MLIR. IR. TensorType (Int64[], MLIR. IR. Type (Float64))
701
647
702
648
simulate_op = MLIR. Dialects. enzyme. simulate (
703
- inputs;
704
- trace= trace_ty,
705
- weight= weight_ty,
706
- outputs= out_tys,
707
- fn= fn_attr,
708
- traced_output_indices= traced_output_indices,
649
+ inputs; trace= trace_ty, weight= weight_ty, outputs= out_tys, fn= fn_attr
709
650
)
710
651
711
652
for (i, res) in enumerate (linear_results)
712
653
resv = MLIR. IR. result (simulate_op, i + 2 )
713
654
if TracedUtils. has_idx (res, resprefix)
714
655
path = TracedUtils. get_idx (res, resprefix)
715
656
TracedUtils. set! (result, path[2 : end ], resv)
716
- elseif TracedUtils. has_idx (res, argprefix)
657
+ end
658
+
659
+ if TracedUtils. has_idx (res, argprefix)
717
660
idx, path = TracedUtils. get_argidx (res, argprefix)
718
- if idx == 1 && fnwrap
661
+ if idx == 2 && fnwrap
719
662
TracedUtils. set! (f, path[3 : end ], resv)
720
663
else
721
- if fnwrap
664
+ if fnwrap && idx > 2
722
665
idx -= 1
723
666
end
724
667
TracedUtils. set! (args[idx], path[3 : end ], resv)
725
668
end
726
- else
669
+ end
670
+
671
+ if ! TracedUtils. has_idx (res, resprefix) && ! TracedUtils. has_idx (res, argprefix)
727
672
TracedUtils. set! (res, (), resv)
728
673
end
729
674
end
@@ -751,24 +696,25 @@ function simulate_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs}
751
696
end
752
697
753
698
function generate (
754
- f:: Function , args:: Vararg{Any,Nargs} ; constraint:: Constraint = Dict {Symbol,Any} ()
699
+ rng:: AbstractRNG ,
700
+ f:: Function ,
701
+ args:: Vararg{Any,Nargs} ;
702
+ constraint:: Constraint = Dict {Symbol,Any} (),
755
703
) where {Nargs}
756
704
trace = nothing
757
- weight = nothing
758
- res = nothing
759
705
760
706
constraint_ptr = ConcreteRNumber (reinterpret (UInt64, pointer_from_objref (constraint)))
761
707
constrained_symbols = collect (keys (constraint))
762
708
763
- function wrapper_fn (constraint_ptr, args... )
764
- return generate_internal (f, args... ; constraint_ptr, constrained_symbols)
709
+ function wrapper_fn (rng, constraint_ptr, args... )
710
+ return generate_internal (rng, f, args... ; constraint_ptr, constrained_symbols)
765
711
end
766
712
767
- compiled_fn = @compile optimize = :probprog wrapper_fn (constraint_ptr, args... )
713
+ compiled_fn = @compile optimize = :probprog wrapper_fn (rng, constraint_ptr, args... )
768
714
769
715
old_gc_state = GC. enable (false )
770
716
try
771
- trace, weight, res = compiled_fn (constraint_ptr, args... )
717
+ trace, _, _ = compiled_fn (rng, constraint_ptr, args... )
772
718
finally
773
719
GC. enable (old_gc_state)
774
720
end
@@ -779,6 +725,7 @@ function generate(
779
725
end
780
726
781
727
function generate_internal (
728
+ rng:: AbstractRNG ,
782
729
f:: Function ,
783
730
args:: Vararg{Any,Nargs} ;
784
731
constraint_ptr:: TracedRNumber ,
@@ -788,15 +735,22 @@ function generate_internal(
788
735
resprefix:: Symbol = gensym (" generateresult" )
789
736
resargprefix:: Symbol = gensym (" generateresarg" )
790
737
738
+ wrapper_fn = (all_args... ) -> begin
739
+ res = f (all_args... )
740
+ (all_args[1 ], (res isa Tuple ? res : (res,)). .. )
741
+ end
742
+
743
+ args = (rng, args... )
744
+
791
745
mlir_fn_res = invokelatest (
792
746
TracedUtils. make_mlir_fn,
793
- f ,
747
+ wrapper_fn ,
794
748
args,
795
749
(),
796
750
string (f),
797
751
false ;
798
752
do_transpose= false ,
799
- args_in_result= :all ,
753
+ args_in_result= :result ,
800
754
argprefix,
801
755
resprefix,
802
756
resargprefix,
@@ -809,21 +763,13 @@ function generate_internal(
809
763
fname = TracedUtils. get_attribute_by_name (func2, " sym_name" )
810
764
fn_attr = MLIR. IR. FlatSymbolRefAttribute (Base. String (fname))
811
765
812
- # Specify which outputs to add to the trace.
813
- traced_output_indices = Int[]
814
- for (i, res) in enumerate (linear_results)
815
- if TracedUtils. has_idx (res, resprefix)
816
- push! (traced_output_indices, i - 1 )
817
- end
818
- end
819
-
820
766
inputs = MLIR. IR. Value[]
821
767
for a in linear_args
822
768
idx, path = TracedUtils. get_argidx (a, argprefix)
823
- if idx == 1 && fnwrap
769
+ if idx == 2 && fnwrap
824
770
TracedUtils. push_val! (inputs, f, path[3 : end ])
825
771
else
826
- if fnwrap
772
+ if fnwrap && idx > 2
827
773
idx -= 1
828
774
end
829
775
TracedUtils. push_val! (inputs, args[idx], path[3 : end ])
@@ -865,25 +811,28 @@ function generate_internal(
865
811
outputs= out_tys,
866
812
fn= fn_attr,
867
813
constrained_symbols= MLIR. IR. Attribute (constrained_symbols_attr),
868
- traced_output_indices,
869
814
)
870
815
871
816
for (i, res) in enumerate (linear_results)
872
817
resv = MLIR. IR. result (generate_op, i + 2 )
873
818
if TracedUtils. has_idx (res, resprefix)
874
819
path = TracedUtils. get_idx (res, resprefix)
875
820
TracedUtils. set! (result, path[2 : end ], resv)
876
- elseif TracedUtils. has_idx (res, argprefix)
821
+ end
822
+
823
+ if TracedUtils. has_idx (res, argprefix)
877
824
idx, path = TracedUtils. get_argidx (res, argprefix)
878
- if idx == 1 && fnwrap
825
+ if idx == 2 && fnwrap
879
826
TracedUtils. set! (f, path[3 : end ], resv)
880
827
else
881
- if fnwrap
828
+ if fnwrap && idx > 2
882
829
idx -= 1
883
830
end
884
831
TracedUtils. set! (args[idx], path[3 : end ], resv)
885
832
end
886
- else
833
+ end
834
+
835
+ if ! TracedUtils. has_idx (res, resprefix) && ! TracedUtils. has_idx (res, argprefix)
887
836
TracedUtils. set! (res, (), resv)
888
837
end
889
838
end
0 commit comments