Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
76 commits
Select commit Hold shift + click to select a range
8750255
Merge branch 'master' into torfjelde/transformations
torfjelde Apr 16, 2024
ed6ee88
Merge branch 'master' into torfjelde/transformations
torfjelde Jun 18, 2024
607bdb3
Update test/model.jl
torfjelde Jun 18, 2024
55c8098
Apply suggestions from code review
torfjelde Jun 21, 2024
cc910d5
Merge remote-tracking branch 'origin/torfjelde/transformations' into …
torfjelde Jun 27, 2024
ec9f985
Merge branch 'master' into torfjelde/transformations
torfjelde Jul 14, 2024
a079606
Merge remote-tracking branch 'origin/torfjelde/transformations' into …
torfjelde Jul 19, 2024
ad959ec
Type-stability tests are now correctly using `rand_prior_true` instead
torfjelde Jul 21, 2024
9f84070
`getindex_internal` now calls `getindex` instead of `view`, as the
torfjelde Jul 21, 2024
7d39934
Removed seemingly unnecessary definition of `getindex_internal`
torfjelde Jul 21, 2024
b554504
Fixed references to `newmetadata` which has been replaced by `replace…
torfjelde Jul 28, 2024
ddb1dfe
Made implementation of `recombine` more explicit
torfjelde Jul 28, 2024
3b08f1d
Added docstrings for `untyped_varinfo` and `typed_varinfo`
torfjelde Jul 28, 2024
96ccebe
Added TODO comment about implementing `view` for `VarInfo`
torfjelde Jul 28, 2024
beaeeaa
Fixed potential infinite recursion as suggested by @mhauru
torfjelde Jul 28, 2024
ab2c98b
added docstring to `from_vec_trnasform_for_size
torfjelde Jul 28, 2024
f1f7968
Replaced references to `vectorize(dist, x)` with `tovec(x)`
torfjelde Jul 28, 2024
6e57822
Fixed docstring
torfjelde Jul 28, 2024
841215f
Update src/extract_priors.jl
torfjelde Jul 28, 2024
78b2083
Bump minor version since this is a breaking change
torfjelde Jul 28, 2024
b6ecf7b
Merge remote-tracking branch 'origin/torfjelde/transformations' into …
torfjelde Jul 28, 2024
7100ce1
Merge branch 'master' into torfjelde/transformations
torfjelde Jul 28, 2024
bab63e1
Apply suggestions from code review
sunxd3 Jul 30, 2024
6997019
Update src/varinfo.jl
sunxd3 Jul 30, 2024
9dc7f02
Apply suggestions from code review
torfjelde Jul 30, 2024
c0f9923
Apply suggestions from code review
torfjelde Jul 30, 2024
9056928
Update src/extract_priors.jl
torfjelde Aug 6, 2024
e43dd1b
Added fix for product distributions of targets with changing support …
torfjelde Aug 6, 2024
a7673fd
Addeed tests for product of distributions with dynamic support
torfjelde Aug 6, 2024
e8d4c96
Apply suggestions from code review
torfjelde Aug 6, 2024
2fe7605
Merge remote-tracking branch 'origin/torfjelde/transformations' into …
mhauru Aug 8, 2024
ca0c951
Fix typos, improve docstrings
mhauru Aug 8, 2024
60bb054
Use Accessors rather than Setfield
mhauru Aug 8, 2024
32a9ec7
Simplify group_by_symbol
mhauru Aug 8, 2024
bc41d82
Add short_varinfo_name(::VectorVarInfo)
mhauru Aug 8, 2024
6900034
Add tests for subset
mhauru Aug 8, 2024
e9be160
Export VectorVarInfo
mhauru Aug 8, 2024
2ae8516
Tighter type bound for has_varnamevector
mhauru Aug 8, 2024
524c148
Add some VectorVarName methods
mhauru Aug 8, 2024
b076aef
Add todo notes, remove dead code, fix a typo.
mhauru Aug 8, 2024
f28b430
Bug fixes and small improvements
mhauru Aug 14, 2024
5f02494
VarNameVector improvements
mhauru Aug 15, 2024
56fac99
Improve generated_quantities and its tests
mhauru Aug 19, 2024
c793ada
Improvement to VarNameVector
mhauru Aug 19, 2024
ed2d695
Fix a test to work with VectorVarName
mhauru Aug 19, 2024
01935c8
Fix generated_quantities
mhauru Aug 19, 2024
f8d0100
Fix type stability issues
mhauru Aug 21, 2024
d4ba9f5
Various VarNameVector fixes and improvements
mhauru Aug 21, 2024
fef615d
Merge remote-tracking branch 'origin/master' into mhauru/varnamevector
mhauru Aug 22, 2024
bd67b38
Bump version number
mhauru Aug 22, 2024
06d9df5
Merge remote-tracking branch 'origin/torfjelde/varnamevector' into mh…
mhauru Aug 22, 2024
9596bea
Improvements to generated_quantities
mhauru Aug 22, 2024
b8309d2
Code formatting
mhauru Aug 22, 2024
44fc385
Code style
mhauru Aug 22, 2024
ad13acf
Add fallback implementation of findinds for VarNameVector
mhauru Aug 22, 2024
d0322b7
Rename VarNameVector to VarNamedVector
mhauru Aug 22, 2024
250010d
More renaming of VNV. Remove unused VarNamedVector.metadata field.
mhauru Aug 22, 2024
02d5187
Rename FromVec to ReshapeTransform
mhauru Aug 23, 2024
94cf179
Progress towards having VarNamedVector as storage for SimpleVarInfo
mhauru Aug 28, 2024
27bac26
Fix unflatten(vnv::VarNamedVector, vals)
mhauru Aug 29, 2024
38147da
More work on SimpleVarInfo{VarNamedVector}
mhauru Aug 29, 2024
3990914
More tests for SimpleVarInfo{VarNamedVector}
mhauru Aug 29, 2024
a7a9974
More tests for SimpleVarInfo{VarNamedVector}
mhauru Aug 29, 2024
1937359
Respond to review feedback
mhauru Aug 30, 2024
d4120c3
Add float_type_with_fallback(::Type{Union{}})
mhauru Aug 30, 2024
ca04666
Move some VNV functions to the correct file
mhauru Aug 30, 2024
bb9ae76
Fix push! for VNV
mhauru Aug 30, 2024
536a476
Rename VNV.is_transformed to VNV.is_unconstrained
mhauru Aug 30, 2024
6a029bb
Improve VNV docstring
mhauru Aug 30, 2024
d8f8b17
Add VNV inner constructor checks
mhauru Aug 30, 2024
076e478
Reorganise parts of VNV code
mhauru Aug 30, 2024
f11f007
Documentation and small fixes for VNV
mhauru Sep 2, 2024
f8361f6
Rename loosen_types!! and tighten_types, add docstrings and doctests
mhauru Sep 2, 2024
004c327
Rename VarNameVector to VarNamedVector in docs
mhauru Sep 2, 2024
5291290
Documentation and small fixes to VNV
mhauru Sep 2, 2024
3d472f4
Fix subset(::VarNamedVector, args...) for unconstrained variables.
mhauru Sep 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 21 additions & 21 deletions docs/src/internals/varinfo.md
Original file line number Diff line number Diff line change
Expand Up @@ -132,21 +132,21 @@ Hence we obtain a "type-stable when possible"-representation by wrapping it in a

