Skip to content

Commit 2e9c2e6

Browse files
cscherreroschulz
andauthored
Dev (#88)
* drop debugging code * small update for Likelihood, and a test * fixing up likelihoods * improve `basemeasure_depth` dispatch * still some trouble with inferred basemeasure_depth * clean up `For` dispatch * simplify _logdensityof * optimize for Returns{True} case * rework basemeasure_depth * aggressive tests passing!! * drop type-level stuff * drop help * license * affero * copyright notice * merge * Drop Create Commons * cleanup after merge * update support computations * insupport(d::SuperpositionMeasure, x) * dorp ParamWeighted * insupport(d::FactoredBase, x) * export unsafe_logdensityof * call promote_type instead of promote_rule * logdensity_def for named tuple product measures * type annotation for now * debugging * drop shows * speed up mapped arrays * throw an error for `Union{}` types * MT tests passing * updates * get tests passing * MIT license for MeasureBase * bump version * cleanup * spacing * Move ConditionalMeasure to MeasureBase * add LogarithmicNumbers * export basemeasure_sequence * update superpose * fix logdensity_rel * remove FIXME (it's fixed!!) * logdensityof(d::Density, x) * simplify insupport(::Lebesgue, ::Real) * clean up * assume insupport yields Bool * change logdensity_rel fall-through to warning and return NaN * update logdensity_rel * drop old code * fix warning * export logdensity_rel * logdensity_def(μ::Dirac, ν::Dirac, x) * logdensity_def methods * drop `static` * ]add StatsFuns * Fixing up superposition * [compat] entries * trying to speed things up * bugfixes * logdensity_rel tests * logdensity_rel tests * drop qualifier, and add a test * more tests * type constraint in "logdensityof(μ::AbstractMeasure, x)" (was piracy, oops) * add some docs * docs * docs * typo * moar speed * don't export Test * some more updates * logdensity_rel for products * `kleisli` docs * update instance_type * instance_type => Core.Typeof * `powermeasure` bug fix * fix logdensity_rel bug * get `commonbase` to take x type into account * test powers * commonbase docstring * deprecate instance_type * avoid breakage * switch || terms * @ifelse macro * simplify logdensity_rel * give up on this @ifelse business * bump version * working on likelihoods * update likelihood * powerweightedmeasure * powerweighted update * more powerweighted methods * bugfix * dropFactoredBase * drop FactoredBase * (::ProductMeasure) | constraint * update conditional measure * update Dirac * move conditional.jl down in the `include`s * Kleisli => TransitionKernel * simplify logdensity_def(::PowerMeasure, x) * rename kleisli.jl to kernel.jl * update Dirac tests * update Half * get tests passing * update kernel * Update Project.toml * no call-site inlining * restrict single-arg `kernel` to <:ParameterizedMeasure * export log_likelihood_ratio * Drop DensityKind(::Likelihood), at least for now * isfinite(x) instead of x>-Inf * add `condition` constructor * EOF newline * simplify logdensity_def for power measures * finishing up * updates * kernel stuff * kernel stuff * update showe methods * ass a TODO * use `dot` instead of `sum` * drop old code * typo * formatting * cleanup * kernel updates * uncomment * bugfix * drop old code * pretty printing * exports, cleanup * drop old for.jl * Make DensityKind(::AbstractLikelihood) = IsDensity() * update Compat version * Make likelihoods work with Distributions * _map(f, x::MappedArrays.ReadonlyMappedArray) * export productmeasure * AbstractMeasure(::AbstractMeasure) * fixedrng * StdNormal * add SpecialFunctions * no need to qualify * update basemeasure * include stdnormal * include fixedrng * update tests * using SpecialFunctions * fixing transport_def * transport_def bugfix * StdMeasure(::typeof(randn)) * checked_arg for LebesgueMeasure * NoTransformOrigin => NoTransportOrigin * transport interface for pushforwards * transporting pushforwards * Use LebesgueMeasure for basemeasure * updates * make testvalue fall back on FixedRNG approach * un-break testvalue * CI for Juila 1.8 * fixes * `rand` on a pushforward calls rand on its parent * LebesgueMeasure => LebesgueBase CountingMeasure => CountingBase * tests passing! * change `invoke` type * Change `test_interface` to check for 2-arg testvalue * manually-specifed inverses * more pushfwd stuff * A little less wrong * add mass interface * pullback * mass interface * working on mass interface * add some `massof` methods * Maybe <:Number is better for invalidations? * float instead of Int * logmassof * transports for proxies * drop latent-joint.jl * drop exports * Drop `logmassof` for now * reorganize Lebesgue measure * IntervalSets * proxy(::Lebesgue{MeasureBase.RealNumbers}) = LebesgueBase() * calling a "useproxy" measure calls its proxy * StdUniform()(s::Interval) * typo * (m::AbstractMeasure)(s::Interval) * bugfix * comment * IntervalSets version constraint * update dynamic_basemeasure_depth * format * Calling a measure calls `massof` * work on massof * AbstractSuperpositionMeasure * fix typo * typo * format * docstrings * remove massof(::PowerWeightedMeasure) method * make `massof` better * update testvalue * formatting * update _massof * Update transports for weighted measures * add chain rules * invariant mass under transport * typo * bugfix * hasmethod => Tricks.static_hasmethod * `massof` methods * roll back tranports for WeightedMeasure * Improve transport implementation and add product support (#97) * Improve default transport implementation Increases type stability. * Rename NoTransformOrigin to NoTransportOrigin * Add rrule for _origin_depth * Fix ambiguities when forwarding NoTransportOrigin and NoTransportOrigin * Define getdof for product measures * Generalize test_transport to tuple-valued measures * Implement transport for tuple-based products Co-authored-by: Chad Scherrer <[email protected]> * `@useproxy` delegates `massof` * drop CI for nightly * callable densities (#85) * callable densities * separate `Density` and `LogDensity`, etc * bugfix * move some code around * format * updates * working on densities * update CI * bugfix * formatting * reorg * fix typos * Drop LogDensityMeasure and refactor * docstring * inner type constructor with assertion * type parameters * 2-arg density_rel and logdensity_rel * oops * properties * typo * fix ambiguity * bugfix * bugfix * updates * drop densityof and logdensityof for AbstractDensity * updates * update tests * update * formatting * bad calls throw errors * drop CI for nightly * Pushfwd-inverses (#98) * use InverseFunctions.setinverse * bug fixes * bugfix * bugfix * pushfwd of a pushfwd * format * drop old comment * drop CI for nightly * working on pushfwd * bugfix * inverse(f) => ν.finv * separate logdensity functions from transport API * format * don't unwrap FunctionWithInverse * drop redundant method * leave logdensityof alone, instead write unsafe_logdensityof * more tests * more work on tests * tests * tests * still messing with tests * tests passing * small edits * formatting * add some more failing tests * add atol to isapprox in test * getdof(μ::PushforwardMeasure) = getdof(transport_origin(μ)) * update atol * small fix * drop ((-) ∘ log1p ∘ (-), StdUniform(), StdExponential()) * remove duplicate method * remove duplicate `include` * simplify getdof(::PushforwardMeasure) * Stieltjes measure function (#100) * smf * more smf stuff * transport_to * Lebesgue smf * smf for std measures * format * transport_to * oops * smfinv * bugfix * more fixes * minor refactoring * formatting * change x to p * bugfix * smfinv(::StdLogistic, p) * add NoSMF and NoSMFInverse * roll back some changes * transport_def methods * formatting * another rollback * make transport_def depend on smf(inv) * update smf and transports for ::Half * change `include` order * test_smf * more tests * tests * add tests * formatting * Drop unneeded type parameters * smfinv => invsmf * add some inverses * Base.Fix1 versions * some more methods * drop redundant `transport_def`s * update `pushfwd` * change name * add type * formatting * fix docstring * depend on FunctinoChains * Use fchain * simplify transport_def for StdLogistic * simplify transport_def for StdNormal * drop redundant method Co-authored-by: Oliver Schulz <[email protected]>
1 parent bc0a61e commit 2e9c2e6

37 files changed

+950
-288
lines changed

.github/workflows/CI.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ jobs:
2525
version:
2626
- '1.6'
2727
- '1.7'
28-
- 'nightly'
28+
- '1.8'
2929
os:
3030
- ubuntu-latest
3131
arch:

Project.toml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@ Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
1010
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
1111
DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d"
1212
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
13+
FunctionChains = "8e6b2b91-af83-483e-ba35-d00930e4cf9b"
1314
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
15+
IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
1416
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
1517
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
1618
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -21,6 +23,7 @@ NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
2123
PrettyPrinting = "54e16d92-306c-5ea0-a30b-337be88ac337"
2224
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2325
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
26+
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
2427
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
2528
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2629
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
@@ -33,15 +36,18 @@ Compat = "3.35, 4"
3336
ConstructionBase = "1.3"
3437
DensityInterface = "0.4"
3538
FillArrays = "0.12, 0.13"
39+
FunctionChains = "0.1"
3640
IfElse = "0.1"
37-
InverseFunctions = "0.1.7"
41+
IntervalSets = "0.7"
42+
InverseFunctions = "0.1.8"
3843
IrrationalConstants = "0.1"
3944
LogExpFunctions = "0.3"
4045
LogarithmicNumbers = "1"
4146
MappedArrays = "0.4"
4247
NaNMath = "0.3, 1"
4348
PrettyPrinting = "0.3, 0.4"
4449
Reexport = "1"
50+
SpecialFunctions = "2"
4551
Static = "0.5, 0.6"
4652
Tricks = "0.1"
4753
julia = "1.3"

src/MeasureBase.jl

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,16 @@ import Random: gentype
88
using Statistics
99
using LinearAlgebra
1010

11+
import IntervalSets
12+
# This seems harder than it should be to get `IntervalSets.:(..)`
13+
@eval (using IntervalSets: $(Symbol(IntervalSets.:(..))))
14+
15+
using IntervalSets: Interval, width
16+
1117
import DensityInterface: logdensityof
1218
import DensityInterface: densityof
1319
import DensityInterface: DensityKind
20+
using DensityInterface: FuncDensity, LogFuncDensity
1421
using DensityInterface
1522

1623
using InverseFunctions
@@ -19,13 +26,15 @@ using ChangesOfVariables
1926
import Base.iterate
2027
import ConstructionBase
2128
using ConstructionBase: constructorof
29+
using IntervalSets
2230

2331
using PrettyPrinting
2432
const Pretty = PrettyPrinting
2533

2634
using ChainRulesCore
2735
using FillArrays
2836
using Static
37+
using FunctionChains
2938

3039
export
3140
export gentype
@@ -108,17 +117,18 @@ using Compat
108117

109118
using IrrationalConstants
110119

120+
include("smf.jl")
111121
include("getdof.jl")
112122
include("transport.jl")
113123
include("schema.jl")
114124
include("splat.jl")
115125
include("proxies.jl")
116126
include("kernel.jl")
117127
include("parameterized.jl")
118-
include("combinators/half.jl")
119128
include("domains.jl")
120129
include("primitive.jl")
121130
include("utils.jl")
131+
include("mass-interface.jl")
122132
# include("absolutecontinuity.jl")
123133

124134
include("primitives/counting.jl")
@@ -144,9 +154,11 @@ include("standard/stdmeasure.jl")
144154
include("standard/stduniform.jl")
145155
include("standard/stdexponential.jl")
146156
include("standard/stdlogistic.jl")
147-
include("latent-joint.jl")
157+
include("standard/stdnormal.jl")
158+
include("combinators/half.jl")
148159

149160
include("rand.jl")
161+
include("fixedrng.jl")
150162

151163
include("density.jl")
152164
include("density-core.jl")

src/combinators/half.jl

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,18 @@ logdensity_def(μ::Half, x) = logdensity_def(unhalf(μ), x)
2626
insupport(unhalf(d), x)
2727
end
2828

29-
testvalue(::Half) = 1.0
29+
testvalue(::Type{T}, ::Half) where {T} = one(T)
30+
31+
massof::Half) = massof(unhalf(μ))
32+
33+
function smf::Half, x)
34+
2 * smf.parent, max(x, zero(x))) - 1
35+
end
36+
37+
function invsmf::Half, p)
38+
@assert zero(p) p one(p)
39+
invsmf.parent, (p + 1) / 2)
40+
end
41+
42+
transport_def::Half, ::StdUniform, p) = invsmf(μ, p)
43+
transport_def(::StdUniform, μ::Half, x) = smf(μ, x)

