Skip to content

Commit 73e112c

Browse files
authored
Refactoring (#31)
* some refactoring * pretty printing * some updates for resettablerng * deepcopy on iterate * Move ResettableRNG to MeasureTheory * Make Pretty a const * drop extra spaces * formatting * formatting * formatting * add DensityInterface dependency * faster `rootmeasure` * updates * primitive measure docs * get tests passing * cleaning up * formatting * drop some unneeded methods * update for DensityInterface * updating densities to DensityInterface approach * update domains * moving things around * updates to IntegerBounds * more domain mucking * fix some exports * update deps * update `using` * working on tests * Working toward tests passing * some refactoring * working on tests * cleaning up * get tests to pass * tests passing * update logdensity_def for pointwise product * PrettyPrinting + tests * speed up `rootmeasure` * oops didn't mean to include that * tile(::FactoredBase) * Update Half and FactoredBase * drop exp.jl * simplify basemeasure * drop old `include` * drop redundant method * update compat Returns * add AbstractDensity * Move Affine to MeasureTheory * drop some old `For` code * update counting measure * add testvalue(::Type{T}) * bugfix * some dispatch adjustments * simplify show * formatting * transformed measures * ZeroSets * using LinearAlgebra, Statistics * working on MeasureTheory tests * updates * updates * typo * update default to mimic Base * fix tile(::Lebesgue) * Add test_interface function * small doc update * DensityKind(::Likelihood) * fixing show(::Likelihood) * adding some docs * update `rand` method * drop old integration code * add `rebase` * export rebase * law for ⊗ * typo * Make Likelihood more flexible * update kernel methods * add Likelihood method (avoid stack overflow) * refactoring * compat * Maybe Comat just works? * refactoring * some new stuff * working on tests * update powermeasure combinator * bugfix * comment out debugging lines * more refactoring * test @inferred basemeasure_depth(μ) * drop `constructor` (just use ConstructionBase.constructorof`) * debugging * update help * update interface * make tests harder * fixes * Dirac bugfix * formatting * improve type inference * working on type inference * update interface * udpates * get test passing * @test !isabstracttype(typejoin(...)) * work on show methods * update CI * remove old code * update productmeasure * prettyprinting stuff * Drop te @constprop :aggressive stuff (maybe don't need it?) * nerline * dropping some old code * update tbasemeasure_type(::PowerMeasure) * moar tests * update SpikeMixture * update superpose type parameter name * drop old tests * func_string * more updates * getting closer * almost there! * generated function for type stability * tests passing! * newline * more fixes * exports and bugfix * insupport(μ::Counting{T}, x) where {T<:Type} * working on MeasureTheory tests * MeasureTheory tests passing * drop some old code * inlining * improve inference * update `tile(::For)` * tighten down infrerence * update basemeasure(::For) for generators * loosen type bound on instance_type * drop debugging code * small update for Likelihood, and a test * fixing up likelihoods * improve `basemeasure_depth` dispatch * still some trouble with inferred basemeasure_depth * clean up `For` dispatch * simplify _logdensityof * optimize for Returns{True} case * rework basemeasure_depth * aggressive tests passing!! * drop type-level stuff * drop help * merge * partial-static * cleanup * fix merge bug * fix typo * instance_type(T::DataType) * fine-tuning some logdensity_def methods * Constructor method for Counting * some adjustments to partialstatic * update Static.dynamic(x::PartialStatic) * back to actual tests * update logdensity_def(d::PowerMeasure{M}, x) * solve! * license * Move `solve` to MeasureTheory * speed up likelihoods and pointwise products * promote_rule for PartialStatic * some updates * logdensityof(ℓ::AbstractLikelihood, par) * fix test bug
1 parent 02cc7a2 commit 73e112c

16 files changed

+221
-66449
lines changed

src/MeasureBase.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ function logdensity_def end
6464

6565
using Compat
6666

67+
include("partial-static.jl")
6768
include("proxies.jl")
6869
include("kleisli.jl")
6970
include("parameterized.jl")

src/combinators/for.jl

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,42 @@ end
1717

1818
# For(f, gen::Base.Generator) = ProductMeasure(Base.Generator(f ∘ gen.f, gen.iter))
1919

