Skip to content

Commit 8744525

Browse files
authored
Dev (#49)
* drop some unneeded methods * update for DensityInterface * updating densities to DensityInterface approach * update domains * moving things around * updates to IntegerBounds * more domain mucking * fix some exports * update deps * update `using` * working on tests * Working toward tests passing * some refactoring * working on tests * cleaning up * get tests to pass * tests passing * update logdensity_def for pointwise product * PrettyPrinting + tests * speed up `rootmeasure` * oops didn't mean to include that * tile(::FactoredBase) * Update Half and FactoredBase * drop exp.jl * simplify basemeasure * drop old `include` * drop redundant method * update compat Returns * add AbstractDensity * Move Affine to MeasureTheory * drop some old `For` code * update counting measure * add testvalue(::Type{T}) * bugfix * some dispatch adjustments * simplify show * formatting * transformed measures * ZeroSets * using LinearAlgebra, Statistics * working on MeasureTheory tests * updates * updates * typo * update default to mimic Base * fix tile(::Lebesgue) * Add test_interface function * small doc update * DensityKind(::Likelihood) * fixing show(::Likelihood) * adding some docs * update `rand` method * drop old integration code * add `rebase` * export rebase * law for ⊗ * typo * Make Likelihood more flexible * update kernel methods * add Likelihood method (avoid stack overflow) * refactoring * compat * Maybe Comat just works? * refactoring * some new stuff * working on tests * update powermeasure combinator * bugfix * comment out debugging lines * more refactoring * test @inferred basemeasure_depth(μ) * drop `constructor` (just use ConstructionBase.constructorof`) * debugging * update help * update interface * make tests harder * fixes * Dirac bugfix * formatting * improve type inference * working on type inference * update interface * udpates * get test passing * @test !isabstracttype(typejoin(...)) * work on show methods * update CI * remove old code * update productmeasure * prettyprinting stuff * Drop te @constprop :aggressive stuff (maybe don't need it?) * nerline * dropping some old code * update tbasemeasure_type(::PowerMeasure) * moar tests * update SpikeMixture * update superpose type parameter name * drop old tests * func_string * more updates * getting closer * almost there! * generated function for type stability * tests passing! * newline * more fixes * exports and bugfix * insupport(μ::Counting{T}, x) where {T<:Type} * working on MeasureTheory tests * MeasureTheory tests passing * drop some old code * inlining * improve inference * update `tile(::For)` * tighten down infrerence * update basemeasure(::For) for generators * loosen type bound on instance_type * drop debugging code * small update for Likelihood, and a test * fixing up likelihoods * improve `basemeasure_depth` dispatch * still some trouble with inferred basemeasure_depth * clean up `For` dispatch * simplify _logdensityof * optimize for Returns{True} case * rework basemeasure_depth * aggressive tests passing!! * drop type-level stuff * drop help * license * affero * copyright notice * merge * Drop Create Commons * cleanup after merge * update support computations * insupport(d::SuperpositionMeasure, x) * dorp ParamWeighted * insupport(d::FactoredBase, x) * export unsafe_logdensityof * call promote_type instead of promote_rule * logdensity_def for named tuple product measures * type annotation for now * debugging * drop shows * speed up mapped arrays * throw an error for `Union{}` types * MT tests passing * updates * get tests passing * MIT license for MeasureBase * bump version * cleanup * spacing * Move ConditionalMeasure to MeasureBase * add LogarithmicNumbers * export basemeasure_sequence * update superpose * fix logdensity_rel * remove FIXME (it's fixed!!) * logdensityof(d::Density, x) * simplify insupport(::Lebesgue, ::Real) * clean up * assume insupport yields Bool * change logdensity_rel fall-through to warning and return NaN * update logdensity_rel * drop old code * fix warning * export logdensity_rel * logdensity_def(μ::Dirac, ν::Dirac, x) * logdensity_def methods * drop `static` * ]add StatsFuns * Fixing up superposition * [compat] entries * trying to speed things up * bugfixes * logdensity_rel tests * logdensity_rel tests * drop qualifier, and add a test * more tests * type constraint in "logdensityof(μ::AbstractMeasure, x)" (was piracy, oops) * add some docs * docs * docs * typo * moar speed * don't export Test * some more updates * logdensity_rel for products * `kleisli` docs * update instance_type * instance_type => Core.Typeof * `powermeasure` bug fix * fix logdensity_rel bug * get `commonbase` to take x type into account * test powers * commonbase docstring * deprecate instance_type * avoid breakage * switch || terms * @ifelse macro * simplify logdensity_rel * give up on this @ifelse business * bump version * working on likelihoods * update likelihood * powerweightedmeasure * powerweighted update * more powerweighted methods * bugfix * dropFactoredBase * drop FactoredBase * (::ProductMeasure) | constraint * update conditional measure * update Dirac * move conditional.jl down in the `include`s * Kleisli => TransitionKernel * simplify logdensity_def(::PowerMeasure, x) * rename kleisli.jl to kernel.jl * update Dirac tests * update Half * get tests passing * update kernel * Update Project.toml * no call-site inlining * restrict single-arg `kernel` to <:ParameterizedMeasure * export log_likelihood_ratio * Drop DensityKind(::Likelihood), at least for now * isfinite(x) instead of x>-Inf * add `condition` constructor * EOF newline * simplify logdensity_def for power measures * finishing up
1 parent e6dcca1 commit 8744525

22 files changed

+335
-319
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MeasureBase"
22
uuid = "fa1605e6-acd5-459c-a1e6-7e635759db14"
33
authors = ["Chad Scherrer <[email protected]> and contributors"]
4-
version = "0.7.0"
4+
version = "0.8.0"
55

66
[deps]
77
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

src/MeasureBase.jl

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import DensityInterface: densityof
1313
import DensityInterface: DensityKind
1414
using DensityInterface
1515

16+
import Base.iterate
1617
import ConstructionBase
1718
using ConstructionBase: constructorof
1819

@@ -31,7 +32,7 @@ export AbstractMeasure
3132
import IfElse: ifelse
3233
export logdensity_def
3334
export basemeasure
34-
export basekleisli
35+
export basekernel
3536

3637
"""
3738
inssupport(m, x)
@@ -51,11 +52,9 @@ abstract type AbstractMeasure end
5152

5253
using Static: @constprop
5354

54-
function Pretty.tile(d::M) where {M<:AbstractMeasure}
55+
function Pretty.quoteof(d::M) where {M<:AbstractMeasure}
5556
the_names = fieldnames(typeof(d))
56-
result = Pretty.literal(repr(M))
57-
isempty(the_names) && return result * Pretty.literal("()")
58-
Pretty.list_layout(Pretty.tile.([getfield(d, n) for n in the_names]); prefix=result)
57+
:($M($([getfield(d, n) for n in the_names]...)))
5958
end
6059

6160
@inline DensityKind(::AbstractMeasure) = HasDensity()
@@ -97,23 +96,21 @@ using Compat
9796
include("schema.jl")
9897
include("splat.jl")
9998
include("proxies.jl")
100-
include("kleisli.jl")
99+
include("kernel.jl")
101100
include("parameterized.jl")
102101
include("combinators/half.jl")
103102
include("domains.jl")
104103
include("primitive.jl")
105104
include("utils.jl")
106-
include("absolutecontinuity.jl")
105+
# include("absolutecontinuity.jl")
107106

108107
include("primitives/counting.jl")
109108
include("primitives/lebesgue.jl")
110109
include("primitives/dirac.jl")
111110
include("primitives/trivial.jl")
112111

113-
include("combinators/conditional.jl")
114112
include("combinators/bind.jl")
115113
include("combinators/transformedmeasure.jl")
116-
include("combinators/factoredbase.jl")
117114
include("combinators/weighted.jl")
118115
include("combinators/superpose.jl")
119116
include("combinators/product.jl")
@@ -123,6 +120,8 @@ include("combinators/likelihood.jl")
123120
include("combinators/pointwise.jl")
124121
include("combinators/restricted.jl")
125122
include("combinators/smart-constructors.jl")
123+
include("combinators/powerweighted.jl")
124+
include("combinators/conditional.jl")
126125

127126
include("rand.jl")
128127

src/combinators/conditional.jl

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,24 @@ the notion of normalization may not even make sense.
2323
Because of this, this interface is not yet stable, and users should expect
2424
upcoming changes.
2525
"""
26-
Base.:|::AbstractMeasure, constraint) = ConditionalMeasure(μ, constraint)
26+
Base.:|::AbstractMeasure, constraint) = condition(μ, constraint)
2727

