Skip to content

Commit 5f1d99d

Browse files
authored
Add ChangesOfVariables definitions and extend tests
1 parent da5130f commit 5f1d99d

File tree

5 files changed

+54
-8
lines changed

5 files changed

+54
-8
lines changed

ext/LogExpFunctionsChangesOfVariablesExt.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,25 @@ function ChangesOfVariables.with_logabsdet_jacobian(::typeof(log1pexp), x::Real)
88
y = log1pexp(x)
99
return y, x - y
1010
end
11+
function ChangesOfVariables.with_logabsdet_jacobian(::typeof(softplus), x::Real)
12+
return ChangesOfVariables.with_logabsdet_jacobian(log1pexp, x)
13+
end
14+
function ChangesOfVariables.with_logabsdet_jacobian(f::Base.Fix2{typeof(softplus),<:Real}, x::Real)
15+
y = f(x)
16+
return y, f.x * (x - y)
17+
end
1118

1219
function ChangesOfVariables.with_logabsdet_jacobian(::typeof(logexpm1), x::Real)
1320
y = logexpm1(x)
1421
return y, x - y
1522
end
23+
function ChangesOfVariables.with_logabsdet_jacobian(::typeof(invsoftplus), x::Real)
24+
return ChangesOfVariables.with_logabsdet_jacobian(logexpm1, x)
25+
end
26+
function ChangesOfVariables.with_logabsdet_jacobian(f::Base.Fix2{typeof(invsoftplus),<:Real}, x::Real)
27+
y = f(x)
28+
return y, f.x * (x - y)
29+
end
1630

1731
function ChangesOfVariables.with_logabsdet_jacobian(::typeof(log1mexp), x::Real)
1832
y = log1mexp(x)

ext/LogExpFunctionsInverseFunctionsExt.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,13 @@ InverseFunctions.inverse(::typeof(log1mlogistic)) = logit1mexp
2323
InverseFunctions.inverse(::typeof(logit1mexp)) = log1mlogistic
2424

2525
InverseFunctions.inverse(::typeof(softplus)) = invsoftplus
26+
function InverseFunctions.inverse(f::Base.Fix2{typeof(softplus),<:Real})
27+
Base.Fix2(invsoftplus, f.x)
28+
end
29+
2630
InverseFunctions.inverse(::typeof(invsoftplus)) = softplus
31+
function InverseFunctions.inverse(f::Base.Fix2{typeof(invsoftplus),<:Real})
32+
Base.Fix2(softplus, f.x)
33+
end
2734

2835
end # module

test/basicfuns.jl

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -162,13 +162,15 @@ end
162162
end
163163

164164
@testset "softplus" begin
165-
@test softplus(2) log1pexp(2)
166-
@test softplus(2, 1) log1pexp(2)
167-
@test softplus(2, 10) < log1pexp(2)
168-
@test invsoftplus(softplus(2), 1) 2
165+
for T in (Int, Float64, Float32, Float16)
166+
@test @inferred(softplus(T(2))) === log1pexp(T(2))
167+
@test @inferred(softplus(T(2), 1)) isa float(T)
168+
@test @inferred(softplus(T(2), 1)) softplus(T(2))
169+
@test @inferred(softplus(T(2), 5)) softplus(5 * T(2)) / 5
170+
@test @inferred(softplus(T(2), 10)) softplus(10 * T(2)) / 10
171+
end
169172
end
170173

171-
172174
@testset "log1mexp" begin
173175
for T in (Float64, Float32, Float16)
174176
@test @inferred(log1mexp(-T(1))) isa T
@@ -194,6 +196,16 @@ end
194196
end
195197
end
196198

199+
@testset "invsoftplus" begin
200+
for T in (Int, Float64, Float32, Float16)
201+
@test @inferred(invsoftplus(T(2))) === logexpm1(T(2))
202+
@test @inferred(invsoftplus(T(2), 1)) isa float(T)
203+
@test @inferred(invsoftplus(T(2), 1)) invsoftplus(T(2))
204+
@test @inferred(invsoftplus(T(2), 5)) invsoftplus(5 * T(2)) / 5
205+
@test @inferred(invsoftplus(T(2), 10)) invsoftplus(10 * T(2)) / 10
206+
end
207+
end
208+
197209
@testset "log1pmx" begin
198210
@test iszero(log1pmx(0.0))
199211
@test log1pmx(1.0) log(2.0) - 1.0

test/inverse.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
@testset "inverse.jl" begin
22
InverseFunctions.test_inverse(log1pexp, randn())
3+
InverseFunctions.test_inverse(softplus, randn())
4+
InverseFunctions.test_inverse(Base.Fix2(softplus, randexp()), randn())
5+
36
InverseFunctions.test_inverse(logexpm1, randexp())
7+
InverseFunctions.test_inverse(invsoftplus, randexp())
8+
InverseFunctions.test_inverse(Base.Fix2(invsoftplus, randexp()), randexp())
49

510
InverseFunctions.test_inverse(log1mexp, -randexp())
611

@@ -17,7 +22,4 @@
1722

1823
InverseFunctions.test_inverse(log1mlogistic, randexp())
1924
InverseFunctions.test_inverse(logit1mexp, -randexp())
20-
21-
InverseFunctions.test_inverse(softplus, randn())
22-
InverseFunctions.test_inverse(invsoftplus, randexp())
2325
end

test/with_logabsdet_jacobian.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,23 @@
11
@testset "with_logabsdet_jacobian" begin
22
derivative(f, x) = ChainRulesTestUtils.frule((ChainRulesTestUtils.NoTangent(), 1), f, x)[2]
3+
derivative(::typeof(softplus), x) = derivative(log1pexp, x)
4+
derivative(f::Base.Fix2{typeof(softplus),<:Real}, x) = derivative(log1pexp, f.x * x)
5+
derivative(::typeof(invsoftplus), x) = derivative(logexpm1, x)
6+
derivative(f::Base.Fix2{typeof(invsoftplus),<:Real}, x) = derivative(logexpm1, f.x * x)
37

48
x = randexp()
9+
y = randexp()
510

611
ChangesOfVariables.test_with_logabsdet_jacobian(log1pexp, x, derivative)
712
ChangesOfVariables.test_with_logabsdet_jacobian(log1pexp, -x, derivative)
13+
ChangesOfVariables.test_with_logabsdet_jacobian(softplus, x, derivative)
14+
ChangesOfVariables.test_with_logabsdet_jacobian(softplus, -x, derivative)
15+
ChangesOfVariables.test_with_logabsdet_jacobian(Base.Fix2(softplus, y), x, derivative)
16+
ChangesOfVariables.test_with_logabsdet_jacobian(Base.Fix2(softplus, y), -x, derivative)
817

918
ChangesOfVariables.test_with_logabsdet_jacobian(logexpm1, x, derivative)
19+
ChangesOfVariables.test_with_logabsdet_jacobian(invsoftplus, x, derivative)
20+
ChangesOfVariables.test_with_logabsdet_jacobian(Base.Fix2(invsoftplus, y), x, derivative)
1021

1122
ChangesOfVariables.test_with_logabsdet_jacobian(log1mexp, -x, derivative)
1223

0 commit comments

Comments
 (0)