Skip to content

Commit 2ff6d3b

Browse files
authored
Dev (#41)
* some refactoring * pretty printing * some updates for resettablerng * deepcopy on iterate * Move ResettableRNG to MeasureTheory * Make Pretty a const * drop extra spaces * formatting * formatting * formatting * add DensityInterface dependency * faster `rootmeasure` * updates * primitive measure docs * get tests passing * cleaning up * formatting * drop some unneeded methods * update for DensityInterface * updating densities to DensityInterface approach * update domains * moving things around * updates to IntegerBounds * more domain mucking * fix some exports * update deps * update `using` * working on tests * Working toward tests passing * some refactoring * working on tests * cleaning up * get tests to pass * tests passing * update logdensity_def for pointwise product * PrettyPrinting + tests * speed up `rootmeasure` * oops didn't mean to include that * tile(::FactoredBase) * Update Half and FactoredBase * drop exp.jl * simplify basemeasure * drop old `include` * drop redundant method * update compat Returns * add AbstractDensity * Move Affine to MeasureTheory * drop some old `For` code * update counting measure * add testvalue(::Type{T}) * bugfix * some dispatch adjustments * simplify show * formatting * transformed measures * ZeroSets * using LinearAlgebra, Statistics * working on MeasureTheory tests * updates * updates * typo * update default to mimic Base * fix tile(::Lebesgue) * Add test_interface function * small doc update * DensityKind(::Likelihood) * fixing show(::Likelihood) * adding some docs * update `rand` method * drop old integration code * add `rebase` * export rebase * law for ⊗ * typo * Make Likelihood more flexible * update kernel methods * add Likelihood method (avoid stack overflow) * refactoring * compat * Maybe Comat just works? * refactoring * some new stuff * working on tests * update powermeasure combinator * bugfix * comment out debugging lines * more refactoring * test @inferred basemeasure_depth(μ) * drop `constructor` (just use ConstructionBase.constructorof`) * debugging * update help * update interface * make tests harder * fixes * Dirac bugfix * formatting * improve type inference * working on type inference * update interface * udpates * get test passing * @test !isabstracttype(typejoin(...)) * work on show methods * update CI * remove old code * update productmeasure * prettyprinting stuff * Drop te @constprop :aggressive stuff (maybe don't need it?) * nerline * dropping some old code * update tbasemeasure_type(::PowerMeasure) * moar tests * update SpikeMixture * update superpose type parameter name * drop old tests * func_string * more updates * getting closer * almost there! * generated function for type stability * tests passing! * newline * more fixes * exports and bugfix * insupport(μ::Counting{T}, x) where {T<:Type} * working on MeasureTheory tests * MeasureTheory tests passing * drop some old code * inlining * improve inference * update `tile(::For)` * tighten down infrerence * update basemeasure(::For) for generators * loosen type bound on instance_type * drop debugging code * small update for Likelihood, and a test * fixing up likelihoods * improve `basemeasure_depth` dispatch * still some trouble with inferred basemeasure_depth * clean up `For` dispatch * simplify _logdensityof * optimize for Returns{True} case * rework basemeasure_depth * aggressive tests passing!! * drop type-level stuff * drop help * license * affero * copyright notice * merge * Drop Create Commons * cleanup after merge * update support computations * insupport(d::SuperpositionMeasure, x) * dorp ParamWeighted * insupport(d::FactoredBase, x) * export unsafe_logdensityof * call promote_type instead of promote_rule * logdensity_def for named tuple product measures * type annotation for now * debugging * drop shows * speed up mapped arrays * throw an error for `Union{}` types * MT tests passing * updates * get tests passing * MIT license for MeasureBase * bump version * cleanup * spacing * Move ConditionalMeasure to MeasureBase * add LogarithmicNumbers * export basemeasure_sequence * update superpose * fix logdensity_rel * remove FIXME (it's fixed!!) * logdensityof(d::Density, x) * simplify insupport(::Lebesgue, ::Real) * clean up * assume insupport yields Bool * change logdensity_rel fall-through to warning and return NaN * update logdensity_rel * drop old code * fix warning * export logdensity_rel * logdensity_def(μ::Dirac, ν::Dirac, x) * logdensity_def methods * drop `static` * ]add StatsFuns * Fixing up superposition * [compat] entries * trying to speed things up * bugfixes * logdensity_rel tests * logdensity_rel tests * drop qualifier, and add a test * more tests * type constraint in "logdensityof(μ::AbstractMeasure, x)" (was piracy, oops) * add some docs * docs * docs * typo * moar speed * don't export Test * some more updates * logdensity_rel for products * `kleisli` docs * update instance_type * instance_type => Core.Typeof * `powermeasure` bug fix * fix logdensity_rel bug * get `commonbase` to take x type into account * test powers * commonbase docstring * deprecate instance_type * avoid breakage * switch || terms * @ifelse macro * simplify logdensity_rel * give up on this @ifelse business * bump version
1 parent d1c8540 commit 2ff6d3b

File tree

11 files changed

+120
-73
lines changed

11 files changed

+120
-73
lines changed

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MeasureBase"
22
uuid = "fa1605e6-acd5-459c-a1e6-7e635759db14"
33
authors = ["Chad Scherrer <[email protected]> and contributors"]
4-
version = "0.6.1"
4+
version = "0.7.1"
55

66
[deps]
77
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
@@ -13,12 +13,12 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1313
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
1414
LogarithmicNumbers = "aa2f6b4e-9042-5d33-9679-40d3a6b85899"
1515
MappedArrays = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900"
16+
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
1617
PrettyPrinting = "54e16d92-306c-5ea0-a30b-337be88ac337"
1718
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1819
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1920
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
2021
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
21-
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
2222
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2323
Tricks = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775"
2424

@@ -31,10 +31,10 @@ IfElse = "0.1"
3131
LogExpFunctions = "0.3"
3232
LogarithmicNumbers = "1"
3333
MappedArrays = "0.4"
34+
NaNMath = "0.3, 1"
3435
PrettyPrinting = "0.3, 0.4"
3536
Reexport = "1"
3637
Static = "0.5, 0.6"
37-
StatsFuns = "0.9"
3838
Tricks = "0.1"
3939
julia = "1.3"
4040

src/MeasureBase.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,11 @@ gentype(μ::AbstractMeasure) = typeof(testvalue(μ))
6464

6565
# gentype(μ::AbstractMeasure) = gentype(basemeasure(μ))
6666

67+
using NaNMath
6768
using LogExpFunctions: logsumexp
6869

70+
@deprecate instance_type(x) Core.Typeof(x) false
71+
6972
"""
7073
`logdensity_def` is the standard way to define a log-density for a new measure.
7174
Note that this definition does not include checking for membership in the

src/combinators/product.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,10 @@ struct ProductMeasure{M} <: AbstractProductMeasure
6565
marginals::M
6666
end
6767

68+
@inline function logdensity_rel::ProductMeasure, ν::ProductMeasure, x)
69+
mapreduce(logdensity_rel, +, marginals(μ), marginals(ν), x)
70+
end
71+
6872
function Pretty.tile(d::ProductMeasure{T}) where {T<:Tuple}
6973
Pretty.list_layout(Pretty.tile.([marginals(d)...]), sep="")
7074
end

src/combinators/smart-constructors.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,16 @@ end
1414
###############################################################################
1515
# PowerMeaure
1616

17-
function powermeasure::WeightedMeasure, dims::NTuple{N,I}) where {N,I}
17+
function powermeasure::WeightedMeasure, dims::NTuple{N,I}) where {N,I<:AbstractArray}
1818
k = mapreduce(length, *, dims) * μ.logweight
1919
return weightedmeasure(k, μ.base^dims)
2020
end
2121

22+
function powermeasure::WeightedMeasure, dims::NTuple{N,I}) where {N,I}
23+
k = prod(dims) * μ.logweight
24+
return weightedmeasure(k, μ.base^dims)
25+
end
26+
2227
###############################################################################
2328
# ProductMeasure
2429

@@ -84,7 +89,7 @@ function superpose(μ::T, ν::T) where {T}
8489
end
8590
end
8691

87-
function superpose::AbstractMeasure, ν::AbstractMeasure)
92+
function superpose(μ, ν)
8893
components = (μ, ν)
8994
superpose(components)
9095
end

src/combinators/superpose.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,7 @@ testvalue(μ::SuperpositionMeasure) = testvalue(first(μ.components))
5454
# end
5555

5656
function Base.:+::AbstractMeasure, ν::AbstractMeasure)
57-
components = (μ, ν)
58-
superpose(components)
57+
superpose(μ, ν)
5958
end
6059

6160
using LogarithmicNumbers
@@ -64,8 +63,10 @@ oneplus(x::ULogarithmic) = exp(ULogarithmic, log1pexp(x.log))
6463

6564
@inline function density_def(s::SuperpositionMeasure{Tuple{A,B}}, x) where {A,B}
6665
(μ, ν) = s.components
66+
6767
insupport(μ, x) || return exp(ULogarithmic, logdensity_def(ν, x))
6868
insupport(ν, x) || return exp(ULogarithmic, logdensity_def(μ, x))
69+
6970
α = basemeasure(μ)
7071
β = basemeasure(ν)
7172
dμ_dα = exp(ULogarithmic, logdensity_def(μ, x))
@@ -75,7 +76,7 @@ oneplus(x::ULogarithmic) = exp(ULogarithmic, log1pexp(x.log))
7576
return dμ_dα / oneplus(dβ_dα) + dν_dβ / oneplus(dα_dβ)
7677
end
7778

78-
using StatsFuns
79+
using LogExpFunctions
7980

8081
@inline function logdensity_def::T, ν::T, x::Any) where T<:(SuperpositionMeasure{Tuple{A, B}} where {A, B})
8182
if μ === ν
@@ -87,6 +88,7 @@ end
8788

8889
@inline function logdensity_def(s::SuperpositionMeasure{Tuple{A,B}}, β, x) where {A,B}
8990
(μ, ν) = s.components
91+
9092
insupport(μ, x) || return logdensity_rel(ν, β, x)
9193
insupport(ν, x) || return logdensity_rel(μ, β, x)
9294
return logaddexp(logdensity_rel(μ, β, x), logdensity_rel(ν, β, x))
@@ -105,7 +107,7 @@ end
105107

106108
@inline logdensity_def(s::SuperpositionMeasure, x) = log(density_def(s, x))
107109

108-
basemeasure::SuperpositionMeasure) = superpose(map(basemeasure, μ.components))
110+
basemeasure::SuperpositionMeasure) = superpose(map(basemeasure, μ.components)...)
109111

110112
# TODO: Fix `rand` method (this one is wrong)
111113
# function Base.rand(μ::SuperpositionMeasure{X,N}) where {X,N}

src/density.jl

Lines changed: 36 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -174,12 +174,11 @@ known to be in the support of both, it can be more efficient to call
174174
`unsafe_logdensity_rel`.
175175
"""
176176
@inline function logdensity_rel::M, ν::N, x::X) where {M,N,X}
177-
T = unstatic(float(promote_type(return_type(logdensity_def, (μ, x)), return_type(logdensity_def, (ν, x)))))
178-
insupport(μ, x) || begin
179-
insupport(ν, x) || return convert(T, NaN)
180-
return convert(T, -Inf)
181-
end
182-
insupport(ν, x) || return convert(T, Inf)
177+
T = unstatic(promote_type(return_type(logdensity_def, (μ, x)), return_type(logdensity_def, (ν, x))))
178+
inμ = insupport(μ, x)
179+
inν = insupport(ν, x)
180+
inμ || return convert(T, ifelse(inν, -Inf, NaN))
181+
inν || return convert(T, Inf)
183182

