Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/accumulators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,10 @@ See also: [`combine`](@ref)
function split end

"""
combine(acc::TAcc, acc2::TAcc) where {TAcc<:AbstractAccumulator}
combine(acc::AbstractAccumulator, acc2::AbstractAccumulator)

Combine two accumulators of the same type. Returns a new accumulator of the same type.
Combine two accumulators which have the same type (but may, in general, have different type
parameters). Returns a new accumulator of the same type.

See also: [`split`](@ref)
"""
Expand Down
72 changes: 58 additions & 14 deletions src/debug_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,13 +156,13 @@
const _DEBUG_ACC_NAME = :Debug
DynamicPPL.accumulator_name(::Type{<:DebugAccumulator}) = _DEBUG_ACC_NAME

function split(acc::DebugAccumulator)
return DebugAccumulator(

Check warning on line 160 in src/debug_utils.jl

View check run for this annotation

Codecov / codecov/patch

src/debug_utils.jl#L159-L160

Added lines #L159 - L160 were not covered by tests
OrderedDict{VarName,Int}(), Vector{Stmt}(), acc.error_on_failure
)
end
function combine(acc1::DebugAccumulator, acc2::DebugAccumulator)
return DebugAccumulator(

Check warning on line 165 in src/debug_utils.jl

View check run for this annotation

Codecov / codecov/patch

src/debug_utils.jl#L164-L165

Added lines #L164 - L165 were not covered by tests
merge(acc1.varnames_seen, acc2.varnames_seen),
vcat(acc1.statements, acc2.statements),
acc1.error_on_failure || acc2.error_on_failure,
Expand Down Expand Up @@ -217,9 +217,21 @@
end
end

_has_missings(x) = ismissing(x)
function _has_missings(x::AbstractArray)
# Can't just use `any` because `x` might contain `undef`.
for i in eachindex(x)
if isassigned(x, i) && _has_missings(x[i])
return true
end
end
return false
end

_has_nans(x::NamedTuple) = any(_has_nans, x)
_has_nans(x::AbstractArray) = any(_has_nans, x)
_has_nans(x) = isnan(x)
_has_nans(::Missing) = false

Check warning on line 234 in src/debug_utils.jl

View check run for this annotation

Codecov / codecov/patch

src/debug_utils.jl#L234

Added line #L234 was not covered by tests

function DynamicPPL.accumulate_assume!!(
acc::DebugAccumulator, val, _logjac, vn::VarName, right::Distribution
Expand All @@ -233,13 +245,38 @@
function DynamicPPL.accumulate_observe!!(
acc::DebugAccumulator, right::Distribution, val, vn::Union{VarName,Nothing}
)
if _has_missings(val)
# If `val` itself is a missing, that's a bug because that should cause
# us to go down the assume path.
val === missing && error(
"Encountered `missing` value on the left-hand side of an observe" *
" statement. This should not happen. Please open an issue at" *
" https://github.com/TuringLang/DynamicPPL.jl.",
)
# Otherwise it's an array with some missing values.
msg =
"Encountered a container with one or more `missing` value(s) on the" *
" left-hand side of an observe statement. To treat the variable on" *
" the left-hand side as a random variable, you should specify a single" *
" `missing` rather than a vector of `missing`s. It is not possible to" *
" set part but not all of a distribution to be `missing`."
if acc.error_on_failure
error(msg)
else
@warn msg

Check warning on line 266 in src/debug_utils.jl

View check run for this annotation

Codecov / codecov/patch

src/debug_utils.jl#L266

Added line #L266 was not covered by tests
end
end
# Check for NaN's as well
if _has_nans(val)
error(
msg =
"Encountered a NaN value on the left-hand side of an" *
" observe statement; this may indicate that your data" *
" contain NaN values.",
)
" contain NaN values."
if acc.error_on_failure
error(msg)
else
@warn msg

Check warning on line 278 in src/debug_utils.jl

View check run for this annotation

Codecov / codecov/patch

src/debug_utils.jl#L278

Added line #L278 was not covered by tests
end
end
stmt = ObserveStmt(; varname=vn, right=right, value=val)
push!(acc.statements, stmt)
Expand Down Expand Up @@ -338,15 +375,17 @@
julia> print(trace)
assume: x ~ Normal{Float64}(μ=0.0, σ=1.0) ⟼ -0.670252

julia> issuccess, trace = check_model_and_trace(rng, demo_correct() | (x = 1.0,));
julia> cond_model = model | (x = 1.0,);

julia> issuccess, trace = check_model_and_trace(cond_model, VarInfo(cond_model));
┌ Warning: The model does not contain any parameters.
└ @ DynamicPPL.DebugUtils DynamicPPL.jl/src/debug_utils.jl:342

julia> issuccess
true

julia> print(trace)
observe: 1.0 ~ Normal{Float64}(μ=0.0, σ=1.0)
observe: x (= 1.0) ~ Normal{Float64}(μ=0.0, σ=1.0)
```

## Incorrect model
Expand All @@ -359,15 +398,22 @@
end
demo_incorrect (generic function with 2 methods)

julia> issuccess, trace = check_model_and_trace(rng, demo_incorrect(); error_on_failure=true);
julia> # Notice that VarInfo(model_incorrect) evaluates the model, but doesn't actually
# alert us to the issue of `x` being sampled twice.
model = demo_incorrect(); varinfo = VarInfo(model);

julia> issuccess, trace = check_model_and_trace(model, varinfo; error_on_failure=true);
ERROR: varname x used multiple times in model
```
"""
function check_model_and_trace(
model::Model, varinfo::AbstractVarInfo; error_on_failure=false
)
# Add debug accumulator to the VarInfo.
varinfo = DynamicPPL.setacc!!(deepcopy(varinfo), DebugAccumulator(error_on_failure))
# Need a NumProduceAccumulator as well or else get_num_produce may throw
varinfo = DynamicPPL.setaccs!!(
deepcopy(varinfo), (DebugAccumulator(error_on_failure), NumProduceAccumulator())
)

# Perform checks before evaluating the model.
issuccess = check_model_pre_evaluation(model)
Expand All @@ -388,17 +434,15 @@
end

"""
check_model([rng, ]model::Model; kwargs...)

Check that `model` is valid, warning about any potential issues.
check_model(model::Model, varinfo::AbstractVarInfo; error_on_failure=false)

See [`check_model_and_trace`](@ref) for more details on supported keyword arguments
and details of which types of checks are performed.
Check that `model` is valid, warning about any potential issues (or erroring if
`error_on_failure` is `true`).

# Returns
- `issuccess::Bool`: Whether the model check succeeded.
"""
check_model(model::Model, varinfo::AbstractVarInfo=VarInfo(model); error_on_failure=false) =
check_model(model::Model, varinfo::AbstractVarInfo; error_on_failure=false) =
first(check_model_and_trace(model, varinfo; error_on_failure=error_on_failure))

# Convenience method used to check if all elements in a list are the same.
Expand All @@ -415,7 +459,7 @@
end

"""
has_static_constraints([rng, ]model::Model; num_evals=5, kwargs...)
has_static_constraints([rng, ]model::Model; num_evals=5, error_on_failure=false)

Return `true` if the model has static constraints, `false` otherwise.

Expand Down
4 changes: 2 additions & 2 deletions src/simple_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,15 +122,15 @@ Evaluation in transformed space of course also works:

```jldoctest simplevarinfo-general
julia> vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), true)
Transformed SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), NumProduce = NumProduceAccumulator(0)))
Transformed SimpleVarInfo((x = -1.0,), AccumulatorTuple((LogPrior = LogPriorAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), NumProduce = NumProduceAccumulator(0))))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this was one of the reasons why I liked the simplified MIME"text/plain" show: To declutter printing out varinfo types.


julia> # (✓) Positive probability mass on negative numbers!
getlogjoint(last(DynamicPPL.evaluate!!(m, vi)))
-1.3678794411714423

julia> # While if we forget to indicate that it's transformed:
vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), false)
SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), NumProduce = NumProduceAccumulator(0)))
SimpleVarInfo((x = -1.0,), AccumulatorTuple((LogPrior = LogPriorAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), NumProduce = NumProduceAccumulator(0))))

julia> # (✓) No probability mass on negative numbers!
getlogjoint(last(DynamicPPL.evaluate!!(m, vi)))
Expand Down
13 changes: 12 additions & 1 deletion test/debug_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
buggy_model = buggy_demo_model()
varinfo = VarInfo(buggy_model)

@test_logs (:warn,) (:warn,) check_model(buggy_model)
@test_logs (:warn,) (:warn,) check_model(buggy_model, varinfo)
issuccess = check_model(buggy_model, varinfo)
@test !issuccess
@test_throws ErrorException check_model(
Expand Down Expand Up @@ -142,6 +142,17 @@
end

@testset "incorrect use of condition" begin
@testset "missing in multivariate" begin
@model function demo_missing_in_multivariate(x)
return x ~ MvNormal(zeros(length(x)), I)
end
model = demo_missing_in_multivariate([1.0, missing])
# Have to run this check_model call with an empty varinfo, because actually
# instantiating the VarInfo would cause it to throw a MethodError.
model = contextualize(model, SamplingContext())
@test_throws ErrorException check_model(model, VarInfo(); error_on_failure=true)
end

@testset "condition both in args and context" begin
@model function demo_condition_both_in_args_and_context(x)
return x ~ Normal()
Expand Down
Loading