Skip to content

Commit 84c6063

Browse files
authored
Dev (#213)
* Dirichlet(k::Integer, α) = Dirichlet(Fill(α, k)) * export TransformVariables as TV * drop redundant import * 0.0 => zero(Float64) * drop outdated Dists.logpdf * update StudentT * drop redundant import * update Uniform * bump MeasureBase version * reworking beta * small update to StudentT * basemeasure for discrete Distributions * using LogExpFunctions => import LogExpFunctions * quoteof(::Chain) * prettyprinting and chain-mucking * Some refactoring for Markov chains * import MeasureBase: ≪ * version bound for PrettyPrinting * copy(rng) might change its type (e.g. GLOBAL_RNG) * tests pass * cleaning up * more cleanup * big update * get tests passing * formatting * oops typo * move affine to MeasureTheory * updating * Val => StaticSymbol * more fixes * fix fix fix * more logdesnity => logdensity_def * more logdesnity fixes * debugging * formatting * bugfixes * working on tests * updates * working on tests * tests passing! * refactor * working on tests * drop static weight for now * fix sampling from ProductMeasure{<:Base.Generator} * tests passing!! * more stuff * constructor => constructorof * constructor =? construtorof * updates * working on tests * fix Dirichlet * update Bernoulli * working on tests * bugfixes for RealizedSamples * tests passing!! 
* tighten down inference * as(::PowerMeasure) * drop type-level stuff * using InverseFunctions.jl * update license * affero * copyright * update CI to 1.6 * xform => TV.as * oops missed a conflict * fix merge corruption * typo * fix license * Update README.md * merge * enumerate instead of zip * bugfix * inline rand * drop `static` from `insupport` results * update proxies * Move ConditionalMeasure to MeasureBase * IfElse.ifelse(p::Bernoulli, t, f) * IfElseMeasure * update some base measures * test broken :( * fix some redundancies * instance_type => Core.Typeof * update testvalue for Bernoulli and Binomial * un-break broken test (now passing) * Fall-back `For` method for when inference fails * drop extra spaces * more whitespace * bump MeasureBase dependency version * add newline * tidy up * ifelse tests * OEF newline * avoid type piracy * add Julia 1.7 to CI * make Julia 1.6 happy * approx instead of == * Require at least Julia 1.6 * Try Sebastian's idea test_measures ::Any[] * Another Any[] * Drop Likelihood test * drop 1.7 CI (seems buggy?) * bump version * export likelihood * Snedecor's F * Gamma distribution * more gamma stuff * Beroulli() * inverse Gaussian * Getting modifed GLM.jl tests to pass * drop pdf and logpdf * Poisson bugfix * update Normal(μ, σ²) * Gamma(μ, ϕ) for GLMs * updates for GLM support * start on truncated * update parameterized measures * drop FactoredBase * drop old LazyArrays dependency * insupport(::Distribution) * Left out"Dists." 
* don't export `ifelse` (#192) * Kleisli => TransitionKernel * depend on StatsBase * tests passing * bump MeasureBase version * work on truncated and censored * improve func_string * Simplify logdensity_def(::For, x) * Move truncated and censored updates to separate branches * newline * comment out in-progress stuff * newline * bump version * update formatting spec * more formatting * tweedie docs * drop redundant exports * update exports * omega => lambda * drop SequentialEx * get tests passing * add kernel tests * gitignore * better `Pretty.tile` for Affine and AffineTransforms * formatting * kleisli => kernel * update tile(::For) * update Compat version * bump MB version * update gamma * Let's come back to InverseGaussian * CI on 1.7 * update IfElse * formatting * update product * @kwstruct Bernoulli(logitp) * Base.size(r::RealizedSamples) * cdf(::Affine, x) * working on Aqua and JET fixes * formatting * So many `rand` methods, do we really need all of these? * bump version * bugfix * bugfix * loosen constraint
1 parent 8055b52 commit 84c6063

File tree

9 files changed

+132
-30
lines changed

9 files changed

+132
-30
lines changed

Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MeasureTheory"
22
uuid = "eadaa1a4-d27c-401d-8699-e962e1bbc33b"
33
authors = ["Chad Scherrer <[email protected]> and contributors"]
4-
version = "0.16.3"
4+
version = "0.16.4"
55

66
[deps]
77
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
@@ -18,6 +18,7 @@ Infinities = "e1ba4f0e-776d-440f-acd9-e1d2e9742647"
1818
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
1919
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
2020
KeywordCalls = "4d827475-d3e4-43d6-abe3-9688362ede9f"
21+
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
2122
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
2223
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
2324
MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078"
@@ -53,11 +54,12 @@ IfElse = "0.1"
5354
Infinities = "0.1"
5455
InverseFunctions = "0.1"
5556
KeywordCalls = "0.2"
57+
LazyArrays = "0.22"
5658
LogExpFunctions = "0.3.3"
5759
MLStyle = "0.4"
5860
MacroTools = "0.5"
5961
MappedArrays = "0.4"
60-
MeasureBase = "0.10"
62+
MeasureBase = "0.12"
6163
NamedTupleTools = "0.13, 0.14"
6264
NestedTuples = "0.3"
6365
PositiveFactorizations = "0.2"

src/combinators/affine.jl

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
export Affine, AffineTransform
22
using LinearAlgebra
3+
import Base
34

45
const AFFINEPARS = [
56
(, )
@@ -43,8 +44,8 @@ Base.size(f::AffineTransform{(:λ,)}) = size(f.λ)
4344

4445
LinearAlgebra.rank(f::AffineTransform{(:σ,)}) = rank(f.σ)
4546
LinearAlgebra.rank(f::AffineTransform{(:λ,)}) = rank(f.λ)
46-
LinearAlgebra.rank(f::AffineTransform{(:μ,:σ,)}) = rank(f.σ)
47-
LinearAlgebra.rank(f::AffineTransform{(:μ,:λ,)}) = rank(f.λ)
47+
LinearAlgebra.rank(f::AffineTransform{(:μ, :σ)}) = rank(f.σ)
48+
LinearAlgebra.rank(f::AffineTransform{(:μ, :λ)}) = rank(f.λ)
4849

4950
function Base.size(f::AffineTransform{(:μ,)})
5051
(n,) = size(f.μ)
@@ -181,7 +182,7 @@ Affine(nt::NamedTuple, μ::AbstractMeasure) = affine(nt, μ)
181182

182183
Affine(nt::NamedTuple) = affine(nt)
183184

184-
parent(d::Affine) = getfield(d, :parent)
185+
Base.parent(d::Affine) = getfield(d, :parent)
185186

186187
function params(μ::Affine)
187188
nt1 = getfield(getfield(μ, :f), :par)
@@ -262,6 +263,12 @@ end
262263
weightedmeasure(-logjac(d), OrthoLebesgue(params(d)))
263264
end
264265

266+
@inline function basemeasure(
267+
d::MeasureTheory.Affine{N,L,Tuple{A}},
268+
) where {N,L<:MeasureBase.Lebesgue,A<:AbstractArray}
269+
weightedmeasure(-logjac(d), OrthoLebesgue(params(d)))
270+
end
271+
265272
@inline function basemeasure(
266273
d::Affine{N,M,Tuple{A1,A2}},
267274
) where {N,M,A1<:AbstractArray,A2<:AbstractArray}
@@ -328,3 +335,7 @@ end
328335
@inline function insupport(d::Affine, x)
329336
insupport(d.parent, inverse(d.f)(x))
330337
end
338+
339+
@inline function Distributions.cdf(d::Affine, x)
340+
cdf(parent(d), inverse(d.f)(x))
341+
end

src/combinators/exponential-families.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
export ExponentialFamily
2+
using LazyArrays
23

34
@concrete terse struct ExponentialFamily <: AbstractTransitionKernel
45
support_contains
@@ -16,10 +17,10 @@ function ExponentialFamily(support_contains, base, mdim, pdim, t, a)
1617
return ExponentialFamily(support_contains, base, mdim, pdim, t, I, a)
1718
end
1819

19-
function MeasureBase.powermeasure(fam::ExponentialFamily, dims::NTuple{N,I}) where {N,I}
20+
function MeasureBase.powermeasure(fam::ExponentialFamily, dims::NTuple)
2021
support_contains(x) = all(xj -> fam.support_contains(xj), x)
2122
t = Tuple((y -> f.(y) for f in fam.t))
22-
a(η) = BroadcastArray(fam.a, η)
23+
a(η) = LazyArrays.BroadcastArray(fam.a, η)
2324
p = prod(dims)
2425
ExponentialFamily(
2526
support_contains,
@@ -32,6 +33,8 @@ function MeasureBase.powermeasure(fam::ExponentialFamily, dims::NTuple{N,I}) whe
3233
)
3334
end
3435

36+
powermeasure(fam::ExponentialFamily, ::Tuple{}) = fam
37+
3538
@concrete terse struct ExpFamMeasure <: AbstractMeasure
3639
fam
3740
η # instantiated to a value

src/distproxy.jl

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,23 @@ for m in keys(PROXIES)
2525
@eval begin
2626
import $m: $f
2727
export $f
28-
$m.$f(d::AbstractMeasure, args...) = $m.$f(MeasureTheory.proxy(d), args...)
2928
end
3029
end
3130
end
31+
32+
entropy(m::AbstractMeasure, b::Real) = entropy(proxy(m), b)
33+
mean(m::AbstractMeasure) = mean(proxy(m))
34+
std(m::AbstractMeasure) = std(proxy(m))
35+
var(m::AbstractMeasure) = var(proxy(m))
36+
quantile(m::AbstractMeasure, q) = quantile(proxy(m), q)
37+
38+
for f in [
39+
:cdf
40+
:ccdf
41+
:logcdf
42+
:logccdf
43+
]
44+
@eval begin
45+
$f(d::AbstractMeasure, args...) = $f(MeasureTheory.proxy(d), args...)
46+
end
47+
end

src/parameterized/inverse-gaussian.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,6 @@ function logdensity_def(d::InverseGaussian{(:μ, :ϕ)}, x)
4848
end
4949

5050
function basemeasure(d::InverseGaussian{(:μ, :ϕ)})
51-
ℓ = static(-0.5) * (static(log2π) + log(d.ϕ))
51+
ℓ = static(-0.5) * (static(float(log2π)) + log(d.ϕ))
5252
weightedmeasure(ℓ, Lebesgue())
5353
end

src/parameterized/mvnormal.jl

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,14 @@ export MvNormal
2222

2323
as(d::MvNormal{(:μ,)}) = as(Array, length(d.μ))
2424

25-
as(d::MvNormal{(:Σ,),Tuple{C}}) where {C<:Cholesky} = as(Array, size(d.Σ, 1))
26-
as(d::MvNormal{(:Λ,),Tuple{C}}) where {C<:Cholesky} = as(Array, size(d.Λ, 1))
27-
as(d::MvNormal{(:μ, :Σ),Tuple{C}}) where {C<:Cholesky} = as(Array, size(d.Σ, 1))
28-
as(d::MvNormal{(:μ, :Λ),Tuple{C}}) where {C<:Cholesky} = as(Array, size(d.Λ, 1))
25+
as(d::MvNormal{(:Σ,),Tuple{C}}) where {C<:Cholesky} = as(Array, size(d.Σ, 1))
26+
as(d::MvNormal{(:Λ,),Tuple{C}}) where {C<:Cholesky} = as(Array, size(d.Λ, 1))
27+
as(d::MvNormal{(:μ, :Σ),<:Tuple{T,C}}) where {T,C<:Cholesky} = as(Array, size(d.Σ, 1))
28+
as(d::MvNormal{(:μ, :Λ),<:Tuple{T,C}}) where {T,C<:Cholesky} = as(Array, size(d.Λ, 1))
2929

3030
function as(d::MvNormal{(:σ,),Tuple{M}}) where {M<:Triangular}
3131
σ = d.σ
32-
if @inbounds all(i -> σ[i] > 0, diagind(σ))
32+
if @inbounds all(i -> σ[i] ≥ 0, diagind(σ))
3333
return as(Array, size(σ, 1))
3434
else
3535
@error "Not implemented yet"
@@ -49,7 +49,7 @@ for N in setdiff(AFFINEPARS, [(:μ,)])
4949
@eval begin
5050
function as(d::MvNormal{$N})
5151
p = proxy(d)
52-
if rank(getfield(p,:f)) == only(supportdim(d))
52+
if rank(getfield(p, :f)) == only(supportdim(d))
5353
return as(Array, supportdim(d))
5454
else
5555
@error "Not yet implemented"
@@ -61,13 +61,13 @@ end
6161
supportdim(d::MvNormal) = supportdim(params(d))
6262

6363
supportdim(nt::NamedTuple{(:Σ,)}) = size(nt.Σ, 1)
64-
supportdim(nt::NamedTuple{(:μ,:Σ)}) = size(nt.Σ, 1)
64+
supportdim(nt::NamedTuple{(:μ, :Σ)}) = size(nt.Σ, 1)
6565
supportdim(nt::NamedTuple{(:Λ,)}) = size(nt.Λ, 1)
66-
supportdim(nt::NamedTuple{(:μ,:Λ)}) = size(nt.Λ, 1)
66+
supportdim(nt::NamedTuple{(:μ, :Λ)}) = size(nt.Λ, 1)
6767

6868
@useproxy MvNormal
6969

70-
for N in [(:Σ,), (:μ,:Σ), (:Λ,), (:μ,:Λ)]
70+
for N in [(:Σ,), (:μ, :Σ), (:Λ,), (:μ, :Λ)]
7171
@eval basemeasure_depth(d::MvNormal{$N}) = static(2)
7272
end
7373

@@ -78,7 +78,15 @@ rand(rng::AbstractRNG, ::Type{T}, d::MvNormal) where {T} = rand(rng, T, proxy(d)
7878
insupport(d::MvNormal, x) = insupport(proxy(d), x)
7979

8080
# Note: (C::Cholesky).L may or may not make a copy, depending on C.uplo, which is not included in the type
81-
@inline proxy(d::MvNormal{(:Σ,),Tuple{C}}) where {C<:Cholesky} = affine((σ = d.Σ.L,), Normal()^supportdim(d))
82-
@inline proxy(d::MvNormal{(:Λ,),Tuple{C}}) where {C<:Cholesky} = affine((λ = d.Λ.L,), Normal()^supportdim(d))
83-
@inline proxy(d::MvNormal{(:μ, :Σ),Tuple{C}}) where {C<:Cholesky} = affine((μ = d.μ, σ = d.Σ.L), Normal()^supportdim(d))
84-
@inline proxy(d::MvNormal{(:μ, :Λ),Tuple{C}}) where {C<:Cholesky} = affine((μ = d.μ, λ = d.Λ.L), Normal()^supportdim(d))
81+
@inline function proxy(d::MvNormal{(:Σ,),Tuple{C}}) where {C<:Cholesky}
82+
affine((σ = d.Σ.L,), Normal()^supportdim(d))
83+
end
84+
@inline function proxy(d::MvNormal{(:Λ,),Tuple{C}}) where {C<:Cholesky}
85+
affine((λ = d.Λ.L,), Normal()^supportdim(d))
86+
end
87+
@inline function proxy(d::MvNormal{(:μ, :Σ),Tuple{T,C}}) where {T,C<:Cholesky}
88+
affine((μ = d.μ, σ = d.Σ.L), Normal()^supportdim(d))
89+
end
90+
@inline function proxy(d::MvNormal{(:μ, :Λ),Tuple{T,C}}) where {T,C<:Cholesky}
91+
affine((μ = d.μ, λ = d.Λ.L), Normal()^supportdim(d))
92+
end

src/parameterized/normal.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ HalfNormal(σ) = HalfNormal((σ = σ,))
146146
end
147147

148148
@inline function basemeasure(d::Normal{(:σ²,)})
149-
ℓ = static(-0.5) * (static(log2π) + log(d.σ²))
149+
ℓ = static(-0.5) * (static(float(log2π)) + log(d.σ²))
150150
weightedmeasure(ℓ, Lebesgue())
151151
end
152152

src/resettable-rng.jl

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,57 @@ function Random.Sampler(
7878
return Random.Sampler(r.rng, s, r)
7979
end
8080

81-
function Base.rand(r::ResettableRNG, sp::Random.Sampler)
82-
rand(r.rng, sp)
83-
end
81+
# UIntBitsTypes = [UInt128, UInt16, UInt32, UInt64, UInt8]
82+
# IntBitsTypes = [Int128, Int16, Int32, Int64, Int8, UInt128, UInt16, UInt32, UInt64, UInt8]
83+
# FloatBitsTypes = [Float16, Float32, Float64]
84+
85+
# for I in IntBitsTypes
86+
# for T in [
87+
# Random.SamplerTrivial{Random.UInt104Raw{I}}
88+
# Random.SamplerTrivial{Random.UInt10Raw{I}}
89+
# ]
90+
# @eval begin
91+
# function Base.rand(r::ResettableRNG, sp::$T)
92+
# rand(r.rng, sp)
93+
# end
94+
# end
95+
# end
96+
# end
97+
98+
# for U in UIntBitsTypes
99+
# for I in IntBitsTypes
100+
# for T in [
101+
# Random.SamplerRangeInt{T,U} where {T<:Union{IntBitsTypes...}}
102+
# Random.SamplerRangeFast{U,I}
103+
# ]
104+
# @eval begin
105+
# function Base.rand(r::ResettableRNG, sp::$T)
106+
# rand(r.rng, sp)
107+
# end
108+
# end
109+
# end
110+
# end
111+
# end
112+
113+
# for T in [
114+
# Random.Sampler
115+
# Random.SamplerBigInt
116+
# Random.SamplerTag{<:Set,<:Random.Sampler}
117+
# # Random.SamplerTrivial{Random.CloseOpen01{T}} where {T<:FloatBitsTypes}
118+
# # Random.SamplerTrivial{Random.UInt23Raw{UInt32}}
119+
# Random.UniformT
120+
# Random.SamplerSimple{T,S,E} where {E,S,T<:Tuple}
121+
# Random.SamplerType{T} where {T<:AbstractChar}
122+
# Random.SamplerTrivial{Tuple{A}} where {A}
123+
# Random.SamplerSimple{Tuple{A,B,C},S,E} where {E,S,A,B,C}
124+
# Random.SamplerSimple{<:AbstractArray,<:Random.Sampler}
125+
# Random.Masked
126+
# Random.SamplerSimple{BitSet,<:Random.Sampler}
127+
# Random.SamplerTrivial{<:Random.UniformBits{T},E} where {E,T}
128+
# ]
129+
# @eval begin
130+
# function Base.rand(r::ResettableRNG, sp::$T)
131+
# rand(r.rng, sp)
132+
# end
133+
# end
134+
# end

test/runtests.jl

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,18 @@ using MeasureBase.Interface
1212
using MeasureTheory: kernel
1313
using Aqua
1414
using IfElse
15-
Aqua.test_all(MeasureTheory; ambiguities = false, unbound_args = false)
15+
16+
# Aqua._test_ambiguities(
17+
# Aqua.aspkgids(MeasureTheory);
18+
# exclude = [Random.AbstractRNG],
19+
# # packages::Vector{PkgId};
20+
# # color::Union{Bool, Nothing} = nothing,
21+
# # exclude::AbstractArray = [],
22+
# # # Options to be passed to `Test.detect_ambiguities`:
23+
# # detect_ambiguities_options...,
24+
# )
25+
26+
Aqua.test_all(MeasureBase; ambiguities = false)
1627

1728
function draw2(μ)
1829
x = rand(μ)
@@ -23,8 +34,8 @@ function draw2(μ)
2334
return (x, y)
2435
end
2536

26-
x = randn(10,3)
27-
Σ = cholesky(x'*x)
37+
x = randn(10, 3)
38+
Σ = cholesky(x' * x)
2839
Λ = cholesky(inv(Σ))
2940
σ = MeasureTheory.getL(Σ)
3041
λ = MeasureTheory.getL(Λ)
@@ -611,7 +622,7 @@ end
611622
@testset "IfElseMeasure" begin
612623
p = rand()
613624
x = randn()
614-
625+
615626
@test let
616627
a = logdensityof(IfElse.ifelse(Bernoulli(p), Normal(), Normal()), x)
617628
b = logdensityof(Normal(), x)

0 commit comments

Comments
 (0)