Skip to content

Commit 986d016

Browse files
committed
Merge remote-tracking branch 'origin/main' into py/no-repeated-adtype
2 parents 0d2955f + 411a341 commit 986d016

File tree

5 files changed

+43
-1
lines changed

5 files changed

+43
-1
lines changed

HISTORY.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
# Release 0.38.3
2+
3+
`getparams(::Model, ::AbstractVarInfo)` now returns an empty `Float64[]` if the VarInfo contains no parameters.
4+
15
# Release 0.38.2
26

37
Bump compat for `MCMCChains` to `7`.

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Turing"
22
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
3-
version = "0.38.2"
3+
version = "0.38.3"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/mcmc/Inference.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,9 @@ function getparams(model::DynamicPPL.Model, vi::DynamicPPL.VarInfo)
358358
# Materialize the iterators and concatenate.
359359
return mapreduce(collect, vcat, iters)
360360
end
361+
function getparams(::DynamicPPL.Model, ::DynamicPPL.VarInfo{NamedTuple{(),Tuple{}}})
362+
return float(Real)[]
363+
end
361364

362365
function _params_to_array(model::DynamicPPL.Model, ts::Vector)
363366
names_set = OrderedSet{VarName}()

src/mcmc/mh.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,7 @@ end
246246
]
247247
return expr
248248
end
249+
_val_tuple(::VarInfo, ::Tuple{}) = ()
249250

250251
@generated function _dist_tuple(
251252
props::NamedTuple{propnames}, vi::VarInfo, vns::NamedTuple{names}
@@ -267,6 +268,7 @@ end
267268
]
268269
return expr
269270
end
271+
_dist_tuple(::@NamedTuple{}, ::VarInfo, ::Tuple{}) = ()
270272

271273
# Utility functions to link
272274
should_link(varinfo, sampler, proposal) = false

test/mcmc/Inference.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -628,6 +628,39 @@ using Turing
628628
StableRNG(seed), demo_incorrect_missing([missing]), NUTS(), 10; check_model=true
629629
)
630630
end
631+
632+
@testset "getparams" begin
633+
@model function e(x=1.0)
634+
return x ~ Normal()
635+
end
636+
evi = Turing.VarInfo(e())
637+
@test isempty(Turing.Inference.getparams(e(), evi))
638+
639+
@model function f()
640+
return x ~ Normal()
641+
end
642+
fvi = Turing.VarInfo(f())
643+
@test only(Turing.Inference.getparams(f(), fvi)) == (@varname(x), fvi[@varname(x)])
644+
645+
@model function g()
646+
x ~ Normal()
647+
return y ~ Poisson()
648+
end
649+
gvi = Turing.VarInfo(g())
650+
gparams = Turing.Inference.getparams(g(), gvi)
651+
@test gparams[1] == (@varname(x), gvi[@varname(x)])
652+
@test gparams[2] == (@varname(y), gvi[@varname(y)])
653+
@test length(gparams) == 2
654+
end
655+
656+
@testset "empty model" begin
657+
@model function e(x=1.0)
658+
return x ~ Normal()
659+
end
660+
# Can't test with HMC/NUTS because some AD backends error; see
661+
# https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/802
662+
@test sample(e(), IS(), 100) isa MCMCChains.Chains
663+
end
631664
end
632665

633666
end

0 commit comments

Comments
 (0)