Skip to content

Run JuliaFormatter on more files, remove trailing whitespace #2374

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions .JuliaFormatter.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,4 @@ ignore = [
# https://github.com/TuringLang/Turing.jl/pull/2328/files
"src/experimental/gibbs.jl",
"test/experimental/gibbs.jl",
# https://github.com/TuringLang/Turing.jl/pull/1887 # Enzyme PR
"test/mcmc/hmc.jl",
"test/mcmc/sghmc.jl",
]
6 changes: 3 additions & 3 deletions .github/workflows/DocsNav.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ jobs:

# Define the URL of the navbar to be used
NAVBAR_URL="https://raw.githubusercontent.com/TuringLang/turinglang.github.io/main/assets/scripts/TuringNavbar.html"

# Update all HTML files in the current directory (gh-pages root)
./insert_navbar.sh . $NAVBAR_URL

# Remove the insert_navbar.sh file
rm insert_navbar.sh

# Check if there are any changes
if [[ -n $(git status -s) ]]; then
git add .
Expand Down
2 changes: 1 addition & 1 deletion src/mcmc/mh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ Specifying a single distribution implies the use of static MH:

```julia
# Use a static proposal for s² (which happens to be the same
# as the prior) and a static proposal for m (note that this
# as the prior) and a static proposal for m (note that this
# isn't a random walk proposal).
chain = sample(
gdemo(1.5, 2.0),
Expand Down
98 changes: 49 additions & 49 deletions test/mcmc/hmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,60 +22,59 @@ using Turing
# Set a seed
rng = StableRNG(123)
@testset "constrained bounded" begin
obs = [0,1,0,1,1,1,1,1,1,1]
obs = [0, 1, 0, 1, 1, 1, 1, 1, 1, 1]

@model function constrained_test(obs)
p ~ Beta(2,2)
for i = 1:length(obs)
p ~ Beta(2, 2)
for i in 1:length(obs)
obs[i] ~ Bernoulli(p)
end
p
return p
end

chain = sample(
rng,
constrained_test(obs),
HMC(1.5, 3; adtype=adbackend),# using a large step size (1.5)
1000)
1000,
)

check_numerical(chain, [:p], [10/14], atol=0.1)
check_numerical(chain, [:p], [10 / 14]; atol=0.1)
end
@testset "constrained simplex" begin
obs12 = [1,2,1,2,2,2,2,2,2,2]
obs12 = [1, 2, 1, 2, 2, 2, 2, 2, 2, 2]

@model function constrained_simplex_test(obs12)
ps ~ Dirichlet(2, 3)
pd ~ Dirichlet(4, 1)
for i = 1:length(obs12)
for i in 1:length(obs12)
obs12[i] ~ Categorical(ps)
end
return ps
end

chain = sample(
rng,
constrained_simplex_test(obs12),
HMC(0.75, 2; adtype=adbackend),
1000)
rng, constrained_simplex_test(obs12), HMC(0.75, 2; adtype=adbackend), 1000
)

check_numerical(chain, ["ps[1]", "ps[2]"], [5/16, 11/16], atol=0.015)
check_numerical(chain, ["ps[1]", "ps[2]"], [5 / 16, 11 / 16]; atol=0.015)
end
@testset "hmc reverse diff" begin
alg = HMC(0.1, 10; adtype=adbackend)
res = sample(rng, gdemo_default, alg, 4000)
check_gdemo(res, rtol=0.1)
check_gdemo(res; rtol=0.1)
end
@testset "matrix support" begin
@model function hmcmatrixsup()
v ~ Wishart(7, [1 0.5; 0.5 1])
return v ~ Wishart(7, [1 0.5; 0.5 1])
end

model_f = hmcmatrixsup()
n_samples = 1_000
vs = map(1:3) do _
chain = sample(rng, model_f, HMC(0.15, 7; adtype=adbackend), n_samples)
r = reshape(Array(group(chain, :v)), n_samples, 2, 2)
reshape(mean(r; dims = 1), 2, 2)
reshape(mean(r; dims=1), 2, 2)
end

@test maximum(abs, mean(vs) - (7 * [1 0.5; 0.5 1])) <= 0.5
Expand All @@ -92,10 +91,10 @@ using Turing
M = N ÷ 4
x1s = rand(M) * 5
x2s = rand(M) * 5
xt1s = Array([[x1s[i]; x2s[i]] for i = 1:M])
append!(xt1s, Array([[x1s[i] - 6; x2s[i] - 6] for i = 1:M]))
xt0s = Array([[x1s[i]; x2s[i] - 6] for i = 1:M])
append!(xt0s, Array([[x1s[i] - 6; x2s[i]] for i = 1:M]))
xt1s = Array([[x1s[i]; x2s[i]] for i in 1:M])
append!(xt1s, Array([[x1s[i] - 6; x2s[i] - 6] for i in 1:M]))
xt0s = Array([[x1s[i]; x2s[i] - 6] for i in 1:M])
append!(xt0s, Array([[x1s[i] - 6; x2s[i]] for i in 1:M]))

xs = [xt1s; xt0s]
ts = [ones(M); ones(M); zeros(M); zeros(M)]
Expand All @@ -106,20 +105,22 @@ using Turing
var_prior = sqrt(1.0 / alpha) # variance of the Gaussian prior

@model function bnn(ts)
b1 ~ MvNormal([0. ;0.; 0.],
[var_prior 0. 0.; 0. var_prior 0.; 0. 0. var_prior])
w11 ~ MvNormal([0.; 0.], [var_prior 0.; 0. var_prior])
w12 ~ MvNormal([0.; 0.], [var_prior 0.; 0. var_prior])
w13 ~ MvNormal([0.; 0.], [var_prior 0.; 0. var_prior])
b1 ~ MvNormal(
[0.0; 0.0; 0.0], [var_prior 0.0 0.0; 0.0 var_prior 0.0; 0.0 0.0 var_prior]
)
w11 ~ MvNormal([0.0; 0.0], [var_prior 0.0; 0.0 var_prior])
w12 ~ MvNormal([0.0; 0.0], [var_prior 0.0; 0.0 var_prior])
w13 ~ MvNormal([0.0; 0.0], [var_prior 0.0; 0.0 var_prior])
bo ~ Normal(0, var_prior)

wo ~ MvNormal([0.; 0; 0],
[var_prior 0. 0.; 0. var_prior 0.; 0. 0. var_prior])
for i = rand(1:N, 10)
wo ~ MvNormal(
[0.0; 0; 0], [var_prior 0.0 0.0; 0.0 var_prior 0.0; 0.0 0.0 var_prior]
)
for i in rand(1:N, 10)
y = nn(xs[i], b1, w11, w12, w13, bo, wo)
ts[i] ~ Bernoulli(y)
end
b1, w11, w12, w13, bo, wo
return b1, w11, w12, w13, bo, wo
end

# Sampling
Expand Down Expand Up @@ -147,7 +148,7 @@ using Turing
Random.seed!(12345) # particle samplers do not support user-provided `rng` yet
alg3 = Gibbs(PG(20, :s), HMCDA(500, 0.8, 0.25, :m; init_ϵ=0.05, adtype=adbackend))

res3 = sample(rng, gdemo_default, alg3, 3000, discard_initial=1000)
res3 = sample(rng, gdemo_default, alg3, 3000; discard_initial=1000)
check_gdemo(res3)
end

Expand Down Expand Up @@ -191,8 +192,8 @@ using Turing
@testset "check discard" begin
alg = NUTS(100, 0.8; adtype=adbackend)

c1 = sample(rng, gdemo_default, alg, 500, discard_adapt=true)
c2 = sample(rng, gdemo_default, alg, 500, discard_adapt=false)
c1 = sample(rng, gdemo_default, alg, 500; discard_adapt=true)
c2 = sample(rng, gdemo_default, alg, 500; discard_adapt=false)

@test size(c1, 1) == 500
@test size(c2, 1) == 500
Expand All @@ -210,20 +211,20 @@ using Turing
# https://github.com/TuringLang/DynamicPPL.jl/issues/27
@model function mwe1(::Type{T}=Float64) where {T<:Real}
m = Matrix{T}(undef, 2, 3)
m .~ MvNormal(zeros(2), I)
return m .~ MvNormal(zeros(2), I)
end
@test sample(rng, mwe1(), HMC(0.2, 4; adtype=adbackend), 1_000) isa Chains

@model function mwe2(::Type{T}=Matrix{Float64}) where {T}
m = T(undef, 2, 3)
m .~ MvNormal(zeros(2), I)
return m .~ MvNormal(zeros(2), I)
end
@test sample(rng, mwe2(), HMC(0.2, 4; adtype=adbackend), 1_000) isa Chains

# https://github.com/TuringLang/Turing.jl/issues/1308
@model function mwe3(::Type{T}=Array{Float64}) where {T}
m = T(undef, 2, 3)
m .~ MvNormal(zeros(2), I)
return m .~ MvNormal(zeros(2), I)
end
@test sample(rng, mwe3(), HMC(0.2, 4; adtype=adbackend), 1_000) isa Chains
end
Expand All @@ -241,13 +242,17 @@ using Turing
@model function demo_hmc_prior()
# NOTE: Used to use `InverseGamma(2, 3)` but this has infinite variance
# which means that it's _very_ difficult to find a good tolerance in the test below:)
s ~ truncated(Normal(3, 1), lower=0)
m ~ Normal(0, sqrt(s))
s ~ truncated(Normal(3, 1); lower=0)
return m ~ Normal(0, sqrt(s))
end
alg = NUTS(1000, 0.8; adtype=adbackend)
gdemo_default_prior = DynamicPPL.contextualize(demo_hmc_prior(), DynamicPPL.PriorContext())
gdemo_default_prior = DynamicPPL.contextualize(
demo_hmc_prior(), DynamicPPL.PriorContext()
)
chain = sample(gdemo_default_prior, alg, 10_000; initial_params=[3.0, 0.0])
check_numerical(chain, [:s, :m], [mean(truncated(Normal(3, 1); lower=0)), 0], atol=0.2)
check_numerical(
chain, [:s, :m], [mean(truncated(Normal(3, 1); lower=0)), 0]; atol=0.2
)
end

@testset "warning for difficult init params" begin
Expand All @@ -262,7 +267,7 @@ using Turing
@test_logs (
:warn,
"failed to find valid initial parameters in 10 tries; consider providing explicit initial parameters using the `initial_params` keyword",
) (:info,) match_mode=:any begin
) (:info,) match_mode = :any begin
sample(demo_warn_initial_params(), NUTS(; adtype=adbackend), 5)
end
end
Expand All @@ -271,7 +276,7 @@ using Turing
@model function vector_of_dirichlet(::Type{TV}=Vector{Float64}) where {TV}
xs = Vector{TV}(undef, 2)
xs[1] ~ Dirichlet(ones(5))
xs[2] ~ Dirichlet(ones(5))
return xs[2] ~ Dirichlet(ones(5))
end
model = vector_of_dirichlet()
chain = sample(model, NUTS(), 1000)
Expand All @@ -296,15 +301,10 @@ using Turing
end
end

model = buggy_model();
num_samples = 1_000;
model = buggy_model()
num_samples = 1_000

chain = sample(
model,
NUTS(),
num_samples;
initial_params=[0.5, 1.75, 1.0]
)
chain = sample(model, NUTS(), num_samples; initial_params=[0.5, 1.75, 1.0])
chain_prior = sample(model, Prior(), num_samples)

# Extract the `x` like this because running `generated_quantities` was how
Expand Down
10 changes: 5 additions & 5 deletions test/mcmc/sghmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ using Turing

alg = SGHMC(; learning_rate=0.02, momentum_decay=0.5, adtype=adbackend)
chain = sample(rng, gdemo_default, alg, 10_000)
check_gdemo(chain, atol=0.1)
check_gdemo(chain; atol=0.1)
end
end

Expand All @@ -58,15 +58,15 @@ end
@testset "sgld inference" begin
rng = StableRNG(1)

chain = sample(rng, gdemo_default, SGLD(; stepsize = PolynomialStepsize(0.5)), 20_000)
check_gdemo(chain, atol = 0.2)
chain = sample(rng, gdemo_default, SGLD(; stepsize=PolynomialStepsize(0.5)), 20_000)
check_gdemo(chain; atol=0.2)

# Weight samples by step sizes (cf section 4.2 in the paper by Welling and Teh)
v = get(chain, [:SGLD_stepsize, :s, :m])
s_weighted = dot(v.SGLD_stepsize, v.s) / sum(v.SGLD_stepsize)
m_weighted = dot(v.SGLD_stepsize, v.m) / sum(v.SGLD_stepsize)
@test s_weighted ≈ 49/24 atol=0.2
@test m_weighted ≈ 7/6 atol=0.2
@test s_weighted ≈ 49 / 24 atol = 0.2
@test m_weighted ≈ 7 / 6 atol = 0.2
end
end

Expand Down
Loading