From 0a4e29830bad9b762e71057031467894e8581091 Mon Sep 17 00:00:00 2001 From: Don van den Bergh Date: Wed, 25 Sep 2024 15:52:37 +0200 Subject: [PATCH 1/3] theta -> flattened_param_vals --- src/sampler.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sampler.jl b/src/sampler.jl index cfc58942e..833aaf7e2 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -150,7 +150,7 @@ function set_values!!( flattened_param_vals = varinfo[spl] length(flattened_param_vals) == length(initial_params) || throw( DimensionMismatch( - "Provided initial value size ($(length(initial_params))) doesn't match the model size ($(length(theta)))", + "Provided initial value size ($(length(initial_params))) doesn't match the model size ($(length(flattened_param_vals)))", ), ) From 048d4819eca8994e4e0f11053cff437d5d229004 Mon Sep 17 00:00:00 2001 From: Don van den Bergh Date: Wed, 25 Sep 2024 21:56:51 +0200 Subject: [PATCH 2/3] add unittest --- test/sampler.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/sampler.jl b/test/sampler.jl index b29d3caf1..4c7dd9a0c 100644 --- a/test/sampler.jl +++ b/test/sampler.jl @@ -162,6 +162,7 @@ chain1 = sample(model, sampler, 1; progress=false) Random.seed!(1234) chain2 = sample(model, sampler, 1; initial_params=nothing, progress=false) + @test_throws DimensionMismatch sample(model, sampler, 1; progress=false, initial_params=zeros(10)) @test chain1[1].metadata.m.vals == chain2[1].metadata.m.vals @test chain1[1].metadata.s.vals == chain2[1].metadata.s.vals From fe0e5edf432e3d26a5ce8607863206b22e5248c3 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 26 Sep 2024 10:51:29 +0100 Subject: [PATCH 3/3] Update test/sampler.jl --- test/sampler.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/sampler.jl b/test/sampler.jl index 4c7dd9a0c..95e838167 100644 --- a/test/sampler.jl +++ b/test/sampler.jl @@ -162,7 +162,9 @@ chain1 = sample(model, sampler, 1; progress=false) Random.seed!(1234) chain2 = sample(model, sampler, 1; initial_params=nothing, progress=false) - @test_throws DimensionMismatch sample(model, sampler, 1; progress=false, initial_params=zeros(10)) + @test_throws DimensionMismatch sample( + model, sampler, 1; progress=false, initial_params=zeros(10) + ) @test chain1[1].metadata.m.vals == chain2[1].metadata.m.vals @test chain1[1].metadata.s.vals == chain2[1].metadata.s.vals