28-
@inline basemeasure(cm::ConditionalMeasure) = basemeasure(cm.parent) | cm.constraint
28+
condition(μ, constraint) = ConditionalMeasure(μ, constraint)
29+
30+
@inline basemeasure(cm::ConditionalMeasure) = basemeasure(cm.parent) | cm.constraint
31+
32+
# @generated function Base.:|(μ::ProductMeasure{NamedTuple{M,T}}, constraint::NamedTuple{N}) where {M,T,N}
33+
# newkeys = tuple(setdiff(M, N)...)
34+
# quote
35+
# mar = marginals(μ)
36+
# productmeasure(NamedTuple{$newkeys}(mar))
37+
# end
38+
# end
39+
40+
function Base.:|::ProductMeasure{NamedTuple{M,T}}, constraint::NamedTuple{N}) where {M,T,N}
41+
productmeasure(merge(marginals(μ),rmap(Dirac, constraint)))
42+
end
43+
44+
function Pretty.tile(d::ConditionalMeasure)
45+
Pretty.pair_layout(Pretty.tile(d.parent), Pretty.tile(d.constraint), sep=" | ")
46+
end

src/combinators/factoredbase.jl

Lines changed: 0 additions & 21 deletions
This file was deleted.

src/combinators/half.jl

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,8 @@ end
1111

