Skip to content

Commit 5b939db

Browse files
authored
cleanup (#39)
- Get `JET.report_package(MeasureBase)` to pass - Drop the `partialstatic` stuff, fun idea but didn't seem to help much - Add dependencies for `LogarithmicNumbers` and `StatsFuns` - More docstrings - Move `ConditionalMeasure` (`|`) here from MeasureTheory - Improvements for `superpose` - Some minor tidying up
1 parent d037aff commit 5b939db

File tree

15 files changed

+338
-231
lines changed

15 files changed

+338
-231
lines changed

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MeasureBase"
22
uuid = "fa1605e6-acd5-459c-a1e6-7e635759db14"
33
authors = ["Chad Scherrer <[email protected]> and contributors"]
4-
version = "0.6.0"
4+
version = "0.6.1"
55

66
[deps]
77
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
@@ -11,12 +11,14 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
1111
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
1212
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1313
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
14+
LogarithmicNumbers = "aa2f6b4e-9042-5d33-9679-40d3a6b85899"
1415
MappedArrays = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900"
1516
PrettyPrinting = "54e16d92-306c-5ea0-a30b-337be88ac337"
1617
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1718
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1819
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
1920
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
21+
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
2022
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2123
Tricks = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775"
2224

@@ -27,10 +29,12 @@ DensityInterface = "0.4"
2729
FillArrays = "0.12, 0.13"
2830
IfElse = "0.1"
2931
LogExpFunctions = "0.3"
32+
LogarithmicNumbers = "1"
3033
MappedArrays = "0.4"
3134
PrettyPrinting = "0.3, 0.4"
3235
Reexport = "1"
3336
Static = "0.5, 0.6"
37+
StatsFuns = "0.9"
3438
Tricks = "0.1"
3539
julia = "1.3"
3640

src/MeasureBase.jl

Lines changed: 38 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,25 @@ export rebase
2828

2929
export AbstractMeasure
3030

31+
import IfElse: ifelse
32+
export logdensity_def
33+
export basemeasure
34+
export basekleisli
35+
36+
"""
37+
inssupport(m, x)
38+
insupport(m)
39+
40+
`insupport(m,x)` computes whether `x` is in the support of `m`.
41+
42+
`insupport(m)` returns a function, and satisfies
43+
44+
insupport(m)(x) == insupport(m, x)
45+
"""
46+
function insupport end
47+
48+
export insupport
49+
3150
abstract type AbstractMeasure end
3251

3352
using Static: @constprop
@@ -45,28 +64,35 @@ gentype(μ::AbstractMeasure) = typeof(testvalue(μ))
4564

4665
# gentype(μ::AbstractMeasure) = gentype(basemeasure(μ))
4766

48-
import IfElse: ifelse
49-
export logdensity_def
50-
export basemeasure
51-
export basekleisli
52-
5367
using LogExpFunctions: logsumexp
5468

5569
"""
56-
logdensity_def(μ::AbstractMeasure{X}, x::X)
70+
`logdensity_def` is the standard way to define a log-density for a new measure.
71+
Note that this definition does not include checking for membership in the
72+
support; this is instead checked using `insupport`. `logdensity_def` is
73+
a low-level function, and should typically not be called directly. See
74+
`logdensityof` for more information and other alternatives.
75+
76+
---
77+
78+
logdensity_def(m, x)
5779
58-
Compute the logdensity of the measure μ at the point x. This is the standard way
59-
to define `logdensity` for a new measure. the base measure is implicit here, and
60-
is understood to be `basemeasure(μ)`.
80+
Compute the log-density of the measure m at the point `x`, relative to
81+
`basemeasure(m)`, and assuming `insupport(m, x)`.
6182
62-
Methods for computing density relative to other measures will be
83+
---
84+
85+
logdensity_def(m1, m2, x)
86+
87+
Compute the log-density of `m1` relative to `m2` at the point `x`, assuming
88+
`insupport(m1, x)` and `insupport(m2, x)`.
6389
"""
6490
function logdensity_def end
6591

6692
using Compat
6793

94+
include("schema.jl")
6895
include("splat.jl")
69-
include("partial-static.jl")
7096
include("proxies.jl")
7197
include("kleisli.jl")
7298
include("parameterized.jl")
@@ -81,6 +107,7 @@ include("primitives/lebesgue.jl")
81107
include("primitives/dirac.jl")
82108
include("primitives/trivial.jl")
83109

110+
include("combinators/conditional.jl")
84111
include("combinators/bind.jl")
85112
include("combinators/transformedmeasure.jl")
86113
include("combinators/factoredbase.jl")
@@ -97,21 +124,6 @@ include("combinators/smart-constructors.jl")
97124
include("rand.jl")
98125

99126
include("density.jl")
100-
module Interface
101-
102-
using Reexport
103-
using MeasureBase
104-
using MeasureBase:basemeasure_depth, proxy
105-
using MeasureBase: insupport
106-
@reexport using Test
107-
108-
export test_interface
109-
export basemeasure_depth
110-
export proxy
111-
export insupport
112-
113-
include("interface.jl")
114-
end # module Interface
115127

116128
using .Interface
117129

src/combinators/conditional.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
struct ConditionalMeasure{M,C} <: AbstractMeasure
2+
parent::M
3+
constraint::C
4+
end
5+
6+
"""
7+
(m::AbstractMeasure) | constraint
8+
9+
Return a new measure by constraining `m` to satisfy `constraint`.
10+
11+
Note that the form of `constraint` will vary depending on the structure of a
12+
given measure. For example, a measure over `NamedTuple`s may allow `NamedTuple`
13+
constraints, while another may require `constraint` to be a predicate or a
14+
function returning a real number (in which case the constraint could be
15+
considered as the zero-set of that function).
16+
17+
At the time of this writing, invariants required of this function are not yet
18+
settled. Specifically, there's the question of normalization. It's common for
19+
conditional distributions to be normalized, but this can often not be expressed
20+
in closed form, and can be very expensive to compute. For more general measures,
21+
the notion of normalization may not even make sense.
22+
23+
Because of this, this interface is not yet stable, and users should expect
24+
upcoming changes.
25+
"""
26+
Base.:|::AbstractMeasure, constraint) = ConditionalMeasure(μ, constraint)
27+
28+
@inline basemeasure(cm::ConditionalMeasure) = basemeasure(cm.parent) | cm.constraint

