766
766
767
767
768
768
"""
769
- islinked(vi::VarInfo, spl::Sampler)
769
+ islinked(vi::VarInfo, spl::Union{ Sampler, SampleFromPrior} )
770
770
771
771
Check whether `vi` is in the transformed space for a particular sampler `spl`.
772
772
@@ -775,11 +775,11 @@ Turing's Hamiltonian samplers use the `link` and `invlink` functions from
775
775
(for example, one bounded to the space `[0, 1]`) from its constrained space to the set of
776
776
real numbers. `islinked` checks if the number is in the constrained space or the real space.
777
777
"""
778
- function islinked (vi:: UntypedVarInfo , spl:: Sampler )
778
+ function islinked (vi:: UntypedVarInfo , spl:: Union{ Sampler, SampleFromPrior} )
779
779
vns = _getvns (vi, spl)
780
780
return istrans (vi, vns[1 ])
781
781
end
782
- function islinked (vi:: TypedVarInfo , spl:: Sampler )
782
+ function islinked (vi:: TypedVarInfo , spl:: Union{ Sampler, SampleFromPrior} )
783
783
vns = _getvns (vi, spl)
784
784
return _islinked (vi, vns)
785
785
end
@@ -956,16 +956,21 @@ function _show_varnames(io::IO, vi)
956
956
md = vi. metadata
957
957
vns = md. vns
958
958
959
- groups = Dict {Symbol, Vector{VarName}} ()
959
+ vns_by_name = Dict {Symbol, Vector{VarName}} ()
960
960
for vn in vns
961
- group = get! (() -> Vector {VarName} (), groups , getsym (vn))
961
+ group = get! (() -> Vector {VarName} (), vns_by_name , getsym (vn))
962
962
push! (group, vn)
963
963
end
964
964
965
- print (io, length (groups), length (groups) == 1 ? " variable " : " variables " , " (" )
966
- join (io, Iterators. take (keys (groups), _MAX_VARS_SHOWN), " , " )
967
- length (groups) > _MAX_VARS_SHOWN && print (io, " , ..." )
968
- print (io, " ), dimension " , sum (prod (size (md. vals[md. ranges[md. idcs[vn]]])) for vn in vns))
965
+ L = length (vns_by_name)
966
+ if L == 0
967
+ print (io, " 0 variables, dimension 0" )
968
+ else
969
+ (L == 1 ) ? print (io, " 1 variable (" ) : print (io, L, " variables (" )
970
+ join (io, Iterators. take (keys (vns_by_name), _MAX_VARS_SHOWN), " , " )
971
+ (L > _MAX_VARS_SHOWN) && print (io, " , ..." )
972
+ print (io, " ), dimension " , length (md. vals))
973
+ end
969
974
end
970
975
971
976
function Base. show (io:: IO , vi:: UntypedVarInfo )
0 commit comments