1212
unhalf::Half) = μ.parent
1313

14-
isnonnegative(x) = x 0.0
15-
1614
@inline function basemeasure::Half)
17-
const= static(logtwo)
18-
varℓ = Returns(0.0)
19-
base = basemeasure(unhalf(μ))
20-
FactoredBase(constℓ, varℓ, base)
15+
weightedmeasure(static(logtwo), basemeasure(unhalf(μ)))
2116
end
2217

2318
function Base.rand(rng::AbstractRNG, ::Type{T}, μ::Half) where {T}
@@ -27,7 +22,8 @@ end
2722
logdensity_def::Half, x) = logdensity_def(unhalf(μ), x)
2823

2924
@inline function insupport(d::Half, x)
30-
ifelse(isnonnegative(x), insupport(unhalf(d), x), false)
25+
x 0 || return false
26+
insupport(unhalf(d), x)
3127
end
3228

33-
testvalue(::Half) = 1.0
29+
testvalue(::Half) = 1.0

src/combinators/likelihood.jl

Lines changed: 56 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,16 @@ export AbstractLikelihood, Likelihood
22

33
abstract type AbstractLikelihood end
44

5-
@inline logdensityof(ℓ::AbstractLikelihood, par) = logdensity_def(ℓ, par)
5+
# @inline function logdensityof(ℓ::AbstractLikelihood, p)
6+
# t() = dynamic(unsafe_logdensityof(ℓ, p))
7+
# f() = -Inf
8+
# ifelse(insupport(ℓ, p), t, f)()
9+
# end
10+
11+
# insupport(ℓ::AbstractLikelihood, p) = insupport(ℓ.k(p), ℓ.x)
612