## Efficient storage and iteration

Efficient storage and iteration we achieve through implementation of the `metadata`. In particular, we do so with [`VarNameVector`](@ref):
Efficient storage and iteration we achieve through implementation of the `metadata`. In particular, we do so with [`VarNamedVector`](@ref):

```@docs
DynamicPPL.VarNameVector
DynamicPPL.VarNamedVector
```

In a [`VarNameVector{<:VarName,Vector{T}}`](@ref), we achieve the desirata by storing the values for different `VarName`s contiguously in a `Vector{T}` and keeping track of which ranges correspond to which `VarName`s.
In a [`VarNamedVector{<:VarName,Vector{T}}`](@ref), we achieve the desirata by storing the values for different `VarName`s contiguously in a `Vector{T}` and keeping track of which ranges correspond to which `VarName`s.

This does require a bit of book-keeping, in particular when it comes to insertions and deletions. Internally, this is handled by assigning each `VarName` a unique `Int` index in the `varname_to_index` field, which is then used to index into the following fields:

- `varnames::Vector{VarName}`: the `VarName`s in the order they appear in the `Vector{T}`.
- `ranges::Vector{UnitRange{Int}}`: the ranges of indices in the `Vector{T}` that correspond to each `VarName`.
- `transforms::Vector`: the transforms associated with each `VarName`.