20+
@inline function logdensity_def(d::For{T,F,I}, x::AbstractVector{X}) where {X,T,F,I<:Tuple{<:AbstractVector}}
21+
= zero(float(Core.Compiler.return_type(logdensity_def, Tuple{T,X})))
22+
@inbounds for j in eachindex(x)
23+
+= logdensity_def(d.f(j), x[j])
24+
end
25+
26+
end
27+
28+
function logdensity_def(d::For, x::AbstractVector)
29+
sum(eachindex(x)) do i
30+
@inbounds logdensity_def(d.f(getindex.(d.inds,i)...), x[i])
31+
end
32+
end
33+
34+
function logdensity_def(d::For{T,F,I}, x::AbstractArray{X}) where {T,F,I,X}
35+
= zero(float(Core.Compiler.return_type(logdensity_def, Tuple{T,X})))
36+
37+
@inbounds for j in CartesianIndices(x)
38+
i = (getindex(ind, j) for ind in d.inds)
39+
+= logdensity_def(d.f(i...), x[j])
40+
end
41+
42+
end
43+
2044
function logdensity_def(d::For{T,F,I}, x) where {N,T,F,I<:NTuple{N,<:Base.Generator}}
2145
sum(zip(x, d.inds...)) do (xⱼ, dⱼ...)
2246
logdensity_def(d.f(dⱼ...), xⱼ)
2347
end
2448
end
2549

50+
function logdensity_def(d::For{T,F,I}, x::AbstractVector) where {N,T,F,I<:NTuple{N,<:Base.Generator}}
51+
sum(zip(x, d.inds...)) do (xⱼ, dⱼ...)
52+
logdensity_def(d.f(dⱼ...), xⱼ)
53+
end
54+
end
55+
2656
function marginals(d::For{T,F,I}) where {T,F,I}
2757
f(x...) = d.f(x...)::T
2858
mappedarray(f, d.inds...)
@@ -43,9 +73,8 @@ end
4373
end
4474

4575
@inline function _basemeasure(d::For{T,F,I}, ::Type{B}, ::False) where {T,F,I,B<:AbstractMeasure}
46-
new_f = basekleisli(d.f)
47-
new_F = typeof(new_f)
48-
For{B,new_F, I}(new_f, d.inds)
76+
f = basekleisli(d.f)
77+
_For(B, f, d.inds)
4978
end
5079

5180
@inline function _basemeasure(d::For{T,F,I}, ::Type{B}, ::False) where {T,F,I,B}
@@ -58,8 +87,7 @@ end
5887

5988
function _basemeasure(d::For{T,F,I}, ::Type{B}, ::False) where {N,T<:AbstractMeasure,F,I<:NTuple{N,<:Base.Generator},B}
6089
f = basekleisli(d.f)
61-
newF = typeof(f)
62-
For{B,newF,I}(f, d.inds)
90+
_For(B, f, d.inds)
6391
end
6492

6593
function Pretty.tile(d::For{T}) where {T}
@@ -74,6 +102,10 @@ function Pretty.tile(d::For{T}) where {T}
74102
)
75103
end
76104

105+
function _For(::Type{T}, f::F, inds::I) where {T,F,I}
106+
For{T,F,I}(f,inds)
107+
end
108+
77109

