Skip to content

Commit 08cb10b

Browse files
committed
Add test_with_logabsdet_jacobian
1 parent cb0d3b1 commit 08cb10b

File tree

6 files changed

+157
-47
lines changed

6 files changed

+157
-47
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,14 @@ version = "0.1.0"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
7+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
78

89
[compat]
910
julia = "1"
1011

1112
[extras]
1213
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
1314
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
14-
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1515

1616
[targets]
17-
test = ["Documenter", "ForwardDiff", "Test"]
17+
test = ["Documenter", "ForwardDiff"]

src/ChangesOfVariables.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@ transformations).
1010
module ChangesOfVariables
1111

1212
using LinearAlgebra
13+
using Test
1314

1415
include("with_ladj.jl")
16+
include("test.jl")
1517

1618
end # module

src/test.jl

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# This file is a part of ChangesOfVariables.jl, licensed under the MIT License (MIT).
2+
3+
4+
_to_realvec_and_back(V::AbstractVector{<:Real}) = V, identity
5+
_to_realvec_and_back(x::Real) = [x], V -> V[1]
6+
_to_realvec_and_back(x::Complex) = [real(x), imag(x)], V -> Complex(V[1], V[2])
7+
_to_realvec_and_back(x::NTuple{N}) where N = [x...], V -> ntuple(i -> V[i], Val(N))
8+
9+
function _to_realvec_and_back(x::Ref)
10+
xval = x[]
11+
V, to_xval = _to_realvec_and_back(xval)
12+
back_to_ref(V) = Ref(to_xval(V))
13+
return (V, back_to_ref)
14+
end
15+
16+
_to_realvec_and_back(A::AbstractArray{<:Real}) = vec(A), V -> reshape(V, size(A))
17+
18+
function _to_realvec_and_back(A::AbstractArray{Complex{T}, N}) where {T<:Real, N}
19+
RA = cat(real.(A), imag.(A), dims = N+1)
20+
V, to_array = _to_realvec_and_back(RA)
21+
function back_to_complex(V)
22+
RA = to_array(V)
23+
Complex.(view(RA, map(_ -> :, size(A))..., 1), view(RA, map(_ -> :, size(A))..., 2))
24+
end
25+
return (V, back_to_complex)
26+
end
27+
28+
29+
_to_realvec(x) = _to_realvec_and_back(x)[1]
30+
31+
32+
function _auto_with_logabsdet_jacobian(f, x, getjacobian)
33+
y = f(x)
34+
V, to_x = _to_realvec_and_back(x)
35+
vf(V) = _to_realvec(f(to_x(V)))
36+
ladj = logabsdet(getjacobian(vf, V))[1]
37+
return (y, ladj)
38+
end
39+
40+
41+
"""
42+
ChangesOfVariables.test_with_logabsdet_jacobian(
43+
f, x, getjacobian;
44+
test_inferred::Bool = true, kwargs...
45+
)
46+
47+
Test if [`with_logabsdet_jacobian(f, x)`](@ref) is implemented correctly.
48+
49+
Checks if the result of `with_logabsdet_jacobian(f, x)` is approximately
50+
equal to `(f(x), logabsdet(getjacobian(f, x)))`
51+
52+
So the test uses `getjacobian(f, x)` to calculate a reference Jacobian for
53+
`f` at `x`. Passing `ForwardDiff.jabobian`, `Zygote.jacobian` or similar as
54+
the `getjacobian` function will do fine in most cases.
55+
56+
If `x` or `f(x)` are real-valued scalars or complex-valued scalars or arrays,
57+
the test will try to reshape them automatically, to account for limitations
58+
of (e.g.) `ForwardDiff` and to ensure the result of `getjacobian` is a real
59+
matrix.
60+
61+
If `test_inferred == true` will test type inference on
62+
`with_logabsdet_jacobian`.
63+
64+
`kwargs...` are forwarded to `isapprox`.
65+
"""
66+
function test_with_logabsdet_jacobian(f, x, getjacobian; compare=isapprox, test_inferred::Bool = true, kwargs...)
67+
@testset "test_with_logabsdet_jacobian: $f with input $x" begin
68+
y, ladj = if test_inferred
69+
@inferred with_logabsdet_jacobian(f, x)
70+
else
71+
with_logabsdet_jacobian(f, x)
72+
end
73+
ref_y, ref_ladj = _auto_with_logabsdet_jacobian(f, x, getjacobian)
74+
@test compare(y, ref_y; kwargs...)
75+
@test compare(ladj, ref_ladj; kwargs...)
76+
end
77+
return nothing
78+
end

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import ChangesOfVariables
55
import Documenter
66

