4 changes: 1 addition & 3 deletions .github/workflows/Tests.yml
@@ -24,8 +24,6 @@ jobs:
test:
# Run some of the slower test files individually. The last one catches everything
# not included in the others.
- name: "essential/ad"
args: "essential/ad.jl"
- name: "mcmc/gibbs"
args: "mcmc/gibbs.jl"
- name: "mcmc/hmc"
@@ -37,7 +35,7 @@
- name: "mcmc/ess"
args: "mcmc/ess.jl"
- name: "everything else"
args: "--skip essential/ad.jl mcmc/gibbs.jl mcmc/hmc.jl mcmc/abstractmcmc.jl mcmc/Inference.jl mcmc/ess.jl"
args: "--skip mcmc/gibbs.jl mcmc/hmc.jl mcmc/abstractmcmc.jl mcmc/Inference.jl mcmc/ess.jl"
runner:
# Default
- version: '1'
2 changes: 1 addition & 1 deletion Manifest.toml
@@ -2,7 +2,7 @@

julia_version = "1.11.3"
manifest_format = "2.0"
project_hash = "83ec9face19bc568fc30cc287161517dc49f6c5c"
project_hash = "afdf28a30966aaa4af542a30879dd92074661565"

[[deps.ADTypes]]
git-tree-sha1 = "fb97701c117c8162e84dfcf80215caa904aef44f"
2 changes: 0 additions & 2 deletions Project.toml
@@ -23,7 +23,6 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
NamedArrays = "86f7a689-2022-50b4-a561-43c23ac3c673"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
@@ -69,7 +68,6 @@ ForwardDiff = "0.10.3"
Libtask = "0.8.8"
LinearAlgebra = "1"
LogDensityProblems = "2"
LogDensityProblemsAD = "1.7.0"
MCMCChains = "5, 6"
NamedArrays = "0.9, 0.10"
Optim = "1"
24 changes: 9 additions & 15 deletions ext/TuringDynamicHMCExt.jl
@@ -3,17 +3,10 @@
### DynamicHMC backend - https://github.com/tpapp/DynamicHMC.jl
###

if isdefined(Base, :get_extension)
using DynamicHMC: DynamicHMC
using Turing
using Turing: AbstractMCMC, Random, LogDensityProblems, DynamicPPL
using Turing.Inference: ADTypes, LogDensityProblemsAD, TYPEDFIELDS
else
import ..DynamicHMC
using ..Turing
using ..Turing: AbstractMCMC, Random, LogDensityProblems, DynamicPPL
using ..Turing.Inference: ADTypes, LogDensityProblemsAD, TYPEDFIELDS
end
using DynamicHMC: DynamicHMC
using Turing
using Turing: AbstractMCMC, Random, LogDensityProblems, DynamicPPL
using Turing.Inference: ADTypes, TYPEDFIELDS

"""
DynamicNUTS
@@ -69,10 +62,11 @@
end

# Define log-density function.
ℓ = LogDensityProblemsAD.ADgradient(
Turing.LogDensityFunction(
model, vi, DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext())
),
ℓ = DynamicPPL.LogDensityFunction(
model,
vi,
DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext());
adtype=spl.alg.adtype,
)

# Perform initial step.
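A minimal sketch of the pattern this change adopts, assuming a hypothetical toy model (`demo`, observed value `1.5`): the AD backend is handed to `DynamicPPL.LogDensityFunction` through its `adtype` keyword instead of wrapping the log density in `LogDensityProblemsAD.ADgradient`.

```julia
using Turing
using DynamicPPL
using ADTypes: AutoForwardDiff

# Hypothetical model, only to have a log density to differentiate.
@model function demo(y)
    m ~ Normal(0, 1)
    y ~ Normal(m, 1)
end

model = demo(1.5)
vi = DynamicPPL.VarInfo(model)

# Before: ℓ = LogDensityProblemsAD.ADgradient(adtype, Turing.LogDensityFunction(model, vi, ctx))
# After: the AD type travels with the log-density object itself.
ℓ = DynamicPPL.LogDensityFunction(model, vi; adtype=AutoForwardDiff())
```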
49 changes: 24 additions & 25 deletions ext/TuringOptimExt.jl
@@ -1,14 +1,8 @@
module TuringOptimExt

if isdefined(Base, :get_extension)
using Turing: Turing
import Turing: DynamicPPL, NamedArrays, Accessors, Optimisation
using Optim: Optim
else
import ..Turing
import ..Turing: DynamicPPL, NamedArrays, Accessors, Optimisation
import ..Optim
end
using Turing: Turing
import Turing: DynamicPPL, NamedArrays, Accessors, Optimisation
using Optim: Optim

####################
# Optim.jl methods #
@@ -42,7 +36,7 @@
)
ctx = Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext())
f = Optimisation.OptimLogDensity(model, ctx)
init_vals = DynamicPPL.getparams(f)
init_vals = DynamicPPL.getparams(f.ldf)
optimizer = Optim.LBFGS()
return _mle_optimize(model, init_vals, optimizer, options; kwargs...)
end
@@ -65,7 +59,7 @@
)
ctx = Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext())
f = Optimisation.OptimLogDensity(model, ctx)
init_vals = DynamicPPL.getparams(f)
init_vals = DynamicPPL.getparams(f.ldf)
return _mle_optimize(model, init_vals, optimizer, options; kwargs...)
end
function Optim.optimize(
@@ -81,7 +75,7 @@

function _mle_optimize(model::DynamicPPL.Model, args...; kwargs...)
ctx = Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext())
return _optimize(model, Optimisation.OptimLogDensity(model, ctx), args...; kwargs...)
return _optimize(Optimisation.OptimLogDensity(model, ctx), args...; kwargs...)
end

"""
@@ -112,7 +106,7 @@
)
ctx = Optimisation.OptimizationContext(DynamicPPL.DefaultContext())
f = Optimisation.OptimLogDensity(model, ctx)
init_vals = DynamicPPL.getparams(f)
init_vals = DynamicPPL.getparams(f.ldf)
optimizer = Optim.LBFGS()
return _map_optimize(model, init_vals, optimizer, options; kwargs...)
end
@@ -135,7 +129,7 @@
)
ctx = Optimisation.OptimizationContext(DynamicPPL.DefaultContext())
f = Optimisation.OptimLogDensity(model, ctx)
init_vals = DynamicPPL.getparams(f)
init_vals = DynamicPPL.getparams(f.ldf)
return _map_optimize(model, init_vals, optimizer, options; kwargs...)
end
function Optim.optimize(
@@ -151,28 +145,29 @@

function _map_optimize(model::DynamicPPL.Model, args...; kwargs...)
ctx = Optimisation.OptimizationContext(DynamicPPL.DefaultContext())
return _optimize(model, Optimisation.OptimLogDensity(model, ctx), args...; kwargs...)
return _optimize(Optimisation.OptimLogDensity(model, ctx), args...; kwargs...)
end

"""
_optimize(model::Model, f::OptimLogDensity, optimizer=Optim.LBFGS(), args...; kwargs...)
_optimize(f::OptimLogDensity, optimizer=Optim.LBFGS(), args...; kwargs...)

Estimate a mode, i.e., compute a MLE or MAP estimate.
"""
function _optimize(
model::DynamicPPL.Model,
f::Optimisation.OptimLogDensity,
init_vals::AbstractArray=DynamicPPL.getparams(f),
init_vals::AbstractArray=DynamicPPL.getparams(f.ldf),
optimizer::Optim.AbstractOptimizer=Optim.LBFGS(),
options::Optim.Options=Optim.Options(),
args...;
kwargs...,
)
# Convert the initial values, since it is assumed that users provide them
# in the constrained space.
f = Accessors.@set f.varinfo = DynamicPPL.unflatten(f.varinfo, init_vals)
f = Accessors.@set f.varinfo = DynamicPPL.link(f.varinfo, model)
init_vals = DynamicPPL.getparams(f)
# TODO(penelopeysm): As in src/optimisation/Optimisation.jl, unclear
# whether initialisation is really necessary at all
vi = DynamicPPL.unflatten(f.ldf.varinfo, init_vals)
vi = DynamicPPL.link(vi, f.ldf.model)
f = Optimisation.OptimLogDensity(f.ldf.model, vi, f.ldf.context; adtype=f.ldf.adtype)
init_vals = DynamicPPL.getparams(f.ldf)

# Optimize!
M = Optim.optimize(Optim.only_fg!(f), init_vals, optimizer, options, args...; kwargs...)
@@ -186,12 +181,16 @@
end

# Get the optimum in unconstrained space. `getparams` does the invlinking.
f = Accessors.@set f.varinfo = DynamicPPL.unflatten(f.varinfo, M.minimizer)
vns_vals_iter = Turing.Inference.getparams(model, f.varinfo)
vi = f.ldf.varinfo
vi_optimum = DynamicPPL.unflatten(vi, M.minimizer)
logdensity_optimum = Optimisation.OptimLogDensity(
f.ldf.model, vi_optimum, f.ldf.context
)
vns_vals_iter = Turing.Inference.getparams(f.ldf.model, vi_optimum)
varnames = map(Symbol ∘ first, vns_vals_iter)
vals = map(last, vns_vals_iter)
vmat = NamedArrays.NamedArray(vals, varnames)
return Optimisation.ModeResult(vmat, M, -M.minimum, f)
return Optimisation.ModeResult(vmat, M, -M.minimum, logdensity_optimum)
end

end # module
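As a usage note on the hunks above, here is a sketch of the field layout they assume: an `OptimLogDensity` wraps a `DynamicPPL.LogDensityFunction` in its `ldf` field, so model, varinfo, context and AD type are reached through `f.ldf`, and a new `OptimLogDensity` is rebuilt whenever the varinfo changes. The `@model` definition is hypothetical.

```julia
using Turing
using Turing: DynamicPPL, Optimisation

# Hypothetical model for illustration only.
@model function demo(y)
    m ~ Normal(0, 1)
    y ~ Normal(m, 1)
end
model = demo(1.5)

ctx = Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext())
f = Optimisation.OptimLogDensity(model, ctx)

init_vals = DynamicPPL.getparams(f.ldf)  # was: DynamicPPL.getparams(f)

# Re-linking goes through the wrapped LogDensityFunction and rebuilds `f`.
vi = DynamicPPL.link(f.ldf.varinfo, f.ldf.model)
f = Optimisation.OptimLogDensity(f.ldf.model, vi, f.ldf.context; adtype=f.ldf.adtype)
```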
24 changes: 0 additions & 24 deletions src/mcmc/Inference.jl
@@ -48,7 +48,6 @@ import AdvancedPS
import Accessors
import EllipticalSliceSampling
import LogDensityProblems
import LogDensityProblemsAD
import Random
import MCMCChains
import StatsBase: predict
@@ -160,29 +159,6 @@ function externalsampler(
return ExternalSampler(sampler, adtype, Val(unconstrained))
end

getADType(spl::Sampler) = getADType(spl.alg)
getADType(::SampleFromPrior) = Turing.DEFAULT_ADTYPE

getADType(ctx::DynamicPPL.SamplingContext) = getADType(ctx.sampler)
getADType(ctx::DynamicPPL.AbstractContext) = getADType(DynamicPPL.NodeTrait(ctx), ctx)
getADType(::DynamicPPL.IsLeaf, ctx::DynamicPPL.AbstractContext) = Turing.DEFAULT_ADTYPE
function getADType(::DynamicPPL.IsParent, ctx::DynamicPPL.AbstractContext)
return getADType(DynamicPPL.childcontext(ctx))
end

getADType(alg::Hamiltonian) = alg.adtype

function LogDensityProblemsAD.ADgradient(ℓ::DynamicPPL.LogDensityFunction)
return LogDensityProblemsAD.ADgradient(getADType(ℓ.context), ℓ)
end

function LogDensityProblems.logdensity(
f::Turing.LogDensityFunction{<:AbstractVarInfo,<:Model,<:DynamicPPL.DefaultContext},
x::NamedTuple,
)
return DynamicPPL.logjoint(f.model, DynamicPPL.unflatten(f.varinfo, x))
end

# TODO: make a nicer `set_namedtuple!` and move these functions to DynamicPPL.
function DynamicPPL.unflatten(vi::TypedVarInfo, θ::NamedTuple)
set_namedtuple!(deepcopy(vi), θ)
51 changes: 11 additions & 40 deletions src/mcmc/abstractmcmc.jl
@@ -1,6 +1,6 @@
struct TuringState{S,F}
struct TuringState{S,M,V,C}
state::S
logdensity::F
ldf::DynamicPPL.LogDensityFunction{M,V,C}
end

state_to_turing(f::DynamicPPL.LogDensityFunction, state) = TuringState(state, f)
@@ -12,20 +12,10 @@
return Transition(f.model, varinfo, transition)
end

state_to_turing(f::LogDensityProblemsAD.ADGradientWrapper, state) = TuringState(state, f)
function transition_to_turing(f::LogDensityProblemsAD.ADGradientWrapper, transition)
return transition_to_turing(parent(f), transition)
end

function varinfo_from_logdensityfn(f::LogDensityProblemsAD.ADGradientWrapper)
return varinfo_from_logdensityfn(parent(f))
end
varinfo_from_logdensityfn(f::DynamicPPL.LogDensityFunction) = f.varinfo

function varinfo(state::TuringState)
θ = getparams(DynamicPPL.getmodel(state.logdensity), state.state)
θ = getparams(state.ldf.model, state.state)
# TODO: Do we need to link here first?
return DynamicPPL.unflatten(varinfo_from_logdensityfn(state.logdensity), θ)
return DynamicPPL.unflatten(state.ldf.varinfo, θ)
end
varinfo(state::AbstractVarInfo) = state

@@ -40,23 +30,6 @@

getparams(::DynamicPPL.Model, transition::AdvancedMH.Transition) = transition.params

getvarinfo(f::DynamicPPL.LogDensityFunction) = f.varinfo
function getvarinfo(f::LogDensityProblemsAD.ADGradientWrapper)
return getvarinfo(LogDensityProblemsAD.parent(f))
end

function setvarinfo(f::DynamicPPL.LogDensityFunction, varinfo)
return DynamicPPL.LogDensityFunction(f.model, varinfo, f.context; adtype=f.adtype)
end

function setvarinfo(
f::LogDensityProblemsAD.ADGradientWrapper, varinfo, adtype::ADTypes.AbstractADType
)
return LogDensityProblemsAD.ADgradient(
adtype, setvarinfo(LogDensityProblemsAD.parent(f), varinfo)
)
end

# TODO: Do we also support `resume`, etc?
function AbstractMCMC.step(
rng::Random.AbstractRNG,
@@ -69,12 +42,8 @@
alg = sampler_wrapper.alg
sampler = alg.sampler

# Create a log-density function with an implementation of the
# gradient so we ensure that we're using the same AD backend as in Turing.
f = LogDensityProblemsAD.ADgradient(alg.adtype, DynamicPPL.LogDensityFunction(model))

# Link the varinfo if needed.
varinfo = getvarinfo(f)
# Initialise varinfo with initial params and link the varinfo if needed.
varinfo = DynamicPPL.VarInfo(model)
if requires_unconstrained_space(alg)
if initial_params !== nothing
# If we have initial parameters, we need to set the varinfo before linking.
@@ -85,9 +54,11 @@
varinfo = DynamicPPL.link(varinfo, model)
end
end
f = setvarinfo(f, varinfo, alg.adtype)

# Then just call `AdvancedHMC.step` with the right arguments.
# Construct LogDensityFunction
f = DynamicPPL.LogDensityFunction(model, varinfo; adtype=alg.adtype)

# Then just call `AbstractMCMC.step` with the right arguments.
if initial_state === nothing
transition_inner, state_inner = AbstractMCMC.step(
rng, AbstractMCMC.LogDensityModel(f), sampler; initial_params, kwargs...
@@ -114,7 +85,7 @@
kwargs...,
)
sampler = sampler_wrapper.alg.sampler
f = state.logdensity
f = state.ldf

# Then just call `AdvancedHMC.step` with the right arguments.
transition_inner, state_inner = AbstractMCMC.step(
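For orientation, a rough sketch of the external-sampler path after this change, using a hypothetical one-parameter model and `AdvancedMH.RWMH` as the wrapped sampler (neither appears in this diff): a `VarInfo` is initialised, linked when the sampler needs unconstrained space, and the `LogDensityFunction` now carries the AD type directly.

```julia
using Turing
using DynamicPPL
using AdvancedMH
using LinearAlgebra: I
using ADTypes: AutoForwardDiff

# Hypothetical model for illustration only.
@model function demo(y)
    m ~ Normal(0, 1)
    y ~ Normal(m, 1)
end
model = demo(1.5)

varinfo = DynamicPPL.VarInfo(model)         # initialise with draws from the prior
varinfo = DynamicPPL.link(varinfo, model)   # move to unconstrained space if required
f = DynamicPPL.LogDensityFunction(model, varinfo; adtype=AutoForwardDiff())

# User code reaches the same step method via `externalsampler`:
rwmh = AdvancedMH.RWMH(MvNormal(zeros(1), I))
chain = sample(model, externalsampler(rwmh), 200)
```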
2 changes: 1 addition & 1 deletion src/mcmc/gibbs.jl
@@ -438,7 +438,7 @@
state::TuringState,
params::AbstractVarInfo,
)
logdensity = DynamicPPL.setmodel(state.logdensity, model, sampler.alg.adtype)
logdensity = DynamicPPL.setmodel(state.ldf, model, sampler.alg.adtype)
new_inner_state = setparams_varinfo!!(
AbstractMCMC.LogDensityModel(logdensity), sampler, state.state, params
)