Skip to content

Commit b511385

Browse files
authored
Fix various methods for on empty varinfo (#2561)
* Fix getparams on empty varinfo * Fix two methods in MH for empty tuples * Add a test to sample an empty model
1 parent c38abda commit b511385

File tree

5 files changed

+43
-1
lines changed

5 files changed

+43
-1
lines changed

HISTORY.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
# Release 0.38.3
2+
3+
`getparams(::Model, ::AbstractVarInfo)` now returns an empty `Float64` if the VarInfo contains no parameters.
4+
15
# Release 0.38.2
26

37
Bump compat for `MCMCChains` to `7`.

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Turing"
22
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
3-
version = "0.38.2"
3+
version = "0.38.3"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/mcmc/Inference.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,9 @@ function getparams(model::DynamicPPL.Model, vi::DynamicPPL.VarInfo)
358358
# Materialize the iterators and concatenate.
359359
return mapreduce(collect, vcat, iters)
360360
end
361+
function getparams(::DynamicPPL.Model, ::DynamicPPL.VarInfo{NamedTuple{(),Tuple{}}})
362+
return float(Real)[]
363+
end
361364

362365
function _params_to_array(model::DynamicPPL.Model, ts::Vector)
363366
names_set = OrderedSet{VarName}()

src/mcmc/mh.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,7 @@ end
246246
]
247247
return expr
248248
end
249+
_val_tuple(::VarInfo, ::Tuple{}) = ()
249250

250251
@generated function _dist_tuple(
251252
props::NamedTuple{propnames}, vi::VarInfo, vns::NamedTuple{names}
@@ -267,6 +268,7 @@ end
267268
]
268269
return expr
269270
end
271+
_dist_tuple(::@NamedTuple{}, ::VarInfo, ::Tuple{}) = ()
270272

271273
# Utility functions to link
272274
should_link(varinfo, sampler, proposal) = false

test/mcmc/Inference.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -631,6 +631,39 @@ using Turing
631631
StableRNG(seed), demo_incorrect_missing([missing]), NUTS(), 10; check_model=true
632632
)
633633
end
634+
635+
@testset "getparams" begin
636+
@model function e(x=1.0)
637+
return x ~ Normal()
638+
end
639+
evi = Turing.VarInfo(e())
640+
@test isempty(Turing.Inference.getparams(e(), evi))
641+
642+
@model function f()
643+
return x ~ Normal()
644+
end
645+
fvi = Turing.VarInfo(f())
646+
@test only(Turing.Inference.getparams(f(), fvi)) == (@varname(x), fvi[@varname(x)])
647+
648+
@model function g()
649+
x ~ Normal()
650+
return y ~ Poisson()
651+
end
652+
gvi = Turing.VarInfo(g())
653+
gparams = Turing.Inference.getparams(g(), gvi)
654+
@test gparams[1] == (@varname(x), gvi[@varname(x)])
655+
@test gparams[2] == (@varname(y), gvi[@varname(y)])
656+
@test length(gparams) == 2
657+
end
658+
659+
@testset "empty model" begin
660+
@model function e(x=1.0)
661+
return x ~ Normal()
662+
end
663+
# Can't test with HMC/NUTS because some AD backends error; see
664+
# https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/802
665+
@test sample(e(), IS(), 100) isa MCMCChains.Chains
666+
end
634667
end
635668

636669
end

0 commit comments

Comments
 (0)