Skip to content

Commit 73aa3c3

Browse files
committed
Add rv_and_back argument to test_with_logabsdet_jacobian
1 parent 08cb10b commit 73aa3c3

File tree

4 files changed

+79
-62
lines changed

4 files changed

+79
-62
lines changed

src/test.jl

Lines changed: 27 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,22 @@
11
# This file is a part of ChangesOfVariables.jl, licensed under the MIT License (MIT).
22

33

4-
_to_realvec_and_back(V::AbstractVector{<:Real}) = V, identity
5-
_to_realvec_and_back(x::Real) = [x], V -> V[1]
6-
_to_realvec_and_back(x::Complex) = [real(x), imag(x)], V -> Complex(V[1], V[2])
7-
_to_realvec_and_back(x::NTuple{N}) where N = [x...], V -> ntuple(i -> V[i], Val(N))
4+
_generalized_logabsdet(A) = logabsdet(A)
5+
_generalized_logabsdet(x::Real) = log(abs(x))
86

9-
function _to_realvec_and_back(x::Ref)
10-
xval = x[]
11-
V, to_xval = _to_realvec_and_back(xval)
12-
back_to_ref(V) = Ref(to_xval(V))
13-
return (V, back_to_ref)
14-
end
15-
16-
_to_realvec_and_back(A::AbstractArray{<:Real}) = vec(A), V -> reshape(V, size(A))
17-
18-
function _to_realvec_and_back(A::AbstractArray{Complex{T}, N}) where {T<:Real, N}
19-
RA = cat(real.(A), imag.(A), dims = N+1)
20-
V, to_array = _to_realvec_and_back(RA)
21-
function back_to_complex(V)
22-
RA = to_array(V)
23-
Complex.(view(RA, map(_ -> :, size(A))..., 1), view(RA, map(_ -> :, size(A))..., 2))
24-
end
25-
return (V, back_to_complex)
26-
end
27-
28-
29-
_to_realvec(x) = _to_realvec_and_back(x)[1]
30-
31-
32-
function _auto_with_logabsdet_jacobian(f, x, getjacobian)
7+
function _auto_with_logabsdet_jacobian(f, x, getjacobian, rv_and_back)
338
y = f(x)
34-
V, to_x = _to_realvec_and_back(x)
35-
vf(V) = _to_realvec(f(to_x(V)))
36-
ladj = logabsdet(getjacobian(vf, V))[1]
9+
V, to_x = rv_and_back(x)
10+
vf(V) = rv_and_back(f(to_x(V)))[1]
11+
ladj = _generalized_logabsdet(getjacobian(vf, V))[1]
3712
return (y, ladj)
3813
end
3914

4015

4116
"""
4217
ChangesOfVariables.test_with_logabsdet_jacobian(
43-
f, x, getjacobian;
44-
test_inferred::Bool = true, kwargs...
18+
f, x, getjacobian, rv_and_back = x -> (x, identity);
19+
compare = isapprox, test_inferred::Bool = true, kwargs...
4520
)
4621
4722
Test if [`with_logabsdet_jacobian(f, x)`](@ref) is implemented correctly.
@@ -51,26 +26,35 @@ equal to `(f(x), logabsdet(getjacobian(f, x)))`
5126
5227
So the test uses `getjacobian(f, x)` to calculate a reference Jacobian for
5328
`f` at `x`. Passing `ForwardDiff.jabobian`, `Zygote.jacobian` or similar as
54-
the `getjacobian` function will do fine in most cases.
29+
the `getjacobian` function will do fine in most cases. If input and output
30+
of `f` are real scalar values, use `ForwardDiff.derivative`.
31+
32+
If `getjacobian(f, x)` can't handle the type of `x` of `f(x)` because they
33+
are not real-valued vectors, use the `rv_and_back` argument to pass a
34+
function with the following behavior
5535
56-
If `x` or `f(x)` are real-valued scalars or complex-valued scalars or arrays,
57-
the test will try to reshape them automatically, to account for limitations
58-
of (e.g.) `ForwardDiff` and to ensure the result of `getjacobian` is a real
59-
matrix.
36+
```julia
37+
v, back = rv_and_back(x)
38+
v isa AbstractVector{<:Real}
39+
back(v) == x
40+
```
6041
61-
If `test_inferred == true` will test type inference on
62-
`with_logabsdet_jacobian`.
42+
If `test_inferred == true`, type inference on `with_logabsdet_jacobian` will
43+
be tested.
6344
64-
`kwargs...` are forwarded to `isapprox`.
45+
`kwargs...` are forwarded to `compare`.
6546
"""
66-
function test_with_logabsdet_jacobian(f, x, getjacobian; compare=isapprox, test_inferred::Bool = true, kwargs...)
47+
function test_with_logabsdet_jacobian(
48+
f, x, getjacobian, rv_and_back = x -> (x, identity);
49+
compare = isapprox, test_inferred::Bool = true, kwargs...
50+
)
6751
@testset "test_with_logabsdet_jacobian: $f with input $x" begin
6852
y, ladj = if test_inferred
6953
@inferred with_logabsdet_jacobian(f, x)
7054
else
7155
with_logabsdet_jacobian(f, x)
7256
end
73-
ref_y, ref_ladj = _auto_with_logabsdet_jacobian(f, x, getjacobian)
57+
ref_y, ref_ladj = _auto_with_logabsdet_jacobian(f, x, getjacobian, rv_and_back)
7458
@test compare(y, ref_y; kwargs...)
7559
@test compare(ladj, ref_ladj; kwargs...)
7660
end

