Skip to content

Commit d473012

Browse files
authored
more test fixes (#2491)
* Remove LogDensityProblemsAD, part 1 * update Optimisation code to not use LogDensityProblemsAD * Fix field name change * Don't put chunksize=0 * Remove LogDensityProblemsAD dep * Improve OptimLogDensity docstring * Remove unneeded model argument to _optimize * Fix more tests * Remove essential/ad from the list of CI groups * Fix HMC function * More test fixes
1 parent ee2b148 commit d473012

File tree

5 files changed

+10
-8
lines changed

5 files changed

+10
-8
lines changed

src/mcmc/gibbs.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,9 @@ function setparams_varinfo!!(
438438
state::TuringState,
439439
params::AbstractVarInfo,
440440
)
441-
logdensity = DynamicPPL.setmodel(state.ldf, model, sampler.alg.adtype)
441+
logdensity = DynamicPPL.LogDensityFunction(
442+
model, state.ldf.varinfo, state.ldf.context; adtype=sampler.alg.adtype
443+
)
442444
new_inner_state = setparams_varinfo!!(
443445
AbstractMCMC.LogDensityModel(logdensity), sampler, state.state, params
444446
)

src/mcmc/sghmc.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ function DynamicPPL.initialstep(
7272
DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext());
7373
adtype=spl.alg.adtype,
7474
)
75-
state = SGHMCState(ℓ, vi, zero(vi[spl]))
75+
state = SGHMCState(ℓ, vi, zero(vi[:]))
7676

7777
return sample, state
7878
end
@@ -87,7 +87,7 @@ function AbstractMCMC.step(
8787
# Compute gradient of log density.
8888
ℓ = state.logdensity
8989
vi = state.vi
90-
θ = vi[spl]
90+
θ = vi[:]
9191
grad = last(LogDensityProblems.logdensity_and_gradient(ℓ, θ))
9292

9393
# Update latent variables and velocity according to
@@ -246,7 +246,7 @@ function AbstractMCMC.step(
246246
# Perform gradient step.
247247
ℓ = state.logdensity
248248
vi = state.vi
249-
θ = vi[spl]
249+
θ = vi[:]
250250
grad = last(LogDensityProblems.logdensity_and_gradient(ℓ, θ))
251251
step = state.step
252252
stepsize = spl.alg.stepsize(step)

test/mcmc/Inference.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -512,7 +512,7 @@ using Turing
512512

513513
@model function vdemo2(x)
514514
μ ~ MvNormal(zeros(size(x, 1)), I)
515-
return x .~ MvNormal(μ, I)
515+
return x ~ filldist(MvNormal(μ, I), size(x, 2))
516516
end
517517

518518
D = 2
@@ -560,7 +560,7 @@ using Turing
560560

561561
@model function vdemo7()
562562
x = Array{Real}(undef, N, N)
563-
return x .~ [InverseGamma(2, 3) for i in 1:N]
563+
return x ~ filldist(InverseGamma(2, 3), N, N)
564564
end
565565

566566
sample(StableRNG(seed), vdemo7(), alg, 10)

test/mcmc/hmc.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ using Turing
218218
# https://github.com/TuringLang/Turing.jl/issues/1308
219219
@model function mwe3(::Type{T}=Array{Float64}) where {T}
220220
m = T(undef, 2, 3)
221-
return m .~ MvNormal(zeros(2), I)
221+
return m ~ filldist(MvNormal(zeros(2), I), 3)
222222
end
223223
@test sample(StableRNG(seed), mwe3(), HMC(0.2, 4; adtype=adbackend), 100) isa Chains
224224
end

test/mcmc/mh.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ GKernel(var) = (x) -> Normal(x, sqrt.(var))
238238

239239
# Link if proposal is `AdvancedMH.RandomWalkProposal`
240240
vi = deepcopy(vi_base)
241-
d = length(vi_base[DynamicPPL.SampleFromPrior()])
241+
d = length(vi_base[:])
242242
alg = MH(AdvancedMH.RandomWalkProposal(MvNormal(zeros(d), I)))
243243
spl = DynamicPPL.Sampler(alg)
244244
vi = Turing.Inference.maybe_link!!(vi, spl, alg.proposals, gdemo_default)

0 commit comments

Comments (0)