diff --git a/benchmarks/benchmarks.jl b/benchmarks/benchmarks.jl
index 89b65d2de..f98354db5 100644
--- a/benchmarks/benchmarks.jl
+++ b/benchmarks/benchmarks.jl
@@ -39,6 +39,20 @@ chosen_combinations = [
         :forwarddiff,
         false,
     ),
+    (
+        "Simple assume observe",
+        Models.simple_assume_observe(randn(rng)),
+        :typed,
+        :reversediff,
+        false,
+    ),
+    (
+        "Simple assume observe",
+        Models.simple_assume_observe(randn(rng)),
+        :typed,
+        :mooncake,
+        false,
+    ),
     ("Smorgasbord", smorgasbord_instance, :typed, :forwarddiff, false),
     ("Smorgasbord", smorgasbord_instance, :simple_namedtuple, :forwarddiff, true),
     ("Smorgasbord", smorgasbord_instance, :untyped, :forwarddiff, true),
@@ -51,6 +65,7 @@ chosen_combinations = [
     ("Multivariate 10k", multivariate10k, :typed, :mooncake, true),
     ("Dynamic", Models.dynamic(), :typed, :mooncake, true),
     ("Submodel", Models.parent(randn(rng)), :typed, :mooncake, true),
+    ("LDA", lda_instance, :typed, :mooncake, true),
     ("LDA", lda_instance, :typed, :reversediff, true),
 ]

diff --git a/benchmarks/src/Models.jl b/benchmarks/src/Models.jl
index 2c881aa95..997863755 100644
--- a/benchmarks/src/Models.jl
+++ b/benchmarks/src/Models.jl
@@ -127,19 +127,22 @@ end

 """
 A simple Linear Discriminant Analysis model.
+
+The default value for `z` is chosen randomly to make autodiff work. Alternatively one
+could marginalise out `z`.
 """
-@model function lda(K, d, w)
+@model function lda(K, d, w, z=rand(1:K, length(d)), ::Type{T}=Float64) where {T}
     V = length(unique(w))
     D = length(unique(d))
     N = length(d)
     @assert length(w) == N

-    ϕ = Vector{Vector{Real}}(undef, K)
+    ϕ = Vector{Vector{T}}(undef, K)
     for i in 1:K
         ϕ[i] ~ Dirichlet(ones(V) / V)
     end

-    θ = Vector{Vector{Real}}(undef, D)
+    θ = Vector{Vector{T}}(undef, D)
     for i in 1:D
         θ[i] ~ Dirichlet(ones(K) / K)
     end
@@ -150,7 +153,7 @@ A simple Linear Discriminant Analysis model.
         z[i] ~ Categorical(θ[d[i]])
         w[i] ~ Categorical(ϕ[d[i]])
     end
-    return (; ϕ=ϕ, θ=θ, z=z)
+    return (; ϕ=ϕ, θ=θ)
 end

 end
diff --git a/test/ad.jl b/test/ad.jl
index 73519c3f5..0cc40c9a4 100644
--- a/test/ad.jl
+++ b/test/ad.jl
@@ -103,4 +103,73 @@ using DynamicPPL: LogDensityFunction
         )
         @test LogDensityProblems.logdensity_and_gradient(ldf, vi[:]) isa Any
     end
+
+    # Test that various different ways of specifying array types as arguments work with all
+    # ADTypes.
+    @testset "Array argument types" begin
+        reference_adtype = AutoForwardDiff()
+        test_m = randn(2, 3)
+
+        function eval_logp_and_grad(model, m, adtype)
+            model_instance = model()
+            vi = VarInfo(model_instance)
+            ldf = LogDensityFunction(model_instance, vi, DefaultContext(); adtype=adtype)
+            return LogDensityProblems.logdensity_and_gradient(ldf, m[:])
+        end
+
+        @model function scalar_matrix_model(::Type{T}=Float64) where {T<:Real}
+            m = Matrix{T}(undef, 2, 3)
+            return m ~ filldist(MvNormal(zeros(2), I), 3)
+        end
+
+        scalar_matrix_model_reference = eval_logp_and_grad(
+            scalar_matrix_model, test_m, reference_adtype
+        )
+
+        @model function matrix_model(::Type{T}=Matrix{Float64}) where {T}
+            m = T(undef, 2, 3)
+            return m ~ filldist(MvNormal(zeros(2), I), 3)
+        end
+
+        matrix_model_reference = eval_logp_and_grad(matrix_model, test_m, reference_adtype)
+
+        @model function scalar_array_model(::Type{T}=Float64) where {T<:Real}
+            m = Array{T}(undef, 2, 3)
+            return m ~ filldist(MvNormal(zeros(2), I), 3)
+        end
+
+        scalar_array_model_reference = eval_logp_and_grad(
+            scalar_array_model, test_m, reference_adtype
+        )
+
+        @model function array_model(::Type{T}=Array{Float64}) where {T}
+            m = T(undef, 2, 3)
+            return m ~ filldist(MvNormal(zeros(2), I), 3)
+        end
+
+        array_model_reference = eval_logp_and_grad(array_model, test_m, reference_adtype)
+
+        @testset "$adtype" for adtype in [
+            AutoReverseDiff(; compile=false),
+            AutoReverseDiff(; compile=true),
+            AutoMooncake(; config=nothing),
+        ]
+            scalar_matrix_model_logp_and_grad = eval_logp_and_grad(
+                scalar_matrix_model, test_m, adtype
+            )
+            @test scalar_matrix_model_logp_and_grad[1] ≈ scalar_matrix_model_reference[1]
+            @test scalar_matrix_model_logp_and_grad[2] ≈ scalar_matrix_model_reference[2]
+            matrix_model_logp_and_grad = eval_logp_and_grad(matrix_model, test_m, adtype)
+            @test matrix_model_logp_and_grad[1] ≈ matrix_model_reference[1]
+            @test matrix_model_logp_and_grad[2] ≈ matrix_model_reference[2]
+            scalar_array_model_logp_and_grad = eval_logp_and_grad(
+                scalar_array_model, test_m, adtype
+            )
+            @test scalar_array_model_logp_and_grad[1] ≈ scalar_array_model_reference[1]
+            @test scalar_array_model_logp_and_grad[2] ≈ scalar_array_model_reference[2]
+            array_model_logp_and_grad = eval_logp_and_grad(array_model, test_m, adtype)
+            @test array_model_logp_and_grad[1] ≈ array_model_reference[1]
+            @test array_model_logp_and_grad[2] ≈ array_model_reference[2]
+        end
+    end
 end
diff --git a/test/compiler.jl b/test/compiler.jl
index 8d81c530a..3d3c6d9e3 100644
--- a/test/compiler.jl
+++ b/test/compiler.jl
@@ -289,6 +289,20 @@ module Issue537 end
         @test all((isassigned(x, i) for i in eachindex(x)))
     end

+    # Test that using @. to stop unwanted broadcasting on the RHS works.
+    @testset "@. ~ with interpolation" begin
+        @model function at_dot_with_interpolation()
+            x = Vector{Float64}(undef, 2)
+            # Without the interpolation, the RHS would turn into `Normal.(sum.([1.0, 2.0]))`,
+            # which would crash.
+            @. x ~ $(Normal(sum([1.0, 2.0])))
+        end
+
+        # The main check is just that calling at_dot_with_interpolation() doesn't crash;
+        # the check of the keys is not very important.
+        @test keys(VarInfo(at_dot_with_interpolation())) == [@varname(x[1]), @varname(x[2])]
+    end
+
     # A couple of uses of .~ that are no longer valid as of v0.35.
     @testset "old .~ syntax" begin
         @model function multivariate_dot_tilde()
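
For reference (not part of the patch), the following standalone sketch illustrates the pattern the new test/ad.jl testset exercises: a model that receives its array element type via a `::Type{T}` argument and is evaluated through `LogDensityFunction` with an explicitly chosen AD backend. The model name `demo` and the data are illustrative only; the constructor call mirrors the one used in the testset above.

using DynamicPPL, Distributions, LinearAlgebra
using ADTypes: AutoForwardDiff
using LogDensityProblems
import ForwardDiff

# Pre-allocating with the eltype passed via `::Type{T}` lets AD backends that trace
# with their own number types (duals, tracked reals) store those elements in `m`.
@model function demo(::Type{T}=Float64) where {T<:Real}
    m = Vector{T}(undef, 3)
    return m ~ MvNormal(zeros(3), I)
end

model = demo()
vi = VarInfo(model)
ldf = LogDensityFunction(model, vi, DefaultContext(); adtype=AutoForwardDiff())
logp, grad = LogDensityProblems.logdensity_and_gradient(ldf, vi[:])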