Skip to content

Commit 4b0ef78

Browse files
authored
formatting (#64)
* formatting * bump version
1 parent f7b4444 commit 4b0ef78

File tree

13 files changed

+105
-90
lines changed

13 files changed

+105
-90
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MeasureBase"
22
uuid = "fa1605e6-acd5-459c-a1e6-7e635759db14"
33
authors = ["Chad Scherrer <[email protected]> and contributors"]
4-
version = "0.12.0"
4+
version = "0.12.1"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/combinators/power.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ end
8686
end
8787

8888
@inline function logdensity_def(
89-
d::PowerMeasure{M,NTuple{N, Base.OneTo{StaticInt{0}}}},
89+
d::PowerMeasure{M,NTuple{N,Base.OneTo{StaticInt{0}}}},
9090
x,
9191
) where {M,N}
9292
static(0.0)
@@ -108,11 +108,11 @@ end
108108
end
109109
end
110110

111-
112111
@inline getdof::PowerMeasure) = getdof.parent) * prod(map(length, μ.axes))
113112

114-
@inline getdof(::PowerMeasure{<:Any, NTuple{N,Base.OneTo{StaticInt{0}}}}) where N = static(0)
115-
113+
@inline function getdof(::PowerMeasure{<:Any,NTuple{N,Base.OneTo{StaticInt{0}}}}) where {N}
114+
static(0)
115+
end
116116

117117
@propagate_inbounds function checked_arg::PowerMeasure, x::AbstractArray{<:Any})
118118
@boundscheck begin

src/combinators/transformedmeasure.jl

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ function paramnames(::AbstractTransformedMeasure) end
1414

1515
function parent(::AbstractTransformedMeasure) end
1616

17-
1817
export PushforwardMeasure
1918

2019
"""
@@ -35,13 +34,14 @@ end
3534
gettransform::PushforwardMeasure) = ν.f
3635
parent::PushforwardMeasure) = ν.origin
3736

38-
3937
function Pretty.tile::PushforwardMeasure)
4038
Pretty.list_layout(Pretty.tile.([ν.f, ν.inv_f, ν.origin]); prefix = :PushforwardMeasure)
4139
end
4240

43-
44-
@inline function logdensity_def::PushforwardMeasure{FF,IF,M,<:WithVolCorr}, y) where {FF,IF,M}
41+
@inline function logdensity_def(
42+
ν::PushforwardMeasure{FF,IF,M,<:WithVolCorr},
43+
y,
44+
) where {FF,IF,M}
4545
x_orig, inv_ladj = with_logabsdet_jacobian.inv_f, y)
4646
logd_orig = logdensity_def.origin, x_orig)
4747
logd = float(logd_orig + inv_ladj)
@@ -53,16 +53,18 @@ end
5353
# Return constant -Inf to prevent problems with ForwardDiff:
5454
(isfinite(logd_orig) && (inv_ladj == -Inf)),
5555
neginf,
56-
logd
56+
logd,
5757
)
5858
end
5959

60-
@inline function logdensity_def::PushforwardMeasure{FF,IF,M,<:NoVolCorr}, y) where {FF,IF,M}
60+
@inline function logdensity_def(
61+
ν::PushforwardMeasure{FF,IF,M,<:NoVolCorr},
62+
y,
63+
) where {FF,IF,M}
6164
x_orig = to_origin(ν, y)
6265
return logdensity_def.origin, x_orig)
6366
end
6467

65-
6668
insupport::PushforwardMeasure, y) = insupport(transport_origin(ν), to_origin(ν, y))
6769

6870
testvalue::PushforwardMeasure) = from_origin(ν, testvalue(transport_origin(ν)))
@@ -71,30 +73,27 @@ testvalue(ν::PushforwardMeasure) = from_origin(ν, testvalue(transport_origin(
7173
PushforwardMeasure.f, ν.inv_f, basemeasure(transport_origin(ν)), NoVolCorr())
7274
end
7375

74-
75-
_pushfwd_dof(::Type{MU}, ::Type, dof) where MU = NoDOF{MU}()
76-
_pushfwd_dof(::Type{MU}, ::Type{<:Tuple{Any,Real}}, dof) where MU = dof
76+
_pushfwd_dof(::Type{MU}, ::Type, dof) where {MU} = NoDOF{MU}()
77+
_pushfwd_dof(::Type{MU}, ::Type{<:Tuple{Any,Real}}, dof) where {MU} = dof
7778

7879
# Assume that DOF are preserved if with_logabsdet_jacobian is functional:
7980
@inline function getdof::MU) where {MU<:PushforwardMeasure}
8081
T = Core.Compiler.return_type(testvalue, Tuple{typeof.origin)})
81-
R = Core.Compiler.return_type(with_logabsdet_jacobian, Tuple{typeof.f), T})
82+
R = Core.Compiler.return_type(with_logabsdet_jacobian, Tuple{typeof.f),T})
8283
_pushfwd_dof(MU, R, getdof.origin))
8384
end
8485

8586
# Bypass `checked_arg`, would require potentially costly transformation:
8687
@inline checked_arg(::PushforwardMeasure, x) = x
8788

88-
8989
@inline transport_origin::PushforwardMeasure) = ν.origin
9090
@inline from_origin::PushforwardMeasure, x) = ν.f(x)
9191
@inline to_origin::PushforwardMeasure, y) = ν.inv_f(y)
9292

93-
function Base.rand(rng::AbstractRNG, ::Type{T}, ν::PushforwardMeasure) where T
93+
function Base.rand(rng::AbstractRNG, ::Type{T}, ν::PushforwardMeasure) where {T}
9494
return from_origin(ν, rand(rng, T, transport_origin(ν)))
9595
end
9696

97-
9897
export pushfwd
9998

10099
"""