77
Test.@testset "Package ChangesOfVariables" begin
8+
include("test_test.jl")
89
include("test_with_ladj.jl")
910

1011
# doctests
@@ -16,3 +17,4 @@ Test.@testset "Package ChangesOfVariables" begin
1617
)
1718
Documenter.doctest(ChangesOfVariables)
1819
end # testset
20+

test/test_test.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# This file is a part of ChangesOfVariables.jl, licensed under the MIT License (MIT).
2+
3+
using ChangesOfVariables
4+
using Test
5+
6+
using LinearAlgebra
7+
import ForwardDiff
8+
9+
using ChangesOfVariables: test_with_logabsdet_jacobian
10+
11+
12+
@testset "test_with_logabsdet_jacobian" begin
13+
@testset "ChangesOfVariables._to_realvec_and_back" begin
14+
for x in (rand(3), 0.5, Complex(0.2,0.7), (3,5,9), Ref(42), rand(3, 4, 5), Complex.(rand(3,5), rand(3,5)))
15+
V, to_x = ChangesOfVariables._to_realvec_and_back(x)
16+
@test V isa AbstractVector{<:Real}
17+
@test V == ChangesOfVariables._to_realvec(x)
18+
@test x isa Ref ? to_x(V)[] == x[] : to_x(V) == x
19+
end
20+
end
21+
22+
x = Complex(0.2, -0.7)
23+
y, ladj_y = ChangesOfVariables._auto_with_logabsdet_jacobian(inv, x, ForwardDiff.jacobian)
24+
@test y == inv(x)
25+
@test ladj_y -4 * log(abs(x))
26+
27+
X = Complex.(randn(3,3), randn(3,3))
28+
Y, ladj_Y = ChangesOfVariables._auto_with_logabsdet_jacobian(inv, X, ForwardDiff.jacobian)
29+
@test Y == inv(X)
30+
@test ladj_Y == -4 * 3 * logabsdet(X)[1]
31+
32+
myisapprox(a, b; kwargs...) = isapprox(a, b; kwargs...)
33+
test_with_logabsdet_jacobian(inv, X, ForwardDiff.jacobian, test_inferred = true, atol = 10^-6)
34+
test_with_logabsdet_jacobian(inv, X, ForwardDiff.jacobian, test_inferred = false, atol = 10^-6)
35+
test_with_logabsdet_jacobian(inv, X, ForwardDiff.jacobian, compare = myisapprox, atol = 10^-6)
36+
end

test/test_with_ladj.jl

Lines changed: 37 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -4,66 +4,58 @@ using ChangesOfVariables
44
using Test
55

66
using LinearAlgebra
7-
using ForwardDiff: derivative, jacobian
7+
import ForwardDiff
88

9+
using ChangesOfVariables: test_with_logabsdet_jacobian
910

10-
fwddiff_ladj(f, x::Real) = log(abs(derivative(f, x)))
11-
fwddiff_ladj(f, x::AbstractArray{<:Real}) = logabsdet(jacobian(f, x))[1]
12-
fwddiff_with_ladj(f, x) = (f(x), fwddiff_ladj(f, x))
1311

14-
ascomplex(A::AbstractArray{T}) where T = reinterpret(Complex{T}, A)
15-
asreal(A::AbstractArray{Complex{T}}) where T = reinterpret(T, A)
16-
17-
isaprx(a, b) = isapprox(a,b)
18-
isaprx(a::NTuple{N,Any}, b::NTuple{N,Any}) where N = all(map(isaprx, a, b))
19-
20-
21-
foo(x) = inv(exp(-x) + 1)
12+
@testset "with_logabsdet_jacobian" begin
13+
foo(x) = inv(exp(-x) + 1)
2214

23-
function ChangesOfVariables.with_logabsdet_jacobian(::typeof(foo), x)
24-
y = foo(x)
25-
ladj = -x + 2 * log(y)
26-
(y, ladj)
27-
end
15+
function ChangesOfVariables.with_logabsdet_jacobian(::typeof(foo), x)
16+
y = foo(x)
17+
ladj = -x + 2 * log(y)
18+
(y, ladj)
19+
end
2820

2921

30-
@testset "with_logabsdet_jacobian" begin
3122
x = 4.2
3223
X = rand(10)
3324
A = rand(5, 5)
34-
CA = rand(10, 5)
25+
CA = Complex.(rand(5, 5), rand(5, 5))
3526