Mutating functions, e.g. `setindex!(vnv::VarNameVector, val, vn::VarName)`, are then treated according to the following rules:
Mutating functions, e.g. `setindex!(vnv::VarNamedVector, val, vn::VarName)`, are then treated according to the following rules:

1. If `vn` is not already present: add it to the end of `vnv.varnames`, add the `val` to the underlying `vnv.vals`, etc.

Expand All @@ -156,7 +156,7 @@ Mutating functions, e.g. `setindex!(vnv::VarNameVector, val, vn::VarName)`, are
2. If `val` has a *smaller length* than the existing value for `vn`: replace existing value and mark the remaining indices as "inactive" by increasing the entry in `vnv.num_inactive` field.
3. If `val` has a *larger length* than the existing value for `vn`: expand the underlying `vnv.vals` to accommodate the new value, update all `VarName`s occuring after `vn`, and update the `vnv.ranges` to point to the new range for `vn`.

This means that `VarNameVector` is allowed to grow as needed, while "shrinking" (i.e. insertion of smaller elements) is handled by simply marking the redundant indices as "inactive". This turns out to be efficient for use-cases that we are generally interested in.
This means that `VarNamedVector` is allowed to grow as needed, while "shrinking" (i.e. insertion of smaller elements) is handled by simply marking the redundant indices as "inactive". This turns out to be efficient for use-cases that we are generally interested in.

For example, we want to optimize code-paths which effectively boil down to inner-loop in the following example:

Expand Down Expand Up @@ -195,7 +195,7 @@ DynamicPPL.contiguify!
For example, one might encounter the following scenario:

```@example varinfo-design
vnv = DynamicPPL.VarNameVector(@varname(x) => [true])
vnv = DynamicPPL.VarNamedVector(@varname(x) => [true])
println("Before insertion: number of allocated entries $(DynamicPPL.num_allocated(vnv))")

for i in 1:5
Expand All @@ -210,7 +210,7 @@ end
We can then insert a call to [`DynamicPPL.contiguify!`](@ref) after every insertion whenever the allocation grows too large to reduce overall memory usage:

```@example varinfo-design
vnv = DynamicPPL.VarNameVector(@varname(x) => [true])
vnv = DynamicPPL.VarNamedVector(@varname(x) => [true])
println("Before insertion: number of allocated entries $(DynamicPPL.num_allocated(vnv))")

for i in 1:5
Expand All @@ -225,13 +225,13 @@ for i in 1:5
end
```

This does incur a runtime cost as it requires re-allocation of the `ranges` in addition to a `resize!` of the underlying `Vector{T}`. However, this also ensures that the the underlying `Vector{T}` is contiguous, which is important for performance. Hence, if we're about to do a lot of work with the `VarNameVector` without insertions, etc., it can be worth it to do a sweep to ensure that the underlying `Vector{T}` is contiguous.
This does incur a runtime cost as it requires re-allocation of the `ranges` in addition to a `resize!` of the underlying `Vector{T}`. However, this also ensures that the the underlying `Vector{T}` is contiguous, which is important for performance. Hence, if we're about to do a lot of work with the `VarNamedVector` without insertions, etc., it can be worth it to do a sweep to ensure that the underlying `Vector{T}` is contiguous.

!!! note

Higher-dimensional arrays, e.g. `Matrix`, are handled by simply vectorizing them before storing them in the `Vector{T}`, and composing he `VarName`'s transformation with a `DynamicPPL.FromVec`.
Higher-dimensional arrays, e.g. `Matrix`, are handled by simply vectorizing them before storing them in the `Vector{T}`, and composing the `VarName`'s transformation with a `DynamicPPL.ReshapeTransform`.

Continuing from the example from the previous section, we can use a `VarInfo` with a `VarNameVector` as the `metadata` field:
Continuing from the example from the previous section, we can use a `VarInfo` with a `VarNamedVector` as the `metadata` field:

```@example varinfo-design
# Type-unstable
Expand Down Expand Up @@ -287,23 +287,23 @@ DynamicPPL.num_allocated(varinfo_untyped_vnv.metadata, @varname(x))

