Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# Release 0.38.3

`getparams(::Model, ::AbstractVarInfo)` now returns an empty `Float64` if the VarInfo contains no parameters.

# Release 0.38.2

Bump compat for `MCMCChains` to `7`.
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.38.2"
version = "0.38.3"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
3 changes: 3 additions & 0 deletions src/mcmc/Inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,9 @@ function getparams(model::DynamicPPL.Model, vi::DynamicPPL.VarInfo)
# Materialize the iterators and concatenate.
return mapreduce(collect, vcat, iters)
end
function getparams(::DynamicPPL.Model, ::DynamicPPL.VarInfo{NamedTuple{(),Tuple{}}})
return float(Real)[]
end

function _params_to_array(model::DynamicPPL.Model, ts::Vector)
names_set = OrderedSet{VarName}()
Expand Down
2 changes: 2 additions & 0 deletions src/mcmc/mh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@
]
return expr
end
_val_tuple(::VarInfo, ::Tuple{}) = ()

Check warning on line 249 in src/mcmc/mh.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/mh.jl#L249

Added line #L249 was not covered by tests

@generated function _dist_tuple(
props::NamedTuple{propnames}, vi::VarInfo, vns::NamedTuple{names}
Expand All @@ -267,6 +268,7 @@
]
return expr
end
_dist_tuple(::@NamedTuple{}, ::VarInfo, ::Tuple{}) = ()

Check warning on line 271 in src/mcmc/mh.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/mh.jl#L271

Added line #L271 was not covered by tests

# Utility functions to link
should_link(varinfo, sampler, proposal) = false
Expand Down
33 changes: 33 additions & 0 deletions test/mcmc/Inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -631,6 +631,39 @@ using Turing
StableRNG(seed), demo_incorrect_missing([missing]), NUTS(), 10; check_model=true
)
end

@testset "getparams" begin
@model function e(x=1.0)
return x ~ Normal()
end
evi = Turing.VarInfo(e())
Comment on lines +635 to +639
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm aware that these tests are being run once-per-adtype, and there's really no need to do so. But this is true of quite a number of tests in this file and so I thought that was a bigger problem that I didn't want to deal with in this PR.

@test isempty(Turing.Inference.getparams(e(), evi))

@model function f()
return x ~ Normal()
end
fvi = Turing.VarInfo(f())
@test only(Turing.Inference.getparams(f(), fvi)) == (@varname(x), fvi[@varname(x)])

@model function g()
x ~ Normal()
return y ~ Poisson()
end
gvi = Turing.VarInfo(g())
gparams = Turing.Inference.getparams(g(), gvi)
@test gparams[1] == (@varname(x), gvi[@varname(x)])
@test gparams[2] == (@varname(y), gvi[@varname(y)])
@test length(gparams) == 2
end

@testset "empty model" begin
@model function e(x=1.0)
return x ~ Normal()
end
# Can't test with HMC/NUTS because some AD backends error; see
# https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/802
@test sample(e(), IS(), 100) isa MCMCChains.Chains
end
end

end
Loading