-
Notifications
You must be signed in to change notification settings - Fork 36
Accumulators, stage 1 #885
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 9 commits
061acbe
1496868
324e623
bb59885
fc32398
cc5e581
b9c368b
1b8f555
ae9b1cd
4fb0bf4
e410f47
97788bd
7fe03ec
5ba3530
d49f7be
28bbf1c
a0ed665
be27636
e6453fe
c59400d
47033ce
8b841c9
3ee3989
13163f2
37dd6dd
d7013b6
40d4caa
ff5f2cb
c68f1bb
13da08a
d52feec
221e797
1dbcb2c
e1b70e0
3f195e5
68b974a
2b405d9
00cd304
6d1048d
4fef20f
905b874
557954a
6f702c9
f748775
5f4a532
31967fd
ad2f564
10b4f2f
d2b670d
8241d12
7b7a3e2
0b08237
2a4b874
c1e90f7
cb1c6c6
00ef0cf
14f4788
c4ee4ec
7ad9450
c5e2a6b
fb09acc
e324c9b
6b7b9f8
048178b
6437801
bf95169
efc7c53
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -91,36 +91,70 @@ function transformation end | |||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
# Accumulation of log-probabilities. | ||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||
getlogp(vi::AbstractVarInfo) | ||||||||||||||||||||||||||||||||||
getlogjoint(vi::AbstractVarInfo) | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
Return the log of the joint probability of the observed data and parameters sampled in | ||||||||||||||||||||||||||||||||||
`vi`. | ||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||
function getlogp end | ||||||||||||||||||||||||||||||||||
getlogjoint(vi::AbstractVarInfo) = getlogprior(vi) + getloglikelihood(vi) | ||||||||||||||||||||||||||||||||||
getlogp(vi::AbstractVarInfo) = getlogjoint(vi) | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
function setaccs!! end | ||||||||||||||||||||||||||||||||||
function getaccs end | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
getlogprior(vi::AbstractVarInfo) = getacc(vi, LogPrior).logp | ||||||||||||||||||||||||||||||||||
getloglikelihood(vi::AbstractVarInfo) = getacc(vi, LogLikelihood).logp | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
function setacc!!(vi::AbstractVarInfo, acc::AbstractAccumulator) | ||||||||||||||||||||||||||||||||||
return setaccs!!(vi, setacc!!(getaccs(vi), acc)) | ||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
setlogprior!!(vi::AbstractVarInfo, logp) = setacc!!(vi, LogPrior(logp)) | ||||||||||||||||||||||||||||||||||
setloglikelihood!!(vi::AbstractVarInfo, logp) = setacc!!(vi, LogLikelihood(logp)) | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||
setlogp!!(vi::AbstractVarInfo, logp) | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
Set the log of the joint probability of the observed data and parameters sampled in | ||||||||||||||||||||||||||||||||||
`vi` to `logp`, mutating if it makes sense. | ||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||
function setlogp!! end | ||||||||||||||||||||||||||||||||||
function setlogp!!(vi::AbstractVarInfo, logp) | ||||||||||||||||||||||||||||||||||
vi = setlogprior!!(vi, zero(logp)) | ||||||||||||||||||||||||||||||||||
vi = setloglikelihood!!(vi, logp) | ||||||||||||||||||||||||||||||||||
return vi | ||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||
|
lp = getlogp(vi_typed_metadata) | |
varinfos = map(( | |
vi_untyped_metadata, | |
vi_untyped_vnv, | |
vi_typed_metadata, | |
vi_typed_vnv, | |
svi_typed, | |
svi_untyped, | |
svi_vnv, | |
svi_typed_ref, | |
svi_untyped_ref, | |
svi_vnv_ref, | |
)) do vi | |
# Set them all to the same values. | |
DynamicPPL.setlogp!!(update_values!!(vi, example_values, varnames), lp) | |
end |
The logp here contains terms from both prior and likelihood, but after calling setlogp the prior would always be 0, which is inconsistent with the varinfo.
Of course, we can fix this on our end - we would get and set logprior and loglikelihood manually, and we can grep the codebase to make sure that there are no other ill-defined calls to setlogp. We can't guarantee that other people will be similarly careful, though (and us or anyone being careful also doesn't guarantee that everything will be fixed correctly).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
While looking for other uses of setlogp, I encountered this:
AdvancedHMC.Transition
only contains a single notion of log density, so it's not obvious to me how we're going to extract the prior and likelihood components from it 😓 This might require upstream changes to AdvancedHMC. Since the contexts will be removed, I suspect LogDensityFunction
also needs to be changed so that there's a way for it to return only the prior or the likelihood (or maybe it should return both).
(For the record, I'd be quite happy with making all of these changes!)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The logp here contains terms from both prior and likelihood, but after calling setlogp the prior would always be 0, which is inconsistent with the varinfo.
It is inconsistent, but as long as the user only uses getlogp
, they would never see the difference, right? If some of logprior is accidentally stored in loglikelihood or vice versa, as long as one is using getlogp
and DefaultContext
that should be undetectable. What would be trouble is if someone mixes using e.g. setlogp!!
and getlogprior
, which would require adding calls to getlogprior
after upgrading to a version that has deprecated setlogp!!
, but probably people would end up doing that. Maybe the deprecation warning could say something about this?
Since the contexts will be removed, I suspect LogDensityFunction also needs to be changed so that there's a way for it to return only the prior or the likelihood (or maybe it should return both).
Yeah, this sort of stuff will come up (and is coming up) in multiple places. Anything that explicitly uses PriorContext or LikelihoodContext would need to be changed to use LogPrior and LogLikelihood accumulators instead. I'm currently doing this for pointwiselogdensities
.
Uh oh!
There was an error while loading. Please reload this page.