diff --git a/HISTORY.md b/HISTORY.md index ff800b0f4d..33b0823d9b 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,3 +1,7 @@ +# Release 0.38.3 + +`getparams(::Model, ::AbstractVarInfo)` now returns an empty `Float64` if the VarInfo contains no parameters. + # Release 0.38.2 Bump compat for `MCMCChains` to `7`. diff --git a/Project.toml b/Project.toml index 8311e5e0fc..6dbd03e405 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Turing" uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" -version = "0.38.2" +version = "0.38.3" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/mcmc/Inference.jl b/src/mcmc/Inference.jl index 0cbb45b48f..cdc9e2570c 100644 --- a/src/mcmc/Inference.jl +++ b/src/mcmc/Inference.jl @@ -358,6 +358,9 @@ function getparams(model::DynamicPPL.Model, vi::DynamicPPL.VarInfo) # Materialize the iterators and concatenate. return mapreduce(collect, vcat, iters) end +function getparams(::DynamicPPL.Model, ::DynamicPPL.VarInfo{NamedTuple{(),Tuple{}}}) + return float(Real)[] +end function _params_to_array(model::DynamicPPL.Model, ts::Vector) names_set = OrderedSet{VarName}() diff --git a/src/mcmc/mh.jl b/src/mcmc/mh.jl index 6a03f5359f..3450cfdc74 100644 --- a/src/mcmc/mh.jl +++ b/src/mcmc/mh.jl @@ -246,6 +246,7 @@ end ] return expr end +_val_tuple(::VarInfo, ::Tuple{}) = () @generated function _dist_tuple( props::NamedTuple{propnames}, vi::VarInfo, vns::NamedTuple{names} @@ -267,6 +268,7 @@ end ] return expr end +_dist_tuple(::@NamedTuple{}, ::VarInfo, ::Tuple{}) = () # Utility functions to link should_link(varinfo, sampler, proposal) = false diff --git a/test/mcmc/Inference.jl b/test/mcmc/Inference.jl index 5966589090..55d989b154 100644 --- a/test/mcmc/Inference.jl +++ b/test/mcmc/Inference.jl @@ -631,6 +631,39 @@ using Turing StableRNG(seed), demo_incorrect_missing([missing]), NUTS(), 10; check_model=true ) end + + @testset "getparams" begin + @model function e(x=1.0) + return x ~ Normal() + end + evi = Turing.VarInfo(e()) + @test isempty(Turing.Inference.getparams(e(), evi)) + + @model function f() + return x ~ Normal() + end + fvi = Turing.VarInfo(f()) + @test only(Turing.Inference.getparams(f(), fvi)) == (@varname(x), fvi[@varname(x)]) + + @model function g() + x ~ Normal() + return y ~ Poisson() + end + gvi = Turing.VarInfo(g()) + gparams = Turing.Inference.getparams(g(), gvi) + @test gparams[1] == (@varname(x), gvi[@varname(x)]) + @test gparams[2] == (@varname(y), gvi[@varname(y)]) + @test length(gparams) == 2 + end + + @testset "empty model" begin + @model function e(x=1.0) + return x ~ Normal() + end + # Can't test with HMC/NUTS because some AD backends error; see + # https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/802 + @test sample(e(), IS(), 100) isa MCMCChains.Chains + end end end