Skip to content

Commit 7b0d5a9

Browse files
committed
Split up tilde_observe!! for Distribution / Submodel
1 parent bf53f34 commit 7b0d5a9

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

src/context_implementations.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,11 @@ accumulate the log probability, and return the observed value and updated `vi`.
6060
Falls back to `tilde_observe!!(context, right, left, vi)` ignoring the information about variable name
6161
and indices; if needed, these can be accessed through this function, though.
6262
"""
63-
function tilde_observe!!(::DefaultContext, right, left, vn, vi)
64-
right isa DynamicPPL.Submodel &&
65-
throw(ArgumentError("`x ~ to_submodel(...)` is not supported when `x` is observed"))
63+
function tilde_observe!!(::DefaultContext, right::Distribution, left, vn, vi)
6664
vi = accumulate_observe!!(vi, right, left, vn)
6765
return left, vi
6866
end
67+
68+
function tilde_observe!!(::DefaultContext, ::DynamicPPL.Submodel, left, vn, vi)
69+
throw(ArgumentError("`x ~ to_submodel(...)` is not supported when `x` is observed"))
70+
end

src/transforming.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ struct DynamicTransformationContext{isinverse} <: AbstractContext end
1313
NodeTrait(::DynamicTransformationContext) = IsLeaf()
1414

1515
function tilde_assume!!(
16-
::DynamicTransformationContext{isinverse}, right, vn, vi
16+
::DynamicTransformationContext{isinverse}, right::Distribution, vn, vi
1717
) where {isinverse}
1818
# vi[vn, right] always provides the value in unlinked space.
1919
x = vi[vn, right]
@@ -31,7 +31,7 @@ function tilde_assume!!(
3131
return x, vi
3232
end
3333

34-
function tilde_observe!!(::DynamicTransformationContext, right, left, vn, vi)
34+
function tilde_observe!!(::DynamicTransformationContext, right::Distribution, left, vn, vi)
3535
return tilde_observe!!(DefaultContext(), right, left, vn, vi)
3636
end
3737

0 commit comments

Comments
 (0)