@@ -108,86 +108,14 @@ function DynamicPPL.generated_quantities(
108108    varinfo =  DynamicPPL. VarInfo (model)
109109    iters =  Iterators. product (1 : size (chain, 1 ), 1 : size (chain, 3 ))
110110    return  map (iters) do  (sample_idx, chain_idx)
111-         if  DynamicPPL. supports_varname_indexing (chain)
112-             varname_pairs =  _varname_pairs_with_varname_indexing (
113-                 chain, varinfo, sample_idx, chain_idx
114-             )
115-         else 
116-             varname_pairs =  _varname_pairs_without_varname_indexing (
117-                 chain, varinfo, sample_idx, chain_idx
118-             )
119-         end 
120-         fixed_model =  DynamicPPL. fix (model, Dict (varname_pairs))
121-         return  fixed_model ()
111+         #  TODO : Use `fix` once we've addressed https://github.com/TuringLang/DynamicPPL.jl/issues/702.
112+         #  Update the varinfo with the current sample and make variables not present in `chain`
113+         #  to be sampled.
114+         DynamicPPL. setval_and_resample! (varinfo, chain, sample_idx, chain_idx)
115+         #  NOTE: Some of the varialbes can be a view into the `varinfo`, so we need to
116+         #  `deepcopy` the `varinfo` before passing it to the `model`.
117+         model (deepcopy (varinfo))
122118    end 
123119end 
124120
125- """ 
126-     _varname_pairs_with_varname_indexing( 
127-         chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx 
128-     ) 
129- 
130- Get pairs of `VarName => value` for all the variables in the `varinfo`, picking the values 
131- from the chain. 
132- 
133- This implementation assumes `chain` can be indexed using variable names, and is the 
134- preffered implementation. 
135- """ 
136- function  _varname_pairs_with_varname_indexing (
137-     chain:: MCMCChains.Chains , varinfo, sample_idx, chain_idx
138- )
139-     vns =  DynamicPPL. varnames (chain)
140-     vn_parents =  Iterators. map (vns) do  vn
141-         #  The call nested_setindex_maybe! is used to handle cases where vn is not
142-         #  the variable name used in the model, but rather subsumed by one. Except
143-         #  for the subsumption part, this could be
144-         #  vn => getindex_varname(chain, sample_idx, vn, chain_idx)
145-         #  TODO (mhauru) This call to nested_setindex_maybe! is unintuitive.
146-         DynamicPPL. nested_setindex_maybe! (
147-             varinfo, DynamicPPL. getindex_varname (chain, sample_idx, vn, chain_idx), vn
148-         )
149-     end 
150-     varname_pairs =  Iterators. map (Iterators. filter (! isnothing, vn_parents)) do  vn_parent
151-         vn_parent =>  varinfo[vn_parent]
152-     end 
153-     return  varname_pairs
154- end 
155- 
156- """ 
157- Check which keys in `key_strings` are subsumed by `vn_string` and return the their values. 
158- 
159- The subsumption check is done with `DynamicPPL.subsumes_string`, which is quite weak, and 
160- won't catch all cases. We should get rid of this if we can. 
161- """ 
162- #  TODO (mhauru) See docstring above.
163- function  _vcat_subsumed_values (vn_string, values, key_strings)
164-     indices =  findall (Base. Fix1 (DynamicPPL. subsumes_string, vn_string), key_strings)
165-     return  ! isempty (indices) ?  reduce (vcat, values[indices]) :  nothing 
166- end 
167- 
168- """ 
169-     _varname_pairs_without_varname_indexing( 
170-         chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx 
171-     ) 
172- 
173- Get pairs of `VarName => value` for all the variables in the `varinfo`, picking the values 
174- from the chain. 
175- 
176- This implementation does not assume that `chain` can be indexed using variable names. It is 
177- thus not guaranteed to work in cases where the variable names have complex subsumption 
178- patterns, such as if the model has a variable `x` but the chain stores `x.a[1]`. 
179- """ 
180- function  _varname_pairs_without_varname_indexing (
181-     chain:: MCMCChains.Chains , varinfo, sample_idx, chain_idx
182- )
183-     values =  chain. value[sample_idx, :, chain_idx]
184-     keys =  Base. keys (chain)
185-     keys_strings =  map (string, keys)
186-     varname_pairs =  [
187-         vn =>  _vcat_subsumed_values (string (vn), values, keys_strings) for 
188-         vn in  Base. keys (varinfo)
189-     ]
190-     return  varname_pairs
191- end 
192- 
193121end 
0 commit comments