Skip to content

Commit 9d4137e

Browse files
committed
Small addition to link! and invlink! (#231)
Makes it possible to only link/invlink a subset of the space of a sampler *for static models only* (as there's no equivalent for `UntypedVarInfo`). Is non-breaking since this is just adding an intermediate method to an existing implementation, giving increased flexbility. Probably don't want to encourage use of this, but it can be useful in cases such as the MH sampler TuringLang/Turing.jl#1582 (comment).
1 parent 7a45694 commit 9d4137e

File tree

3 files changed

+20
-4
lines changed

3 files changed

+20
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.10.14"
3+
version = "0.10.15"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/varinfo.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -722,8 +722,11 @@ function link!(vi::UntypedVarInfo, spl::Sampler)
722722
end
723723
end
724724
function link!(vi::TypedVarInfo, spl::AbstractSampler)
725+
return link!(vi, spl, Val(getspace(spl)))
726+
end
727+
function link!(vi::TypedVarInfo, spl::AbstractSampler, spaceval::Val)
725728
vns = _getvns(vi, spl)
726-
return _link!(vi.metadata, vi, vns, Val(getspace(spl)))
729+
return _link!(vi.metadata, vi, vns, spaceval)
727730
end
728731
@generated function _link!(metadata::NamedTuple{names}, vi, vns, ::Val{space}) where {names, space}
729732
expr = Expr(:block)
@@ -770,8 +773,11 @@ function invlink!(vi::UntypedVarInfo, spl::AbstractSampler)
770773
end
771774
end
772775
function invlink!(vi::TypedVarInfo, spl::AbstractSampler)
776+
return invlink!(vi, spl, Val(getspace(spl)))
777+
end
778+
function invlink!(vi::TypedVarInfo, spl::AbstractSampler, spaceval::Val)
773779
vns = _getvns(vi, spl)
774-
return _invlink!(vi.metadata, vi, vns, Val(getspace(spl)))
780+
return _invlink!(vi.metadata, vi, vns, spaceval)
775781
end
776782
@generated function _invlink!(metadata::NamedTuple{names}, vi, vns, ::Val{space}) where {names, space}
777783
expr = Expr(:block)

test/turing/varinfo.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,16 @@
6767
@test all(x -> !istrans(vi, x), meta.m.vns)
6868
@test meta.s.vals == v_s
6969
@test meta.m.vals == v_m
70+
71+
# Transforming only a subset of the variables
72+
link!(vi, spl, Val((:m, )))
73+
@test all(x -> !istrans(vi, x), meta.s.vns)
74+
@test all(x -> istrans(vi, x), meta.m.vns)
75+
invlink!(vi, spl, Val((:m, )))
76+
@test all(x -> !istrans(vi, x), meta.s.vns)
77+
@test all(x -> !istrans(vi, x), meta.m.vns)
78+
@test meta.s.vals == v_s
79+
@test meta.m.vals == v_m
7080
end
7181
@testset "orders" begin
7282
csym = gensym() # unique per model
@@ -329,4 +339,4 @@
329339
@test vi.metadata.w.gids[1] == Set([hmc.selector])
330340
@test vi.metadata.u.gids[1] == Set([hmc.selector]) =#
331341
end
332-
end
342+
end

0 commit comments

Comments
 (0)