test/rv_and_back.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# This file is a part of ChangesOfVariables.jl, licensed under the MIT License (MIT).
2+
3+
4+
rv_and_back(V::AbstractVector{<:Real}) = V, identity
5+
rv_and_back(x::Real) = [x], V -> V[1]
6+
rv_and_back(x::Complex) = [real(x), imag(x)], V -> Complex(V[1], V[2])
7+
rv_and_back(x::NTuple{N}) where N = [x...], V -> ntuple(i -> V[i], Val(N))
8+
9+
function rv_and_back(x::Ref)
10+
xval = x[]
11+
V, to_xval = rv_and_back(xval)
12+
back_to_ref(V) = Ref(to_xval(V))
13+
return (V, back_to_ref)
14+
end
15+
16+
rv_and_back(A::AbstractArray{<:Real}) = vec(A), V -> reshape(V, size(A))
17+
18+
function rv_and_back(A::AbstractArray{Complex{T}, N}) where {T<:Real, N}
19+
RA = cat(real.(A), imag.(A), dims = N+1)
20+
V, to_array = rv_and_back(RA)
21+
function back_to_complex(V)
22+
RA = to_array(V)
23+
Complex.(view(RA, map(_ -> :, size(A))..., 1), view(RA, map(_ -> :, size(A))..., 2))
24+
end
25+
return (V, back_to_complex)
26+
end

test/test_test.jl

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,28 +9,33 @@ import ForwardDiff
99
using ChangesOfVariables: test_with_logabsdet_jacobian
1010

1111

12-
@testset "test_with_logabsdet_jacobian" begin
13-
@testset "ChangesOfVariables._to_realvec_and_back" begin
14-
for x in (rand(3), 0.5, Complex(0.2,0.7), (3,5,9), Ref(42), rand(3, 4, 5), Complex.(rand(3,5), rand(3,5)))
15-
V, to_x = ChangesOfVariables._to_realvec_and_back(x)
16-
@test V isa AbstractVector{<:Real}
17-
@test V == ChangesOfVariables._to_realvec(x)
18-
@test x isa Ref ? to_x(V)[] == x[] : to_x(V) == x
19-
end
12+
include("rv_and_back.jl")
13+
14+
@testset "rv_and_back" begin
15+
for x in (rand(3), 0.5, Complex(0.2,0.7), (3,5,9), Ref(42), rand(3, 4, 5), Complex.(rand(3,5), rand(3,5)))
16+
V, to_x = rv_and_back(x)
17+
@test V isa AbstractVector{<:Real}
18+
@test V == rv_and_back(x)[1]
19+
@test x isa Ref ? to_x(V)[] == x[] : to_x(V) == x
2020
end
21+
end
2122

23+
24+
@testset "test_with_logabsdet_jacobian" begin
2225
x = Complex(0.2, -0.7)
23-
y, ladj_y = ChangesOfVariables._auto_with_logabsdet_jacobian(inv, x, ForwardDiff.jacobian)
26+
y, ladj_y = ChangesOfVariables._auto_with_logabsdet_jacobian(inv, x, ForwardDiff.jacobian, rv_and_back)
2427
@test y == inv(x)
2528
@test ladj_y -4 * log(abs(x))
2629

