diff --git a/HISTORY.md b/HISTORY.md index 52616315d..80f301582 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,9 @@ # DynamicPPL Changelog +## 0.35.9 + +Fixed the `isnan` check introduced in 0.35.7 for distributions which returned NamedTuple. + ## 0.35.8 Added the `DynamicPPL.TestUtils.AD.run_ad` function to test the correctness and/or benchmark the performance of an automatic differentiation backend on DynamicPPL models. diff --git a/Project.toml b/Project.toml index 6c72e55d9..1cb375890 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.35.8" +version = "0.35.9" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/debug_utils.jl b/src/debug_utils.jl index 76c097e94..ef661fc3a 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -242,6 +242,10 @@ function _has_missings(x::AbstractArray) return false end +_has_nans(x::NamedTuple) = any(_has_nans, x) +_has_nans(x::AbstractArray) = any(_has_nans, x) +_has_nans(x) = isnan(x) + # assume function record_pre_tilde_assume!(context::DebugContext, vn, dist, varinfo) record_varname!(context, vn, dist) @@ -291,7 +295,7 @@ function record_pre_tilde_observe!(context::DebugContext, left, dist, varinfo) ) end # Check for NaN's as well - if any(isnan, left) + if _has_nans(left) error( "Encountered a NaN value on the left-hand side of an" * " observe statement; this may indicate that your data" * diff --git a/test/debug_utils.jl b/test/debug_utils.jl index b79ff1fbc..2becb690c 100644 --- a/test/debug_utils.jl +++ b/test/debug_utils.jl @@ -130,8 +130,15 @@ x[i] ~ Normal(a) end end - model = demo_nan_in_data([1.0, NaN]) - @test_throws ErrorException check_model(model; error_on_failure=true) + m = demo_nan_in_data([1.0, NaN]) + @test_throws ErrorException check_model(m; error_on_failure=true) + # Test NamedTuples with nested arrays, see #898 + @model function demo_nan_complicated(nt) + nt ~ product_distribution((x=Normal(), y=Dirichlet([2, 4]))) + return x ~ Normal() + end + m = demo_nan_complicated((x=1.0, y=[NaN, 0.5])) + @test_throws ErrorException check_model(m; error_on_failure=true) end @testset "incorrect use of condition" begin