Skip to content

Commit 0b12781

Browse files
committed
Add comparison methods
1 parent 55be793 commit 0b12781

File tree

6 files changed

+75
-22
lines changed

6 files changed

+75
-22
lines changed

src/accumulators.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,8 @@ function Base.haskey(at::AccumulatorTuple, ::Val{accname}) where {accname}
136136
@inline return haskey(at.nt, accname)
137137
end
138138
Base.keys(at::AccumulatorTuple) = keys(at.nt)
139+
Base.:(==)(at1::AccumulatorTuple, at2::AccumulatorTuple) = at1.nt == at2.nt
140+
Base.hash(at::AccumulatorTuple, h::UInt) = Base.hash((AccumulatorTuple, at.nt), h)
139141

140142
function Base.convert(::Type{AccumulatorTuple{N,T}}, accs::AccumulatorTuple{N}) where {N,T}
141143
return AccumulatorTuple(convert(T, accs.nt))

src/default_accumulators.jl

Lines changed: 44 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -57,15 +57,16 @@ struct VariableOrderAccumulator{Eltype<:Integer,VNType<:VarName} <: AbstractAccu
5757
"the number of observations"
5858
num_produce::Eltype
5959
"mapping of variable names to their order in the model"
60-
order::OrderedDict{VNType, Eltype}
60+
order::OrderedDict{VNType,Eltype}
6161
end
6262

6363
"""
6464
VariableOrderAccumulator{T<:Integer}(n=zero(T))
6565
6666
Create a new `VariableOrderAccumulator` accumulator with the number of observations set to n
6767
"""
68-
VariableOrderAccumulator{T}(n=zero(T)) where {T<:Integer} = VariableOrderAccumulator(convert(T, n), OrderedDict{VarName, T}())
68+
VariableOrderAccumulator{T}(n=zero(T)) where {T<:Integer} =
69+
VariableOrderAccumulator(convert(T, n), OrderedDict{VarName,T}())
6970
VariableOrderAccumulator(n) = VariableOrderAccumulator{typeof(n)}(n)
7071
VariableOrderAccumulator() = VariableOrderAccumulator{Int}()
7172

@@ -76,7 +77,38 @@ function Base.show(io::IO, acc::LogLikelihoodAccumulator)
7677
return print(io, "LogLikelihoodAccumulator($(repr(acc.logp)))")
7778
end
7879
function Base.show(io::IO, acc::VariableOrderAccumulator)
79-
return print(io, "VariableOrderAccumulator($(repr(acc.num_produce)), $(repr(acc.order)))")
80+
return print(
81+
io, "VariableOrderAccumulator($(repr(acc.num_produce)), $(repr(acc.order)))"
82+
)
83+
end
84+
85+
# Note that == and isequal are different, and equality under the latter should imply
86+
# equality of hashes. Both of the below implementations are also different from the default
87+
# implementation for structs.
88+
Base.:(==)(acc1::LogPriorAccumulator, acc2::LogPriorAccumulator) = acc1.logp == acc2.logp
89+
function Base.:(==)(acc1::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator)
90+
return acc1.logp == acc2.logp
91+
end
92+
function Base.:(==)(acc1::VariableOrderAccumulator, acc2::VariableOrderAccumulator)
93+
return acc1.num_produce == acc2.num_produce && acc1.order == acc2.order
94+
end
95+
96+
function Base.isequal(acc1::LogPriorAccumulator, acc2::LogPriorAccumulator)
97+
return isequal(acc1.logp, acc2.logp)
98+
end
99+
function Base.isequal(acc1::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator)
100+
return isequal(acc1.logp, acc2.logp)
101+
end
102+
function Base.isequal(acc1::VariableOrderAccumulator, acc2::VariableOrderAccumulator)
103+
return isequal(acc1.num_produce, acc2.num_produce) && isequal(acc1.order, acc2.order)
104+
end
105+
106+
Base.hash(acc::LogPriorAccumulator, h::UInt) = hash((LogPriorAccumulator, acc.logp), h)
107+
function Base.hash(acc::LogLikelihoodAccumulator, h::UInt)
108+
return hash((LogLikelihoodAccumulator, acc.logp), h)
109+
end
110+
function Base.hash(acc::VariableOrderAccumulator, h::UInt)
111+
return hash((VariableOrderAccumulator, acc.num_produce, acc.order), h)
80112
end
81113

82114
accumulator_name(::Type{<:LogPriorAccumulator}) = :LogPrior
@@ -96,7 +128,9 @@ end
96128
function combine(acc::VariableOrderAccumulator, acc2::VariableOrderAccumulator)
97129
# Note that assumptions are not allowed within in parallelised blocks, and thus the
98130
# dictionaries should be identical.
99-
return VariableOrderAccumulator(max(acc.num_produce, acc2.num_produce), merge(acc.order, acc2.order))
131+
return VariableOrderAccumulator(
132+
max(acc.num_produce, acc2.num_produce), merge(acc.order, acc2.order)
133+
)
100134
end
101135

102136
function Base.:+(acc1::LogPriorAccumulator, acc2::LogPriorAccumulator)
@@ -105,7 +139,9 @@ end
105139
function Base.:+(acc1::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator)
106140
return LogLikelihoodAccumulator(acc1.logp + acc2.logp)
107141
end
108-
increment(acc::VariableOrderAccumulator) = VariableOrderAccumulator(acc.num_produce + oneunit(acc.num_produce), acc.order)
142+
function increment(acc::VariableOrderAccumulator)
143+
return VariableOrderAccumulator(acc.num_produce + oneunit(acc.num_produce), acc.order)
144+
end
109145

