From 33121d06bb8456ce8a9bfd2b0055609099cf0a57 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 8 Jan 2025 20:44:01 +0000 Subject: [PATCH 01/14] Add Aqua tests --- test/Aqua.jl | 8 ++++++++ test/Project.toml | 1 + test/runtests.jl | 1 + 3 files changed, 10 insertions(+) create mode 100644 test/Aqua.jl diff --git a/test/Aqua.jl b/test/Aqua.jl new file mode 100644 index 000000000..2ebdb55dc --- /dev/null +++ b/test/Aqua.jl @@ -0,0 +1,8 @@ +module AquaTests + +using Aqua: Aqua +using DynamicPPL + +Aqua.test_all(DynamicPPL) + +end diff --git a/test/Project.toml b/test/Project.toml index e0fbbb8c5..2f0d5b8bf 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -3,6 +3,7 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" diff --git a/test/runtests.jl b/test/runtests.jl index caddef5f9..bdfc01b2c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -45,6 +45,7 @@ include("test_util.jl") # groups are chosen to make both groups take roughly the same amount of # time, but beyond that there is no particular reason for the split. if GROUP == "All" || GROUP == "Group1" + include("Aqua.jl") include("utils.jl") include("compiler.jl") include("varnamedvector.jl") From b64d3f70fe5c8266637c7fba3cf7b6f522ddd9e4 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 8 Jan 2025 16:38:19 +0000 Subject: [PATCH 02/14] Fix logpdf(::NamedDist) method ambiguity --- src/distribution_wrappers.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/distribution_wrappers.jl b/src/distribution_wrappers.jl index c631b6f19..d7097f5b4 100644 --- a/src/distribution_wrappers.jl +++ b/src/distribution_wrappers.jl @@ -17,6 +17,10 @@ Base.length(dist::NamedDist) = Base.length(dist.dist) Base.size(dist::NamedDist) = Base.size(dist.dist) Distributions.logpdf(dist::NamedDist, x::Real) = Distributions.logpdf(dist.dist, x) +function Distributions.logpdf(dist::NamedDist, x::AbstractArray{<:Real,0}) + # extract the singleton value from 0-dimensional array + return Distributions.logpdf(dist.dist, first(x)) +end function Distributions.logpdf(dist::NamedDist, x::AbstractArray{<:Real}) return Distributions.logpdf(dist.dist, x) end From 5c7ebeea8da44230d2ac4709f0e681a9530d666b Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 8 Jan 2025 19:40:32 +0000 Subject: [PATCH 03/14] Fix SimpleVarInfo method ambiguity --- src/simple_varinfo.jl | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 173eaa9e1..c6713f81a 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -232,8 +232,14 @@ function SimpleVarInfo(; kwargs...) end # Constructor from `Model`. -SimpleVarInfo(model::Model, args...) = SimpleVarInfo{Float64}(model, args...) -function SimpleVarInfo{T}(model::Model, args...) where {T<:Real} +function SimpleVarInfo( + model::Model, args::Union{AbstractVarInfo,AbstractSampler,AbstractContext}... +) + return SimpleVarInfo{Float64}(model, args...) +end +function SimpleVarInfo{T}( + model::Model, args::Union{AbstractVarInfo,AbstractSampler,AbstractContext}... +) where {T<:Real} return last(evaluate!!(model, SimpleVarInfo{T}(), args...)) end From 605683448070739f67f36eb1a1717c2e9f9b4291 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 8 Jan 2025 20:42:46 +0000 Subject: [PATCH 04/14] Fix VarInfo method ambiguity --- src/varinfo.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index ca143ea63..dd751052b 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -200,7 +200,11 @@ function VarInfo( ) return typed_varinfo(model, SamplingContext(rng, sampler, context), metadata) end -VarInfo(model::Model, args...) = VarInfo(Random.default_rng(), model, args...) +function VarInfo( + model::Model, args::Union{AbstractSampler,AbstractContext,Metadata,VarNamedVector}... +) + return VarInfo(Random.default_rng(), model, args...) +end """ vector_length(varinfo::VarInfo) From 0bc1a4e94ecc75c25b85fd852a2c1f74a9130fb1 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 8 Jan 2025 20:45:19 +0000 Subject: [PATCH 05/14] Add InteractiveUtils compat entry See: https://discourse.julialang.org/t/psa-compat-requirements-in-the-general-registry-are-changing/104958 --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index a9463a821..0a067f29d 100644 --- a/Project.toml +++ b/Project.toml @@ -56,6 +56,7 @@ Distributions = "0.25" DocStringExtensions = "0.9" EnzymeCore = "0.6 - 0.8" ForwardDiff = "0.10.12" +InteractiveUtils = "1" JET = "0.9" KernelAbstractions = "0.9.33" LinearAlgebra = "1.6" From 1a82abcff6b0fc42b8a8cff0ca10272f4951e1ae Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 8 Jan 2025 20:47:21 +0000 Subject: [PATCH 06/14] Add Random.AbstractRNG type annotation --- src/context_implementations.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index af04d0f57..e4ba5d252 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -195,7 +195,7 @@ function tilde_observe!!(context, right, left, vi) return left, acclogp_observe!!(context, vi, logp) end -function assume(rng, spl::Sampler, dist) +function assume(rng::Random.AbstractRNG, spl::Sampler, dist) return error("DynamicPPL.assume: unmanaged inference algorithm: $(typeof(spl))") end From 5bb16c455d91354697760544853617645805fa11 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 10 Jan 2025 12:19:20 +0000 Subject: [PATCH 07/14] Remove unneeded getsym method --- src/varname.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/varname.jl b/src/varname.jl index a47f42c25..c16587065 100644 --- a/src/varname.jl +++ b/src/varname.jl @@ -41,6 +41,3 @@ Possibly existing indices of `varname` are neglected. ) where {s,missings,_F,_a,_T} return s in missings end - -# HACK: Type-piracy. Is this really the way to go? -AbstractPPL.getsym(::AbstractVector{<:VarName{sym}}) where {sym} = sym From 614587c0c26ca2540f3e428f891d5fb4a19f3e96 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 12 Jan 2025 16:12:40 +0000 Subject: [PATCH 08/14] =?UTF-8?q?Fix=20(newly=20introduced=20=F0=9F=98=85)?= =?UTF-8?q?=20ConditionContext=20method=20ambiguity?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/contexts.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/contexts.jl b/src/contexts.jl index 87ad8df0b..a54c60374 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -335,6 +335,8 @@ function ConditionContext(values::Union{NamedTuple,AbstractDict}) end # Optimisation when there are no values to condition on ConditionContext(::NamedTuple{()}, context::AbstractContext) = context +# Same as above, and avoids method ambiguity with below +ConditionContext(::NamedTuple{()}, context::NamedConditionContext) = context # Collapse consecutive levels of `ConditionContext`. Note that this overrides # values inside the child context, thus giving precedence to the outermost # `ConditionContext`. From d83e5c2c7b323c02256efcef585823a8268d9fb3 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 12 Jan 2025 20:09:55 +0000 Subject: [PATCH 09/14] Fix unwrap_right_left_vns method ambiguity --- src/compiler.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index 8bde5e784..95e76778b 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -250,7 +250,10 @@ x[1][3] ``` """ unwrap_right_left_vns(right, left, vns) = right, left, vns -function unwrap_right_left_vns(right::NamedDist, left, vns) +function unwrap_right_left_vns(right::NamedDist, left::AbstractArray, ::VarName) + return unwrap_right_left_vns(right.dist, left, right.name) +end +function unwrap_right_left_vns(right::NamedDist, left::AbstractMatrix, ::VarName) return unwrap_right_left_vns(right.dist, left, right.name) end function unwrap_right_left_vns( From 637d20e39d302570eae8f1288571edc62585938a Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 9 Mar 2025 01:42:56 +0000 Subject: [PATCH 10/14] KernelAbstractions is a weakdep not a dep --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 0a067f29d..c67384c3e 100644 --- a/Project.toml +++ b/Project.toml @@ -16,7 +16,6 @@ DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" -KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" @@ -30,6 +29,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" +KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" From 072c3dc2158de693951932f87fdac9376f70381d Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 9 Mar 2025 01:58:15 +0000 Subject: [PATCH 11/14] Fix StaticTransformation / ThreadSafeVarInfo link/invlink ambiguity --- src/threadsafe.jl | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/threadsafe.jl b/src/threadsafe.jl index 539c1e9d6..2dc2645de 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -115,6 +115,19 @@ function invlink(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) return invlink!!(t, deepcopy(vi), model) end +# These two StaticTransformation methods needed to resolve ambiguities +function link!!( + t::StaticTransformation{<:Bijectors.Transform}, vi::ThreadSafeVarInfo, model::Model +) + return Accessors.@set vi.varinfo = link!!(t, vi.varinfo, model) +end + +function invlink!!( + t::StaticTransformation{<:Bijectors.Transform}, vi::ThreadSafeVarInfo, model::Model +) + return Accessors.@set vi.varinfo = invlink!!(t, vi.varinfo, model) +end + function maybe_invlink_before_eval!!(vi::ThreadSafeVarInfo, model::Model) # Defer to the wrapped `AbstractVarInfo` object. # NOTE: When computing `getlogp` for `ThreadSafeVarInfo` we do include the From a0b8134a9beb572eb7fc72e48361d9bdbd832c55 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 12 Mar 2025 16:59:24 +0000 Subject: [PATCH 12/14] Fix more RNGs --- benchmarks/src/DynamicPPLBenchmarks.jl | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/benchmarks/src/DynamicPPLBenchmarks.jl b/benchmarks/src/DynamicPPLBenchmarks.jl index b67f2ce06..4c73bf355 100644 --- a/benchmarks/src/DynamicPPLBenchmarks.jl +++ b/benchmarks/src/DynamicPPLBenchmarks.jl @@ -9,6 +9,7 @@ using LogDensityProblems: LogDensityProblems using ForwardDiff: ForwardDiff using Mooncake: Mooncake using ReverseDiff: ReverseDiff +using StableRNGs: StableRNG include("./Models.jl") using .Models: Models @@ -61,18 +62,20 @@ The AD backend should be specified as a Symbol (e.g. `:forwarddiff`, `:reversedi `islinked` determines whether to link the VarInfo for evaluation. """ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::Bool) + rng = StableRNG(23) + suite = BenchmarkGroup() vi = if varinfo_choice == :untyped vi = VarInfo() - model(vi) + model(rng, vi) vi elseif varinfo_choice == :typed - VarInfo(model) + VarInfo(rng, model) elseif varinfo_choice == :simple_namedtuple - SimpleVarInfo{Float64}(model()) + SimpleVarInfo{Float64}(model(rng)) elseif varinfo_choice == :simple_dict - retvals = model() + retvals = model(rng) vns = [VarName{k}() for k in keys(retvals)] SimpleVarInfo{Float64}(Dict(zip(vns, values(retvals)))) else From b1387707715b01d6dc3611419b10189752c993c5 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 12 Mar 2025 18:21:48 +0000 Subject: [PATCH 13/14] Don't run Aqua tests on CI min versions --- .github/workflows/CI.yml | 2 ++ test/runtests.jl | 5 ++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index a0aba44a8..2480d82c1 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -70,6 +70,8 @@ jobs: env: GROUP: ${{ matrix.test_group }} JULIA_NUM_THREADS: ${{ matrix.runner.num_threads }} + # Only run Aqua tests on latest version + AQUA: ${{ matrix.runner.version == '1' ? 'true' : 'false' }} - uses: julia-actions/julia-processcoverage@v1 diff --git a/test/runtests.jl b/test/runtests.jl index a6f6fdeb1..783a1170d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -35,6 +35,7 @@ using OrderedCollections: OrderedSet using DynamicPPL: getargs_dottilde, getargs_tilde const GROUP = get(ENV, "GROUP", "All") +const AQUA = get(ENV, "AQUA", "true") == "true" Random.seed!(100) include("test_util.jl") @@ -44,7 +45,9 @@ include("test_util.jl") # groups are chosen to make both groups take roughly the same amount of # time, but beyond that there is no particular reason for the split. if GROUP == "All" || GROUP == "Group1" - include("Aqua.jl") + if AQUA + include("Aqua.jl") + end include("utils.jl") include("compiler.jl") include("varnamedvector.jl") From f1d5003955a73e1a3d161caa6d567a9c0e14a7e0 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 13 Mar 2025 16:55:18 +0000 Subject: [PATCH 14/14] Fix ternary in GitHub Actions expression --- .github/workflows/CI.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 2480d82c1..ac8414c4e 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -71,7 +71,7 @@ jobs: GROUP: ${{ matrix.test_group }} JULIA_NUM_THREADS: ${{ matrix.runner.num_threads }} # Only run Aqua tests on latest version - AQUA: ${{ matrix.runner.version == '1' ? 'true' : 'false' }} + AQUA: ${{ matrix.runner.version == '1' && 'true' || 'false' }} - uses: julia-actions/julia-processcoverage@v1