Skip to content

Commit 3abd5a2

Browse files
committed
Simplify test_with_logabsdet_jacobian
1 parent 98eee46 commit 3abd5a2

File tree

5 files changed

+58
-72
lines changed

5 files changed

+58
-72
lines changed

src/test.jl

Lines changed: 7 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,7 @@
22

33

44
"""
5-
ChangesOfVariables.test_with_logabsdet_jacobian(
6-
f, x, getjacobian, rv_and_back = x -> (x, identity);
7-
compare = isapprox, kwargs...
8-
)
5+
ChangesOfVariables.test_with_logabsdet_jacobian(f, x, getjacobian; compare = isapprox, kwargs...)
96
107
Test if [`with_logabsdet_jacobian(f, x)`](@ref) is implemented correctly.
118
@@ -17,25 +14,14 @@ So the test uses `getjacobian(f, x)` to calculate a reference Jacobian for
1714
the `getjacobian` function will do fine in most cases. If input and output
1815
of `f` are real scalar values, use `ForwardDiff.derivative`.
1916
20-
If `getjacobian(f, x)` can't handle the type of `x` or `f(x)` because they
21-
are not real-valued vectors, use the `rv_and_back` argument to pass a
22-
function with the following behavior
23-
24-
```julia
25-
v, back = rv_and_back(x)
26-
v isa AbstractVector{<:Real}
27-
back(v) == x
28-
```
29-
30-
If `test_inferred == true`, type inference on `with_logabsdet_jacobian` will
31-
be tested.
17+
Note that the result of `getjacobian(f, x)` must be a real-valued matrix
18+
or a real scalar, so you may need to use a custom `getjacobian` function
19+
that transforms the shape of `x` and `f(x)` internally, in conjunction
20+
with automatic differentiation.
3221
3322
`kwargs...` are forwarded to `compare`.
3423
"""
35-
function test_with_logabsdet_jacobian(
36-
f, x, getjacobian, rv_and_back = x -> (x, identity);
37-
compare = isapprox, kwargs...
38-
)
24+
function test_with_logabsdet_jacobian(f, x, getjacobian; compare = isapprox, kwargs...)
3925
@testset "test_with_logabsdet_jacobian: $f with input $x" begin
4026
ref_y, test_type_inference = try
4127
@inferred(f(x)), true
@@ -49,9 +35,7 @@ function test_with_logabsdet_jacobian(
4935
with_logabsdet_jacobian(f, x)
5036
end
5137

52-
V, to_x = rv_and_back(x)
53-
vf(V) = rv_and_back(f(to_x(V)))[1]
54-
ref_ladj = _generalized_logabsdet(getjacobian(vf, V))[1]
38+
ref_ladj = _generalized_logabsdet(getjacobian(f, x))[1]
5539

5640
@test compare(y, ref_y; kwargs...)
5741
@test compare(ladj, ref_ladj; kwargs...)

test/getjacobian.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# This file is a part of ChangesOfVariables.jl, licensed under the MIT License (MIT).
2+
3+
import ForwardDiff
4+
5+
torv_and_back(V::AbstractVector{<:Real}) = V, identity
6+
torv_and_back(x::Real) = [x], V -> V[1]
7+
torv_and_back(x::Complex) = [real(x), imag(x)], V -> Complex(V[1], V[2])
8+
torv_and_back(x::NTuple{N}) where N = [x...], V -> ntuple(i -> V[i], Val(N))
9+
10+
function torv_and_back(x::Ref)
11+
xval = x[]
12+
V, to_xval = torv_and_back(xval)
13+
back_to_ref(V) = Ref(to_xval(V))
14+
return (V, back_to_ref)
15+
end
16+
17+
torv_and_back(A::AbstractArray{<:Real}) = vec(A), V -> reshape(V, size(A))
18+
19+
function torv_and_back(A::AbstractArray{Complex{T}, N}) where {T<:Real, N}
20+
RA = cat(real.(A), imag.(A), dims = N+1)
21+
V, to_array = torv_and_back(RA)
22+
function back_to_complex(V)
23+
RA = to_array(V)
24+
Complex.(view(RA, map(_ -> :, size(A))..., 1), view(RA, map(_ -> :, size(A))..., 2))
25+
end
26+
return (V, back_to_complex)
27+
end
28+
29+
30+
function getjacobian(f, x)
31+
V, to_x = torv_and_back(x)
32+
vf(V) = torv_and_back(f(to_x(V)))[1]
33+
ForwardDiff.jacobian(vf, V)
34+
end
35+
36+
foo(x) = inv(exp(-x) + 1)

test/rv_and_back.jl

Lines changed: 0 additions & 26 deletions
This file was deleted.

test/test_test.jl

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,17 @@ using ChangesOfVariables
44
using Test
55

66
using LinearAlgebra
7-
import ForwardDiff
87

98
using ChangesOfVariables: test_with_logabsdet_jacobian
109

10+
include("getjacobian.jl")
1111

12-
include("rv_and_back.jl")
1312

14-
@testset "rv_and_back" begin
13+
@testset "torv_and_back" begin
1514
for x in (rand(3), 0.5, Complex(0.2,0.7), (3,5,9), Ref(42), rand(3, 4, 5), Complex.(rand(3,5), rand(3,5)))
16-
V, to_x = rv_and_back(x)
15+
V, to_x = torv_and_back(x)
1716
@test V isa AbstractVector{<:Real}
18-
@test V == rv_and_back(x)[1]
17+
@test V == torv_and_back(x)[1]
1918
@test x isa Ref ? to_x(V)[] == x[] : to_x(V) == x
2019
end
2120
end
@@ -34,10 +33,10 @@ end
3433
@test_throws ErrorException @inferred with_logabsdet_jacobian(noninferrable_inv, rand(2, 2))
3534

3635
test_with_logabsdet_jacobian(inv, rx, ForwardDiff.derivative, atol = 10^-6)
37-
test_with_logabsdet_jacobian(inv, cx, ForwardDiff.jacobian, rv_and_back, atol = 10^-6)
36+
test_with_logabsdet_jacobian(inv, cx, getjacobian, atol = 10^-6)
3837
test_with_logabsdet_jacobian(inv, X, ForwardDiff.jacobian, atol = 10^-6)
39-
test_with_logabsdet_jacobian(inv, CX, ForwardDiff.jacobian, rv_and_back, atol = 10^-6)
40-
test_with_logabsdet_jacobian(inv, CX, ForwardDiff.jacobian, rv_and_back, atol = 10^-6)
41-
test_with_logabsdet_jacobian(inv, CX, ForwardDiff.jacobian, rv_and_back, compare = myisapprox, atol = 10^-6)
42-
test_with_logabsdet_jacobian(noninferrable_inv, CX, ForwardDiff.jacobian, rv_and_back, atol = 10^-6)
38+
test_with_logabsdet_jacobian(inv, CX, getjacobian, atol = 10^-6)
39+
test_with_logabsdet_jacobian(inv, CX, getjacobian, atol = 10^-6)
40+
test_with_logabsdet_jacobian(inv, CX, getjacobian, compare = myisapprox, atol = 10^-6)
41+
test_with_logabsdet_jacobian(noninferrable_inv, CX, getjacobian, atol = 10^-6)
4342
end

test/test_with_ladj.jl

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,60 +4,53 @@ using ChangesOfVariables
44
using Test
55

66
using LinearAlgebra
7-
import ForwardDiff
87

98
using ChangesOfVariables: test_with_logabsdet_jacobian
109

11-
include("rv_and_back.jl")
10+
include("getjacobian.jl")
1211

1312

1413
@testset "with_logabsdet_jacobian" begin
15-
foo(x) = inv(exp(-x) + 1)
16-
1714
function ChangesOfVariables.with_logabsdet_jacobian(::typeof(foo), x)
1815
y = foo(x)
1916
ladj = -x + 2 * log(y)
2017
(y, ladj)
2118
end
2219

23-
2420
x = 4.2
2521
X = rand(10)
2622
A = rand(5, 5)
2723
CA = Complex.(rand(5, 5), rand(5, 5))
2824

29-
30-
getjacobian = ForwardDiff.jacobian
31-
3225
isaprx(a, b; kwargs...) = isapprox(a,b; kwargs...)
3326
isaprx(a::NTuple{N,Any}, b::NTuple{N,Any}; kwargs...) where N = all(map((a,b) -> isaprx(a, b; kwargs...), a, b))
3427

3528

36-
test_with_logabsdet_jacobian(foo, x, getjacobian, rv_and_back)
29+
test_with_logabsdet_jacobian(foo, x, getjacobian)
3730

3831
@static if VERSION >= v"1.6"
39-
test_with_logabsdet_jacobian(log foo, x, getjacobian, rv_and_back)
32+
test_with_logabsdet_jacobian(log foo, x, getjacobian)
4033
end
4134

4235
@testset "getjacobian on mapped and broadcasted" begin
4336
for f in (Base.Fix1(map, foo), Base.Fix1(broadcast, foo))
4437
for arg in (x, fill(x,), Ref(x), (x,), X)
45-
test_with_logabsdet_jacobian(f, arg, getjacobian, rv_and_back, compare = isaprx)
38+
test_with_logabsdet_jacobian(f, arg, getjacobian, compare = isaprx)
4639
end
4740
end
4841
end
4942

5043
@testset "getjacobian on identity, adjoint and transpose" begin
5144
for f in (identity, adjoint, transpose)
5245
for arg in (x, A)
53-
test_with_logabsdet_jacobian(f, arg, getjacobian, rv_and_back)
46+
test_with_logabsdet_jacobian(f, arg, getjacobian)
5447
end
5548
end
5649
end
5750

5851
@testset "getjacobian on inv" begin
5952
for arg in (x, A, CA)
60-
test_with_logabsdet_jacobian(inv, arg, getjacobian, rv_and_back)
53+
test_with_logabsdet_jacobian(inv, arg, getjacobian)
6154
end
6255
end
6356
end

0 commit comments

Comments
 (0)