713
@doc raw"""
8-
Likelihood(k::AbstractKleisli, x)
14+
Likelihood(k::AbstractTransitionKernel, x)
915
1016
"Observe" a value `x`, yielding a function from the parameters to ℝ.
1117
@@ -89,7 +95,8 @@ Finally, let's return to the expression for Bayes's Law,
8995
9096
The product on the right side is computed pointwise. To work with this in
9197
MeasureBase, we have a "pointwise product" `⊙`, which takes a measure and a
92-
likelihood, and returns a new measure, that is, the unnormalized posterior that has density ``P(θ) P(x|θ)`` with respect to the base measure of the prior.
98+
likelihood, and returns a new measure, that is, the unnormalized posterior that
99+
has density ``P(θ) P(x|θ)`` with respect to the base measure of the prior.
93100
94101
For example, say we have
95102
@@ -109,14 +116,11 @@ struct Likelihood{K,X} <: AbstractLikelihood
109116
k::K
110117
x::X
111118

112-
Likelihood(k::K, x::X) where {K<:AbstractKleisli,X} = new{K,X}(k,x)
119+
Likelihood(k::K, x::X) where {K<:AbstractTransitionKernel,X} = new{K,X}(k,x)
113120
Likelihood(k::K, x::X) where {K<:Function,X} = new{K,X}(k,x)
114-
Likelihood(μ, x) = Likelihood(kleisli(μ), x)
121+
Likelihood(μ, x) = Likelihood(kernel(μ), x)
115122
end
116123

117-
# Not really a density, but this makes the code work
118-
@inline DensityKind(::Likelihood) = IsDensity()
119-
120124
function Pretty.quoteof(ℓ::Likelihood)
121125
k = Pretty.quoteof(ℓ.k)
122126
x = Pretty.quoteof(ℓ.x)
@@ -128,10 +132,48 @@ function Base.show(io::IO, ℓ::Likelihood)
128132
Pretty.pprint(io, ℓ)
129133
end
130134

131-
@inline function logdensity_def(ℓ::Likelihood, p::Tuple)
132-
return logdensity_def(ℓ.k(p), ℓ.x)
133-
end
135+
# @inline function logdensity_def(ℓ::Likelihood, p)
136+
# return logdensity_def(ℓ.k(p), ℓ.x)
137+
# end
134138

135-
@inline function logdensity_def(ℓ::Likelihood, p)
136-
return logdensity_def(ℓ.k((p,)), ℓ.x)
137-
end
139+
# basemeasure(ℓ::Likelihood, p) = basemeasure(ℓ.k(p), ℓ.x)
140+
141+
# basemeasure(ℓ::Likelihood) = @error "Likelihood requires local base measure"
142+
143+
export likelihood
144+
145+
"""
146+
likelihood(k::AbstractTransitionKernel, x; constraints...)
147+
likelihood(k::AbstractTransitionKernel, x, constraints::NamedTuple)
148+
149+
A likelihood is *not* a measure. Rather, a likelihood acts on a measure, through
150+
the "pointwise product" `⊙`, yielding another measure.
151+
"""
152+
function likelihood end
153+
154+
likelihood(k, x, ::NamedTuple{()}) = Likelihood(k, x)
155+
156+
likelihood(k, x; kwargs...) = likelihood(k, x, NamedTuple(kwargs))
157+
158+
likelihood(k, x, pars::NamedTuple) = likelihood(kernel(k, pars), x)
159+
160+
likelihood(k::AbstractTransitionKernel, x) = Likelihood(k, x)
161+
162+
export log_likelihood_ratio
163+
164+
"""
165+
log_likelihood_ratio(ℓ::Likelihood, p, q)
166+
167+
Compute the log of the likelihood ratio, in order to compare two choices for
168+
parameters. This is computed as
169+
170+
logdensity_rel(ℓ.k(p), ℓ.k(q), ℓ.x)
171+
172+
Since `logdensity_rel` can leave common base measure unevaluated, this can be
173+
more efficient than
174+
175+
logdensityof(ℓ.k(p), ℓ.x) - logdensityof(ℓ.k(q), ℓ.x)
176+
"""
177+
log_likelihood_ratio(ℓ::Likelihood, p, q) = logdensity_rel(ℓ.k(p), ℓ.k(q), ℓ.x)
178+
179+
# likelihood(k, x; kwargs...) = likelihood(k, x, NamedTuple(kwargs))

src/combinators/pointwise.jl

Lines changed: 19 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,34 @@
11
export
22

3-
struct PointwiseProductMeasure{M,L} <: AbstractMeasure
4-
measure::M
3+
struct PointwiseProductMeasure{P,L} <: AbstractMeasure
4+
prior::P
55
likelihood::L
6-
7-
function PointwiseProductMeasure::M, ℓ::L) where {M,L}
8-
@assert static_hasmethod(logdensity_def, Tuple{L, gentype(μ)})
9-
return new{M,L}(μ, ℓ)
10-
end
116
end
127

13-
function Base.show(io::IO, μ::PointwiseProductMeasure)
14-
io = IOContext(io, :compact => true)
15-
print(io, μ.measure, "", μ.likelihood)
16-
end
178

18-
function Base.show_unquoted(io::IO, μ::PointwiseProductMeasure, indent::Int, prec::Int)
19-
io = IOContext(io, :compact => true)
20-
if Base.operator_precedence(:*) prec
21-
print(io, "(")
22-
show(io, μ)
23-
print(io, ")")
24-
else
25-
show(io, μ)
26-
end
27-
return nothing
9+
10+
iterate(p::PointwiseProductMeasure, i=1) = iterate((p.prior, p.likelihood), i)
11+
12+
function Pretty.tile(d::PointwiseProductMeasure)
13+
Pretty.pair_layout(Pretty.tile(d.prior), Pretty.tile(d.likelihood), sep="")
2814
end
2915

3016
(μ, ℓ) = pointwiseproduct(μ, ℓ)
3117

32-
@inline function logdensity_def(d::PointwiseProductMeasure, x)
33-
logdensity_def(d.measure, x) + logdensity_def(d.likelihood, x)
18+
@inline function logdensity_def(d::PointwiseProductMeasure, p)
19+
μ, ℓ = d
20+
logdensityof(ℓ.k(p), ℓ.x)
3421
end
3522

3623
function gentype(d::PointwiseProductMeasure)
37-
@inbounds gentype(d.measure)
24+
gentype(d.prior)
3825
end
3926

40-
basemeasure(d::PointwiseProductMeasure) = @inbounds basemeasure(d.measure)
27+
@inbounds function insupport(d::PointwiseProductMeasure, p)
28+
μ, ℓ = d
29+
insupport(μ, p) && insupport(ℓ.k(p), ℓ.x)
30+
end
31+
32+
basemeasure(d::PointwiseProductMeasure, x) = d.prior
33+
34+
basemeasure(d::PointwiseProductMeasure) = basemeasure(d.prior)

src/combinators/power.jl

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -65,26 +65,16 @@ params(d::PowerMeasure) = params(first(marginals(d)))
6565
end
6666

6767
@inline function logdensity_def(d::PowerMeasure{M}, x) where {M}
68-
= 0.0
69-
# ℓ = zero(typeintersect(AbstractFloat,Core.Compiler.return_type(logdensity_def, Tuple{M,T})))
7068
parent = d.parent
71-
@simd for xj in x
72-
Δℓ = logdensity_def(parent, xj)
73-
+= Δℓ
69+
sum(x) do xj
70+
logdensity_def(parent, xj)
7471
end
75-
7672
end
7773

78-
@generated function logdensity_def(d::PowerMeasure{M, Tuple{Base.OneTo{StaticInt{N}}}}, x) where {M,N}
79-
quote
80-
$(Expr(:meta, :inline))
81-
= 0.0
82-
parent = d.parent
83-
@inbounds @simd for j in 1:$N
84-
Δℓ = logdensity_def(parent, x[j])
85-
+= Δℓ
86-
end
87-
74+
@inline function logdensity_def(d::PowerMeasure{M, Tuple{Base.OneTo{StaticInt{N}}}}, x) where {M,N}
75+
parent = d.parent
76+
sum(1:N) do j
77+
@inbounds logdensity_def(parent, x[j])
8878
end
8979
end
9080

src/combinators/powerweighted.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
export
2+
3+
struct PowerWeightedMeasure{M,A} <: AbstractMeasure
4+
parent::M
5+
exponent::A
6+
end
7+
8+
logdensity_def(d::PowerWeightedMeasure, x) = d.exponent * logdensity_def(d.parent, x)
9+
10+
basemeasure(d::PowerWeightedMeasure, x) = basemeasure(d.parent, x) d.exponent
11+
12+
basemeasure(d::PowerWeightedMeasure) = basemeasure(d.parent) d.exponent
13+
14+
function powerweightedmeasure(d, α)
15+
isone(α) && return d
16+
PowerWeightedMeasure(d, α)
17+
end
18+
19+
(d::AbstractMeasure) α = powerweightedmeasure(d, α)
20+
21+
insupport(d::PowerWeightedMeasure, x) = insupport(d.parent, x)
22+
23+
function Base.show(io::IO, d::PowerWeightedMeasure)
24+
print(io, d.parent, "", d.exponent)
25+
end
26+
27+
function powerweightedmeasure(d::PowerWeightedMeasure, α)
28+
powerweightedmeasure(d.parent, α * d.exponent)
29+
end
30+
31+
function powerweightedmeasure(d::WeightedMeasure, α)
32+
weightedmeasure*d.logweight, powerweightedmeasure(d.base, α))
33+
end
34+
35+
function Pretty.tile(d::PowerWeightedMeasure)
36+
Pretty.pair_layout(Pretty.tile(d.parent), Pretty.tile(d.exponent), sep="")
37+
end

0 commit comments

Comments
 (0)