Skip to content

Commit 85c00e5

Browse files
authored
Smart constructors (#18)
* update 3-arg logdensity * update logjac * factoredbase * update power measure * update half.jl to use Factoredbase * drop outdated test * add a test * bump version * start on smart constructors * work on smart constructors * get tests passing * roll back superpose.jl for now * smart constructor updates * uncomment some tests * add missing Affine method * update RestrictedMeasure * Get MeasureTheory tests passing * add fake Returns for VERSION < v"1.7.0-beta1.0" * update Breakage * Make JET happy * bump version
1 parent 46b676e commit 85c00e5

File tree

16 files changed

+151
-87
lines changed

16 files changed

+151
-87
lines changed

.github/workflows/Breakage.yml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,7 @@ jobs:
1313
fail-fast: false
1414
matrix:
1515
pkg: [
16-
"cscherrer/MeasureTheory.jl",
17-
"cscherrer/Soss.jl",
18-
"mschauer/Mitosis.jl"
16+
"cscherrer/MeasureTheory.jl"
1917
]
2018
pkgversion: [latest, stable]
2119

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MeasureBase"
22
uuid = "fa1605e6-acd5-459c-a1e6-7e635759db14"
33
authors = ["Chad Scherrer <[email protected]> and contributors"]
4-
version = "0.4.0"
4+
version = "0.4.1"
55

66
[deps]
77
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"

src/MeasureBase.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@ Methods for computing density relative to other measures will be
3434
"""
3535
function logdensity end
3636

37+
38+
if VERSION < v"1.7.0-beta1.0"
39+
@eval Returns(x) = _ -> x
40+
end
41+
3742
include("combinators/half.jl")
3843
include("exp.jl")
3944
include("domains.jl")
@@ -60,6 +65,8 @@ include("combinators/spikemixture.jl")
6065
include("kernel.jl")
6166
include("combinators/likelihood.jl")
6267
include("combinators/pointwise.jl")
68+
include("combinators/restricted.jl")
69+
include("combinators/smart-constructors.jl")
6370

6471
include("rand.jl")
6572

src/combinators/affine.jl

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,11 @@ logjac(f::AffineTransform{(:μ,)}) = 0.0
3434
struct Affine{N,M,T} <: AbstractMeasure
3535
f::AffineTransform{N,T}
3636
parent::M
37+
end
3738

38-
function Affine(f::AffineTransform, parent::WeightedMeasure)
39-
WeightedMeasure(parent.logweight, Affine(f, parent.base))
40-
end
39+
Affine(nt::NamedTuple, μ::AbstractMeasure) = affine(nt, μ)
4140

42-
Affine(f::AffineTransform{N,T}, parent::M) where {N,M,T} = new{N,M,T}(f, parent)
43-
end
41+
Affine(nt::NamedTuple) = affine(nt)
4442

4543
parent(d::Affine) = getfield(d, :parent)
4644

@@ -54,10 +52,6 @@ function paramnames(::Type{A}) where {N,M, A<:Affine{N,M}}
5452
tuple(union(N, paramnames(M))...)
5553
end
5654

57-
Affine(nt::NamedTuple, μ::AbstractMeasure) = Affine(AffineTransform(nt), μ)
58-
59-
Affine(nt::NamedTuple) = μ -> Affine(nt, μ)
60-
6155
Base.propertynames(d::Affine{N}) where {N} = N (:parent,)
6256

6357
@inline function Base.getproperty(d::Affine, s::Symbol)
@@ -77,11 +71,11 @@ logdensity(d::Affine{(:σ,)}, x) = logdensity(d.parent, d.σ \ x)
7771
logdensity(d::Affine{(:ω,)}, x) = logdensity(d.parent, d.ω * x)
7872
logdensity(d::Affine{(:μ,)}, x) = logdensity(d.parent, x - d.μ)
7973

80-
basemeasure(d::Affine) = Affine(getfield(d, :f), basemeasure(d.parent))
74+
basemeasure(d::Affine) = affine(getfield(d, :f), basemeasure(d.parent))
8175

8276
# We can't do this until we know we're working with Lebesgue measure, since for
8377
# example it wouldn't make sense to apply a log-Jacobian to a point measure
84-
basemeasure(d::Affine{N,L}) where {N, L<:Lebesgue} = WeightedMeasure(-logjac(d), d.parent)
78+
basemeasure(d::Affine{N,L}) where {N, L<:Lebesgue} = weightedmeasure(-logjac(d), d.parent)
8579

8680
logjac(d::Affine) = logjac(getfield(d, :f))
8781

src/combinators/factoredbase.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ end
99

1010
function logdensity(d::FactoredBase, x)
1111
d.inbounds(x) || return -Inf
12-
d.const+ d.varℓ
12+
d.const+ d.varℓ()
1313
end
1414

1515
basemeasure(d::FactoredBase) = d.base

src/combinators/half.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ unhalf(μ::Half) = μ.parent
1414
function basemeasure::Half)
1515
inbounds(x) = x > 0
1616
const= logtwo
17-
varℓ = 0.0
17+
varℓ() = 0.0
1818
base = basemeasure(unhalf(μ))
1919
FactoredBase(inbounds, constℓ, varℓ, base)
2020
end

src/combinators/pointwise.jl

Lines changed: 1 addition & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,6 @@ export ⊙
22

33
@concrete terse struct PointwiseProductMeasure{T} <: AbstractMeasure
44
data :: T
5-
6-
PointwiseProductMeasure(μs...) = new{typeof(μs)}(μs)
7-
PointwiseProductMeasure(μs) = new{typeof(μs)}(μs)
85
end
96

107
Base.size::PointwiseProductMeasure) = size.data)
@@ -28,30 +25,7 @@ end
2825

2926
Base.length(m::PointwiseProductMeasure{T}) where {T} = length(m.data)
3027

31-
function ::PointwiseProductMeasure{X}, ν::PointwiseProductMeasure{Y}) where {X,Y}
32-
data =.data..., ν.data...)
33-
PointwiseProductMeasure(data...)
34-
end
35-
36-
function ::AbstractMeasure, ν::PointwiseProductMeasure)
37-
data = (μ, ν.data...)
38-
PointwiseProductMeasure(data...)
39-
end
40-
41-
function ::PointwiseProductMeasure, ν::N) where {N <: AbstractMeasure}
42-
data =.data..., ν)
43-
PointwiseProductMeasure(data...)
44-
end
45-
46-
function ::M, ν::N) where {M <: AbstractMeasure, N <: AbstractMeasure}
47-
data = (μ, ν)
48-
PointwiseProductMeasure(data...)
49-
end
50-
51-
function ::AbstractMeasure, ℓ::Likelihood)
52-
data = (μ, ℓ)
53-
PointwiseProductMeasure(data...)
54-
end
28+
(args...) = pointwiseproduct(args...)
5529

5630
function logdensity(d::PointwiseProductMeasure, x)
5731
sum((logdensity(dⱼ, x) for dⱼ in d.data))

src/combinators/power.jl

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,11 @@ export PowerMeasure
3232
const PowerMeasure{F,T,N,A} = ProductMeasure{F,Fill{T,N,A}}
3333

3434
function Base.:^::AbstractMeasure, dims::Integer...)
35-
return μ^dims
35+
return μ ^ dims
3636
end
3737

3838
function Base.:^::M, dims::NTuple{N,I}) where {M<:AbstractMeasure,N,I<:Integer}
39-
ProductMeasure(identity, Fill(μ, dims))
39+
powermeasure(μ, dims)
4040
end
4141

4242
# Same as PowerMeasure
@@ -53,31 +53,32 @@ end
5353

5454
# sampletype(d::PowerMeasure{M,N}) where {M,N} = @inbounds Array{sampletype(first(marginals(d))), N}
5555

56-
function Base.:^::WeightedMeasure, dims::NTuple{N,I}) where {N,I<:Integer}
57-
k = prod(dims) * μ.logweight
58-
return WeightedMeasure(k, μ.base^dims)
59-
end
56+
6057

6158
params(d::ProductMeasure{F,<:Fill}) where {F} = params(first(marginals(d)))
6259

6360
params(::Type{P}) where {F,P<:ProductMeasure{F,<:Fill}} = params(D)
6461

6562
# basemeasure(μ::PowerMeasure) = @inbounds basemeasure(first(μ.data))^size(μ.data)
6663

67-
@inline basemeasure(d::PowerMeasure) = _basemeasure(d, (basemeasure(d.f(first(d.pars)))))
64+
# Same as PowerMeasure
65+
@inline basemeasure(d::ProductMeasure{F,<:Fill}) where {F}= _basemeasure(d, (basemeasure(d.f(first(d.pars)))))
6866

69-
@inline _basemeasure(d::PowerMeasure, b) = b ^ size(d.pars)
67+
# Same as PowerMeasure
68+
@inline _basemeasure(d::ProductMeasure{F,<:Fill}, b) where {F} = b ^ size(d.pars)
7069

71-
@inline function _basemeasure(d::PowerMeasure, b::FactoredBase)
70+
# Same as PowerMeasure
71+
@inline function _basemeasure(d::ProductMeasure{F,<:Fill}, b::FactoredBase) where {F}
7272
n = length(d.pars)
7373
inbounds(x) = all(xj -> b.inbounds(xj), x)
7474
const= n * b.const
75-
varℓ = n * b.varℓ
75+
varℓ() = n * b.varℓ()
7676
base = b.base ^ size(d.pars)
7777
FactoredBase(inbounds, constℓ, varℓ, base)
7878
end
7979

80-
function logdensity(d::PowerMeasure, x)
80+
# Same as PowerMeasure
81+
function logdensity(d::ProductMeasure{F,<:Fill}, x) where {F}
8182
d1 = d.f(first(d.pars))
8283
sum(xj -> logdensity(d1, xj), x)
8384
end

src/combinators/product.jl

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,13 @@ struct ProductMeasure{F,I} <: AbstractMeasure
88
pars::I
99
end
1010

11-
ProductMeasure(nt::NamedTuple) = ProductMeasure(identity, nt)
12-
1311
Base.size::ProductMeasure) = size(marginals(μ))
1412

1513
Base.length(m::ProductMeasure{T}) where {T} = length(marginals(μ))
1614

1715
# TODO: Pull weights outside
1816
basemeasure(d::ProductMeasure) = ProductMeasure(basemeasure d.f, d.pars)
1917

20-
2118
export marginals
2219

2320
function marginals(d::ProductMeasure{F,I}) where {F,I}
@@ -39,9 +36,6 @@ function Base.show(io::IO, μ::ProductMeasure{NamedTuple{N,T}}) where {N,T}
3936
print(io, "Product(".data, ")")
4037
end
4138

42-
43-
44-
4539
function Base.show_unquoted(io::IO, μ::ProductMeasure, indent::Int, prec::Int)
4640
io = IOContext(io, :compact => true)
4741
if Base.operator_precedence(:*) prec
@@ -141,11 +135,6 @@ end
141135
return x
142136
end
143137

144-
145-
146-
147-
148-
149138
export rand!
150139
using Random: rand!, GLOBAL_RNG, AbstractRNG
151140

src/combinators/restricted.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
struct RestrictedMeasure{F,M} <: AbstractMeasure
2+
f::F
3+
base::M
4+
end
5+
6+
function logdensity(d::RestrictedMeasure, x)
7+
d.f(x) || return -Inf
8+
end
9+
10+
function density(d::RestrictedMeasure, x)
11+
d.f(x) || return 0.0
12+
end
13+
14+
basemeasure::RestrictedMeasure) = μ.base

0 commit comments

Comments
 (0)