184183
return unsafe_logdensity_rel(μ, ν, x)
185184
end
@@ -198,7 +197,24 @@ See also `logdensity_rel`.
198197
end
199198
μs = basemeasure_sequence(μ)
200199
νs = basemeasure_sequence(ν)
201-
return _logdensity_rel(μs, νs, x)
200+
cb = commonbase(μs, νs, X)
201+
# _logdensity_rel(μ, ν)
202+
isnothing(cb) && begin
203+
μ = μs[end]
204+
ν = νs[end]
205+
@warn """
206+
No common base measure for
207+
208+
and
209+
210+
211+
Returning a relative log-density of NaN. If this is incorrect, add a
212+
three-argument method
213+
logdensity_def(, , x)
214+
"""
215+
return NaN
216+
end
217+
return _logdensity_rel(μs, νs, cb, x)
202218
end
203219

204220
# Note that this method assumes `μ` and `ν` to have the same type
@@ -210,44 +226,29 @@ function logdensity_def(μ::T, ν::T, x) where {T}
210226
end
211227
end
212228

213-
@generated function _logdensity_rel(μs::Tμ, νs::Tν, x::X) where {Tμ, Tν, X}
229+
@generated function _logdensity_rel(μs::Tμ, νs::Tν, ::Tuple{StaticInt{M},StaticInt{N}}, x::X) where {Tμ, Tν,M,N,X}
214230
= schema(Tμ)
215231
= schema(Tν)
216232