36-
@test isaprx(@inferred(with_logabsdet_jacobian(foo, x)), fwddiff_with_ladj(foo, x))
27+
28+
getjacobian = ForwardDiff.jacobian
29+
30+
isaprx(a, b; kwargs...) = isapprox(a,b; kwargs...)
31+
isaprx(a::NTuple{N,Any}, b::NTuple{N,Any}; kwargs...) where N = all(map((a,b) -> isaprx(a, b; kwargs...), a, b))
32+
33+
34+
test_with_logabsdet_jacobian(foo, x, getjacobian)
3735

3836
@static if VERSION >= v"1.6"
39-
log_foo = log foo
40-
@test isaprx(@inferred(with_logabsdet_jacobian(log_foo, x)), fwddiff_with_ladj(log_foo, x))
37+
test_with_logabsdet_jacobian(log foo, x, getjacobian)
38+
end
39+
40+
@testset "getjacobian on mapped and broadcasted" begin
41+
for f in (Base.Fix1(map, foo), Base.Fix1(broadcast, foo))
42+
for arg in (x, fill(x,), Ref(x), (x,), X)
43+
test_with_logabsdet_jacobian(f, arg, getjacobian, compare = isaprx)
44+
end
45+
end
4146
end
4247

43-
mapped_foo = Base.Fix1(map, foo)
44-
@test isaprx(@inferred(with_logabsdet_jacobian(mapped_foo, x)), fwddiff_with_ladj(mapped_foo, x))
45-
@test isaprx(@inferred(with_logabsdet_jacobian(mapped_foo, fill(x))), fwddiff_with_ladj(mapped_foo, fill(x)))
46-
@test isaprx(@inferred(with_logabsdet_jacobian(mapped_foo, Ref(x))), fwddiff_with_ladj(mapped_foo, fill(x)))
47-
@test isaprx(@inferred(with_logabsdet_jacobian(mapped_foo, (x,))), (mapped_foo((x,)), fwddiff_ladj(mapped_foo, x)))
48-
@test isaprx(@inferred(with_logabsdet_jacobian(mapped_foo, X)), fwddiff_with_ladj(mapped_foo, X))
49-
50-
broadcasted_foo = Base.Fix1(broadcast, foo)
51-
@test isaprx(@inferred(with_logabsdet_jacobian(broadcasted_foo, x)), fwddiff_with_ladj(broadcasted_foo, x))
52-
@test isaprx(@inferred(with_logabsdet_jacobian(broadcasted_foo, fill(x))), fwddiff_with_ladj(broadcasted_foo, x))
53-
@test isaprx(@inferred(with_logabsdet_jacobian(broadcasted_foo, Ref(x))), fwddiff_with_ladj(broadcasted_foo, x))
54-
@test isaprx(@inferred(with_logabsdet_jacobian(broadcasted_foo, (x,))), (mapped_foo((x,)), fwddiff_ladj(mapped_foo, x)))
55-
@test isaprx(@inferred(with_logabsdet_jacobian(broadcasted_foo, X)), fwddiff_with_ladj(broadcasted_foo, X))
56-
57-
for f in (identity, adjoint, transpose)
58-
@test isaprx(@inferred(with_logabsdet_jacobian(f, x)), fwddiff_with_ladj(f, x))
59-
@test isaprx(@inferred(with_logabsdet_jacobian(f, A)), fwddiff_with_ladj(f, A))
48+
@testset "getjacobian on identity, adjoint and transpose" begin
49+
for f in (identity, adjoint, transpose)
50+
for arg in (x, A)
51+
test_with_logabsdet_jacobian(f, arg, getjacobian)
52+
end
53+
end
6054
end
61-
62-
@test isaprx(@inferred(with_logabsdet_jacobian(inv, x)), fwddiff_with_ladj(inv, x))
63-
@test isaprx(@inferred(with_logabsdet_jacobian(inv, A)), fwddiff_with_ladj(inv, A))
64-
@test isaprx(@inferred(with_logabsdet_jacobian(inv, ascomplex(CA))), (inv(ascomplex(CA)), fwddiff_ladj(CA -> asreal(inv(ascomplex(CA))), CA)))
6555

66-
for f in (exp, log, exp2, log2, exp10, log10, expm1, log1p)
67-
@test isaprx(@inferred(with_logabsdet_jacobian(f, x)), fwddiff_with_ladj(f, x))
56+
@testset "getjacobian on inv" begin
57+
for arg in (x, A, CA)
58+
test_with_logabsdet_jacobian(inv, arg, getjacobian)
59+
end
6860
end
6961
end

0 commit comments

Comments
 (0)