src/combinators/half.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ isnonnegative(x) = x ≥ 0.0
1515

1616
@inline function basemeasure::Half)
1717
const= static(logtwo)
18-
varℓ = Returns(static(0.0))
18+
varℓ = Returns(0.0)
1919
base = basemeasure(unhalf(μ))
2020
FactoredBase(constℓ, varℓ, base)
2121
end

src/combinators/product.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,13 +100,13 @@ end
100100
end
101101

102102
vals = map(x -> Expr(:(=), x,x), N)
103-
push!(q, Expr(:tuple, vals...))
103+
push!(q.args, Expr(:tuple, vals...))
104104
return q
105105
end
106106

107107
function basemeasure::ProductMeasure{Base.Generator{I,F}}) where {I,F}
108108
mar = marginals(μ)
109-
T = Core.Compiler.return_type(mar.f, Tuple{_eltype(mar.iter)})
109+
T = Core.Compiler.return_type(mar.f, Tuple{eltype(mar.iter)})
110110
B = Core.Compiler.return_type(basemeasure, Tuple{T})
111111
_basemeasure(μ, B, static(Base.issingletontype(B)))
112112
end

src/combinators/smart-constructors.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,14 @@ superpose(a::AbstractArray) = SuperpositionMeasure(a)
7676
superpose(t::Tuple) = SuperpositionMeasure(t)
7777
superpose(nt::NamedTuple) = SuperpositionMeasure(nt)
7878

79+
function superpose::T, ν::T) where {T}
80+
if μ==ν
81+
return weightedmeasure(static(logtwo), μ)
82+
else
83+
return superpose((μ, ν))
84+
end
85+
end
86+
7987
function superpose::AbstractMeasure, ν::AbstractMeasure)
8088
components = (μ, ν)
8189
superpose(components)

src/combinators/superpose.jl

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,53 @@ function Base.:+(μ::AbstractMeasure, ν::AbstractMeasure)
5858
superpose(components)
5959
end
6060

61-
function logdensity_def::SuperpositionMeasure, x)
62-
logsumexp((logdensity_def(m, x) for m in μ.components))
61+
using LogarithmicNumbers
62+
63+
oneplus(x::ULogarithmic) = exp(ULogarithmic, log1pexp(x.log))
64+
65+
@inline function density_def(s::SuperpositionMeasure{Tuple{A,B}}, x) where {A,B}
66+
(μ, ν) = s.components
67+
insupport(μ, x) || return exp(ULogarithmic, logdensity_def(ν, x))
68+
insupport(ν, x) || return exp(ULogarithmic, logdensity_def(μ, x))
69+
α = basemeasure(μ)
70+
β = basemeasure(ν)
71+
dμ_dα = exp(ULogarithmic, logdensity_def(μ, x))
72+
dν_dβ = exp(ULogarithmic, logdensity_def(ν, x))
73+
dα_dβ = exp(ULogarithmic, logdensity_rel(α, β, x))
74+
dβ_dα = inv(dα_dβ)
75+
return dμ_dα / oneplus(dβ_dα) + dν_dβ / oneplus(dα_dβ)
6376
end
6477

78+
using StatsFuns
79+
80+
@inline function logdensity_def::T, ν::T, x::Any) where T<:(SuperpositionMeasure{Tuple{A, B}} where {A, B})
81+
if μ === ν
82+
return zero(return_type(logdensity_def, (μ, x)))
83+
else
84+
return logdensity_def(μ,x) - logdensity_def(ν, x)
85+
end
86+
end
87+
88+
@inline function logdensity_def(s::SuperpositionMeasure{Tuple{A,B}}, β, x) where {A,B}
89+
(μ, ν) = s.components
90+
insupport(μ, x) || return logdensity_rel(ν, β, x)
91+
insupport(ν, x) || return logdensity_rel(μ, β, x)
92+
return logaddexp(logdensity_rel(μ, β, x), logdensity_rel(ν, β, x))
93+
end
94+
95+
@inline function logdensity_def(s::SuperpositionMeasure{Tuple{A,B}}, β::SuperpositionMeasure, x) where {A,B}
96+
(μ, ν) = s.components
97+
insupport(μ, x) || return logdensity_rel(ν, β, x)
98+
insupport(ν, x) || return logdensity_rel(μ, β, x)
99+
return logaddexp(logdensity_rel(μ, β, x), logdensity_rel(ν, β, x))
100+
end
101+
102+
@inline function logdensity_def(s, β::SuperpositionMeasure{Tuple{A,B}}, x) where {A,B}
103+
-logdensity_def(β, s, x)
104+
end
105+
106+
@inline logdensity_def(s::SuperpositionMeasure, x) = log(density_def(s, x))
107+
65108
basemeasure::SuperpositionMeasure) = superpose(map(basemeasure, μ.components))
66109

67110
# TODO: Fix `rand` method (this one is wrong)

0 commit comments

Comments
 (0)