Skip to content

Commit d9fdb6a

Browse files
authored
Dev (#51)
* PrettyPrinting + tests * speed up `rootmeasure` * oops didn't mean to include that * tile(::FactoredBase) * Update Half and FactoredBase * drop exp.jl * simplify basemeasure * drop old `include` * drop redundant method * update compat Returns * add AbstractDensity * Move Affine to MeasureTheory * drop some old `For` code * update counting measure * add testvalue(::Type{T}) * bugfix * some dispatch adjustments * simplify show * formatting * transformed measures * ZeroSets * using LinearAlgebra, Statistics * working on MeasureTheory tests * updates * updates * typo * update default to mimic Base * fix tile(::Lebesgue) * Add test_interface function * small doc update * DensityKind(::Likelihood) * fixing show(::Likelihood) * adding some docs * update `rand` method * drop old integration code * add `rebase` * export rebase * law for ⊗ * typo * Make Likelihood more flexible * update kernel methods * add Likelihood method (avoid stack overflow) * refactoring * compat * Maybe Comat just works? * refactoring * some new stuff * working on tests * update powermeasure combinator * bugfix * comment out debugging lines * more refactoring * test @inferred basemeasure_depth(μ) * drop `constructor` (just use ConstructionBase.constructorof`) * debugging * update help * update interface * make tests harder * fixes * Dirac bugfix * formatting * improve type inference * working on type inference * update interface * udpates * get test passing * @test !isabstracttype(typejoin(...)) * work on show methods * update CI * remove old code * update productmeasure * prettyprinting stuff * Drop te @constprop :aggressive stuff (maybe don't need it?) * nerline * dropping some old code * update tbasemeasure_type(::PowerMeasure) * moar tests * update SpikeMixture * update superpose type parameter name * drop old tests * func_string * more updates * getting closer * almost there! * generated function for type stability * tests passing! * newline * more fixes * exports and bugfix * insupport(μ::Counting{T}, x) where {T<:Type} * working on MeasureTheory tests * MeasureTheory tests passing * drop some old code * inlining * improve inference * update `tile(::For)` * tighten down infrerence * update basemeasure(::For) for generators * loosen type bound on instance_type * drop debugging code * small update for Likelihood, and a test * fixing up likelihoods * improve `basemeasure_depth` dispatch * still some trouble with inferred basemeasure_depth * clean up `For` dispatch * simplify _logdensityof * optimize for Returns{True} case * rework basemeasure_depth * aggressive tests passing!! * drop type-level stuff * drop help * license * affero * copyright notice * merge * Drop Create Commons * cleanup after merge * update support computations * insupport(d::SuperpositionMeasure, x) * dorp ParamWeighted * insupport(d::FactoredBase, x) * export unsafe_logdensityof * call promote_type instead of promote_rule * logdensity_def for named tuple product measures * type annotation for now * debugging * drop shows * speed up mapped arrays * throw an error for `Union{}` types * MT tests passing * updates * get tests passing * MIT license for MeasureBase * bump version * cleanup * spacing * Move ConditionalMeasure to MeasureBase * add LogarithmicNumbers * export basemeasure_sequence * update superpose * fix logdensity_rel * remove FIXME (it's fixed!!) * logdensityof(d::Density, x) * simplify insupport(::Lebesgue, ::Real) * clean up * assume insupport yields Bool * change logdensity_rel fall-through to warning and return NaN * update logdensity_rel * drop old code * fix warning * export logdensity_rel * logdensity_def(μ::Dirac, ν::Dirac, x) * logdensity_def methods * drop `static` * ]add StatsFuns * Fixing up superposition * [compat] entries * trying to speed things up * bugfixes * logdensity_rel tests * logdensity_rel tests * drop qualifier, and add a test * more tests * type constraint in "logdensityof(μ::AbstractMeasure, x)" (was piracy, oops) * add some docs * docs * docs * typo * moar speed * don't export Test * some more updates * logdensity_rel for products * `kleisli` docs * update instance_type * instance_type => Core.Typeof * `powermeasure` bug fix * fix logdensity_rel bug * get `commonbase` to take x type into account * test powers * commonbase docstring * deprecate instance_type * avoid breakage * switch || terms * @ifelse macro * simplify logdensity_rel * give up on this @ifelse business * bump version * working on likelihoods * update likelihood * powerweightedmeasure * powerweighted update * more powerweighted methods * bugfix * dropFactoredBase * drop FactoredBase * (::ProductMeasure) | constraint * update conditional measure * update Dirac * move conditional.jl down in the `include`s * Kleisli => TransitionKernel * simplify logdensity_def(::PowerMeasure, x) * rename kleisli.jl to kernel.jl * update Dirac tests * update Half * get tests passing * update kernel * Update Project.toml * no call-site inlining * restrict single-arg `kernel` to <:ParameterizedMeasure * export log_likelihood_ratio * Drop DensityKind(::Likelihood), at least for now * isfinite(x) instead of x>-Inf * add `condition` constructor * EOF newline * simplify logdensity_def for power measures * finishing up * updates * kernel stuff * kernel stuff * update showe methods * ass a TODO * use `dot` instead of `sum` * drop old code * typo * formatting * cleanup * kernel updates * uncomment * bugfix * drop old code * pretty printing * exports, cleanup * drop old for.jl * Make DensityKind(::AbstractLikelihood) = IsDensity() * update Compat version
1 parent bf06bbd commit d9fdb6a

File tree

10 files changed

+148
-76
lines changed

10 files changed

+148
-76
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2323
Tricks = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775"
2424

2525
[compat]
26-
Compat = "3.35"
26+
Compat = "3.35, 4"
2727
ConstructionBase = "1.3"
2828
DensityInterface = "0.4"
2929
FillArrays = "0.12, 0.13"

src/combinators/for.jl

Lines changed: 0 additions & 1 deletion
This file was deleted.

src/combinators/likelihood.jl

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,8 @@ struct Likelihood{K,X} <: AbstractLikelihood
121121
Likelihood(μ, x) = Likelihood(kernel(μ), x)
122122
end
123123

124+
DensityInterface.DensityKind(::AbstractLikelihood) = IsDensity()
125+
124126
function Pretty.quoteof(ℓ::Likelihood)
125127
k = Pretty.quoteof(ℓ.k)
126128
x = Pretty.quoteof(ℓ.x)
@@ -132,11 +134,17 @@ function Base.show(io::IO, ℓ::Likelihood)
132134
Pretty.pprint(io, ℓ)
133135
end
134136

135-
# @inline function logdensity_def(ℓ::Likelihood, p)
136-
# return logdensity_def(ℓ.k(p), ℓ.x)
137-
# end
137+
insupport(ℓ::AbstractLikelihood, p) = insupport(ℓ.k(p), ℓ.x)
138+
139+
@inline function logdensityof(ℓ::AbstractLikelihood, p)
140+
result = dynamic(unsafe_logdensityof(ℓ, p))
141+
ifelse(insupport(ℓ, p) == true, result, oftype(result, -Inf))
142+
end
143+
144+
@inline function unsafe_logdensityof(ℓ::AbstractLikelihood, p)
145+
return unsafe_logdensityof(ℓ.k(p), ℓ.x)
146+
end
138147

139-
# basemeasure(ℓ::Likelihood, p) = basemeasure(ℓ.k(p), ℓ.x)
140148

141149
# basemeasure(ℓ::Likelihood) = @error "Likelihood requires local base measure"
142150

src/combinators/pointwise.jl

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,22 +11,20 @@ function Pretty.tile(d::PointwiseProductMeasure)
1111
Pretty.pair_layout(Pretty.tile(d.prior), Pretty.tile(d.likelihood), sep = "")
1212
end
1313

14-
(μ, ℓ) = pointwiseproduct(μ, ℓ)
14+
(prior, ℓ) = pointwiseproduct(prior, ℓ)
15+
16+
@inbounds function insupport(d::PointwiseProductMeasure, p)
17+
prior, ℓ = d
18+
insupport(prior, p) && insupport(ℓ, p)
19+
end
1520

1621
@inline function logdensity_def(d::PointwiseProductMeasure, p)
17-
μ, ℓ = d
18-
logdensityof(ℓ.k(p), ℓ.x)
22+
prior, ℓ = d
23+
unsafe_logdensityof(ℓ, p)
1924
end
2025

26+
basemeasure(d::PointwiseProductMeasure) = d.prior
27+
2128
function gentype(d::PointwiseProductMeasure)
2229
gentype(d.prior)
2330
end
24-
25-
@inbounds function insupport(d::PointwiseProductMeasure, p)
26-
μ, ℓ = d
27-
insupport(μ, p) && insupport(ℓ.k(p), ℓ.x)
28-
end
29-
30-
basemeasure(d::PointwiseProductMeasure, x) = d.prior
31-
32-
basemeasure(d::PointwiseProductMeasure) = basemeasure(d.prior)

src/combinators/smart-constructors.jl

Lines changed: 50 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ productmeasure(tup::Tuple) = ProductMeasure(tup)
5454
productmeasure(f, param_maps, pars) = ProductMeasure(kernel(f, param_maps), pars)
5555

5656
function productmeasure(k::ParameterizedTransitionKernel, pars)
57-
productmeasure(k.f, k.param_maps, pars)
57+
productmeasure(k.suff, k.param_maps, pars)
5858
end
5959

6060
function productmeasure(f::Returns{W}, ::typeof(identity), pars) where {W<:WeightedMeasure}
@@ -117,8 +117,6 @@ end
117117
###############################################################################
118118
# TransitionKernel
119119

120-
kernel(f, pars::NamedTuple) = ParameterizedTransitionKernel(f, pars)
121-
122120
# kernel(Normal(μ=2))
123121
function kernel::M) where {M<:ParameterizedMeasure}
124122
kernel(M)
@@ -128,16 +126,57 @@ function kernel(d::PowerMeasure)
128126
Base.Fix2(powermeasure, d.axes) kernel(d.parent)
129127
end
130128