### Performance summary

In the end, we have the following "rough" performance characteristics for `VarNameVector`:
In the end, we have the following "rough" performance characteristics for `VarNamedVector`:

| Method | Is blazingly fast? |
|:---------------------------------------:|:--------------------------------------------------------------------------------------------:|
| `getindex` | ${\color{green} \checkmark}$ |
| `setindex!` | ${\color{green} \checkmark}$ |
| `push!` | ${\color{green} \checkmark}$ |
| `delete!` | ${\color{red} \times}$ |
| `update!` on existing `VarName` | ${\color{green} \checkmark}$ if smaller or same size / ${\color{red} \times}$ if larger size |
| `values_as(::VarNameVector, Vector{T})` | ${\color{green} \checkmark}$ if contiguous / ${\color{orange} \div}$ otherwise |
| Method | Is blazingly fast? |
|:----------------------------------------:|:--------------------------------------------------------------------------------------------:|
| `getindex` | ${\color{green} \checkmark}$ |
| `setindex!` | ${\color{green} \checkmark}$ |
| `push!` | ${\color{green} \checkmark}$ |
| `delete!` | ${\color{red} \times}$ |
| `update!` on existing `VarName` | ${\color{green} \checkmark}$ if smaller or same size / ${\color{red} \times}$ if larger size |
| `values_as(::VarNamedVector, Vector{T})` | ${\color{green} \checkmark}$ if contiguous / ${\color{orange} \div}$ otherwise |

## Other methods

```@docs
DynamicPPL.replace_values(::VarNameVector, vals::AbstractVector)
DynamicPPL.replace_values(::VarNamedVector, vals::AbstractVector)
```

```@docs; canonical=false
DynamicPPL.values_as(::VarNameVector)
DynamicPPL.values_as(::VarNamedVector)
```
145 changes: 138 additions & 7 deletions ext/DynamicPPLMCMCChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ end
_has_varname_to_symbol(info::NamedTuple{names}) where {names} = :varname_to_symbol in names
function _check_varname_indexing(c::MCMCChains.Chains)
return DynamicPPL.supports_varname_indexing(c) ||
error("Chains do not support indexing using $vn.")
error("$(typeof(c)) do not support indexing using varnmes.")
end

# Load state from a `Chains`: By convention, it is stored in `:samplerstate` metadata
Expand Down Expand Up @@ -41,21 +41,152 @@ function DynamicPPL.varnames(c::MCMCChains.Chains)
return keys(c.info.varname_to_symbol)
end

"""
generated_quantities(model::Model, chain::MCMCChains.Chains)

Execute `model` for each of the samples in `chain` and return an array of the values
returned by the `model` for each sample.

# Examples
## General
Often you might have additional quantities computed inside the model that you want to
inspect, e.g.
```julia
@model function demo(x)
# sample and observe
θ ~ Prior()
x ~ Likelihood()
return interesting_quantity(θ, x)
end
m = demo(data)
chain = sample(m, alg, n)
# To inspect the `interesting_quantity(θ, x)` where `θ` is replaced by samples
# from the posterior/`chain`:
generated_quantities(m, chain) # <= results in a `Vector` of returned values
# from `interesting_quantity(θ, x)`
```
## Concrete (and simple)
```julia
julia> using DynamicPPL, Turing

julia> @model function demo(xs)
s ~ InverseGamma(2, 3)
m_shifted ~ Normal(10, √s)
m = m_shifted - 10

for i in eachindex(xs)
xs[i] ~ Normal(m, √s)
end

return (m, )
end
demo (generic function with 1 method)

julia> model = demo(randn(10));

julia> chain = sample(model, MH(), 10);

julia> generated_quantities(model, chain)
10×1 Array{Tuple{Float64},2}:
(2.1964758025119338,)
(2.1964758025119338,)
(0.09270081916291417,)
(0.09270081916291417,)
(0.09270081916291417,)
(0.09270081916291417,)
(0.09270081916291417,)
(0.043088571494005024,)
(-0.16489786710222099,)
(-0.16489786710222099,)
```
"""
function DynamicPPL.generated_quantities(
model::DynamicPPL.Model, chain_full::MCMCChains.Chains
)
chain = MCMCChains.get_sections(chain_full, :parameters)
varinfo = DynamicPPL.VarInfo(model)
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
return map(iters) do (sample_idx, chain_idx)
# Update the varinfo with the current sample and make variables not present in `chain`
# to be sampled.
DynamicPPL.setval_and_resample!(varinfo, chain, sample_idx, chain_idx)
if DynamicPPL.supports_varname_indexing(chain)
varname_pairs = _varname_pairs_with_varname_indexing(
chain, varinfo, sample_idx, chain_idx
)
else
varname_pairs = _varname_pairs_without_varname_indexing(
chain, varinfo, sample_idx, chain_idx
)
end
fixed_model = DynamicPPL.fix(model, Dict(varname_pairs))
return fixed_model()
end
end

