Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Expand All @@ -30,6 +29,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"

Expand All @@ -56,6 +56,7 @@ Distributions = "0.25"
DocStringExtensions = "0.9"
EnzymeCore = "0.6 - 0.8"
ForwardDiff = "0.10.12"
InteractiveUtils = "1"
JET = "0.9"
KernelAbstractions = "0.9.33"
LinearAlgebra = "1.6"
Expand Down
5 changes: 4 additions & 1 deletion src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,10 @@
```
"""
unwrap_right_left_vns(right, left, vns) = right, left, vns
function unwrap_right_left_vns(right::NamedDist, left, vns)
function unwrap_right_left_vns(right::NamedDist, left::AbstractArray, ::VarName)
return unwrap_right_left_vns(right.dist, left, right.name)

Check warning on line 254 in src/compiler.jl

View check run for this annotation

Codecov / codecov/patch

src/compiler.jl#L253-L254

Added lines #L253 - L254 were not covered by tests
end
function unwrap_right_left_vns(right::NamedDist, left::AbstractMatrix, ::VarName)

Check warning on line 256 in src/compiler.jl

View check run for this annotation

Codecov / codecov/patch

src/compiler.jl#L256

Added line #L256 was not covered by tests
return unwrap_right_left_vns(right.dist, left, right.name)
end
function unwrap_right_left_vns(
Expand Down
2 changes: 1 addition & 1 deletion src/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@
return left, acclogp_observe!!(context, vi, logp)
end

function assume(rng, spl::Sampler, dist)
function assume(rng::Random.AbstractRNG, spl::Sampler, dist)

Check warning on line 198 in src/context_implementations.jl

View check run for this annotation

Codecov / codecov/patch

src/context_implementations.jl#L198

Added line #L198 was not covered by tests
return error("DynamicPPL.assume: unmanaged inference algorithm: $(typeof(spl))")
end

Expand Down
2 changes: 2 additions & 0 deletions src/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,8 @@
end
# Optimisation when there are no values to condition on
ConditionContext(::NamedTuple{()}, context::AbstractContext) = context
# Same as above, and avoids method ambiguity with below
ConditionContext(::NamedTuple{()}, context::NamedConditionContext) = context

Check warning on line 339 in src/contexts.jl

View check run for this annotation

Codecov / codecov/patch

src/contexts.jl#L339

Added line #L339 was not covered by tests
# Collapse consecutive levels of `ConditionContext`. Note that this overrides
# values inside the child context, thus giving precedence to the outermost
# `ConditionContext`.
Expand Down
4 changes: 4 additions & 0 deletions src/distribution_wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
Base.size(dist::NamedDist) = Base.size(dist.dist)

Distributions.logpdf(dist::NamedDist, x::Real) = Distributions.logpdf(dist.dist, x)
function Distributions.logpdf(dist::NamedDist, x::AbstractArray{<:Real,0})

Check warning on line 20 in src/distribution_wrappers.jl

View check run for this annotation

Codecov / codecov/patch

src/distribution_wrappers.jl#L20

Added line #L20 was not covered by tests
# extract the singleton value from 0-dimensional array
return Distributions.logpdf(dist.dist, first(x))

Check warning on line 22 in src/distribution_wrappers.jl

View check run for this annotation

Codecov / codecov/patch

src/distribution_wrappers.jl#L22

Added line #L22 was not covered by tests
end
function Distributions.logpdf(dist::NamedDist, x::AbstractArray{<:Real})
return Distributions.logpdf(dist.dist, x)
end
Expand Down
10 changes: 8 additions & 2 deletions src/simple_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -232,8 +232,14 @@ function SimpleVarInfo(; kwargs...)
end

# Constructor from `Model`.
SimpleVarInfo(model::Model, args...) = SimpleVarInfo{Float64}(model, args...)
function SimpleVarInfo{T}(model::Model, args...) where {T<:Real}
function SimpleVarInfo(
model::Model, args::Union{AbstractVarInfo,AbstractSampler,AbstractContext}...
)
return SimpleVarInfo{Float64}(model, args...)
end
function SimpleVarInfo{T}(
model::Model, args::Union{AbstractVarInfo,AbstractSampler,AbstractContext}...
) where {T<:Real}
return last(evaluate!!(model, SimpleVarInfo{T}(), args...))
end

Expand Down
13 changes: 13 additions & 0 deletions src/threadsafe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,19 @@
return invlink!!(t, deepcopy(vi), model)
end

# These two StaticTransformation methods needed to resolve ambiguities
function link!!(

Check warning on line 119 in src/threadsafe.jl

View check run for this annotation

Codecov / codecov/patch

src/threadsafe.jl#L119

Added line #L119 was not covered by tests
t::StaticTransformation{<:Bijectors.Transform}, vi::ThreadSafeVarInfo, model::Model
)
return Accessors.@set vi.varinfo = link!!(t, vi.varinfo, model)

Check warning on line 122 in src/threadsafe.jl

View check run for this annotation

Codecov / codecov/patch

src/threadsafe.jl#L122

Added line #L122 was not covered by tests
end

function invlink!!(

Check warning on line 125 in src/threadsafe.jl

View check run for this annotation

Codecov / codecov/patch

src/threadsafe.jl#L125

Added line #L125 was not covered by tests
t::StaticTransformation{<:Bijectors.Transform}, vi::ThreadSafeVarInfo, model::Model
)
return Accessors.@set vi.varinfo = invlink!!(t, vi.varinfo, model)

Check warning on line 128 in src/threadsafe.jl

View check run for this annotation

Codecov / codecov/patch

src/threadsafe.jl#L128

Added line #L128 was not covered by tests
end
Comment on lines +118 to +129
Copy link
Member Author

@penelopeysm penelopeysm Mar 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Most of the method ambiguity fixes in this PR are really straightforward, but it took me a long time to figure out what to do here.

Firstly, note that this method is only ever invoked in a very specific situation, where we have (1) multithreaded evaluation hence ThreadSafeVarInfo, (2) a StaticTransformation, and (3) the nested varinfo is e.g. a SimpleVarInfo and specifically not a VarInfo, because the ThreadSafeVarInfo{VarInfo} methods are overridden in varinfo.jl.

The DynamicPPL test suite does not ever call this method (if it did, it would have failed with a method ambiguity), and since this has never been reported before, one might justifiably assume that no user has ever called it too.

To be fully honest, while I can't see any reason why this would be wrong, I also can't convince myself 100% that this is the correct behaviour. However, looking at the lines directly above this, it seems clear enough to me that the intent of that code was to suggest that all AbstractTransformations should behave that way, except that DynamicTransformation was to be special-cased. So I figured that it would suffice to duplicate the general AbstractTransformation implementation here.


function maybe_invlink_before_eval!!(vi::ThreadSafeVarInfo, model::Model)
# Defer to the wrapped `AbstractVarInfo` object.
# NOTE: When computing `getlogp` for `ThreadSafeVarInfo` we do include the
Expand Down
6 changes: 5 additions & 1 deletion src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,11 @@ function VarInfo(
)
return typed_varinfo(model, SamplingContext(rng, sampler, context), metadata)
end
VarInfo(model::Model, args...) = VarInfo(Random.default_rng(), model, args...)
function VarInfo(
model::Model, args::Union{AbstractSampler,AbstractContext,Metadata,VarNamedVector}...
)
return VarInfo(Random.default_rng(), model, args...)
end

"""
vector_length(varinfo::VarInfo)
Expand Down
3 changes: 0 additions & 3 deletions src/varname.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,3 @@ Possibly existing indices of `varname` are neglected.
) where {s,missings,_F,_a,_T}
return s in missings
end

# HACK: Type-piracy. Is this really the way to go?
AbstractPPL.getsym(::AbstractVector{<:VarName{sym}}) where {sym} = sym
8 changes: 8 additions & 0 deletions test/Aqua.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
module AquaTests

using Aqua: Aqua
using DynamicPPL

Aqua.test_all(DynamicPPL)

end
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ include("test_util.jl")
# groups are chosen to make both groups take roughly the same amount of
# time, but beyond that there is no particular reason for the split.
if GROUP == "All" || GROUP == "Group1"
include("Aqua.jl")
include("utils.jl")
include("compiler.jl")
include("varnamedvector.jl")
Expand Down
Loading