131-
# kernel(Normal{(:μ,), Tuple{Int64}})
132-
function kernel(::Type{M}) where {M<:AbstractMeasure}
133-
constructorof(M)
129+
function kernel(f)
130+
T = Core.Compiler.return_type(f, Tuple{Any} )
131+
_kernel(f, T)
132+
end
133+
134+
function _kernel(f, ::Type{T}) where {T}
135+
GenericTransitionKernel(f)
136+
end
137+
138+
function _kernel(f, ::Type{P}) where {N,P<:ParameterizedMeasure{N}}
139+
k = length(N)
140+
C = constructorof(P)
141+
maps = ntuple(Val(k)) do i
142+
x -> @inbounds x[i]
143+
end
144+
145+
kernel(params f, C, NamedTuple{N}(maps))
134146
end
135147

136-
# kernel(::Type{P}, op::O) where {O, N, P<:ParameterizedMeasure{N}} = kernel{constructorof(P),O}(op)
148+
kernel(f::F, ::Type{M}; kwargs...) where {F<:Function,M} = kernel(f, M, NamedTuple(kwargs))
137149

138-
function kernel(::Type{M}; param_maps...) where {M}
139-
nt = NamedTuple(param_maps)
140-
kernel(M, nt)
150+
function kernel(f::F, ::Type{M}, nt::NamedTuple) where {F<:Function,M}
151+
ParameterizedTransitionKernel(M, f, nt)
141152
end
142153