src/getdof.jl

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ a global property of the measure.
77
"""
88
struct NoDOF{MU} end
99

10-
1110
"""
1211
getdof(μ)
1312
@@ -23,11 +22,10 @@ Also see [`check_dof`](@ref).
2322
function getdof end
2423

2524
# Prevent infinite recursion:
26-
@inline _default_getdof(::Type{MU}, ::MU) where MU = NoDOF{MU}
27-
@inline _default_getdof(::Type{MU}, mu_base) where MU = getdof(mu_base)
28-
29-
@inline getdof::MU) where MU = _default_getdof(MU, basemeasure(μ))
25+
@inline _default_getdof(::Type{MU}, ::MU) where {MU} = NoDOF{MU}
26+
@inline _default_getdof(::Type{MU}, mu_base) where {MU} = getdof(mu_base)
3027

28+
@inline getdof::MU) where {MU} = _default_getdof(MU, basemeasure(μ))
3129

3230
"""
3331
MeasureBase.check_dof(ν, μ)::Nothing
@@ -41,15 +39,18 @@ function check_dof(ν, μ)
4139
n_ν = getdof(ν)
4240
n_μ = getdof(μ)
4341
if n_ν != n_μ
44-
throw(ArgumentError("Measure ν of type $(nameof(typeof(ν))) has $(n_ν) DOF but μ of type $(nameof(typeof(μ))) has $(n_μ) DOF"))
42+
throw(
43+
ArgumentError(
44+
"Measure ν of type $(nameof(typeof(ν))) has $(n_ν) DOF but μ of type $(nameof(typeof(μ))) has $(n_μ) DOF",
45+
),
46+
)
4547
end
4648
return nothing
4749
end
4850

4951
_check_dof_pullback(ΔΩ) = NoTangent(), NoTangent(), NoTangent()
5052
ChainRulesCore.rrule(::typeof(check_dof), ν, μ) = check_dof(ν, μ), _check_dof_pullback
5153

52-
5354
"""
5455
MeasureBase.NoArgCheck{MU,T}
5556
@@ -58,7 +59,6 @@ variate of measures of type `MU`.
5859
"""
5960
struct NoArgCheck{MU,T} end
6061

61-
6262
"""
6363
MeasureBase.checked_arg(μ::MU, x::T)::T
6464
@@ -68,10 +68,16 @@ return `NoArgCheck{MU,T}()` if not check can be performed.
6868
function checked_arg end
6969

