@@ -163,11 +163,15 @@ end
163
163
164
164
# Leaf contexts
165
165
tilde_observe (:: DefaultContext , right, left, vi) = observe (right, left, vi)
166
- tilde_observe (:: DefaultContext , sampler, right, left, vi) = observe (right, left, vi)
166
+ function tilde_observe (:: DefaultContext , sampler, right, left, vi)
167
+ return observe (sampler, right, left, vi)
168
+ end
167
169
tilde_observe (:: PriorContext , right, left, vi) = 0
168
170
tilde_observe (:: PriorContext , sampler, right, left, vi) = 0
169
171
tilde_observe (:: LikelihoodContext , right, left, vi) = observe (right, left, vi)
170
- tilde_observe (:: LikelihoodContext , sampler, right, left, vi) = observe (right, left, vi)
172
+ function tilde_observe (:: LikelihoodContext , sampler, right, left, vi)
173
+ return observe (sampler, right, left, vi)
174
+ end
171
175
172
176
# `MiniBatchContext`
173
177
function tilde_observe (context:: MiniBatchContext , right, left, vi)
@@ -259,6 +263,7 @@ function assume(
259
263
end
260
264
261
265
# default fallback (used e.g. by `SampleFromPrior` and `SampleUniform`)
266
+ observe (sampler:: AbstractSampler , right, left, vi) = observe (right, left, vi)
262
267
function observe (right:: Distribution , left, vi)
263
268
increment_num_produce! (vi)
264
269
return Distributions. loglikelihood (right, left)
@@ -633,46 +638,19 @@ function dot_tilde_observe!(context, right, left, vi)
633
638
return left
634
639
end
635
640
636
- # Ambiguity error when not sure to use Distributions convention or Julia broadcasting semantics
637
- function dot_observe (
638
- :: Union{SampleFromPrior,SampleFromUniform} ,
639
- dist:: MultivariateDistribution ,
640
- value:: AbstractMatrix ,
641
- vi,
642
- )
641
+ # Falls back to non-sampler definition.
642
+ function dot_observe (:: AbstractSampler , dist, value, vi)
643
643
return dot_observe (dist, value, vi)
644
644
end
645
645
function dot_observe (dist:: MultivariateDistribution , value:: AbstractMatrix , vi)
646
646
increment_num_produce! (vi)
647
- @debug " dist = $dist "
648
- @debug " value = $value "
649
647
return Distributions. loglikelihood (dist, value)
650
648
end
651
- function dot_observe (
652
- :: Union{SampleFromPrior,SampleFromUniform} ,
653
- dists:: Distribution ,
654
- value:: AbstractArray ,
655
- vi,
656
- )
657
- return dot_observe (dists, value, vi)
658
- end
659
649
function dot_observe (dists:: Distribution , value:: AbstractArray , vi)
660
650
increment_num_produce! (vi)
661
- @debug " dists = $dists "
662
- @debug " value = $value "
663
651
return Distributions. loglikelihood (dists, value)
664
652
end
665
- function dot_observe (
666
- :: Union{SampleFromPrior,SampleFromUniform} ,
667
- dists:: AbstractArray{<:Distribution} ,
668
- value:: AbstractArray ,
669
- vi,
670
- )
671
- return dot_observe (dists, value, vi)
672
- end
673
653
function dot_observe (dists:: AbstractArray{<:Distribution} , value:: AbstractArray , vi)
674
654
increment_num_produce! (vi)
675
- @debug " dists = $dists "
676
- @debug " value = $value "
677
655
return sum (Distributions. loglikelihood .(dists, value))
678
656
end
0 commit comments