143-
kernel(k::ParameterizedTransitionKernel) = k
154+
function kernel(f::F, ::Type{M}, ::NamedTuple{()}) where {F<:Function,M}
155+
T = Core.Compiler.return_type(f, Tuple{Any})
156+
_kernel(f, M, T)
157+
end
158+
159+
kernel(::Type{P}, nt::NamedTuple) where {P<:ParameterizedMeasure} = kernel(identity, P, nt)
160+
161+
kernel(::Type{T}; kwargs...) where {T} = kernel(T, NamedTuple(kwargs))
162+
163+
function kernel(::Type{M}, ::NamedTuple{()}) where {M}
164+
C = constructorof(M)
165+
TypedTransitionKernel(C, identity)
166+
end
167+
168+
function _kernel(f::F, ::Type{M}, ::Type{NT}) where {M,F,N,NT<:NamedTuple{N}}
169+
k = length(N)
170+
maps = ntuple(Val(k)) do i
171+
x -> @inbounds x[i]
172+
end
173+
174+
ParameterizedTransitionKernel(M, values f, NamedTuple{N}(maps))
175+
end
176+
177+
kernel(f::F; kwargs...) where {F<:Function} = kernel(f, NamedTuple(kwargs))
178+
179+
function kernel(f::F, nt::NamedTuple{()}) where {F<:Function}
180+
T = Core.Compiler.return_type(f, Tuple{Any})
181+
_kernel(f, T)
182+
end

src/combinators/weighted.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ _logweight(μ::WeightedMeasure) = μ.logweight
3131
basemeasure::AbstractWeightedMeasure) = μ.base
3232

3333
function Pretty.tile(d::WeightedMeasure)
34-
weight = round(exp(d.logweight), sigdigits = 4)
34+
weight = round(dynamic(exp(d.logweight)), sigdigits = 4)
3535
Pretty.pair_layout(Pretty.tile(weight), Pretty.tile(d.base), sep = " * ")
3636
end
3737