217233
q = quote
218234
$(Expr(:meta, :inline))
235+
= logdensity_def(μs[$M], νs[$N], x)
219236
end
220237

221-
for it in Iterators.product(enumerate(sμ), enumerate(sν))
222-
((nμ, μtype), (nν, νtype)) = it
223-
if static_hasmethod(logdensity_def, Tuple{μtype, νtype, X})
224-
push!(q.args, :(ℓ = logdensity_def(μs[$nμ], νs[$nν], x)))
225-
for i in 1:-1
226-
push!(q.args, :(ℓ += logdensity_def(μs[$i], x)))
227-
end
228-
for j in 1:-1
229-
push!(q.args, :(ℓ -= logdensity_def(νs[$j], x)))
230-
end
231-
232-
return q
233-
end
238+
for i in 1:M-1
239+
push!(q.args, :(Δℓ = logdensity_def(μs[$i], x)))
240+
# push!(q.args, :(println("Adding", Δℓ)))
241+
push!(q.args, :(ℓ += Δℓ))
234242
end
235243

236-
return quote
237-
μ = μs[end]
238-
ν = νs[end]
239-
@warn """
240-
No common base measure for
241-
242-
and
243-
244-
245-
Returning a relative log-density of NaN. If this is incorrect, add a
246-
three-argument method
247-
logdensity_def(μ, ν, x)
248-
"""
249-
NaN
244+
for j in 1:N-1
245+
push!(q.args, :(Δℓ = logdensity_def(νs[$j], x)))
246+
# push!(q.args, :(println("Subtracting", Δℓ)))
247+
push!(q.args, :(ℓ -= Δℓ))
250248
end
249+
250+
push!(q.args, :(return ℓ))
251+
return q
251252
end
252253

