diff --git a/test/ad.jl b/test/ad.jl index 73519c3f5..a4f3dbfa7 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -1,6 +1,14 @@ using DynamicPPL: LogDensityFunction @testset "Automatic differentiation" begin + # Used as the ground truth that others are compared against. + ref_adtype = AutoForwardDiff() + test_adtypes = [ + AutoReverseDiff(; compile=false), + AutoReverseDiff(; compile=true), + AutoMooncake(; config=nothing), + ] + @testset "Unsupported backends" begin @model demo() = x ~ Normal() @test_logs (:warn, r"not officially supported") LogDensityFunction( @@ -18,15 +26,10 @@ using DynamicPPL: LogDensityFunction f = LogDensityFunction(m, varinfo) x = DynamicPPL.getparams(f) # Calculate reference logp + gradient of logp using ForwardDiff - ref_adtype = ADTypes.AutoForwardDiff() ref_ldf = LogDensityFunction(m, varinfo; adtype=ref_adtype) ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ref_ldf, x) - @testset "$adtype" for adtype in [ - AutoReverseDiff(; compile=false), - AutoReverseDiff(; compile=true), - AutoMooncake(; config=nothing), - ] + @testset "$adtype" for adtype in test_adtypes @info "Testing AD on: $(m.f) - $(short_varinfo_name(varinfo)) - $adtype" # Put predicates here to avoid long lines @@ -103,4 +106,66 @@ using DynamicPPL: LogDensityFunction ) @test LogDensityProblems.logdensity_and_gradient(ldf, vi[:]) isa Any end + + # Test that various different ways of specifying array types as arguments work with all + # ADTypes. 
+ @testset "Array argument types" begin + test_m = randn(2, 3) + + function eval_logp_and_grad(model, m, adtype) + ldf = LogDensityFunction(model(); adtype=adtype) + return LogDensityProblems.logdensity_and_gradient(ldf, m[:]) + end + + @model function scalar_matrix_model(::Type{T}=Float64) where {T<:Real} + m = Matrix{T}(undef, 2, 3) + return m ~ filldist(MvNormal(zeros(2), I), 3) + end + + scalar_matrix_model_reference = eval_logp_and_grad( + scalar_matrix_model, test_m, ref_adtype + ) + + @model function matrix_model(::Type{T}=Matrix{Float64}) where {T} + m = T(undef, 2, 3) + return m ~ filldist(MvNormal(zeros(2), I), 3) + end + + matrix_model_reference = eval_logp_and_grad(matrix_model, test_m, ref_adtype) + + @model function scalar_array_model(::Type{T}=Float64) where {T<:Real} + m = Array{T}(undef, 2, 3) + return m ~ filldist(MvNormal(zeros(2), I), 3) + end + + scalar_array_model_reference = eval_logp_and_grad( + scalar_array_model, test_m, ref_adtype + ) + + @model function array_model(::Type{T}=Array{Float64}) where {T} + m = T(undef, 2, 3) + return m ~ filldist(MvNormal(zeros(2), I), 3) + end + + array_model_reference = eval_logp_and_grad(array_model, test_m, ref_adtype) + + @testset "$adtype" for adtype in test_adtypes + scalar_matrix_model_logp_and_grad = eval_logp_and_grad( + scalar_matrix_model, test_m, adtype + ) + @test scalar_matrix_model_logp_and_grad[1] ≈ scalar_matrix_model_reference[1] + @test scalar_matrix_model_logp_and_grad[2] ≈ scalar_matrix_model_reference[2] + matrix_model_logp_and_grad = eval_logp_and_grad(matrix_model, test_m, adtype) + @test matrix_model_logp_and_grad[1] ≈ matrix_model_reference[1] + @test matrix_model_logp_and_grad[2] ≈ matrix_model_reference[2] + scalar_array_model_logp_and_grad = eval_logp_and_grad( + scalar_array_model, test_m, adtype + ) + @test scalar_array_model_logp_and_grad[1] ≈ scalar_array_model_reference[1] + @test scalar_array_model_logp_and_grad[2] ≈ scalar_array_model_reference[2] + 
array_model_logp_and_grad = eval_logp_and_grad(array_model, test_m, adtype) + @test array_model_logp_and_grad[1] ≈ array_model_reference[1] + @test array_model_logp_and_grad[2] ≈ array_model_reference[2] + end + end end diff --git a/test/compiler.jl b/test/compiler.jl index 8d81c530a..3d3c6d9e3 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -289,6 +289,20 @@ module Issue537 end @test all((isassigned(x, i) for i in eachindex(x))) end + # Test that using @. to stop unwanted broadcasting on the RHS works. + @testset "@. ~ with interpolation" begin + @model function at_dot_with_interpolation() + x = Vector{Float64}(undef, 2) + # Without the interpolation the RHS would turn into `Normal.(sum.([1.0, 2.0]))`, + # which would crash. + @. x ~ $(Normal(sum([1.0, 2.0]))) + end + + # The main check is just that calling at_dot_with_interpolation() doesn't crash, + # the check of the keys is not very important. + @test keys(VarInfo(at_dot_with_interpolation())) == [@varname(x[1]), @varname(x[2])] + end + # A couple of uses of .~ that are no longer valid as of v0.35. @testset "old .~ syntax" begin @model function multivariate_dot_tilde()