src/density.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,6 @@ Define a new measure in terms of a log-density `f` over some measure `base`.
100100
"""
101101
∫exp(f::Function, μ) = (logfuncdensity(f), μ)
102102

103-
# TODO: `density` and `logdensity` functions for `DensityMeasure`
104103

105104
"""
106105
logdensityof(m::AbstractMeasure, x)

src/domains.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ projectto!(x, ::Simplex) = normalize!(x, 1)
133133
struct Sphere <: CodimOne end
134134

135135
function zeroset(::Sphere)
136-
f(x::AbstractArray{T}) where {T} = sum(xⱼ -> xⱼ^2, x) - one(T)
136+
f(x::AbstractArray{T}) where {T} = dot(x, x) - one(T)
137137
∇f(x::AbstractArray{T}) where {T} = x
138138
ZeroSet(f, ∇f)
139139
end

src/kernel.jl

Lines changed: 57 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,59 @@
1-
# TODO: Dangerous to export this - let's not
1+
export AbstractTransitionKernel,
2+
GenericTransitionKernel, TypedTransitionKernel, ParameterizedTransitionKernel
3+
24
abstract type AbstractTransitionKernel <: AbstractMeasure end
35

4-
struct ParameterizedTransitionKernel{F,N,T} <: AbstractTransitionKernel
6+
struct GenericTransitionKernel{F} <: AbstractTransitionKernel
7+
f::F
8+
end
9+
10+
(k::GenericTransitionKernel)(x) = k.f(x)
11+
12+
struct TypedTransitionKernel{M,F} <: AbstractTransitionKernel
13+
m::M
514
f::F
15+
end
16+
17+
(k::TypedTransitionKernel)(x) = (k.m k.f)(x)
18+
struct ParameterizedTransitionKernel{M,S,N,T} <: AbstractTransitionKernel
19+
m::M
20+
suff::S
621
param_maps::NamedTuple{N,T}
722

823
function ParameterizedTransitionKernel(
9-
::Type{F},
24+
::Type{M},
25+
suff::S,
1026
param_maps::NamedTuple{N,T},
11-
) where {F,N,T}
12-
new{Type{F},N,T}(F, param_maps)
27+
) where {M,S,N,T}
28+
new{Type{M},S,N,T}(M, suff, param_maps)
1329
end
14-
function ParameterizedTransitionKernel(f::F, param_maps::NamedTuple{N,T}) where {F,N,T}
15-
new{F,N,T}(f, param_maps)
30+
function ParameterizedTransitionKernel(
31+
m::M,
32+
suff::S,
33+
param_maps::NamedTuple{N,T},
34+
) where {M,S,N,T}
35+
new{M,S,N,T}(m, suff, param_maps)
1636
end
1737
end
1838

1939
"""
20-
kernel(f, M)
21-
kernel((f1, f2, ...), M)
40+
A *kernel* is a function that returns a measure.
41+
42+
k1 = kernel() do x
43+
Normal(x, x^2)
44+
end
2245
23-
A kernel `κ = kernel(f, m)` returns a wrapper around a function `f` giving the
24-
parameters for a measure of type `M`, such that `κ(x) = M(f(x)...)` respective
25-
`κ(x) = M(f1(x), f2(x), ...)`
46+
k2 = kernel(Normal) do x
47+
(μ = x, σ = x^2)
48+
end
49+
50+
k3 = kernel(Normal; μ = identity, σ = abs2)
51+
52+
k4 = kernel(Normal; μ = first, σ = last) do x
53+
(x, x^2)
54+
end
2655
27-
If the argument is a named tuple `(;a=f1, b=f1)`, `κ(x)` is defined as
28-
`M(;a=f(x),b=g(x))`.
56+
x = randn(); k1(x) == k2(x) == k3(x) == k4(x)
2957
3058
This function is not exported, because "kernel" can have so many other meanings.
3159
See for example https://github.com/JuliaGaussianProcesses/KernelFunctions.jl for
@@ -37,23 +65,17 @@ another common use of this term.
3765
"""
3866
function kernel end
3967

40-
# kernel(Normal) do x
41-
# (μ=x,σ=x^2)
42-
# end
43-
44-
kernel(f, ::Type{M}) where {M} = kernel(M, f)
45-
4668
mapcall(t, x) = map(func -> func(x), t)
4769

48-
# (k::TransitionKernel{Type{P},<:Tuple})(x) where {P<:ParameterizedMeasure} = k.f(mapcall(k.param_maps, x)...)
70+
function (k::ParameterizedTransitionKernel)(x)
71+
s = k.suff(x)
72+
k.m(; mapcall(k.param_maps, s)...)
73+
end
4974