78110
"""
79111
For(f, base...)

src/combinators/likelihood.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1-
export Likelihood
1+
export AbstractLikelihood, Likelihood
2+
3+
abstract type AbstractLikelihood end
4+
5+
@inline logdensityof(ℓ::AbstractLikelihood, par) = logdensity_def(ℓ, par)
26

37
@doc raw"""
48
Likelihood(k::AbstractKleisli, x)
@@ -101,7 +105,7 @@ and we observe `x=3`. We can compute the posterior measure on `μ` as
101105
julia> logdensity_def(post, 2)
102106
-2.5
103107
"""
104-
struct Likelihood{K,X}
108+
struct Likelihood{K,X} <: AbstractLikelihood
105109
k::K
106110
x::X
107111

@@ -124,10 +128,10 @@ function Base.show(io::IO, ℓ::Likelihood)
124128
Pretty.pprint(io, ℓ)
125129
end
126130

127-
@inline function logdensity_def(ℓ::Likelihood, p)
131+
@inline function logdensity_def(ℓ::Likelihood, p::Tuple)
128132
return logdensity_def(ℓ.k(p), ℓ.x)
129133
end
130134

131-
@inline function logdensityof(ℓ::Likelihood, p)
132-
return logdensityof(ℓ.k(p), ℓ.x)
135+
@inline function logdensity_def(ℓ::Likelihood, p)
136+
return logdensity_def(ℓ.k((p,)), ℓ.x)
133137
end

src/combinators/pointwise.jl

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,20 @@
11
export
22

3-
struct PointwiseProductMeasure{T} <: AbstractMeasure
4-
data::T
3+
struct PointwiseProductMeasure{M,L} <: AbstractMeasure
4+
measure::M
5+
likelihood::L
6+
7+
function PointwiseProductMeasure::M, ℓ::L) where {M,L}
8+
@assert static_hasmethod(logdensity_def, Tuple{L, gentype(μ)})
9+
return new{M,L}(μ, ℓ)
10+
end
511
end
612

713
Base.size::PointwiseProductMeasure) = size.data)
814

915
function Base.show(io::IO, μ::PointwiseProductMeasure)
1016
io = IOContext(io, :compact => true)
11-
print(io, join(string.(μ.data), ""))
17+
print(io, μ.measure, "", μ.likelihood)
1218
end
1319

1420
function Base.show_unquoted(io::IO, μ::PointwiseProductMeasure, indent::Int, prec::Int)
@@ -28,13 +34,11 @@ Base.length(m::PointwiseProductMeasure{T}) where {T} = length(m.data)
2834
(args...) = pointwiseproduct(args...)
2935

3036
@inline function logdensity_def(d::PointwiseProductMeasure, x)
31-
sum(d.data) do dⱼ
32-
logdensity_def(dⱼ, x)
33-
end
37+
logdensity_def(d.measure, x) + logdensity_def(d.likelihood, x)
3438
end
3539

3640
function gentype(d::PointwiseProductMeasure)
37-
@inbounds gentype(first(d.data))
41+
@inbounds gentype(d.measure)
3842
end
3943

40-
basemeasure(d::PointwiseProductMeasure) = @inbounds basemeasure(first(d.data))
44+
basemeasure(d::PointwiseProductMeasure) = @inbounds basemeasure(d.measure)

src/combinators/power.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,12 @@ params(d::PowerMeasure) = params(first(marginals(d)))
5959
basemeasure(d.parent) ^ d.axes
6060
end
6161

62-
@inline function logdensity_def(d::PowerMeasure, x)
63-
sum(x) do xj
64-
logdensity_def(d.parent, xj)
62+
@inline function logdensity_def(d::PowerMeasure{M}, x) where {M}
63+
T = eltype(x)
64+
= zero(float(Core.Compiler.return_type(logdensity_def, Tuple{M,T})))
65+
parent = d.parent
66+
@inbounds for xj in x
67+
+= logdensity_def(parent, xj)
6568
end
69+
6670
end

src/combinators/product.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ function Base.:(==)(a::AbstractProductMeasure, b::AbstractProductMeasure)
2121
end
2222
Base.length::AbstractProductMeasure) = length(marginals(μ))
2323
Base.size::AbstractProductMeasure) = size(marginals(μ))
24+
2425
basemeasure(d::AbstractProductMeasure) = productmeasure(map(basemeasure, marginals(d)))
2526

2627
function Base.rand(rng::AbstractRNG, ::Type{T}, d::AbstractProductMeasure) where {T}

src/combinators/smart-constructors.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,8 @@ half(μ::AbstractMeasure) = Half(μ)
77
###############################################################################
88
# PointwiseProductMeasure
99

10-
pointwiseproduct::AbstractMeasure...) = PointwiseProductMeasure(μ)
11-
1210
function pointwiseproduct::AbstractMeasure, ℓ::Likelihood)
13-
data = (μ, ℓ)
14-
return PointwiseProductMeasure(data)
11+
return PointwiseProductMeasure(μ, ℓ)
1512
end
1613

1714
###############################################################################

src/density.jl

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -101,36 +101,35 @@ Define a new measure in terms of a log-density `f` over some measure `base`.
101101

102102
# TODO: `density` and `logdensity` functions for `DensityMeasure`
103103

104-
@inline logdensityof(μ, x) = _logdensityof(μ, x)
104+
@inline logdensityof(μ, x) = dynamic(_logdensityof(μ, x))
105105

106106
@inline _logdensityof(μ, x) = _logdensityof(μ, basemeasure(μ, x), x)
107107

108108
@inline function _logdensityof(μ, α, x)
109-
= dynamic(logdensity_def(μ, x))
110-
L = typeof(ℓ)
111-
_logdensityof(μ, α, x, ℓ)::L
109+
_logdensityof(μ, α, x, partialstatic(logdensity_def(μ, x)))
112110
end
113111

114112
@inline function _logdensityof::M, β::M, x, ℓ) where {M}
115113
return
116114
end
117115

118116
@inline function _logdensityof::M, β, x, ℓ) where {M}
119-
n = basemeasure_depth(μ) - static(1)
117+
n = static(basemeasure_depth(β))
120118
_logdensityof(β, basemeasure(β,x), x, ℓ, n)
121119
end
122120

123-
@generated function _logdensityof(μ, β, x, ℓ::T, ::StaticInt{n}) where {n,T}
121+
@generated function _logdensityof(μ, β, x, ℓ, ::StaticInt{n}) where {n}
124122
nsteps = max(n, 0)
125123
quote
126124
$(Expr(:meta, :inline))
125+
# @show ℓ
127126
Base.Cartesian.@nexprs $nsteps i -> begin
128-
Δℓ = oftype(ℓ, logdensity_def(μ, x))
127+
Δℓ = logdensity_def(μ, x)
129128
# @show μ
130129
# @show Δℓ
131130
# println()
132131
μ,β = β, basemeasure(β, x)
133-
+= Δℓ
132+
+= partialstatic(Δℓ)
134133
end
135134
return
136135
end
@@ -164,7 +163,6 @@ export logdensityof
164163
export density_def
165164

166165
density_def(μ, ν::AbstractMeasure, x) = exp(logdensity_def(μ, ν, x))
167-
168166
density_def(μ, x) = exp(logdensity_def(μ, x))
169167

170168
"""

