Skip to content

Commit 892b971

Browse files
committed
Fixed fallback for observe (#265)
In the recent breaking release we mistakenly removed the call to `observe(::AbstractSampler, right left, vi)` as a fallback for `DefaultContext`, leading to certain sampler breaking in Turing (TuringLang/Turing.jl#1636). This PR adds back a proper fallback, making overloads such as https://github.com/TuringLang/Turing.jl/blob/tor%2Fdppl-update/src/inference/AdvancedSMC.jl#L353-L356 work as before 0.11.
1 parent 4f31771 commit 892b971

File tree

2 files changed

+10
-32
lines changed

2 files changed

+10
-32
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.12.0"
3+
version = "0.12.1"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/context_implementations.jl

Lines changed: 9 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -163,11 +163,15 @@ end
163163

164164
# Leaf contexts
165165
tilde_observe(::DefaultContext, right, left, vi) = observe(right, left, vi)
166-
tilde_observe(::DefaultContext, sampler, right, left, vi) = observe(right, left, vi)
166+
function tilde_observe(::DefaultContext, sampler, right, left, vi)
167+
return observe(sampler, right, left, vi)
168+
end
167169
tilde_observe(::PriorContext, right, left, vi) = 0
168170
tilde_observe(::PriorContext, sampler, right, left, vi) = 0
169171
tilde_observe(::LikelihoodContext, right, left, vi) = observe(right, left, vi)
170-
tilde_observe(::LikelihoodContext, sampler, right, left, vi) = observe(right, left, vi)
172+
function tilde_observe(::LikelihoodContext, sampler, right, left, vi)
173+
return observe(sampler, right, left, vi)
174+
end
171175

172176
# `MiniBatchContext`
173177
function tilde_observe(context::MiniBatchContext, right, left, vi)
@@ -259,6 +263,7 @@ function assume(
259263
end
260264

261265
# default fallback (used e.g. by `SampleFromPrior` and `SampleUniform`)
266+
observe(sampler::AbstractSampler, right, left, vi) = observe(right, left, vi)
262267
function observe(right::Distribution, left, vi)
263268
increment_num_produce!(vi)
264269
return Distributions.loglikelihood(right, left)
@@ -633,46 +638,19 @@ function dot_tilde_observe!(context, right, left, vi)
633638
return left
634639
end
635640

636-
# Ambiguity error when not sure to use Distributions convention or Julia broadcasting semantics
637-
function dot_observe(
638-
::Union{SampleFromPrior,SampleFromUniform},
639-
dist::MultivariateDistribution,
640-
value::AbstractMatrix,
641-
vi,
642-
)
641+
# Falls back to non-sampler definition.
642+
function dot_observe(::AbstractSampler, dist, value, vi)
643643
return dot_observe(dist, value, vi)
644644
end
645645
function dot_observe(dist::MultivariateDistribution, value::AbstractMatrix, vi)
646646
increment_num_produce!(vi)
647-
@debug "dist = $dist"
648-
@debug "value = $value"
649647
return Distributions.loglikelihood(dist, value)
650648
end
651-
function dot_observe(
652-
::Union{SampleFromPrior,SampleFromUniform},
653-
dists::Distribution,
654-
value::AbstractArray,
655-
vi,
656-
)
657-
return dot_observe(dists, value, vi)
658-
end
659649
function dot_observe(dists::Distribution, value::AbstractArray, vi)
660650
increment_num_produce!(vi)
661-
@debug "dists = $dists"
662-
@debug "value = $value"
663651
return Distributions.loglikelihood(dists, value)
664652
end
665-
function dot_observe(
666-
::Union{SampleFromPrior,SampleFromUniform},
667-
dists::AbstractArray{<:Distribution},
668-
value::AbstractArray,
669-
vi,
670-
)
671-
return dot_observe(dists, value, vi)
672-
end
673653
function dot_observe(dists::AbstractArray{<:Distribution}, value::AbstractArray, vi)
674654
increment_num_produce!(vi)
675-
@debug "dists = $dists"
676-
@debug "value = $value"
677655
return sum(Distributions.loglikelihood.(dists, value))
678656
end

0 commit comments

Comments
 (0)