7070
# Prevent infinite recursion:
71-
@propagate_inbounds _default_checked_arg(::Type{MU}, ::MU, ::T) where {MU,T} = NoArgCheck{MU,T}
72-
@propagate_inbounds _default_checked_arg(::Type{MU}, mu_base, x) where MU = checked_arg(mu_base, x)
71+
@propagate_inbounds function _default_checked_arg(::Type{MU}, ::MU, ::T) where {MU,T}
72+
NoArgCheck{MU,T}
73+
end
74+
@propagate_inbounds function _default_checked_arg(::Type{MU}, mu_base, x) where {MU}
75+
checked_arg(mu_base, x)
76+
end
7377

74-
@propagate_inbounds checked_arg(mu::MU, x) where MU = _default_checked_arg(MU, basemeasure(mu), x)
78+
@propagate_inbounds function checked_arg(mu::MU, x) where {MU}
79+
_default_checked_arg(MU, basemeasure(mu), x)
80+
end
7581

7682
_checked_arg_pullback(ΔΩ) = NoTangent(), NoTangent(), ΔΩ
7783
ChainRulesCore.rrule(::typeof(checked_arg), ν, x) = checked_arg(ν, x), _checked_arg_pullback

src/insupport.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ insupport(m)(x) == insupport(m, x)
1010
"""
1111
function insupport end
1212

13-
1413
"""
1514
MeasureBase.require_insupport(μ, x)::Nothing
1615

src/interface.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,9 @@ function test_interface(μ::M) where {M}
6565
end
6666
end
6767

68-
6968
function test_transport(ν, μ)
7069
supertype(x::Real) = Real
71-
supertype(x::AbstractArray{<:Real,N}) where N = AbstractArray{<:Real,N}
70+
supertype(x::AbstractArray{<:Real,N}) where {N} = AbstractArray{<:Real,N}
7271

7372
@testset "transport_to to " begin
7473
x = rand(μ)
@@ -82,7 +81,7 @@ function test_transport(ν, μ)
8281
x2, ladj_inv = with_logabsdet_jacobian(inverse(f), y)
8382
@test x x2
8483
@test y y2
85-
@test ladj_fwd - ladj_inv
84+
@test ladj_fwd -ladj_inv
8685
@test ladj_fwd logdensityof(μ, x) - logdensityof(ν, y)
8786
end
8887
end

src/standard/stdexponential.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@ insupport(d::StdExponential, x) = x ≥ zero(x)
77
@inline logdensity_def(::StdExponential, x) = -x
88
@inline basemeasure(::StdExponential) = Lebesgue()
99

10-
@inline transport_def(::StdUniform, μ::StdExponential, x) = - expm1(-x)
11-
@inline transport_def(::StdExponential, μ::StdUniform, x) = - log1p(-x)
10+
@inline transport_def(::StdUniform, μ::StdExponential, x) = -expm1(-x)
11+
@inline transport_def(::StdExponential, μ::StdUniform, x) = -log1p(-x)
1212

1313
Base.rand(rng::Random.AbstractRNG, ::Type{T}, ::StdExponential) where {T} = randexp(rng, T)
14-

src/standard/stdlogistic.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@ export StdLogistic
44

55
@inline insupport(d::StdLogistic, x) = true
66

7-
@inline logdensity_def(::StdLogistic, x) = (u = -abs(x); u - 2*log1pexp(u))
7+
@inline logdensity_def(::StdLogistic, x) = (u = -abs(x); u - 2 * log1pexp(u))
88
@inline basemeasure(::StdLogistic) = Lebesgue()
99