"""
_varname_pairs_with_varname_indexing(
chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx
)

# TODO: Some of the variables can be a view into the `varinfo`, so we need to
# `deepcopy` the `varinfo` before passing it to `model`.
model(deepcopy(varinfo))
Get pairs of `VarName => value` for all the variables in the `varinfo`, picking the values
from the chain.

This implementation assumes `chain` can be indexed using variable names, and is the
preffered implementation.
"""
function _varname_pairs_with_varname_indexing(
chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx
)
vns = DynamicPPL.varnames(chain)
vn_parents = Iterators.map(vns) do vn
# The call nested_setindex_maybe! is used to handle cases where vn is not
# the variable name used in the model, but rather subsumed by one. Except
# for the subsumption part, this could be
# vn => getindex_varname(chain, sample_idx, vn, chain_idx)
# TODO(mhauru) This call to nested_setindex_maybe! is unintuitive.
DynamicPPL.nested_setindex_maybe!(
varinfo, DynamicPPL.getindex_varname(chain, sample_idx, vn, chain_idx), vn
)
end
varname_pairs = Iterators.map(Iterators.filter(!isnothing, vn_parents)) do vn_parent
vn_parent => varinfo[vn_parent]
end
return varname_pairs
end

"""
Check which keys in `key_strings` are subsumed by `vn_string` and return the their values.

The subsumption check is done with `DynamicPPL.subsumes_string`, which is quite weak, and
won't catch all cases. We should get rid of this if we can.
"""
# TODO(mhauru) See docstring above.
function _vcat_subsumed_values(vn_string, values, key_strings)
indices = findall(Base.Fix1(DynamicPPL.subsumes_string, vn_string), key_strings)
return !isempty(indices) ? reduce(vcat, values[indices]) : nothing
end

"""
_varname_pairs_without_varname_indexing(
chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx
)

Get pairs of `VarName => value` for all the variables in the `varinfo`, picking the values
from the chain.

This implementation does not assume that `chain` can be indexed using variable names. It is
thus not guaranteed to work in cases where the variable names have complex subsumption
patterns, such as if the model has a variable `x` but the chain stores `x.a[1]`.
"""
function _varname_pairs_without_varname_indexing(
chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx
)
values = chain.value[sample_idx, :, chain_idx]
keys = Base.keys(chain)
keys_strings = map(string, keys)
varname_pairs = [
vn => _vcat_subsumed_values(string(vn), values, keys_strings) for
vn in Base.keys(varinfo)
]
return varname_pairs
end

end
6 changes: 3 additions & 3 deletions ext/DynamicPPLReverseDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@ else
end