src/combinators/likelihood.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,4 +202,6 @@ more efficient than
202202
203203
logdensityof(ℓ.k(p), ℓ.x) - logdensityof(ℓ.k(q), ℓ.x)
204204
"""
205-
likelihood_ratio(ℓ::Likelihood, p, q) = exp(ULogarithmic, logdensity_rel(ℓ.k(p), ℓ.k(q), ℓ.x))
205+
function likelihood_ratio(ℓ::Likelihood, p, q)
206+
exp(ULogarithmic, logdensity_rel(ℓ.k(p), ℓ.k(q), ℓ.x))
207+
end

src/combinators/power.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,3 +128,5 @@ end
128128
function checked_arg::PowerMeasure, x::Any)
129129
throw(ArgumentError("Size of variate doesn't match size of power measure"))
130130
end
131+
132+
massof(m::PowerMeasure) = massof(m.parent)^prod(m.axes)

src/combinators/product.jl

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ function Pretty.tile(μ::AbstractProductMeasure)
1616
result *= Pretty.literal(")")
1717
end
1818

19+
massof(m::AbstractProductMeasure) = prod(massof, marginals(m))
20+
1921
export marginals
2022

2123
function Base.:(==)(a::AbstractProductMeasure, b::AbstractProductMeasure)
@@ -159,9 +161,11 @@ marginals(μ::ProductMeasure) = μ.marginals
159161

160162
# TODO: Better `map` support in MappedArrays
161163
_map(f, args...) = map(f, args...)
162-
_map(f, x::MappedArrays.ReadonlyMappedArray) = mappedarray(f x.f, x.data)
164+
_map(f, x::MappedArrays.ReadonlyMappedArray) = mappedarray(fchain((x.f, f)), x.data)
163165

164-
testvalue(d::AbstractProductMeasure) = _map(testvalue, marginals(d))
166+
function testvalue(::Type{T}, d::AbstractProductMeasure) where {T}
167+
_map(m -> testvalue(T, m), marginals(d))
168+
end
165169

166170
export
167171

@@ -220,3 +224,16 @@ end
220224
end
221225
return true
222226
end
227+
228+
getdof(d::AbstractProductMeasure) = mapreduce(getdof, +, marginals(d))
229+
230+
function checked_arg::ProductMeasure{<:NTuple{N,Any}}, x::NTuple{N,Any}) where {N}
231+
map(checked_arg, marginals(μ), x)
232+
end
233+
234+
function checked_arg(
235+
μ::ProductMeasure{<:NamedTuple{names}},
236+
x::NamedTuple{names},
237+
) where {names}
238+
NamedTuple{names}(map(checked_arg, values(marginals(μ)), values(x)))
239+
end

src/combinators/smart-constructors.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ superpose(nt::NamedTuple) = SuperpositionMeasure(nt)
8585

8686
function superpose::T, ν::T) where {T<:AbstractMeasure}
8787
if μ == ν
88-
return weightedmeasure(logtwo, μ)
88+
return weightedmeasure(static(float(logtwo)), μ)
8989
else
9090
return superpose((μ, ν))
9191
end

src/combinators/spikemixture.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,6 @@ function Base.rand(rng::AbstractRNG, T::Type, μ::SpikeMixture)
3737
return (rand(rng, T) < μ.w) * rand(rng, T, μ.m)
3838
end
3939

40-
testvalue::SpikeMixture) = testvalue.m)
40+
testvalue(::Type{T}, μ::SpikeMixture) where {T} = zero(T)
4141

4242
insupport::SpikeMixture, x) = dynamic(insupport.m, x)) || iszero(x)

src/combinators/superpose.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ using LogExpFunctions
44

55
export SuperpositionMeasure
66

7+
abstract type AbstractSuperpositionMeasure <: AbstractMeasure end
8+
79
@doc raw"""
810
struct SuperpositionMeasure{NT} <: AbstractMeasure
911
components :: NT
@@ -24,17 +26,19 @@ Superposition measures satisfy
2426
\end{aligned}
2527
```
2628
"""
27-
struct SuperpositionMeasure{C} <: AbstractMeasure
29+
struct SuperpositionMeasure{C} <: AbstractSuperpositionMeasure
2830
components::C
2931
end
3032

33+
massof(m::SuperpositionMeasure) = sum(massof, m.components)
34+
3135
function Pretty.tile(d::SuperpositionMeasure)
3236
result = Pretty.literal("SuperpositionMeasure(")
3337
result *= Pretty.list_layout([Pretty.tile.(d.components)...])
3438
result *= Pretty.literal(")")
3539
end
3640

37-
testvalue::SuperpositionMeasure) = testvalue(first.components))
41+
testvalue(::Type{T}, μ::SuperpositionMeasure) where {T} = testvalue(T, first.components))
3842

3943
# SuperpositionMeasure(ms :: AbstractMeasure...) = SuperpositionMeasure{X,length(ms)}(ms)
4044

0 commit comments

Comments
 (0)