1010
@inline transport_def(::StdUniform, μ::StdLogistic, x) = logistic(x)
1111
@inline transport_def(::StdLogistic, μ::StdUniform, x) = logit(x)
1212

13-
@inline Base.rand(rng::Random.AbstractRNG, ::Type{T}, ::StdLogistic) where {T} = logit(rand(rng, T))
13+
@inline function Base.rand(rng::Random.AbstractRNG, ::Type{T}, ::StdLogistic) where {T}
14+
logit(rand(rng, T))
15+
end

src/standard/stdmeasure.jl

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,8 @@ abstract type StdMeasure <: AbstractMeasure end
33
StdMeasure(::typeof(rand)) = StdUniform()
44
StdMeasure(::typeof(randexp)) = StdExponential()
55

6-
76
@inline check_dof(::StdMeasure, ::StdMeasure) = nothing
87

9-
108
@inline transport_def(::MU, μ::MU, x) where {MU<:StdMeasure} = x
119

1210
function transport_def::StdMeasure, μ::PowerMeasure{<:StdMeasure}, x)
@@ -17,24 +15,34 @@ function transport_def(ν::PowerMeasure{<:StdMeasure}, μ::StdMeasure, x)
1715
return Fill(transport_def.parent, μ, only(x)), map(length, ν.axes)...)
1816
end
1917

20-
function transport_def::PowerMeasure{<:StdMeasure,<:NTuple{1,Base.OneTo}}, μ::PowerMeasure{<:StdMeasure,<:NTuple{1,Base.OneTo}}, x)
18+
function transport_def(
19+
ν::PowerMeasure{<:StdMeasure,<:NTuple{1,Base.OneTo}},
20+
μ::PowerMeasure{<:StdMeasure,<:NTuple{1,Base.OneTo}},
21+
x,
22+
)
2123
return transport_to.parent, μ.parent).(x)
2224
end
2325

24-
function transport_def::PowerMeasure{<:StdMeasure,<:NTuple{N,Base.OneTo}}, μ::PowerMeasure{<:StdMeasure,<:NTuple{M,Base.OneTo}}, x) where {N,M}
26+
function transport_def(
27+
ν::PowerMeasure{<:StdMeasure,<:NTuple{N,Base.OneTo}},
28+
μ::PowerMeasure{<:StdMeasure,<:NTuple{M,Base.OneTo}},
29+
x,
30+
) where {N,M}
2531
return reshape(transport_to.parent, μ.parent).(x), map(length, ν.axes)...)
2632
end
2733

28-
2934
# Implement transport_to(NU::Type{<:StdMeasure}, μ) and transport_to(ν, MU::Type{<:StdMeasure}):
3035

3136
_std_measure(::Type{M}, ::StaticInt{1}) where {M<:StdMeasure} = M()
3237
_std_measure(::Type{M}, dof::Integer) where {M<:StdMeasure} = M()^dof
3338
_std_measure_for(::Type{M}, μ::Any) where {M<:StdMeasure} = _std_measure(M, getdof(μ))
3439

35-
MeasureBase.transport_to(::Type{NU}, μ) where {NU<:StdMeasure} = transport_to(_std_measure_for(NU, μ), μ)
36-
MeasureBase.transport_to(ν, ::Type{MU}) where {MU<:StdMeasure} = transport_to(ν, _std_measure_for(MU, ν))
37-
40+
function MeasureBase.transport_to(::Type{NU}, μ) where {NU<:StdMeasure}
41+
transport_to(_std_measure_for(NU, μ), μ)
42+
end
43+
function MeasureBase.transport_to(ν, ::Type{MU}) where {MU<:StdMeasure}
44+
transport_to(ν, _std_measure_for(MU, ν))
45+
end
3846

3947
# Transform between standard measures and Dirac:
4048

0 commit comments

Comments
 (0)