Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# Release 0.38.2

`getparams(::Model, ::AbstractVarInfo)` now returns an empty `Float64` if the VarInfo contains no parameters.

# Release 0.38.1

The method `Bijectors.bijector(::DynamicPPL.Model)` was moved to DynamicPPL.jl.
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.38.1"
version = "0.38.2"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
3 changes: 3 additions & 0 deletions src/mcmc/Inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,9 @@ function getparams(model::DynamicPPL.Model, vi::DynamicPPL.VarInfo)
# Materialize the iterators and concatenate.
return mapreduce(collect, vcat, iters)
end
function getparams(::DynamicPPL.Model, ::DynamicPPL.VarInfo{NamedTuple{(),Tuple{}}})
return float(Real)[]
end

function _params_to_array(model::DynamicPPL.Model, ts::Vector)
names_set = OrderedSet{VarName}()
Expand Down
2 changes: 2 additions & 0 deletions src/mcmc/mh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@
]
return expr
end
_val_tuple(::VarInfo, ::Tuple{}) = ()

Check warning on line 249 in src/mcmc/mh.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/mh.jl#L249

Added line #L249 was not covered by tests

@generated function _dist_tuple(
props::NamedTuple{propnames}, vi::VarInfo, vns::NamedTuple{names}
Expand All @@ -267,6 +268,7 @@
]
return expr
end
_dist_tuple(::@NamedTuple{}, ::VarInfo, ::Tuple{}) = ()

Check warning on line 271 in src/mcmc/mh.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/mh.jl#L271

Added line #L271 was not covered by tests

# Utility functions to link
should_link(varinfo, sampler, proposal) = false
Expand Down
33 changes: 33 additions & 0 deletions test/mcmc/Inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -631,6 +631,39 @@ using Turing
StableRNG(seed), demo_incorrect_missing([missing]), NUTS(), 10; check_model=true
)
end

@testset "getparams" begin
@model function e(x=1.0)
return x ~ Normal()
end
evi = Turing.VarInfo(e())
Comment on lines +635 to +639
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm aware that these tests are being run once-per-adtype, and there's really no need to do so. But this is true of quite a number of tests in this file and so I thought that was a bigger problem that I didn't want to deal with in this PR.

@test isempty(Turing.Inference.getparams(e(), evi))

@model function f()
return x ~ Normal()
end
fvi = Turing.VarInfo(f())
@test only(Turing.Inference.getparams(f(), fvi)) == (@varname(x), fvi[@varname(x)])

@model function g()
x ~ Normal()
return y ~ Poisson()
end
gvi = Turing.VarInfo(g())
gparams = Turing.Inference.getparams(g(), gvi)
@test gparams[1] == (@varname(x), gvi[@varname(x)])
@test gparams[2] == (@varname(y), gvi[@varname(y)])
@test length(gparams) == 2
end

@testset "empty model" begin
@model function e(x=1.0)
return x ~ Normal()
end
# Can't test with HMC/NUTS because some AD backends error; see
# https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/802
@test sample(e(), IS(), 100) isa MCMCChains.Chains
end
end

end
Loading