Skip to content

Commit ed17b7e

Browse files
committed
Simplify tests as we no longer support Julia <1.8
1 parent 410d98e commit ed17b7e

File tree

3 files changed

+24
-32
lines changed

3 files changed

+24
-32
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ AdvancedHMC = "0.3.0, 0.4.0, 0.5.2, 0.6"
5656
AdvancedMH = "0.8"
5757
AdvancedPS = "0.6.0"
5858
AdvancedVI = "0.2"
59-
BangBang = "0.4"
59+
BangBang = "0.4.2"
6060
Bijectors = "0.13.6"
6161
Compat = "4.15.0"
6262
DataStructures = "0.18"

test/mcmc/hmc.jl

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -269,19 +269,15 @@ ADUtils.install_tapir && import Tapir
269269
end
270270
end
271271

272-
# Disable on Julia <1.8 due to https://github.com/TuringLang/Turing.jl/pull/2197.
273-
# TODO: Remove this block once https://github.com/JuliaFolds2/BangBang.jl/pull/22 has been released.
274-
if VERSION v"1.8"
275-
@testset "(partially) issue: #2095" begin
276-
@model function vector_of_dirichlet(::Type{TV}=Vector{Float64}) where {TV}
277-
xs = Vector{TV}(undef, 2)
278-
xs[1] ~ Dirichlet(ones(5))
279-
xs[2] ~ Dirichlet(ones(5))
280-
end
281-
model = vector_of_dirichlet()
282-
chain = sample(model, NUTS(), 1000)
283-
@test mean(Array(chain)) 0.2
272+
@testset "(partially) issue: #2095" begin
273+
@model function vector_of_dirichlet(::Type{TV}=Vector{Float64}) where {TV}
274+
xs = Vector{TV}(undef, 2)
275+
xs[1] ~ Dirichlet(ones(5))
276+
xs[2] ~ Dirichlet(ones(5))
284277
end
278+
model = vector_of_dirichlet()
279+
chain = sample(model, NUTS(), 1000)
280+
@test mean(Array(chain)) 0.2
285281
end
286282

287283
@testset "issue: #2195" begin

test/mcmc/mh.jl

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -185,28 +185,24 @@ GKernel(var) = (x) -> Normal(x, sqrt.(var))
185185
# @test v1 < v2
186186
end
187187

188-
# Disable on Julia <1.8 due to https://github.com/TuringLang/Turing.jl/pull/2197.
189-
# TODO: Remove this block once https://github.com/JuliaFolds2/BangBang.jl/pull/22 has been released.
190-
if VERSION v"1.8"
191-
@testset "vector of multivariate distributions" begin
192-
@model function test(k)
193-
T = Vector{Vector{Float64}}(undef, k)
194-
for i in 1:k
195-
T[i] ~ Dirichlet(5, 1.0)
196-
end
188+
@testset "vector of multivariate distributions" begin
189+
@model function test(k)
190+
T = Vector{Vector{Float64}}(undef, k)
191+
for i in 1:k
192+
T[i] ~ Dirichlet(5, 1.0)
197193
end
194+
end
198195

199-
Random.seed!(100)
200-
chain = sample(test(1), MH(), 5_000)
201-
for i in 1:5
202-
@test mean(chain, "T[1][$i]") 0.2 atol = 0.01
203-
end
196+
Random.seed!(100)
197+
chain = sample(test(1), MH(), 5_000)
198+
for i in 1:5
199+
@test mean(chain, "T[1][$i]") 0.2 atol = 0.01
200+
end
204201

205-
Random.seed!(100)
206-
chain = sample(test(10), MH(), 5_000)
207-
for j in 1:10, i in 1:5
208-
@test mean(chain, "T[$j][$i]") 0.2 atol = 0.01
209-
end
202+
Random.seed!(100)
203+
chain = sample(test(10), MH(), 5_000)
204+
for j in 1:10, i in 1:5
205+
@test mean(chain, "T[$j][$i]") 0.2 atol = 0.01
210206
end
211207
end
212208

0 commit comments

Comments
 (0)