@@ -355,10 +355,12 @@ end
355
355
356
356
function generate_tilde_literal (left, right)
357
357
# If the LHS is a literal, it is always an observation
358
+ @gensym value
358
359
return quote
359
- $ (DynamicPPL. tilde_observe!)(
360
+ $ value, __varinfo__ = $ (DynamicPPL. tilde_observe! !)(
360
361
__context__, $ (DynamicPPL. check_tilde_rhs)($ right), $ left, __varinfo__
361
362
)
363
+ $ value
362
364
end
363
365
end
364
366
@@ -373,7 +375,7 @@ function generate_tilde(left, right)
373
375
374
376
# Otherwise it is determined by the model or its value,
375
377
# if the LHS represents an observation
376
- @gensym vn isassumption
378
+ @gensym vn isassumption value
377
379
378
380
# HACK: Usage of `drop_escape` is unfortunate. It's a consequence of the fact
379
381
# that in DynamicPPL we the entire function body. Instead we should be
@@ -389,32 +391,38 @@ function generate_tilde(left, right)
389
391
$ left = $ (DynamicPPL. getvalue_nested)(__context__, $ vn)
390
392
end
391
393
392
- $ (DynamicPPL. tilde_observe!)(
394
+ $ value, __varinfo__ = $ (DynamicPPL. tilde_observe! !)(
393
395
__context__,
394
396
$ (DynamicPPL. check_tilde_rhs)($ right),
395
397
$ (maybe_view (left)),
396
398
$ vn,
397
399
__varinfo__,
398
400
)
401
+ $ value
399
402
end
400
403
end
401
404
end
402
405
403
406
function generate_tilde_assume (left, right, vn)
404
- expr = :(
405
- $ left = $ (DynamicPPL. tilde_assume!)(
407
+ # HACK: Because the Setfield.jl macro does not support assignment
408
+ # with multiple arguments on the LHS, we need to capture the return-values
409
+ # and then update the LHS variables one by one.
410
+ @gensym value
411
+ expr = :($ left = $ value)
412
+ if left isa Expr
413
+ expr = AbstractPPL. drop_escape (
414
+ Setfield. setmacro (BangBang. prefermutation, expr; overwrite= true )
415
+ )
416
+ end
417
+
418
+ return quote
419
+ $ value, __varinfo__ = $ (DynamicPPL. tilde_assume!!)(
406
420
__context__,
407
421
$ (DynamicPPL. unwrap_right_vn)($ (DynamicPPL. check_tilde_rhs)($ right), $ vn). .. ,
408
422
__varinfo__,
409
423
)
410
- )
411
-
412
- return if left isa Expr
413
- AbstractPPL. drop_escape (
414
- Setfield. setmacro (BangBang. prefermutation, expr; overwrite= true )
415
- )
416
- else
417
- return expr
424
+ $ expr
425
+ $ value
418
426
end
419
427
end
420
428
@@ -428,7 +436,7 @@ function generate_dot_tilde(left, right)
428
436
429
437
# Otherwise it is determined by the model or its value,
430
438
# if the LHS represents an observation
431
- @gensym vn isassumption
439
+ @gensym vn isassumption value
432
440
return quote
433
441
$ vn = $ (AbstractPPL. drop_escape (varname (left)))
434
442
$ isassumption = $ (DynamicPPL. isassumption (left))
@@ -440,13 +448,14 @@ function generate_dot_tilde(left, right)
440
448
$ left .= $ (DynamicPPL. getvalue_nested)(__context__, $ vn)
441
449
end
442
450
443
- $ (DynamicPPL. dot_tilde_observe!)(
451
+ $ value, __varinfo__ = $ (DynamicPPL. dot_tilde_observe! !)(
444
452
__context__,
445
453
$ (DynamicPPL. check_tilde_rhs)($ right),
446
454
$ (maybe_view (left)),
447
455
$ vn,
448
456
__varinfo__,
449
457
)
458
+ $ value
450
459
end
451
460
end
452
461
end
@@ -455,15 +464,82 @@ function generate_dot_tilde_assume(left, right, vn)
455
464
# We don't need to use `Setfield.@set` here since
456
465
# `.=` is always going to be inplace + needs `left` to
457
466
# be something that supports `.=`.
458
- return :(
459
- $ left .= $ (DynamicPPL. dot_tilde_assume!)(
467
+ @gensym value
468
+ return quote
469
+ $ value, __varinfo__ = $ (DynamicPPL. dot_tilde_assume!!)(
460
470
__context__,
461
471
$ (DynamicPPL. unwrap_right_left_vns)(
462
472
$ (DynamicPPL. check_tilde_rhs)($ right), $ (maybe_view (left)), $ vn
463
473
). .. ,
464
474
__varinfo__,
465
475
)
466
- )
476
+ $ left .= $ value
477
+ $ value
478
+ end
479
+ end
480
+
481
+ # Note that we cannot use `MacroTools.isdef` because
482
+ # of https://github.com/FluxML/MacroTools.jl/issues/154.
483
+ """
484
+ isfuncdef(expr)
485
+
486
+ Return `true` if `expr` is any form of function definition, and `false` otherwise.
487
+ """
488
+ function isfuncdef (e:: Expr )
489
+ return if Meta. isexpr (e, :function )
490
+ # Classic `function f(...)`
491
+ true
492
+ elseif Meta. isexpr (e, :-> )
493
+ # Anonymous functions/lambdas, e.g. `do` blocks or `->` defs.
494
+ true
495
+ elseif Meta. isexpr (e, :(= )) && Meta. isexpr (e. args[1 ], :call )
496
+ # Short function defs, e.g. `f(args...) = ...`.
497
+ true
498
+ else
499
+ false
500
+ end
501
+ end
502
+
503
+ """
504
+ replace_returns(expr)
505
+
506
+ Return `Expr` with all `return ...` statements replaced with
507
+ `return ..., DynamicPPL.return_values(__varinfo__)`.
508
+
509
+ Note that this method will _not_ replace `return` statements within function
510
+ definitions. This is checked using [`isfuncdef`](@ref).
511
+ """
512
+ replace_returns (e) = e
513
+ function replace_returns (e:: Expr )
514
+ if isfuncdef (e)
515
+ return e
516
+ end
517
+
518
+ if Meta. isexpr (e, :return )
519
+ # NOTE: `return` always has an argument. In the case of
520
+ # an empty `return`, the lowered expression will be `return nothing`.
521
+ # Hence we don't need any special handling for empty returns.
522
+ retval_expr = if length (e. args) > 1
523
+ Expr (:tuple , e. args... )
524
+ else
525
+ e. args[1 ]
526
+ end
527
+
528
+ return :(return ($ retval_expr, __varinfo__))
529
+ end
530
+
531
+ return Expr (e. head, map (replace_returns, e. args)... )
532
+ end
533
+
534
+ # If it's just a symbol, e.g. `f(x) = 1`, then we make it `f(x) = return 1`.
535
+ make_returns_explicit! (body) = Expr (:return , body)
536
+ function make_returns_explicit! (body:: Expr )
537
+ # If the last statement is a return-statement, we don't do anything.
538
+ # Otherwise we replace the last statement with a `return` statement.
539
+ if ! Meta. isexpr (body. args[end ], :return )
540
+ body. args[end ] = Expr (:return , body. args[end ])
541
+ end
542
+ return body
467
543
end
468
544
469
545
const FloatOrArrayType = Type{<: Union{AbstractFloat,AbstractArray} }
@@ -496,10 +572,14 @@ function build_output(modelinfo, linenumbernode)
496
572
# Replace the user-provided function body with the version created by DynamicPPL.
497
573
# We use `MacroTools.@q begin ... end` instead of regular `quote ... end` to ensure
498
574
# that no new `LineNumberNode`s are added apart from the reference `linenumbernode`
499
- # to the call site
575
+ # to the call site.
576
+ # NOTE: We need to replace statements of the form `return ...` with
577
+ # `return (..., __varinfo__)` to ensure that the second
578
+ # element in the returned value is always the most up-to-date `__varinfo__`.
579
+ # See the docstrings of `replace_returns` for more info.
500
580
evaluatordef[:body ] = MacroTools. @q begin
501
581
$ (linenumbernode)
502
- $ (modelinfo[:body ])
582
+ $ (replace_returns ( make_returns_explicit! ( modelinfo[:body ])) )
503
583
end
504
584
505
585
# # Build the model function.
0 commit comments