Skip to content

Commit a26ce11

Browse files
penelopeysmabhinavsnssunxd3
authored
Fix remaining method ambiguities (#2304)
* Enabling aqua ambiguity testing for Turing We test ambiguities only for Turing and not its dependencies. * Format * Fix bundle_samples method ambiguity Concretely: 1. Creating an `AbstractTransition` type which all the Transitions in Turing subtype. 2. Modifying the type signature of bundle_samples to take a Vector{<:Union{AbstractTransition,AbstractVarInfo}} as the first argument. The AbstractVarInfo case occurs when sampling with Prior(), so the type signature of this argument mirrors that of the Sampler in the same function. * Fix get() ambiguities Done by: 1. Constraining the type parameter to AbstractVector{Symbol} 2. Modifying the method below it to use a vector instead of a tuple * Bump to 0.34.0 --------- Co-authored-by: Abhinav Singh <[email protected]> Co-authored-by: Xianda Sun <[email protected]>
1 parent 5b5da11 commit a26ce11

File tree

6 files changed

+16
-13
lines changed

6 files changed

+16
-13
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Turing"
22
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
3-
version = "0.33.3"
3+
version = "0.34.0"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/mcmc/Inference.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,9 @@ end
213213
# Extended in contrib/inference/abstractmcmc.jl
214214
getstats(t) = nothing
215215

216-
struct Transition{T, F<:AbstractFloat, S<:Union{NamedTuple, Nothing}}
216+
abstract type AbstractTransition end
217+
218+
struct Transition{T, F<:AbstractFloat, S<:Union{NamedTuple, Nothing}} <: AbstractTransition
217219
θ :: T
218220
lp :: F # TODO: merge `lp` with `stat`
219221
stat :: S
@@ -409,7 +411,7 @@ getlogevidence(transitions, sampler, state) = missing
409411
# Default MCMCChains.Chains constructor.
410412
# This is type piracy (at least for SampleFromPrior).
411413
function AbstractMCMC.bundle_samples(
412-
ts::Vector,
414+
ts::Vector{<:Union{AbstractTransition,AbstractVarInfo}},
413415
model::AbstractModel,
414416
spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior},
415417
state,
@@ -472,7 +474,7 @@ end
472474

473475
# This is type piracy (for SampleFromPrior).
474476
function AbstractMCMC.bundle_samples(
475-
ts::Vector,
477+
ts::Vector{<:Union{AbstractTransition,AbstractVarInfo}},
476478
model::AbstractModel,
477479
spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior},
478480
state,

src/mcmc/particle_mcmc.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ end
4545
SMC(space::Symbol...) = SMC(space)
4646
SMC(space::Tuple) = SMC(AdvancedPS.ResampleWithESSThreshold(), space)
4747

48-
struct SMCTransition{T,F<:AbstractFloat}
48+
struct SMCTransition{T,F<:AbstractFloat} <: AbstractTransition
4949
"The parameters for any given sample."
5050
θ::T
5151
"The joint log probability of the sample (NOTE: does not work, always set to zero)."
@@ -222,7 +222,7 @@ end
222222

223223
const CSMC = PG # type alias of PG as Conditional SMC
224224

225-
struct PGTransition{T,F<:AbstractFloat}
225+
struct PGTransition{T,F<:AbstractFloat} <: AbstractTransition
226226
"The parameters for any given sample."
227227
θ::T
228228
"The joint log probability of the sample (NOTE: does not work, always set to zero)."

src/mcmc/sghmc.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ function SGLD(
193193
return SGLD{typeof(adtype),space,typeof(stepsize)}(stepsize, adtype)
194194
end
195195

196-
struct SGLDTransition{T,F<:Real}
196+
struct SGLDTransition{T,F<:Real} <: AbstractTransition
197197
"The parameters for any given sample."
198198
θ::T
199199
"The joint log probability of the sample."

src/optimisation/Optimisation.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -277,13 +277,13 @@ StatsBase.loglikelihood(m::ModeResult) = m.lp
277277

278278
"""
279279
Base.get(m::ModeResult, var_symbol::Symbol)
280-
Base.get(m::ModeResult, var_symbols)
280+
Base.get(m::ModeResult, var_symbols::AbstractVector{Symbol})
281281
282282
Return the values of all the variables with the symbol(s) `var_symbol` in the mode result
283283
`m`. The return value is a `NamedTuple` with `var_symbols` as the key(s). The second
284-
argument should be either a `Symbol` or an iterator of `Symbol`s.
284+
argument should be either a `Symbol` or a vector of `Symbol`s.
285285
"""
286-
function Base.get(m::ModeResult, var_symbols)
286+
function Base.get(m::ModeResult, var_symbols::AbstractVector{Symbol})
287287
log_density = m.f
288288
# Get all the variable names in the model. This is the same as the list of keys in
289289
# m.values, but they are more convenient to filter when they are VarNames rather than
@@ -304,7 +304,7 @@ function Base.get(m::ModeResult, var_symbols)
304304
return (; zip(var_symbols, value_vectors)...)
305305
end
306306

307-
Base.get(m::ModeResult, var_symbol::Symbol) = get(m, (var_symbol,))
307+
Base.get(m::ModeResult, var_symbol::Symbol) = get(m, [var_symbol])
308308

309309
"""
310310
ModeResult(log_density::OptimLogDensity, solution::SciMLBase.OptimizationSolution)

test/Aqua.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@ module AquaTests
33
using Aqua: Aqua
44
using Turing
55

6-
# TODO(mhauru) We skip testing for method ambiguities because it catches a lot of problems
7-
# in dependencies. Would like to check it for just Turing.jl itself though.
6+
# We test ambiguities separately because it catches a lot of problems
7+
# in dependencies but we test it for Turing.
8+
Aqua.test_ambiguities([Turing])
89
Aqua.test_all(Turing; ambiguities=false)
910

1011
end

0 commit comments

Comments
 (0)