Skip to content

Commit e85f821

Browse files
authored
Merge pull request #127 from TuringLang/dev
Update master
2 parents a673228 + 33be8a0 commit e85f821

File tree

2 files changed

+15
-10
lines changed

2 files changed

+15
-10
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
33
authors = ["mohamed82008 <[email protected]>"]
4-
version = "0.7.3"
4+
version = "0.8.0"
55

66
[deps]
77
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/varinfo.jl

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -766,7 +766,7 @@ end
766766

767767

768768
"""
769-
islinked(vi::VarInfo, spl::Sampler)
769+
islinked(vi::VarInfo, spl::Union{Sampler, SampleFromPrior})
770770
771771
Check whether `vi` is in the transformed space for a particular sampler `spl`.
772772
@@ -775,11 +775,11 @@ Turing's Hamiltonian samplers use the `link` and `invlink` functions from
775775
(for example, one bounded to the space `[0, 1]`) from its constrained space to the set of
776776
real numbers. `islinked` checks if the number is in the constrained space or the real space.
777777
"""
778-
function islinked(vi::UntypedVarInfo, spl::Sampler)
778+
function islinked(vi::UntypedVarInfo, spl::Union{Sampler, SampleFromPrior})
779779
vns = _getvns(vi, spl)
780780
return istrans(vi, vns[1])
781781
end
782-
function islinked(vi::TypedVarInfo, spl::Sampler)
782+
function islinked(vi::TypedVarInfo, spl::Union{Sampler, SampleFromPrior})
783783
vns = _getvns(vi, spl)
784784
return _islinked(vi, vns)
785785
end
@@ -956,16 +956,21 @@ function _show_varnames(io::IO, vi)
956956
md = vi.metadata
957957
vns = md.vns
958958

959-
groups = Dict{Symbol, Vector{VarName}}()
959+
vns_by_name = Dict{Symbol, Vector{VarName}}()
960960
for vn in vns
961-
group = get!(() -> Vector{VarName}(), groups, getsym(vn))
961+
group = get!(() -> Vector{VarName}(), vns_by_name, getsym(vn))
962962
push!(group, vn)
963963
end
964964

965-
print(io, length(groups), length(groups) == 1 ? " variable " : " variables ", "(")
966-
join(io, Iterators.take(keys(groups), _MAX_VARS_SHOWN), ", ")
967-
length(groups) > _MAX_VARS_SHOWN && print(io, ", ...")
968-
print(io, "), dimension ", sum(prod(size(md.vals[md.ranges[md.idcs[vn]]])) for vn in vns))
965+
L = length(vns_by_name)
966+
if L == 0
967+
print(io, "0 variables, dimension 0")
968+
else
969+
(L == 1) ? print(io, "1 variable (") : print(io, L, " variables (")
970+
join(io, Iterators.take(keys(vns_by_name), _MAX_VARS_SHOWN), ", ")
971+
(L > _MAX_VARS_SHOWN) && print(io, ", ...")
972+
print(io, "), dimension ", length(md.vals))
973+
end
969974
end
970975

971976
function Base.show(io::IO, vi::UntypedVarInfo)

0 commit comments

Comments
 (0)