110146
Base.zero(acc::LogPriorAccumulator) = LogPriorAccumulator(zero(acc.logp))
111147
Base.zero(acc::LogLikelihoodAccumulator) = LogLikelihoodAccumulator(zero(acc.logp))
@@ -138,9 +174,9 @@ function Base.convert(
138174
return LogLikelihoodAccumulator(convert(T, acc.logp))
139175
end
140176
function Base.convert(
141-
::Type{VariableOrderAccumulator{ElType, VnType}}, acc::VariableOrderAccumulator
142-
) where {ElType, VnType}
143-
order = OrderedDict{VnType, ElType}()
177+
::Type{VariableOrderAccumulator{ElType,VnType}}, acc::VariableOrderAccumulator
178+
) where {ElType,VnType}
179+
order = OrderedDict{VnType,ElType}()
144180
for (k, v) in acc.order
145181
order[convert(VnType, k)] = convert(ElType, v)
146182
end

src/simple_varinfo.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,12 @@ struct SimpleVarInfo{NT,Accs<:AccumulatorTuple where {N},C<:AbstractTransformati
198198
transformation::C
199199
end
200200

201+
function Base.:(==)(vi1::SimpleVarInfo, vi2::SimpleVarInfo)
202+
return vi1.values == vi2.values &&
203+
vi1.accs == vi2.accs &&
204+
vi1.transformation == vi2.transformation
205+
end
206+
201207
transformation(vi::SimpleVarInfo) = vi.transformation
202208

203209
function SimpleVarInfo(values, accs)

src/threadsafe.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,9 @@ end
7878

7979
syms(vi::ThreadSafeVarInfo) = syms(vi.varinfo)
8080

81-
setorder!!(vi::ThreadSafeVarInfo, vn::VarName, index::Int) = ThreadSafeVarInfo(setorder!!(vi.varinfo, vn, index), vi.accs_by_thread)
81+
function setorder!!(vi::ThreadSafeVarInfo, vn::VarName, index::Int)
82+
return ThreadSafeVarInfo(setorder!!(vi.varinfo, vn, index), vi.accs_by_thread)
83+
end
8284
setval!(vi::ThreadSafeVarInfo, val, vn::VarName) = setval!(vi.varinfo, val, vn)
8385

8486
keys(vi::ThreadSafeVarInfo) = keys(vi.varinfo)

src/varinfo.jl

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,17 @@ struct Metadata{
6060
flags::Dict{String,BitVector}
6161
end
6262

63+
function Base.:(==)(md1::Metadata, md2::Metadata)
64+
return (
65+
md1.idcs == md2.idcs &&
66+
md1.vns == md2.vns &&
67+
md1.ranges == md2.ranges &&
68+
md1.vals == md2.vals &&
69+
md1.dists == md2.dists &&
70+
md1.flags == md2.flags
71+
)
72+
end
73+
6374
###########
6475
# VarInfo #
6576
###########
@@ -155,6 +166,10 @@ const VarInfoOrThreadSafeVarInfo{Tmeta} = Union{
155166
VarInfo{Tmeta},ThreadSafeVarInfo{<:VarInfo{Tmeta}}
156167
}
157168

169+
function Base.:(==)(vi1::VarInfo, vi2::VarInfo)
170+
return (vi1.metadata == vi2.metadata && vi1.accs == vi2.accs)
171+
end
172+
158173
# NOTE: This is kind of weird, but it effectively preserves the "old"
159174
# behavior where we're allowed to call `link!` on the same `VarInfo`
160175
# multiple times.
@@ -275,9 +290,7 @@ function typed_varinfo(vi::UntypedVarInfo)
275290

276291
push!(
277292
new_metas,
278-
Metadata(
279-
sym_idcs, sym_vns, sym_ranges, sym_vals, sym_dists, sym_flags
280-
),
293+
Metadata(sym_idcs, sym_vns, sym_ranges, sym_vals, sym_dists, sym_flags),
281294
)
282295
end
283296
nt = NamedTuple{syms_tuple}(Tuple(new_metas))
@@ -597,14 +610,7 @@ function subset(metadata::Metadata, vns_given::AbstractVector{VN}) where {VN<:Va
597610
end
598611

599612
flags = Dict(k => v[indices_for_vns] for (k, v) in metadata.flags)
600-
return Metadata(
601-
indices,
602-
vns,
603-
ranges,
604-
vals,
605-
metadata.dists[indices_for_vns],
606-
flags,
607-
)
613+
return Metadata(indices, vns, ranges, vals, metadata.dists[indices_for_vns], flags)
608614
end
609615

610616
function Base.merge(varinfo_left::VarInfo, varinfo_right::VarInfo)

test/accumulators.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,9 @@ using DynamicPPL:
7171
@test convert(
7272
LogLikelihoodAccumulator{Float32}, LogLikelihoodAccumulator(1.0)
7373
) == LogLikelihoodAccumulator{Float32}(1.0f0)
74-
@test convert(VariableOrderAccumulator{UInt8,VarName}, VariableOrderAccumulator(1)) ==
75-
VariableOrderAccumulator{UInt8}(1)
74+
@test convert(
75+
VariableOrderAccumulator{UInt8,VarName}, VariableOrderAccumulator(1)
76+
) == VariableOrderAccumulator{UInt8}(1)
7677

7778
@test convert_eltype(Float32, LogPriorAccumulator(1.0)) ==
7879
LogPriorAccumulator{Float32}(1.0f0)

0 commit comments

Comments
 (0)