253254
export densityof

src/interface.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@ using Reexport
55
@reexport using MeasureBase
66

77
using MeasureBase:basemeasure_depth, proxy
8-
using MeasureBase: insupport
8+
using MeasureBase: insupport, basemeasure_sequence, commonbase
99

1010
export test_interface
1111
export basemeasure_depth
1212
export proxy
1313
export insupport
14+
export basemeasure_sequence
15+
export commonbase
1416

1517
using Test
1618

src/kleisli.jl

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,32 @@ end
1515
kleisli(f, M)
1616
kleisli((f1, f2, ...), M)
1717
18-
A kleisli `κ = kleisli(f, m)` returns a wrapper around
19-
a function `f` giving the parameters for a measure of type `M`,
20-
such that `κ(x) = M(f(x)...)`
21-
respective `κ(x) = M(f1(x), f2(x), ...)`
18+
`kleisli` was originally called `kernel`, as in a *Markov kernel*. Such a kernel
19+
can be considered to map each value in its domain to a probability measure.
20+
21+
In the context of MeasureTheory, the codomain is not required to be a
22+
*probability* measure; any measure will do. This makes "Markov" not really fit,
23+
since the map need not be Markovian.
24+
25+
This leaves us with "kernel", which can mean too wide a range of things to be
26+
useful in such a general context as measure theory. See for example
27+
https://github.com/JuliaGaussianProcesses/KernelFunctions.jl for one common use
28+
of this term.
29+
30+
We solve this problem by changing to a term from an even more general context.
31+
In category theory, a *Kleisli arrow* is a function taking monadic values.
32+
Since measures comprise a monad, our use is a special case of this.
33+
34+
A kleisli `κ = kleisli(f, m)` returns a wrapper around a function `f` giving the
35+
parameters for a measure of type `M`, such that `κ(x) = M(f(x)...)` respective
36+
`κ(x) = M(f1(x), f2(x), ...)`
2237
2338
If the argument is a named tuple `(;a=f1, b=f1)`, `κ(x)` is defined as
2439
`M(;a=f(x),b=g(x))`.
2540
2641
# Reference
2742
28-
* https://en.wikipedia.org/wiki/Markov_kleisli
43+
* https://en.wikipedia.org/wiki/Markov_kernel
2944
"""
3045
function kleisli end
3146

src/primitives/counting.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ insupport(::CountingMeasure, x) = true
99
struct Counting{T} <: AbstractMeasure
1010
support::T
1111

12-
Counting(supp) = new{instance_type(supp)}(supp)
12+
Counting(supp) = new{Core.Typeof(supp)}(supp)
1313
end
1414

1515
function logdensity_def::Counting, x)

src/utils.jl

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,6 @@ end
5252
return getfield(T, :instance)::T
5353
end
5454

55-
# See https://github.com/cscherrer/KeywordCalls.jl/issues/22
56-
@inline instance_type(f::F) where {F} = F
57-
@inline instance_type(T::UnionAll) = Type{T}
58-
@inline instance_type(T::DataType) = Type{T}
59-
6055
export basemeasure_depth
6156

6257
@inline function basemeasure_depth::M) where {M}
@@ -70,9 +65,12 @@ export basemeasure_depth
7065
return static(10)
7166
end
7267

68+
"""
69+
basemeasure_sequence(m)
7370
74-
export basemeasure_sequence
75-
71+
Construct the longest `Tuple` starting with `m` having each term as the base
72+
measure of the previous term, and with no repeated entries.
73+
"""
7674
@inline function basemeasure_sequence::M) where {M}
7775
b_1 = μ
7876
done = false
@@ -86,21 +84,32 @@ export basemeasure_sequence
8684
return filter(!isnothing, Base.Cartesian.@ntuple 10 b)
8785
end
8886

89-
# @inline function basemeasure_depth(μ::M) where {M}
90-
# return basemeasure_depth(μ, basemeasure(μ), static(0))
91-
# end
92-
93-
# @inline function basemeasure_depth(μ::M, β::M, s::StaticInt{N}) where {M,N}
94-
# s
95-
# end
96-
97-
# @generated function basemeasure_depth(μ::M, β::B, ::StaticInt{N}) where {M,B,N}
98-
# s = Expr(:call, Expr(:curly, :StaticInt, N + 1))
99-
# quote
100-
# $(Expr(:meta, :inline))
101-
# basemeasure_depth(β, basemeasure(β), $s)
102-
# end
103-
# end
87+
commonbase(μ, ν) = commonbase(μ, ν, Any)
88+
89+
"""
90+
commonbase(μ, ν, T) -> Tuple{StaticInt{i}, StaticInt{j}}
91+
92+
Find minimal (with respect to their sum) `i` and `j` such that there is a method
93+
94+
logdensity_def(basemeasure_sequence(μ)[i], basemeasure_sequence(ν)[j], ::T)
95+
96+
This is used in `logdensity_rel` to help make that function efficient.
97+
"""
98+
@inline function commonbase(μ, ν, ::Type{T}) where {T}
99+
return commonbase(basemeasure_sequence(μ), basemeasure_sequence(ν), T)
100+
end
101+
102+
@generated function commonbase::M, ν::N, ::Type{T}) where {M<:Tuple,N<:Tuple,T}
103+
m = schema(M)
104+
n = schema(N)
105+
106+
sols = Iterators.filter(((i,j),) -> static_hasmethod(logdensity_def, Tuple{m[i], n[j], T}), Iterators.product(1:length(m), 1:length(n)))
107+
isempty(sols) && return :(nothing)
108+
minsol = static.(argmin(((i,j),) -> i+j, sols))
109+
quote
110+
$minsol
111+
end
112+
end
104113

105114
mymap(f, gen::Base.Generator) = mymap(f gen.f, gen.iter)
106115
mymap(f, inds...) = Iterators.map(f, inds...)

0 commit comments

Comments
 (0)