function LogDensityProblemsAD.ADgradient(
ad::ADTypes.AutoReverseDiff{Tcompile}, ℓ::DynamicPPL.LogDensityFunction
) where {Tcompile}
ad::ADTypes.AutoReverseDiff, ℓ::DynamicPPL.LogDensityFunction
)
return LogDensityProblemsAD.ADgradient(
Val(:ReverseDiff),
ℓ;
compile=Val(Tcompile),
compile=Val(ad.compile),
# `getparams` can return `Vector{Real}`, in which case, `ReverseDiff` will initialize the gradients to Integer 0
# because at https://github.com/JuliaDiff/ReverseDiff.jl/blob/c982cde5494fc166965a9d04691f390d9e3073fd/src/tracked.jl#L473
# `zero(D)` will return 0 when D is Real.
Expand Down
5 changes: 3 additions & 2 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,9 @@ export AbstractVarInfo,
VarInfo,
UntypedVarInfo,
TypedVarInfo,
VectorVarInfo,
Copy link
Member

@yebai yebai Aug 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mhauru Is it possible to also introduce and test against VectorSimpleVarInfo (i.e. SimpleVarInfo with VarNameVector as storage format)? If so, can you investigate what might be needed from VarNameVector?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Working on this, and it seems doable. I got a big chunk of the test suite to pass, still need to add some more tests and see if they fail.

SimpleVarInfo,
VarNameVector,
VarNamedVector,
push!!,
empty!!,
subset,
Expand Down Expand Up @@ -176,7 +177,7 @@ include("sampler.jl")
include("varname.jl")
include("distribution_wrappers.jl")
include("contexts.jl")
include("varnamevector.jl")
include("varnamedvector.jl")
include("abstract_varinfo.jl")
include("threadsafe.jl")
include("varinfo.jl")
Expand Down
6 changes: 3 additions & 3 deletions src/abstract_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ julia> # Just use an example model to construct the `VarInfo` because we're lazy
julia> vi[@varname(s)] = 1.0; vi[@varname(m)] = 2.0;

julia> # For the sake of brevity, let's just check the type.
md = values_as(vi); md.s isa DynamicPPL.Metadata
md = values_as(vi); md.s isa Union{DynamicPPL.Metadata, DynamicPPL.VarNamedVector}
true

julia> values_as(vi, NamedTuple)
Expand All @@ -321,7 +321,7 @@ julia> # Just use an example model to construct the `VarInfo` because we're lazy
julia> vi[@varname(s)] = 1.0; vi[@varname(m)] = 2.0;

julia> # For the sake of brevity, let's just check the type.
values_as(vi) isa DynamicPPL.Metadata
values_as(vi) isa Union{DynamicPPL.Metadata, Vector}
true

julia> values_as(vi, NamedTuple)
Expand Down Expand Up @@ -349,7 +349,7 @@ Determine the default `eltype` of the values returned by `vi[spl]`.
This should generally not be called explicitly, as it's only used in
[`matchingvalue`](@ref) to determine the default type to use in place of
type-parameters passed to the model.

This method is considered legacy, and is likely to be deprecated in the future.
"""
function Base.eltype(vi::AbstractVarInfo, spl::Union{AbstractSampler,SampleFromPrior})
Expand Down
15 changes: 12 additions & 3 deletions src/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,10 @@ function assume(
if haskey(vi, vn)
# Always overwrite the parameters with new ones for `SampleFromUniform`.
if sampler isa SampleFromUniform || is_flagged(vi, vn, "del")
unset_flag!(vi, vn, "del")
# TODO(mhauru) Is it important to unset the flag here? The `true` allows us
# to ignore the fact that for VarNamedVector this does nothing, but I'm unsure if
# that's okay.
unset_flag!(vi, vn, "del", true)
r = init(rng, dist, sampler)
f = to_maybe_linked_internal_transform(vi, vn, dist)
BangBang.setindex!!(vi, f(r), vn)
Expand Down Expand Up @@ -516,7 +519,10 @@ function get_and_set_val!(
if haskey(vi, vns[1])
# Always overwrite the parameters with new ones for `SampleFromUniform`.
if spl isa SampleFromUniform || is_flagged(vi, vns[1], "del")
unset_flag!(vi, vns[1], "del")
# TODO(mhauru) Is it important to unset the flag here? The `true` allows us
# to ignore the fact that for VarNamedVector this does nothing, but I'm unsure if
# that's okay.
unset_flag!(vi, vns[1], "del", true)
r = init(rng, dist, spl, n)
for i in 1:n
vn = vns[i]
Expand Down Expand Up @@ -554,7 +560,10 @@ function get_and_set_val!(
if haskey(vi, vns[1])
# Always overwrite the parameters with new ones for `SampleFromUniform`.
if spl isa SampleFromUniform || is_flagged(vi, vns[1], "del")
unset_flag!(vi, vns[1], "del")
# TODO(mhauru) Is it important to unset the flag here? The `true` allows us
# to ignore the fact that for VarNamedVector this does nothing, but I'm unsure if
# that's okay.
unset_flag!(vi, vns[1], "del", true)
f = (vn, dist) -> init(rng, dist, spl)
r = f.(vns, dists)
for i in eachindex(vns)
Expand Down
Loading
Loading