Skip to content

Commit e318654

Browse files
authored
implement AbstractMCMC.getstats (#119)
* Implement AbstractMCMC.getstats * Add tests * Bump compat etc * Disable JuliaPre workflow * Add testing on 1.10 (lts)
1 parent 6ac07d5 commit e318654

File tree

7 files changed

+36
-12
lines changed

7 files changed

+36
-12
lines changed

.github/workflows/CI.yml

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,27 +17,25 @@ jobs:
1717
# Current stable version
1818
- version: '1'
1919
os: ubuntu-latest
20-
arch: x64
20+
# LTS
21+
- version: 'lts'
22+
os: ubuntu-latest
2123
# Minimum supported version
2224
- version: 'min'
2325
os: ubuntu-latest
24-
arch: x64
2526
# Windows
2627
- version: '1'
2728
os: windows-latest
28-
arch: x64
2929
# macOS
3030
- version: '1'
3131
os: macos-latest
32-
arch: aarch64
3332

3433
steps:
3534
- uses: actions/checkout@v4
3635

3736
- uses: julia-actions/setup-julia@v2
3837
with:
3938
version: ${{ matrix.runner.version }}
40-
arch: ${{ matrix.runner.arch }}
4139

4240
- uses: julia-actions/cache@v2
4341

.github/workflows/JuliaPre.yml

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
name: JuliaPre
22

33
on:
4-
push:
5-
branches:
6-
- main
7-
pull_request:
4+
workflow_dispatch:
5+
# Disabled for now because there is no available prerelease of Julia.
6+
# push:
7+
# branches:
8+
# - main
9+
# pull_request:
810

911
# needed to allow julia-actions/cache to delete old caches that it has created
1012
permissions:

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "AdvancedMH"
22
uuid = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
3-
version = "0.8.8"
3+
version = "0.8.9"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -24,7 +24,7 @@ AdvancedMHMCMCChainsExt = "MCMCChains"
2424
AdvancedMHStructArraysExt = "StructArrays"
2525

2626
[compat]
27-
AbstractMCMC = "5.6"
27+
AbstractMCMC = "5.9"
2828
DiffResults = "1"
2929
Distributions = "0.25"
3030
DocStringExtensions = "0.9"

src/AdvancedMH.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,9 @@ end
146146
function AbstractMCMC.getparams(t::Transition)
147147
return t.params
148148
end
149+
function AbstractMCMC.getstats(t::Transition)
150+
return (accepted=t.accepted,)
151+
end
149152

150153
# TODO (sunxd): remove `DensityModel` in favor of `AbstractMCMC.LogDensityModel`
151154
function AbstractMCMC.setparams!!(model::DensityModelOrLogDensityModel, t::Transition, params)

src/MALA.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ logdensity(model::DensityModelOrLogDensityModel, t::GradientTransition) = t.lp
2323
function AbstractMCMC.getparams(t::GradientTransition)
2424
return t.params
2525
end
26+
function AbstractMCMC.getstats(t::GradientTransition)
27+
return (accepted=t.accepted,)
28+
end
2629

2730
function AbstractMCMC.setparams!!(model::DensityModelOrLogDensityModel, t::GradientTransition, params)
2831
lp, gradient = logdensity_and_gradient(model, params)

src/RobustAdaptiveMetropolis.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,9 +116,18 @@ end
116116
AbstractMCMC.getparams(state::RobustAdaptiveMetropolisState) = state.x
117117
function AbstractMCMC.setparams!!(state::RobustAdaptiveMetropolisState, x)
118118
return RobustAdaptiveMetropolisState(
119-
x, state.logprob, state.S, state.logα, state.η, state.iteration, state.isaccept
119+
x,
120+
state.logprob,
121+
state.S,
122+
state.logα,
123+
state.η,
124+
state.iteration,
125+
state.isaccept,
120126
)
121127
end
128+
function AbstractMCMC.getstats(state::RobustAdaptiveMetropolisState)
129+
return (logα = state.logα, η = state.η, accepted = state.isaccept)
130+
end
122131

123132
function ram_step_inner(
124133
rng::Random.AbstractRNG,

test/runtests.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,15 @@ include("util.jl")
5353
end
5454
end
5555

56+
@testset "AbstractMCMC.getstats" begin
57+
t1, _ = AbstractMCMC.step(Random.default_rng(), model, StaticMH([Normal(0, 1), Normal(0, 1)]))
58+
t2, _ = AbstractMCMC.step(Random.default_rng(), model, MALA(x -> MvNormal(x, I)); initial_params=ones(2))
59+
for t in [t1, t2]
60+
stats = AbstractMCMC.getstats(t)
61+
@test stats == (accepted = t.accepted,)
62+
end
63+
end
64+
5665
@testset "StaticMH" begin
5766
# Set up our sampler with initial parameters.
5867
spl1 = StaticMH([Normal(0,1), Normal(0, 1)])

0 commit comments

Comments
 (0)