Skip to content

Commit 92689ca

Browse files
devmotionyebai
andcommitted
Use Distributions.loglikelihood instead of Distributions.logpdf (#153)
This PR replaces some occurrences of `logpdf` with `loglikelihood` since the former is intended to be used for a single sample only but we are interested in the log probability of individual and multiple samples. JuliaStats/Distributions.jl#1144 allows us to use `loglikelihood` even for individual samples and arrays of samples (everything that can be sampled by `rand` should support the computation of `loglikelihood`). Currently, `logpdf` is misused (also in DistributionsAD) to compute arrays of log densities for multiple samples which are summed afterwards. Usually, this intermittent step can be avoided by summing the log densities directly (which is the default implementation in Distributions). Similar issues and possible optimizations exist for `Bijectors.logpdf_with_trans` (see TuringLang/Bijectors.jl#120). Co-authored-by: Hong Ge <[email protected]>
1 parent ddbb14c commit 92689ca

File tree

2 files changed

+22
-7
lines changed

2 files changed

+22
-7
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
1515
[compat]
1616
AbstractMCMC = "1"
1717
Bijectors = "0.5.2, 0.6, 0.7, 0.8"
18-
Distributions = "0.22, 0.23"
18+
Distributions = "0.23.8"
1919
MacroTools = "0.5.1"
2020
NaturalSort = "1"
2121
ZygoteRules = "0.2"

src/context_implementations.jl

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ function observe(
149149
vi,
150150
)
151151
increment_num_produce!(vi)
152-
return Distributions.logpdf(dist, value)
152+
return Distributions.loglikelihood(dist, value)
153153
end
154154

155155
# .~ functions
@@ -444,18 +444,33 @@ function dot_observe(
444444
vi,
445445
)
446446
increment_num_produce!(vi)
447-
@debug "dot_observe" dist value
448-
return sum(Distributions.logpdf(dist, value))
447+
@debug "dist = $dist"
448+
@debug "value = $value"
449+
return Distributions.loglikelihood(dist, value)
449450
end
450451
function dot_observe(
451452
spl::Union{SampleFromPrior, SampleFromUniform},
452-
dists::Union{Distribution, AbstractArray{<:Distribution}},
453+
dists::Distribution,
453454
value::AbstractArray,
454455
vi,
455456
)
456457
increment_num_produce!(vi)
457-
@debug "dot_observe" dists value
458-
return sum(Distributions.logpdf.(dists, value))
458+
@debug "dists = $dists"
459+
@debug "value = $value"
460+
return Distributions.loglikelihood(dists, value)
461+
end
462+
function dot_observe(
463+
spl::Union{SampleFromPrior, SampleFromUniform},
464+
dists::AbstractArray{<:Distribution},
465+
value::AbstractArray,
466+
vi,
467+
)
468+
increment_num_produce!(vi)
469+
@debug "dists = $dists"
470+
@debug "value = $value"
471+
return sum(zip(dists, value)) do (d, v)
472+
Distributions.loglikelihood(d, v)
473+
end
459474
end
460475
function dot_observe(
461476
spl::Sampler,

0 commit comments

Comments
 (0)