@@ -902,33 +902,59 @@ function _inner_transform!(vi::VarInfo, vn::VarName, dist, f)
902
902
return vi
903
903
end
904
904
905
+ # HACK: We need `SampleFromPrior` to result in ALL values which are in need
906
+ # of a transformation to be transformed. `_getvns` will by default return
907
+ # an empty iterable for `SampleFromPrior`, so we need to override it here.
908
+ # This is quite hacky, but seems safer than changing the behavior of `_getvns`.
909
+ _getvns_link (varinfo:: VarInfo , spl:: AbstractSampler ) = _getvns (varinfo, spl)
910
+ _getvns_link (varinfo:: UntypedVarInfo , spl:: SampleFromPrior ) = nothing
911
+ function _getvns_link (varinfo:: TypedVarInfo , spl:: SampleFromPrior )
912
+ return map (Returns (nothing ), varinfo. metadata)
913
+ end
914
+
905
915
function link (:: DynamicTransformation , varinfo:: VarInfo , spl:: AbstractSampler , model:: Model )
906
- return _link (varinfo)
916
+ return _link (varinfo, spl )
907
917
end
908
918
909
- function _link (varinfo:: UntypedVarInfo )
919
+ function _link (varinfo:: UntypedVarInfo , spl :: AbstractSampler )
910
920
varinfo = deepcopy (varinfo)
911
921
return VarInfo (
912
- _link_metadata! (varinfo, varinfo. metadata),
922
+ _link_metadata! (varinfo, varinfo. metadata, _getvns_link (varinfo, spl) ),
913
923
Base. Ref (getlogp (varinfo)),
914
924
Ref (get_num_produce (varinfo)),
915
925
)
916
926
end
917
927
918
- function _link (varinfo:: TypedVarInfo )
928
+ function _link (varinfo:: TypedVarInfo , spl :: AbstractSampler )
919
929
varinfo = deepcopy (varinfo)
920
- md = map (Base. Fix1 (_link_metadata!, varinfo), varinfo. metadata)
921
- # TODO : Update logp, etc.
930
+ md = _link_metadata_namedtuple! (
931
+ varinfo, varinfo. metadata, _getvns_link (varinfo, spl), Val (getspace (spl))
932
+ )
922
933
return VarInfo (md, Base. Ref (getlogp (varinfo)), Ref (get_num_produce (varinfo)))
923
934
end
924
935
925
- function _link_metadata! (varinfo:: VarInfo , metadata:: Metadata )
936
+ @generated function _link_metadata_namedtuple! (
937
+ varinfo:: VarInfo , metadata:: NamedTuple{names} , vns:: NamedTuple , :: Val{space}
938
+ ) where {names,space}
939
+ vals = Expr (:tuple )
940
+ for f in names
941
+ if inspace (f, space) || length (space) == 0
942
+ push! (vals. args, :(_link_metadata! (varinfo, metadata.$ f, vns.$ f)))
943
+ else
944
+ push! (vals. args, :(metadata.$ f))
945
+ end
946
+ end
947
+
948
+ return :(NamedTuple {$names} ($ vals))
949
+ end
950
+ function _link_metadata! (varinfo:: VarInfo , metadata:: Metadata , target_vns)
926
951
vns = metadata. vns
927
952
928
953
# Construct the new transformed values, and keep track of their lengths.
929
954
vals_new = map (vns) do vn
930
955
# Return early if we're already in unconstrained space.
931
- if istrans (varinfo, vn)
956
+ # HACK: if `target_vns` is `nothing`, we ignore the `target_vns` check.
957
+ if istrans (varinfo, vn) || (target_vns != = nothing && vn ∉ target_vns)
932
958
return metadata. vals[getrange (metadata, vn)]
933
959
end
934
960
@@ -972,32 +998,49 @@ end
972
998
function invlink (
973
999
:: DynamicTransformation , varinfo:: VarInfo , spl:: AbstractSampler , model:: Model
974
1000
)
975
- return _invlink (varinfo)
1001
+ return _invlink (varinfo, spl )
976
1002
end
977
1003
978
- function _invlink (varinfo:: UntypedVarInfo )
1004
+ function _invlink (varinfo:: UntypedVarInfo , spl :: AbstractSampler )
979
1005
varinfo = deepcopy (varinfo)
980
1006
return VarInfo (
981
- _invlink_metadata! (varinfo, varinfo. metadata),
1007
+ _invlink_metadata! (varinfo, varinfo. metadata, _getvns_link (varinfo, spl) ),
982
1008
Base. Ref (getlogp (varinfo)),
983
1009
Ref (get_num_produce (varinfo)),
984
1010
)
985
1011
end
986
1012
987
- function _invlink (varinfo:: TypedVarInfo )
1013
+ function _invlink (varinfo:: TypedVarInfo , spl :: AbstractSampler )
988
1014
varinfo = deepcopy (varinfo)
989
- md = map (Base. Fix1 (_invlink_metadata!, varinfo), varinfo. metadata)
990
- # TODO : Update logp, etc.
1015
+ md = _invlink_metadata_namedtuple! (
1016
+ varinfo, varinfo. metadata, _getvns_link (varinfo, spl), Val (getspace (spl))
1017
+ )
991
1018
return VarInfo (md, Base. Ref (getlogp (varinfo)), Ref (get_num_produce (varinfo)))
992
1019
end
993
1020
994
- function _invlink_metadata! (varinfo:: VarInfo , metadata:: Metadata )
1021
+ @generated function _invlink_metadata_namedtuple! (
1022
+ varinfo:: VarInfo , metadata:: NamedTuple{names} , vns:: NamedTuple , :: Val{space}
1023
+ ) where {names,space}
1024
+ vals = Expr (:tuple )
1025
+ for f in names
1026
+ if inspace (f, space) || length (space) == 0
1027
+ push! (vals. args, :(_invlink_metadata! (varinfo, metadata.$ f, vns.$ f)))
1028
+ else
1029
+ push! (vals. args, :(metadata.$ f))
1030
+ end
1031
+ end
1032
+
1033
+ return :(NamedTuple {$names} ($ vals))
1034
+ end
1035
+ function _invlink_metadata! (varinfo:: VarInfo , metadata:: Metadata , target_vns)
995
1036
vns = metadata. vns
996
1037
997
1038
# Construct the new transformed values, and keep track of their lengths.
998
1039
vals_new = map (vns) do vn
999
- # Return early if we're already in constrained space.
1000
- if ! istrans (varinfo, vn)
1040
+ # Return early if we're already in constrained space OR if we're not
1041
+ # supposed to touch this `vn`, e.g. when `vn` does not belong to the current sampler.
1042
+ # HACK: if `target_vns` is `nothing`, we ignore the `target_vns` check.
1043
+ if ! istrans (varinfo, vn) || (target_vns != = nothing && vn ∉ target_vns)
1001
1044
return metadata. vals[getrange (metadata, vn)]
1002
1045
end
1003
1046
0 commit comments