diff --git a/HISTORY.md b/HISTORY.md index 42613d08f..d1f8c2ba5 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,11 @@ # DynamicPPL Changelog +## 0.36.5 + +`varinfo[:]` now returns an empty vector if `varinfo::DynamicPPL.NTVarInfo` is empty, rather than erroring. + +In its place, `check_model` now issues a warning if the model is empty. + ## 0.36.4 Added compatibility with DifferentiationInterface.jl 0.7, and also with JET.jl 0.10. diff --git a/Project.toml b/Project.toml index 6b3f445e3..b9253d2a5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.36.4" +version = "0.36.5" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/debug_utils.jl b/src/debug_utils.jl index 15ef8fb01..754b344ee 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -338,6 +338,11 @@ function conditioned_varnames(context) end function check_varnames_seen(varnames_seen::AbstractDict{VarName,Int}) + if isempty(varnames_seen) + @warn "The model does not contain any parameters." + return true + end + issuccess = true for (varname, count) in varnames_seen if count == 0 @@ -416,6 +421,8 @@ julia> print(trace) assume: x ~ Normal{Float64}(μ=0.0, σ=1.0) ⟼ -0.670252 (logprob = -1.14356) julia> issuccess, trace = check_model_and_trace(rng, demo_correct() | (x = 1.0,)); +┌ Warning: The model does not contain any parameters. +└ @ DynamicPPL.DebugUtils DynamicPPL.jl/src/debug_utils.jl:342 julia> issuccess true diff --git a/src/varinfo.jl b/src/varinfo.jl index 360857ef7..bc59c67a6 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -854,6 +854,9 @@ getindex_internal(vi::VarInfo, ::Colon) = getindex_internal(vi.metadata, Colon() function getindex_internal(vi::NTVarInfo, ::Colon) return reduce(vcat, map(Base.Fix2(getindex_internal, Colon()), vi.metadata)) end +function getindex_internal(vi::VarInfo{NamedTuple{(),Tuple{}}}, ::Colon) + return float(Real)[] +end function getindex_internal(md::Metadata, ::Colon) return mapreduce( Base.Fix1(getindex_internal, md), vcat, md.vns; init=similar(md.vals, 0) diff --git a/test/varinfo.jl b/test/varinfo.jl index 777917aa6..7439869cf 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -78,8 +78,8 @@ end @test vn2 == vn1 @test hash(vn2) == hash(vn1) - function test_base!!(vi_original) - vi = empty!!(vi_original) + function test_base(vi_original) + vi = deepcopy(vi_original) @test getlogp(vi) == 0 @test isempty(vi[:]) @@ -97,8 +97,10 @@ end @test length(vi[vn]) == 1 @test vi[vn] == r + @test vi[:] == [r] vi = DynamicPPL.setindex!!(vi, 2 * r, vn) @test vi[vn] == 2 * r + @test vi[:] == [2 * r] # TODO(mhauru) Implement these functions for other VarInfo types too. if vi isa DynamicPPL.UntypedVectorVarInfo @@ -113,12 +115,11 @@ end @test ~isempty(vi) end - vi = VarInfo() - test_base!!(vi) - test_base!!(DynamicPPL.typed_varinfo(vi)) - test_base!!(SimpleVarInfo()) - test_base!!(SimpleVarInfo(Dict())) - test_base!!(SimpleVarInfo(DynamicPPL.VarNamedVector())) + test_base(VarInfo()) + test_base(DynamicPPL.typed_varinfo(VarInfo())) + test_base(SimpleVarInfo()) + test_base(SimpleVarInfo(Dict())) + test_base(SimpleVarInfo(DynamicPPL.VarNamedVector())) end @testset "get/set/acc/resetlogp" begin