Skip to content

Commit f6306ce

Browse files
committed
Use new getlogjoint for optimisation
1 parent 6aaad1f commit f6306ce

File tree

2 files changed

+7
-112
lines changed

2 files changed

+7
-112
lines changed

ext/TuringOptimExt.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ function Optim.optimize(
102102
options::Optim.Options=Optim.Options();
103103
kwargs...,
104104
)
105-
f = Optimisation.OptimLogDensity(model, Optimisation.getlogjoint_without_jacobian)
105+
f = Optimisation.OptimLogDensity(model, DynamicPPL.getlogjoint)
106106
init_vals = DynamicPPL.getparams(f.ldf)
107107
optimizer = Optim.LBFGS()
108108
return _map_optimize(model, init_vals, optimizer, options; kwargs...)
@@ -124,7 +124,7 @@ function Optim.optimize(
124124
options::Optim.Options=Optim.Options();
125125
kwargs...,
126126
)
127-
f = Optimisation.OptimLogDensity(model, Optimisation.getlogjoint_without_jacobian)
127+
f = Optimisation.OptimLogDensity(model, DynamicPPL.getlogjoint)
128128
init_vals = DynamicPPL.getparams(f.ldf)
129129
return _map_optimize(model, init_vals, optimizer, options; kwargs...)
130130
end
@@ -140,7 +140,7 @@ function Optim.optimize(
140140
end
141141

142142
function _map_optimize(model::DynamicPPL.Model, args...; kwargs...)
143-
f = Optimisation.OptimLogDensity(model, Optimisation.getlogjoint_without_jacobian)
143+
f = Optimisation.OptimLogDensity(model, DynamicPPL.getlogjoint)
144144
return _optimize(f, args...; kwargs...)
145145
end
146146

src/optimisation/Optimisation.jl

Lines changed: 4 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -43,113 +43,6 @@ Concrete type for maximum a posteriori estimation. Only used for the Optim.jl in
4343
"""
4444
struct MAP <: ModeEstimator end
4545

46-
# Most of these functions for LogPriorWithoutJacobianAccumulator are copied from
47-
# LogPriorAccumulator. The only one that is different is the accumulate_assume!! one.
48-
"""
49-
LogPriorWithoutJacobianAccumulator{T} <: DynamicPPL.AbstractAccumulator
50-
51-
Exactly like DynamicPPL.LogPriorAccumulator, but does not include the log determinant of the
52-
Jacobian of any variable transformations.
53-
54-
Used for MAP optimisation.
55-
"""
56-
struct LogPriorWithoutJacobianAccumulator{T} <: DynamicPPL.AbstractAccumulator
57-
logp::T
58-
end
59-
60-
"""
61-
LogPriorWithoutJacobianAccumulator{T}()
62-
63-
Create a new `LogPriorWithoutJacobianAccumulator` accumulator with the log prior initialized to zero.
64-
"""
65-
LogPriorWithoutJacobianAccumulator{T}() where {T<:Real} =
66-
LogPriorWithoutJacobianAccumulator(zero(T))
67-
function LogPriorWithoutJacobianAccumulator()
68-
return LogPriorWithoutJacobianAccumulator{DynamicPPL.LogProbType}()
69-
end
70-
71-
function Base.show(io::IO, acc::LogPriorWithoutJacobianAccumulator)
72-
return print(io, "LogPriorWithoutJacobianAccumulator($(repr(acc.logp)))")
73-
end
74-
75-
function DynamicPPL.accumulator_name(::Type{<:LogPriorWithoutJacobianAccumulator})
76-
return :LogPriorWithoutJacobian
77-
end
78-
79-
Base.copy(acc::LogPriorWithoutJacobianAccumulator) = acc
80-
81-
function DynamicPPL.split(::LogPriorWithoutJacobianAccumulator{T}) where {T}
82-
return LogPriorWithoutJacobianAccumulator(zero(T))
83-
end
84-
85-
function DynamicPPL.combine(
86-
acc::LogPriorWithoutJacobianAccumulator, acc2::LogPriorWithoutJacobianAccumulator
87-
)
88-
return LogPriorWithoutJacobianAccumulator(acc.logp + acc2.logp)
89-
end
90-
91-
function Base.:+(
92-
acc1::LogPriorWithoutJacobianAccumulator, acc2::LogPriorWithoutJacobianAccumulator
93-
)
94-
return LogPriorWithoutJacobianAccumulator(acc1.logp + acc2.logp)
95-
end
96-
97-
function Base.zero(acc::LogPriorWithoutJacobianAccumulator)
98-
return LogPriorWithoutJacobianAccumulator(zero(acc.logp))
99-
end
100-
101-
function DynamicPPL.accumulate_assume!!(
102-
acc::LogPriorWithoutJacobianAccumulator, val, logjac, vn, right
103-
)
104-
return acc + LogPriorWithoutJacobianAccumulator(Distributions.logpdf(right, val))
105-
end
106-
function DynamicPPL.accumulate_observe!!(
107-
acc::LogPriorWithoutJacobianAccumulator, right, left, vn
108-
)
109-
return acc
110-
end
111-
112-
function Base.convert(
113-
::Type{LogPriorWithoutJacobianAccumulator{T}}, acc::LogPriorWithoutJacobianAccumulator
114-
) where {T}
115-
return LogPriorWithoutJacobianAccumulator(convert(T, acc.logp))
116-
end
117-
118-
function DynamicPPL.convert_eltype(
119-
::Type{T}, acc::LogPriorWithoutJacobianAccumulator
120-
) where {T}
121-
return LogPriorWithoutJacobianAccumulator(convert(T, acc.logp))
122-
end
123-
124-
function getlogprior_without_jacobian(vi::DynamicPPL.AbstractVarInfo)
125-
acc = DynamicPPL.getacc(vi, Val(:LogPriorWithoutJacobian))
126-
return acc.logp
127-
end
128-
129-
function getlogjoint_without_jacobian(vi::DynamicPPL.AbstractVarInfo)
130-
return getlogprior_without_jacobian(vi) + DynamicPPL.getloglikelihood(vi)
131-
end
132-
133-
# This is called when constructing a LogDensityFunction, and ensures the VarInfo has the
134-
# right accumulators.
135-
function DynamicPPL.ldf_default_varinfo(
136-
model::DynamicPPL.Model, ::typeof(getlogprior_without_jacobian)
137-
)
138-
vi = DynamicPPL.VarInfo(model)
139-
vi = DynamicPPL.setaccs!!(vi, (LogPriorWithoutJacobianAccumulator(),))
140-
return vi
141-
end
142-
143-
function DynamicPPL.ldf_default_varinfo(
144-
model::DynamicPPL.Model, ::typeof(getlogjoint_without_jacobian)
145-
)
146-
vi = DynamicPPL.VarInfo(model)
147-
vi = DynamicPPL.setaccs!!(
148-
vi, (LogPriorWithoutJacobianAccumulator(), DynamicPPL.LogLikelihoodAccumulator())
149-
)
150-
return vi
151-
end
152-
15346
"""
15447
OptimLogDensity{
15548
M<:DynamicPPL.Model,
@@ -625,8 +518,10 @@ function estimate_mode(
625518

626519
# Create an OptimLogDensity object that can be used to evaluate the objective function,
627520
# i.e. the negative log density.
628-
getlogdensity =
629-
estimator isa MAP ? getlogjoint_without_jacobian : DynamicPPL.getloglikelihood
521+
# Note that we use `getlogjoint` rather than `getlogjoint_internal`: this
522+
# is intentional, because even though the VarInfo may be linked, the
523+
# optimisation target should not take the Jacobian term into account.
524+
getlogdensity = estimator isa MAP ? DynamicPPL.getlogjoint : DynamicPPL.getloglikelihood
630525

631526
# Set its VarInfo to the initial parameters.
632527
# TODO(penelopeysm): Unclear if this is really needed? Any time that logp is calculated

0 commit comments

Comments
 (0)