Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "AdvancedMH"
uuid = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
version = "0.8.3"
version = "0.8.4"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand All @@ -23,7 +23,7 @@ AdvancedMHMCMCChainsExt = "MCMCChains"
AdvancedMHStructArraysExt = "StructArrays"

[compat]
AbstractMCMC = "5"
AbstractMCMC = "5.5"
DiffResults = "1"
Distributions = "0.25"
FillArrays = "1"
Expand Down
10 changes: 10 additions & 0 deletions src/AdvancedMH.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module AdvancedMH

# Import the relevant libraries.
using AbstractMCMC
using AbstractMCMC: BangBang
using Distributions
using LinearAlgebra: I
using FillArrays: Zeros
Expand Down Expand Up @@ -140,6 +141,15 @@ function __init__()
end
end

# AbstractMCMC.jl interface
function AbstractMCMC.getparams(t::Transition)
return t.params
end

function AbstractMCMC.setparams!!(t::Transition, params)
return BangBang.setproperty!!(t, :params, params)
end

# Include inference methods.
include("proposal.jl")
include("mh-core.jl")
Expand Down
10 changes: 10 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using AdvancedMH
using AbstractMCMC
using DiffResults
using Distributions
using ForwardDiff
Expand Down Expand Up @@ -33,6 +34,15 @@ include("util.jl")
LogDensityProblems.logdensity(::typeof(density), θ) = density(θ)
LogDensityProblems.dimension(::typeof(density)) = 2

@testset "getparams/setparams!! (AbstractMCMC interface)" begin
test_spl = StaticMH([Normal(0, 1), Normal(0, 1)])
t, _ = AbstractMCMC.step(Random.default_rng(), model, test_spl)
@test AbstractMCMC.getparams(t) == t.params
@test AbstractMCMC.setparams!!(t, AbstractMCMC.getparams(t)) == t
t_replaced = AbstractMCMC.setparams!!(t, (μ=1.0, σ=2.0))
@test t_replaced.params == (μ=1.0, σ=2.0)
end

@testset "StaticMH" begin
# Set up our sampler with initial parameters.
spl1 = StaticMH([Normal(0,1), Normal(0, 1)])
Expand Down
Loading