Skip to content

Commit eedb02f

Browse files
committed
add getparams and setparams!!
1 parent 47b212a commit eedb02f

File tree

3 files changed

+52
-31
lines changed

3 files changed

+52
-31
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "AdvancedHMC"
22
uuid = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
3-
version = "0.6.2"
3+
version = "0.6.3"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -30,7 +30,7 @@ AdvancedHMCMCMCChainsExt = "MCMCChains"
3030
AdvancedHMCOrdinaryDiffEqExt = "OrdinaryDiffEq"
3131

3232
[compat]
33-
AbstractMCMC = "5"
33+
AbstractMCMC = "5.5"
3434
ArgCheck = "1, 2"
3535
CUDA = "3, 4, 5"
3636
DocStringExtensions = "0.8, 0.9"

src/abstractmcmc.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,15 @@ getadaptor(state::HMCState) = state.adaptor
3030
getmetric(state::HMCState) = state.metric
3131
getintegrator(state::HMCState) = state.κ.τ.integrator
3232

33+
function AbstractMCMC.getparams(state::HMCState)
34+
# TODO(sunxd): should we return a copy?
35+
return state.transition.z.θ
36+
end
37+
38+
function AbstractMCMC.setparams!!(state::HMCState, θ)
39+
return @set state.transition.z.θ = θ
40+
end
41+
3342
"""
3443
$(TYPEDSIGNATURES)
3544

test/abstractmcmc.jl

Lines changed: 41 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ using Statistics: mean
88
θ_init = randn(rng, 2)
99

1010
nuts = NUTS(0.8)
11-
hmc = HMC(100; integrator = Leapfrog(0.05))
11+
hmc = HMC(100; integrator=Leapfrog(0.05))
1212
hmcda = HMCDA(0.8, 0.1)
1313

1414
integrator = Leapfrog(1e-3)
@@ -21,15 +21,27 @@ using Statistics: mean
2121
LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓπ_gdemo),
2222
)
2323

24+
@testset "getparams and setparams!!" begin
25+
t, s = AbstractMCMC.step(
26+
rng,
27+
model,
28+
nuts;
29+
)
30+
31+
θ = AbstractMCMC.getparams(s)
32+
@test θ == t.z.θ
33+
@test AbstractMCMC.setparams!!(s, θ) == s
34+
end
35+
2436
samples_nuts = AbstractMCMC.sample(
2537
rng,
2638
model,
2739
nuts,
2840
n_adapts + n_samples;
29-
n_adapts = n_adapts,
30-
initial_params = θ_init,
31-
progress = false,
32-
verbose = false,
41+
n_adapts=n_adapts,
42+
initial_params=θ_init,
43+
progress=false,
44+
verbose=false,
3345
)
3446

3547
# Error if keyword argument `nadapts` is used
@@ -38,10 +50,10 @@ using Statistics: mean
3850
model,
3951
nuts,
4052
n_adapts + n_samples;
41-
nadapts = n_adapts,
42-
initial_params = θ_init,
43-
progress = false,
44-
verbose = false,
53+
nadapts=n_adapts,
54+
initial_params=θ_init,
55+
progress=false,
56+
verbose=false,
4557
)
4658
@test_throws ArgumentError AbstractMCMC.sample(
4759
rng,
@@ -50,10 +62,10 @@ using Statistics: mean
5062
MCMCThreads(),
5163
n_adapts + n_samples,
5264
2;
53-
nadapts = n_adapts,
54-
initial_params = θ_init,
55-
progress = false,
56-
verbose = false,
65+
nadapts=n_adapts,
66+
initial_params=θ_init,
67+
progress=false,
68+
verbose=false,
5769
)
5870

5971
# Transform back to original space.
@@ -73,10 +85,10 @@ using Statistics: mean
7385
model,
7486
hmc,
7587
n_adapts + n_samples;
76-
n_adapts = n_adapts,
77-
initial_params = θ_init,
78-
progress = false,
79-
verbose = false,
88+
n_adapts=n_adapts,
89+
initial_params=θ_init,
90+
progress=false,
91+
verbose=false,
8092
)
8193

8294
# Transform back to original space.
@@ -96,10 +108,10 @@ using Statistics: mean
96108
model,
97109
custom,
98110
n_adapts + n_samples;
99-
n_adapts = 0,
100-
initial_params = θ_init,
101-
progress = false,
102-
verbose = false,
111+
n_adapts=0,
112+
initial_params=θ_init,
113+
progress=false,
114+
verbose=false,
103115
)
104116

105117
# Transform back to original space.
@@ -122,20 +134,20 @@ using Statistics: mean
122134
model,
123135
custom,
124136
10;
125-
n_adapts = 0,
126-
initial_params = θ_init,
127-
progress = false,
128-
verbose = false,
137+
n_adapts=0,
138+
initial_params=θ_init,
139+
progress=false,
140+
verbose=false,
129141
)
130142
samples2 = AbstractMCMC.sample(
131143
rng2,
132144
model,
133145
custom,
134146
10;
135-
n_adapts = 0,
136-
initial_params = θ_init,
137-
progress = false,
138-
verbose = false,
147+
n_adapts=0,
148+
initial_params=θ_init,
149+
progress=false,
150+
verbose=false,
139151
)
140152
@test mapreduce(*, samples1, samples2) do s1, s2
141153
s1.z.θ == s2.z.θ

0 commit comments

Comments
 (0)