2730
X = Complex.(randn(3,3), randn(3,3))
28-
Y, ladj_Y = ChangesOfVariables._auto_with_logabsdet_jacobian(inv, X, ForwardDiff.jacobian)
31+
Y, ladj_Y = ChangesOfVariables._auto_with_logabsdet_jacobian(inv, X, ForwardDiff.jacobian, rv_and_back)
2932
@test Y == inv(X)
30-
@test ladj_Y == -4 * 3 * logabsdet(X)[1]
33+
@test ladj_Y -4 * 3 * logabsdet(X)[1]
3134

3235
myisapprox(a, b; kwargs...) = isapprox(a, b; kwargs...)
33-
test_with_logabsdet_jacobian(inv, X, ForwardDiff.jacobian, test_inferred = true, atol = 10^-6)
34-
test_with_logabsdet_jacobian(inv, X, ForwardDiff.jacobian, test_inferred = false, atol = 10^-6)
35-
test_with_logabsdet_jacobian(inv, X, ForwardDiff.jacobian, compare = myisapprox, atol = 10^-6)
36+
test_with_logabsdet_jacobian(inv, 0.5, ForwardDiff.derivative, test_inferred = true, atol = 10^-6)
37+
test_with_logabsdet_jacobian(inv, rand(2,2), ForwardDiff.jacobian, test_inferred = true, atol = 10^-6)
38+
test_with_logabsdet_jacobian(inv, X, ForwardDiff.jacobian, rv_and_back, test_inferred = true, atol = 10^-6)
39+
test_with_logabsdet_jacobian(inv, X, ForwardDiff.jacobian, rv_and_back, test_inferred = false, atol = 10^-6)
40+
test_with_logabsdet_jacobian(inv, X, ForwardDiff.jacobian, rv_and_back, compare = myisapprox, atol = 10^-6)
3641
end

test/test_with_ladj.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ import ForwardDiff
88

99
using ChangesOfVariables: test_with_logabsdet_jacobian
1010

11+
include("rv_and_back.jl")
12+
1113

1214
@testset "with_logabsdet_jacobian" begin
1315
foo(x) = inv(exp(-x) + 1)
@@ -31,31 +33,31 @@ using ChangesOfVariables: test_with_logabsdet_jacobian
3133
isaprx(a::NTuple{N,Any}, b::NTuple{N,Any}; kwargs...) where N = all(map((a,b) -> isaprx(a, b; kwargs...), a, b))
3234

3335

34-
test_with_logabsdet_jacobian(foo, x, getjacobian)
36+
test_with_logabsdet_jacobian(foo, x, getjacobian, rv_and_back)
3537

3638
@static if VERSION >= v"1.6"
37-
test_with_logabsdet_jacobian(log foo, x, getjacobian)
39+
test_with_logabsdet_jacobian(log foo, x, getjacobian, rv_and_back)
3840
end
3941

4042
@testset "getjacobian on mapped and broadcasted" begin
4143
for f in (Base.Fix1(map, foo), Base.Fix1(broadcast, foo))
4244
for arg in (x, fill(x,), Ref(x), (x,), X)
43-
test_with_logabsdet_jacobian(f, arg, getjacobian, compare = isaprx)
45+
test_with_logabsdet_jacobian(f, arg, getjacobian, rv_and_back, compare = isaprx)
4446
end
4547
end
4648
end
4749

4850
@testset "getjacobian on identity, adjoint and transpose" begin
4951
for f in (identity, adjoint, transpose)
5052
for arg in (x, A)
51-
test_with_logabsdet_jacobian(f, arg, getjacobian)
53+
test_with_logabsdet_jacobian(f, arg, getjacobian, rv_and_back)
5254
end
5355
end
5456
end
5557

5658
@testset "getjacobian on inv" begin
5759
for arg in (x, A, CA)
58-
test_with_logabsdet_jacobian(inv, arg, getjacobian)
60+
test_with_logabsdet_jacobian(inv, arg, getjacobian, rv_and_back)
5961
end
6062
end
6163
end

0 commit comments

Comments
 (0)