@@ -349,6 +349,23 @@ function __init__()
349
349
end
350
350
351
351
function sample (
352
+ rng:: AbstractRNG ,
353
+ f:: Function ,
354
+ args:: Vararg{Any,Nargs} ;
355
+ symbol:: Symbol = gensym (" sample" ),
356
+ logpdf:: Union{Nothing,Function} = nothing ,
357
+ ) where {Nargs}
358
+ res = sample_internal (rng, f, args... ; symbol, logpdf)
359
+
360
+ @assert res isa Tuple && length (res) >= 1 && res[1 ] isa AbstractRNG " Expected first result to be RNG"
361
+
362
+ res = res[2 : end ]
363
+
364
+ return length (res) == 1 ? res[1 ] : res
365
+ end
366
+
367
+ function sample_internal (
368
+ rng:: AbstractRNG ,
352
369
f:: Function ,
353
370
args:: Vararg{Any,Nargs} ;
354
371
symbol:: Symbol = gensym (" sample" ),
@@ -358,15 +375,22 @@ function sample(
358
375
resprefix:: Symbol = gensym (" sampleresult" )
359
376
resargprefix:: Symbol = gensym (" sampleresarg" )
360
377
378
+ wrapper_fn = (all_args... ) -> begin
379
+ res = f (all_args... )
380
+ (all_args[1 ], (res isa Tuple ? res : (res,)). .. )
381
+ end
382
+
383
+ args = (rng, args... )
384
+
361
385
mlir_fn_res = invokelatest (
362
386
TracedUtils. make_mlir_fn,
363
- f ,
387
+ wrapper_fn ,
364
388
args,
365
389
(),
366
390
string (f),
367
391
false ;
368
392
do_transpose= false ,
369
- args_in_result= :all ,
393
+ args_in_result= :result ,
370
394
argprefix,
371
395
resprefix,
372
396
resargprefix,
@@ -378,10 +402,13 @@ function sample(
378
402
inputs = MLIR. IR. Value[]
379
403
for a in linear_args
380
404
idx, path = TracedUtils. get_argidx (a, argprefix)
381
- if idx == 1 && fnwrap
405
+ if idx == 2 && fnwrap
382
406
TracedUtils. push_val! (inputs, f, path[3 : end ])
383
407
else
384
- idx -= fnwrap ? 1 : 0
408
+ if fnwrap && idx > 1
409
+ idx -= 1
410
+ end
411
+
385
412
TracedUtils. push_val! (inputs, args[idx], path[3 : end ])
386
413
end
387
414
end
@@ -464,7 +491,7 @@ function sample(
464
491
string (logpdf),
465
492
false ;
466
493
do_transpose= false ,
467
- args_in_result= :all ,
494
+ args_in_result= :result ,
468
495
)
469
496
470
497
logpdf_sym = TracedUtils. get_attribute_by_name (logpdf_mlir. f, " sym_name" )
@@ -485,46 +512,67 @@ function sample(
485
512
486
513
for (i, res) in enumerate (linear_results)
487
514
resv = MLIR. IR. result (sample_op, i)
515
+
488
516
if TracedUtils. has_idx (res, resprefix)
489
517
path = TracedUtils. get_idx (res, resprefix)
490
518
TracedUtils. set! (result, path[2 : end ], resv)
491
- elseif TracedUtils. has_idx (res, argprefix)
519
+ end
520
+
521
+ if TracedUtils. has_idx (res, argprefix)
492
522
idx, path = TracedUtils. get_argidx (res, argprefix)
493
- if idx == 1 && fnwrap
523
+ if fnwrap && idx == 2
494
524
TracedUtils. set! (f, path[3 : end ], resv)
495
525
else
496
- if fnwrap
526
+ if fnwrap && idx > 2
497
527
idx -= 1
498
528
end
499
529
TracedUtils. set! (args[idx], path[3 : end ], resv)
500
530
end
501
- else
531
+ end
532
+
533
+ if ! TracedUtils. has_idx (res, resprefix) && ! TracedUtils. has_idx (res, argprefix)
502
534
TracedUtils. set! (res, (), resv)
503
535
end
504
536
end
505
537
506
538
return result
507
539
end
508
540
509
- function call (f:: Function , args:: Vararg{Any,Nargs} ) where {Nargs}
510
- res = @jit optimize = :probprog call_internal (f, args... )
511
- return res isa AbstractConcreteArray ? Array (res) : res
541
+ function call (rng:: AbstractRNG , f:: Function , args:: Vararg{Any,Nargs} ) where {Nargs}
542
+ res = @jit optimize = :probprog call_internal (rng, f, args... )
543
+
544
+ @assert res isa Tuple && length (res) >= 1 && res[1 ] isa AbstractRNG " Expected first result to be RNG"
545
+
546
+ res = map (res[2 : end ]) do r
547
+ r isa AbstractConcreteArray ? Array (r) : r
548
+ end
549
+
550
+ @show res
551
+
552
+ return length (res) == 1 ? res[1 ] : res
512
553
end
513
554
514
- function call_internal (f:: Function , args:: Vararg{Any,Nargs} ) where {Nargs}
555
+ function call_internal (rng :: AbstractRNG , f:: Function , args:: Vararg{Any,Nargs} ) where {Nargs}
515
556
argprefix:: Symbol = gensym (" callarg" )
516
557
resprefix:: Symbol = gensym (" callresult" )
517
558
resargprefix:: Symbol = gensym (" callresarg" )
518
559
560
+ wrapper_fn = (all_args... ) -> begin
561
+ res = f (all_args... )
562
+ (all_args[1 ], (res isa Tuple ? res : (res,)). .. )
563
+ end
564
+
565
+ args = (rng, args... )
566
+
519
567
mlir_fn_res = invokelatest (
520
568
TracedUtils. make_mlir_fn,
521
- f ,
569
+ wrapper_fn ,
522
570
args,
523
571
(),
524
572
string (f),
525
573
false ;
526
574
do_transpose= false ,
527
- args_in_result= :all ,
575
+ args_in_result= :result ,
528
576
argprefix,
529
577
resprefix,
530
578
resargprefix,
@@ -533,6 +581,8 @@ function call_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs}
533
581
fnwrap = mlir_fn_res. fnwrapped
534
582
func2 = mlir_fn_res. f
535
583
584
+ @show length (linear_results), linear_results
585
+
536
586
out_tys = [MLIR. IR. type (TracedUtils. get_mlir_data (res)) for res in linear_results]
537
587
fname = TracedUtils. get_attribute_by_name (func2, " sym_name" )
538
588
fn_attr = MLIR. IR. FlatSymbolRefAttribute (Base. String (fname))
@@ -557,17 +607,21 @@ function call_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs}
557
607
if TracedUtils. has_idx (res, resprefix)
558
608
path = TracedUtils. get_idx (res, resprefix)
559
609
TracedUtils. set! (result, path[2 : end ], resv)
560
- elseif TracedUtils. has_idx (res, argprefix)
610
+ end
611
+
612
+ if TracedUtils. has_idx (res, argprefix)
561
613
idx, path = TracedUtils. get_argidx (res, argprefix)
562
- if idx == 1 && fnwrap
614
+ if fnwrap && idx == 2
563
615
TracedUtils. set! (f, path[3 : end ], resv)
564
616
else
565
- if fnwrap
617
+ if fnwrap && idx > 2
566
618
idx -= 1
567
619
end
568
620
TracedUtils. set! (args[idx], path[3 : end ], resv)
569
621
end
570
- else
622
+ end
623
+
624
+ if ! TracedUtils. has_idx (res, resprefix) && ! TracedUtils. has_idx (res, argprefix)
571
625
TracedUtils. set! (res, (), resv)
572
626
end
573
627
end
0 commit comments