diff --git a/docs/src/api.md b/docs/src/api.md index 14b2447b5..9a1923b53 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -380,7 +380,6 @@ acclogprior!! getloglikelihood setloglikelihood!! accloglikelihood!! -resetlogp!! ``` #### Variables and their realizations diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index f190c7605..b400e83dd 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -70,7 +70,6 @@ export AbstractVarInfo, acclogjac!!, acclogprior!!, accloglikelihood!!, - resetlogp!!, is_flagged, set_flag!, unset_flag!, diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 7940f20e6..ac841baab 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -330,6 +330,15 @@ function map_accumulators!!(func::Function, vi::AbstractVarInfo) return setaccs!!(vi, map(func, getaccs(vi))) end +""" + resetaccs!!(vi::AbstractVarInfo) + +Reset the values of all accumulators, using [`reset`](@ref). +""" +function resetaccs!!(vi::AbstractVarInfo) + return setaccs!!(vi, map(reset, getaccs(vi))) +end + """ map_accumulator!!(func::Function, vi::AbstractVarInfo, ::Val{accname}) where {accname} @@ -423,24 +432,6 @@ function acclogp!!(vi::AbstractVarInfo, logp::Number) return accloglikelihood!!(vi, logp) end -""" - resetlogp!!(vi::AbstractVarInfo) - -Reset the values of the log probabilities (prior and likelihood) in `vi` to zero. -""" -function resetlogp!!(vi::AbstractVarInfo) - if hasacc(vi, Val(:LogPrior)) - vi = map_accumulator!!(zero, vi, Val(:LogPrior)) - end - if hasacc(vi, Val(:LogJacobian)) - vi = map_accumulator!!(zero, vi, Val(:LogJacobian)) - end - if hasacc(vi, Val(:LogLikelihood)) - vi = map_accumulator!!(zero, vi, Val(:LogLikelihood)) - end - return vi -end - # Variables and their realizations. @doc """ keys(vi::AbstractVarInfo) @@ -491,7 +482,7 @@ function getindex_internal end @doc """ empty!!(vi::AbstractVarInfo) -Empty `vi` of variables and reset any `logp` accumulators zeros. +Empty `vi` of variables and reset all accumulators. This is useful when using a sampling algorithm that assumes an empty `vi`, e.g. `SMC`. """ BangBang.empty!! diff --git a/src/accumulators.jl b/src/accumulators.jl index 0dcf9c7cf..0208f19a5 100644 --- a/src/accumulators.jl +++ b/src/accumulators.jl @@ -13,6 +13,7 @@ An accumulator type `T <: AbstractAccumulator` must implement the following meth - `accumulator_name(acc::T)` or `accumulator_name(::Type{T})` - `accumulate_observe!!(acc::T, dist, val, vn)` - `accumulate_assume!!(acc::T, val, logjac, vn, dist)` +- `reset(acc::T)` - `Base.copy(acc::T)` In these functions: @@ -50,9 +51,7 @@ accumulator_name(acc::AbstractAccumulator) = accumulator_name(typeof(acc)) Update `acc` in a `tilde_observe!!` call. Returns the updated `acc`. -`vn` is the name of the variable being observed, `left` is the value of the variable, and -`right` is the distribution on the RHS of the tilde statement. `vn` is `nothing` in the case -of literal observations like `0.0 ~ Normal()`. +See [`AbstractAccumulator`](@ref) for the meaning of the arguments. `accumulate_observe!!` may mutate `acc`, but not any of the other arguments. @@ -65,11 +64,7 @@ function accumulate_observe!! end Update `acc` in a `tilde_assume!!` call. Returns the updated `acc`. -`vn` is the name of the variable being assumed, `val` is the value of the variable (in the -original, unlinked space), and `right` is the distribution on the RHS of the tilde -statement. `logjac` is the log determinant of the Jacobian of the transformation that was -done to convert the value of `vn` as it was given to `val`: for example, if the sampler is -operating in linked (Euclidean) space, then logjac will be nonzero. +See [`AbstractAccumulator`](@ref) for the meaning of the arguments. `accumulate_assume!!` may mutate `acc`, but not any of the other arguments. @@ -77,14 +72,37 @@ See also: [`accumulate_observe!!`](@ref) """ function accumulate_assume!! end +""" + reset(acc::AbstractAccumulator) + +Return a new accumulator like `acc`, but with its contents reset to the state that they +should be at the beginning of model evaluation. + +Note that this may in general have very similar behaviour to [`split`](@ref), and may share +the same implementation, but the difference is that `split` may in principle happen at any +stage during model evaluation, whereas `reset` is only called at the beginning of model +evaluation. +""" +function reset end + +@doc """ + Base.copy(acc::AbstractAccumulator) + +Create a new accumulator that is a copy of `acc`, without aliasing (i.e., this should +behave conceptually like a `deepcopy`). +""" Base.copy + """ split(acc::AbstractAccumulator) -Return a new accumulator like `acc` but empty. +Return a new accumulator like `acc` suitable for use in a forked thread. + +The returned value should be such that `combine(acc, split(acc))` is equal to `acc`. This is +used in the context of multi-threading where different threads may accumulate independently +and the results are then combined. -The precise meaning of "empty" is that that the returned value should be such that -`combine(acc, split(acc))` is equal to `acc`. This is used in the context of multi-threading -where different threads may accumulate independently and the results are then combined. +Note that this may in general have very similar behaviour to [`reset`](@ref), but is +semantically different. See [`reset`](@ref) for more details. See also: [`combine`](@ref) """ diff --git a/src/debug_utils.jl b/src/debug_utils.jl index d71fa57cc..c2be4b46b 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -156,12 +156,22 @@ end const _DEBUG_ACC_NAME = :Debug DynamicPPL.accumulator_name(::Type{<:DebugAccumulator}) = _DEBUG_ACC_NAME -function split(acc::DebugAccumulator) +function Base.:(==)(acc1::DebugAccumulator, acc2::DebugAccumulator) + return ( + acc1.varnames_seen == acc2.varnames_seen && + acc1.statements == acc2.statements && + acc1.error_on_failure == acc2.error_on_failure + ) +end + +function _zero(acc::DebugAccumulator) return DebugAccumulator( OrderedDict{VarName,Int}(), Vector{Stmt}(), acc.error_on_failure ) end -function combine(acc1::DebugAccumulator, acc2::DebugAccumulator) +DynamicPPL.reset(acc::DebugAccumulator) = _zero(acc) +DynamicPPL.split(acc::DebugAccumulator) = _zero(acc) +function DynamicPPL.combine(acc1::DebugAccumulator, acc2::DebugAccumulator) return DebugAccumulator( merge(acc1.varnames_seen, acc2.varnames_seen), vcat(acc1.statements, acc2.statements), @@ -416,7 +426,7 @@ function check_model_and_trace( issuccess = check_model_pre_evaluation(model) # Force single-threaded execution. - DynamicPPL.evaluate_threadunsafe!!(model, varinfo) + _, varinfo = DynamicPPL.evaluate_threadunsafe!!(model, varinfo) # Perform checks after evaluating the model. debug_acc = DynamicPPL.getacc(varinfo, Val(_DEBUG_ACC_NAME)) diff --git a/src/default_accumulators.jl b/src/default_accumulators.jl index 2b505f1bb..0dfd12401 100644 --- a/src/default_accumulators.jl +++ b/src/default_accumulators.jl @@ -47,8 +47,9 @@ end Base.hash(acc::T, h::UInt) where {T<:LogProbAccumulator} = hash((T, logp(acc)), h) -split(::AccType) where {T,AccType<:LogProbAccumulator{T}} = AccType(zero(T)) - +_zero(::Tacc) where {Tlogp,Tacc<:LogProbAccumulator{Tlogp}} = Tacc(zero(Tlogp)) +reset(acc::LogProbAccumulator) = _zero(acc) +split(acc::LogProbAccumulator) = _zero(acc) function combine(acc::LogProbAccumulator, acc2::LogProbAccumulator) if basetypeof(acc) !== basetypeof(acc2) msg = "Cannot combine accumulators of different types: $(basetypeof(acc)) and $(basetypeof(acc2))" @@ -59,8 +60,6 @@ end acclogp(acc::LogProbAccumulator, val) = basetypeof(acc)(logp(acc) + val) -Base.zero(acc::T) where {T<:LogProbAccumulator} = T(zero(logp(acc))) - function Base.convert( ::Type{AccType}, acc::LogProbAccumulator ) where {T,AccType<:LogProbAccumulator{T}} diff --git a/src/extract_priors.jl b/src/extract_priors.jl index 64dcf2eea..d311a5f63 100644 --- a/src/extract_priors.jl +++ b/src/extract_priors.jl @@ -10,7 +10,13 @@ end accumulator_name(::PriorDistributionAccumulator) = :PriorDistributionAccumulator -split(acc::PriorDistributionAccumulator) = PriorDistributionAccumulator(empty(acc.priors)) +function Base.:(==)(acc1::PriorDistributionAccumulator, acc2::PriorDistributionAccumulator) + return acc1.priors == acc2.priors +end + +_zero(acc::PriorDistributionAccumulator) = PriorDistributionAccumulator(empty(acc.priors)) +reset(acc::PriorDistributionAccumulator) = _zero(acc) +split(acc::PriorDistributionAccumulator) = _zero(acc) function combine(acc1::PriorDistributionAccumulator, acc2::PriorDistributionAccumulator) return PriorDistributionAccumulator(merge(acc1.priors, acc2.priors)) end diff --git a/src/model.jl b/src/model.jl index ac9968cf2..9f9c6ec3b 100644 --- a/src/model.jl +++ b/src/model.jl @@ -884,7 +884,7 @@ This method is not exposed and supposed to be used only internally in DynamicPPL See also: [`evaluate_threadsafe!!`](@ref) """ function evaluate_threadunsafe!!(model, varinfo) - return _evaluate!!(model, resetlogp!!(varinfo)) + return _evaluate!!(model, resetaccs!!(varinfo)) end """ @@ -899,7 +899,7 @@ This method is not exposed and supposed to be used only internally in DynamicPPL See also: [`evaluate_threadunsafe!!`](@ref) """ function evaluate_threadsafe!!(model, varinfo) - wrapper = ThreadSafeVarInfo(resetlogp!!(varinfo)) + wrapper = ThreadSafeVarInfo(resetaccs!!(varinfo)) result, wrapper_new = _evaluate!!(model, wrapper) # TODO(penelopeysm): If seems that if you pass a TSVI to this method, it # will return the underlying VI, which is a bit counterintuitive (because diff --git a/src/pointwise_logdensities.jl b/src/pointwise_logdensities.jl index dea432022..61834ab62 100644 --- a/src/pointwise_logdensities.jl +++ b/src/pointwise_logdensities.jl @@ -32,6 +32,12 @@ function PointwiseLogProbAccumulator{whichlogprob,KeyType}() where {whichlogprob return PointwiseLogProbAccumulator{whichlogprob,KeyType,typeof(logps)}(logps) end +function Base.:(==)( + acc1::PointwiseLogProbAccumulator{wlp1}, acc2::PointwiseLogProbAccumulator{wlp2} +) where {wlp1,wlp2} + return (wlp1 == wlp2 && acc1.logps == acc2.logps) +end + function Base.copy(acc::PointwiseLogProbAccumulator{whichlogprob}) where {whichlogprob} return PointwiseLogProbAccumulator{whichlogprob}(copy(acc.logps)) end @@ -56,10 +62,11 @@ function accumulator_name( return Symbol("PointwiseLogProbAccumulator{$whichlogprob}") end -function split(acc::PointwiseLogProbAccumulator{whichlogprob}) where {whichlogprob} +function _zero(acc::PointwiseLogProbAccumulator{whichlogprob}) where {whichlogprob} return PointwiseLogProbAccumulator{whichlogprob}(empty(acc.logps)) end - +reset(acc::PointwiseLogProbAccumulator) = _zero(acc) +split(acc::PointwiseLogProbAccumulator) = _zero(acc) function combine( acc::PointwiseLogProbAccumulator{whichlogprob}, acc2::PointwiseLogProbAccumulator{whichlogprob}, @@ -223,23 +230,37 @@ function pointwise_logdensities( # Get the data by executing the model once vi = VarInfo(model) + # This accumulator tracks the pointwise log-probabilities in a single iteration. AccType = PointwiseLogProbAccumulator{whichlogprob,KeyType} vi = setaccs!!(vi, (AccType(),)) iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) + + # Maintain a separate accumulator that isn't tied to a VarInfo but rather + # tracks _all_ iterations. + all_logps = AccType() for (sample_idx, chain_idx) in iters # Update the values setval!(vi, chain, sample_idx, chain_idx) # Execute model vi = last(evaluate!!(model, vi)) + + # Get the log-probabilities + this_iter_logps = getacc(vi, Val(accumulator_name(AccType))).logps + + # Merge into main acc + for (varname, this_lp) in this_iter_logps + # Because `this_lp` is obtained from one model execution, it should only + # contain one variable, hence `only()`. + push!(all_logps, varname, only(this_lp)) + end end - logps = getacc(vi, Val(accumulator_name(AccType))).logps niters = size(chain, 1) nchains = size(chain, 3) logdensities = OrderedDict( - varname => reshape(vals, niters, nchains) for (varname, vals) in logps + varname => reshape(vals, niters, nchains) for (varname, vals) in all_logps.logps ) return logdensities end diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index ad22bf52d..cfad93ed9 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -287,7 +287,7 @@ function unflatten(svi::SimpleVarInfo, x::AbstractVector) end function BangBang.empty!!(vi::SimpleVarInfo) - return resetlogp!!(Accessors.@set vi.values = empty!!(vi.values)) + return resetaccs!!(Accessors.@set vi.values = empty!!(vi.values)) end Base.isempty(vi::SimpleVarInfo) = isempty(vi.values) diff --git a/src/threadsafe.jl b/src/threadsafe.jl index 450dd2c38..6ca3b9852 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -171,27 +171,13 @@ end isempty(vi::ThreadSafeVarInfo) = isempty(vi.varinfo) function BangBang.empty!!(vi::ThreadSafeVarInfo) - return resetlogp!!(Accessors.@set(vi.varinfo = empty!!(vi.varinfo))) + return resetaccs!!(Accessors.@set(vi.varinfo = empty!!(vi.varinfo))) end -function resetlogp!!(vi::ThreadSafeVarInfo) - vi = Accessors.@set vi.varinfo = resetlogp!!(vi.varinfo) +function resetaccs!!(vi::ThreadSafeVarInfo) + vi = Accessors.@set vi.varinfo = resetaccs!!(vi.varinfo) for i in eachindex(vi.accs_by_thread) - if hasacc(vi, Val(:LogPrior)) - vi.accs_by_thread[i] = map_accumulator( - zero, vi.accs_by_thread[i], Val(:LogPrior) - ) - end - if hasacc(vi, Val(:LogJacobian)) - vi.accs_by_thread[i] = map_accumulator( - zero, vi.accs_by_thread[i], Val(:LogJacobian) - ) - end - if hasacc(vi, Val(:LogLikelihood)) - vi.accs_by_thread[i] = map_accumulator( - zero, vi.accs_by_thread[i], Val(:LogLikelihood) - ) - end + vi.accs_by_thread[i] = map(reset, vi.accs_by_thread[i]) end return vi end diff --git a/src/values_as_in_model.jl b/src/values_as_in_model.jl index df663bf54..71baebe92 100644 --- a/src/values_as_in_model.jl +++ b/src/values_as_in_model.jl @@ -20,15 +20,21 @@ function ValuesAsInModelAccumulator(include_colon_eq) return ValuesAsInModelAccumulator(OrderedDict{VarName,Any}(), include_colon_eq) end +function Base.:(==)(acc1::ValuesAsInModelAccumulator, acc2::ValuesAsInModelAccumulator) + return (acc1.include_colon_eq == acc2.include_colon_eq && acc1.values == acc2.values) +end + function Base.copy(acc::ValuesAsInModelAccumulator) return ValuesAsInModelAccumulator(copy(acc.values), acc.include_colon_eq) end accumulator_name(::Type{<:ValuesAsInModelAccumulator}) = :ValuesAsInModel -function split(acc::ValuesAsInModelAccumulator) +function _zero(acc::ValuesAsInModelAccumulator) return ValuesAsInModelAccumulator(empty(acc.values), acc.include_colon_eq) end +reset(acc::ValuesAsInModelAccumulator) = _zero(acc) +split(acc::ValuesAsInModelAccumulator) = _zero(acc) function combine(acc1::ValuesAsInModelAccumulator, acc2::ValuesAsInModelAccumulator) if acc1.include_colon_eq != acc2.include_colon_eq msg = "Cannot combine accumulators with different include_colon_eq values." diff --git a/src/varinfo.jl b/src/varinfo.jl index e115a6799..dec4db3ec 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -883,7 +883,7 @@ end function BangBang.empty!!(vi::VarInfo) _empty!(vi.metadata) - vi = resetlogp!!(vi) + vi = resetaccs!!(vi) return vi end diff --git a/test/accumulators.jl b/test/accumulators.jl index df20f2c11..e45dfb028 100644 --- a/test/accumulators.jl +++ b/test/accumulators.jl @@ -13,6 +13,7 @@ using DynamicPPL: convert_eltype, getacc, map_accumulator, + reset, setacc!!, split @@ -23,12 +24,12 @@ using DynamicPPL: LogPriorAccumulator() == LogPriorAccumulator{Float64}() == LogPriorAccumulator{Float64}(0.0) == - zero(LogPriorAccumulator(1.0)) + DynamicPPL.reset(LogPriorAccumulator(1.0)) @test LogLikelihoodAccumulator(0.0) == LogLikelihoodAccumulator() == LogLikelihoodAccumulator{Float64}() == LogLikelihoodAccumulator{Float64}(0.0) == - zero(LogLikelihoodAccumulator(1.0)) + DynamicPPL.reset(LogLikelihoodAccumulator(1.0)) end @testset "addition and incrementation" begin @@ -135,7 +136,7 @@ using DynamicPPL: @testset "map_accumulator(s)!!" begin # map over all accumulators accs = AccumulatorTuple(lp_f32, ll_f32) - @test map(zero, accs) == AccumulatorTuple( + @test map(DynamicPPL.reset, accs) == AccumulatorTuple( LogPriorAccumulator(0.0f0), LogLikelihoodAccumulator(0.0f0) ) # Test that the original wasn't modified. @@ -146,7 +147,7 @@ using DynamicPPL: AccumulatorTuple(LogPriorAccumulator(1.0), LogLikelihoodAccumulator(1.0)) # only apply to a particular accumulator - @test map_accumulator(zero, accs, Val(:LogLikelihood)) == + @test map_accumulator(DynamicPPL.reset, accs, Val(:LogLikelihood)) == AccumulatorTuple(lp_f32, LogLikelihoodAccumulator(0.0f0)) @test map_accumulator( acc -> convert_eltype(Float64, acc), accs, Val(:LogLikelihood) diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index be6deb96e..526fce92c 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -204,7 +204,7 @@ end # Reset the logp accumulators. - svi_eval = DynamicPPL.resetlogp!!(svi_eval) + svi_eval = DynamicPPL.resetaccs!!(svi_eval) # Compute `logjoint` using the varinfo. logπ = logjoint(model, svi_eval) diff --git a/test/threadsafe.jl b/test/threadsafe.jl index 24a738a78..0421c89e2 100644 --- a/test/threadsafe.jl +++ b/test/threadsafe.jl @@ -25,7 +25,7 @@ @test getlogjoint(vi) == lp @test getlogjoint(threadsafe_vi) == lp + 42 - threadsafe_vi = resetlogp!!(threadsafe_vi) + threadsafe_vi = DynamicPPL.resetaccs!!(threadsafe_vi) @test iszero(getlogjoint(threadsafe_vi)) expected_accs = DynamicPPL.AccumulatorTuple( (DynamicPPL.split(acc) for acc in DynamicPPL.getaccs(threadsafe_vi.varinfo))... diff --git a/test/varinfo.jl b/test/varinfo.jl index 202ddc1b2..ba7c17b34 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -112,7 +112,7 @@ end test_base(SimpleVarInfo(DynamicPPL.VarNamedVector())) end - @testset "get/set/acc/resetlogp" begin + @testset "get/set/acclogp" begin function test_varinfo_logp!(vi) @test DynamicPPL.getlogjoint(vi) === 0.0 vi = DynamicPPL.setlogprior!!(vi, 1.0) @@ -131,8 +131,6 @@ end @test DynamicPPL.getlogprior(vi) === 2.0 @test DynamicPPL.getloglikelihood(vi) === 2.0 @test DynamicPPL.getlogjoint(vi) === 4.0 - vi = DynamicPPL.resetlogp!!(vi) - @test DynamicPPL.getlogjoint(vi) === 0.0 end vi = VarInfo() @@ -143,7 +141,7 @@ end test_varinfo_logp!(SimpleVarInfo(DynamicPPL.VarNamedVector())) end - @testset "accumulators" begin + @testset "logp accumulators" begin @model function demo() a ~ Normal() b ~ Normal() @@ -227,6 +225,71 @@ end @test_throws r"has no field `?LogPrior" getlogjoint(vi) end + @testset "resetaccs" begin + # Put in a bunch of accumulators, check that they're all reset either + # when we call resetaccs!!, empty!!, or evaluate!!. + @model function demo() + a ~ Normal() + return x ~ Normal(a) + end + model = demo() + vi_orig = VarInfo(model) + # It already has the logp accumulators, so let's add in some more. + vi_orig = DynamicPPL.setacc!!(vi_orig, DynamicPPL.DebugUtils.DebugAccumulator(true)) + vi_orig = DynamicPPL.setacc!!(vi_orig, DynamicPPL.ValuesAsInModelAccumulator(true)) + vi_orig = DynamicPPL.setacc!!(vi_orig, DynamicPPL.PriorDistributionAccumulator()) + vi_orig = DynamicPPL.setacc!!( + vi_orig, DynamicPPL.PointwiseLogProbAccumulator{:both}() + ) + # And evaluate the model once so that they are populated. + _, vi_orig = DynamicPPL.evaluate!!(model, vi_orig) + + function all_accs_empty(vi::AbstractVarInfo) + for acc_key in keys(DynamicPPL.getaccs(vi)) + acc = DynamicPPL.getacc(vi, Val(acc_key)) + acc == DynamicPPL.reset(acc) || return false + end + return true + end + + @test !all_accs_empty(vi_orig) + + vi = DynamicPPL.resetaccs!!(deepcopy(vi_orig)) + @test all_accs_empty(vi) + @test getlogjoint(vi) == 0.0 # for good measure + @test getlogprior(vi) == 0.0 + @test getloglikelihood(vi) == 0.0 + + vi = DynamicPPL.empty!!(deepcopy(vi_orig)) + @test all_accs_empty(vi) + @test getlogjoint(vi) == 0.0 + @test getlogprior(vi) == 0.0 + @test getloglikelihood(vi) == 0.0 + + function all_accs_same(vi1::AbstractVarInfo, vi2::AbstractVarInfo) + # Check that they have the same accs + keys1 = Set(keys(DynamicPPL.getaccs(vi1))) + keys2 = Set(keys(DynamicPPL.getaccs(vi2))) + keys1 == keys2 || return false + # Check that they have the same values + for acc_key in keys1 + acc1 = DynamicPPL.getacc(vi1, Val(acc_key)) + acc2 = DynamicPPL.getacc(vi2, Val(acc_key)) + if acc1 != acc2 + @show acc1, acc2 + end + acc1 == acc2 || return false + end + return true + end + # Hopefully this doesn't matter + @test all_accs_same(vi_orig, deepcopy(vi_orig)) + # If we re-evaluate, then we expect the accs to be reset prior to evaluation. + # Thus after re-evaluation, the accs should be exactly the same as before. + _, vi = DynamicPPL.evaluate!!(model, deepcopy(vi_orig)) + @test all_accs_same(vi, vi_orig) + end + @testset "flags" begin # Test flag setting: # is_flagged, set_flag!, unset_flag!