50-
(k::ParameterizedTransitionKernel)(x) = k.f(; mapcall(k.param_maps, x)...)
75+
(k::AbstractTransitionKernel)(x1, x2, xs...) = k((x1, x2, xs...))
5176

52-
(k::ParameterizedTransitionKernel)(x...) = k(x)
77+
(k::AbstractTransitionKernel)(; kwargs...) = k(NamedTuple(kwargs))
5378

54-
function (k::ParameterizedTransitionKernel)(x::Tuple)
55-
k.f(NamedTuple{k.param_maps}(x))
56-
end
5779

5880
"""
5981
For any `k::TransitionKernel`, `basekernel` is expected to satisfy
@@ -63,29 +85,31 @@ basekernel(k)(p) == (basemeasure ∘ k)(p)
6385
6486
The main purpose of `basekernel` is to make it efficient to compute
6587
```
66-
basemeasure(d::ProductMeasure) = productmeasure(basekernel(d.f), d.xs)
88+
basemeasure(d::ProductMeasure) == productmeasure(basekernel(d.f), d.xs)
6789
```
6890
"""
6991
function basekernel end
7092

7193
# TODO: Find a way to do better than this
7294
basekernel(f) = basemeasure f
7395

74-
basekernel(k::ParameterizedTransitionKernel) = kernel(basekernel(k.f), k.param_maps)
75-
7696
basekernel(f::Returns) = Returns(basemeasure(f.value))
7797

7898
function Base.show(io::IO, μ::AbstractTransitionKernel)
7999
io = IOContext(io, :compact => true)
80100
Pretty.pprint(io, μ)
81101
end
82102

83-
function Pretty.quoteof(k::ParameterizedTransitionKernel)
84-
qf = Pretty.quoteof(k.f)
85-
qg = Pretty.quoteof(k.param_maps)
86-
:(ParameterizedTransitionKernel($qf, $qg))
103+
function Pretty.tile(k::K) where {K<:AbstractTransitionKernel}
104+
Pretty.list_layout(
105+
Pretty.tile.([getproperty(k, p) for p in propertynames(k)]),
106+
prefix = nameof(constructorof(K)),
107+
)
87108
end
88109

110+
89111
const kleisli = kernel
90112

91113
export kleisli
114+
115+
kernel(k::AbstractTransitionKernel) = k

src/parameterized.jl

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,23 +10,24 @@ function Base.propertynames(μ::ParameterizedMeasure{N}) where {N}
1010
return N
1111
end
1212

13-
function Base.show(io::IO, μ::ParameterizedMeasure{()})
14-
io = IOContext(io, :compact => true)
15-
print(io, nameof(typeof(μ)), "()")
13+
function Pretty.tile(d::ParameterizedMeasure)
14+
result = Pretty.literal(nameof(typeof(d)))
15+
par = getfield(d, :par)
16+
result *= Pretty.literal(sprint(show, par; context = :compact => true))
17+
result
1618
end
1719

18-
function Base.show(io::IO, μ::ParameterizedMeasure{N}) where {N}
19-
io = IOContext(io, :compact => true)
20-
print(io, nameof(typeof(μ)))
21-
print(io, getfield(μ, :par))
20+
function Pretty.tile(d::ParameterizedMeasure{()})
21+
result = Pretty.literal(nameof(typeof(d)))
22+
par = getfield(d, :par)
23+
result *= Pretty.literal("()")
24+
result
2225
end
2326

2427
# Allow things like
2528
#
2629
# julia> Normal{(:μ,)}(2)
2730
# Normal(μ = 2,)
28-
#
29-
3031
function kernel(::Type{P}) where {N,P<:ParameterizedMeasure{N}}
3132
C = constructorof(P)
3233
_kernel(C, Val(N))
@@ -41,13 +42,17 @@ end
4142
C(NamedTuple{N,Tuple{T}}((arg,)))::C{N,Tuple{T}}
4243
end
4344

44-
f
45+
kernel(f)
4546
end
4647

4748
function (::Type{P})(nt::NamedTuple) where {N,P<:ParameterizedMeasure{N}}
4849
C = constructorof(P)
4950
arg = NamedTuple{N}(nt)
50-
return C(arg)
51+
return _parameterized(C, arg)
52+
end
53+
54+
function _parameterized(::Type{C}, arg::NamedTuple{N,T}) where {C,N,T}
55+
return C(arg)::C{N,T}
5156
end
5257

5358
function (::Type{P})(args...) where {N,P<:ParameterizedMeasure{N}}

0 commit comments

Comments
 (0)