src/interface.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ function test_interface(μ::M) where {M}
1515
μ = $μ
1616
@testset "" begin
1717
μ = $μ
18-
M = $M
1918

2019
###########################################################################
2120
# basemeasure_depth
@@ -31,14 +30,13 @@ function test_interface(μ::M) where {M}
3130
###########################################################################
3231
# testvalue, logdensityof
3332

34-
x = testvalue(μ)
33+
x = @inferred testvalue(μ)
3534
β = @inferred basemeasure(μ, x)
3635

3736
ℓμ = @inferred logdensityof(μ, x)
3837
ℓβ = @inferred logdensityof(β, x)
3938

4039
@test ℓμ logdensity_def(μ, x) + ℓβ
41-
4240
end
4341
end
4442
end

src/parameterized.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,19 @@ end
2929
export kleisli
3030

3131
function kleisli(::Type{P}) where {N,P<:ParameterizedMeasure{N}}
32-
C = constructorof(P)
33-
function(args...) C(NamedTuple{N}(args...)) end
32+
C = constructorof(P)
33+
_kleisli(C, Val(N))
34+
end
35+
36+
@inline function _kleisli(::Type{C}, ::Val{N}) where {C,N}
37+
@inline function(args::T) where {T<:Tuple}
38+
C(NamedTuple{N,T}(args))::C{N,T}
39+
end
3440
end
3541

3642
function (::Type{P})(args...) where {N,P<:ParameterizedMeasure{N}}
3743
C = constructorof(P)
38-
return C(NamedTuple{N}(args...))
44+
return C(NamedTuple{N}(args))::C{N,typeof(args)}
3945
end
4046

4147
(::Type{P})(; kwargs...) where {P<:ParameterizedMeasure} = P(NamedTuple(kwargs))

0 commit comments

Comments
 (0)