Skip to content

Commit 3ec06fa

Browse files
committed
Automatic type inference test in test_with_logabsdet_jacobian
1 parent 73aa3c3 commit 3ec06fa

File tree

2 files changed

+34
-30
lines changed

2 files changed

+34
-30
lines changed

src/test.jl

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,10 @@
11
# This file is a part of ChangesOfVariables.jl, licensed under the MIT License (MIT).
22

33

4-
_generalized_logabsdet(A) = logabsdet(A)
5-
_generalized_logabsdet(x::Real) = log(abs(x))
6-
7-
function _auto_with_logabsdet_jacobian(f, x, getjacobian, rv_and_back)
8-
y = f(x)
9-
V, to_x = rv_and_back(x)
10-
vf(V) = rv_and_back(f(to_x(V)))[1]
11-
ladj = _generalized_logabsdet(getjacobian(vf, V))[1]
12-
return (y, ladj)
13-
end
14-
15-
164
"""
175
ChangesOfVariables.test_with_logabsdet_jacobian(
186
f, x, getjacobian, rv_and_back = x -> (x, identity);
19-
compare = isapprox, test_inferred::Bool = true, kwargs...
7+
compare = isapprox, kwargs...
208
)
219
2210
Test if [`with_logabsdet_jacobian(f, x)`](@ref) is implemented correctly.
@@ -46,17 +34,31 @@ be tested.
4634
"""
4735
function test_with_logabsdet_jacobian(
4836
f, x, getjacobian, rv_and_back = x -> (x, identity);
49-
compare = isapprox, test_inferred::Bool = true, kwargs...
37+
compare = isapprox, kwargs...
5038
)
5139
@testset "test_with_logabsdet_jacobian: $f with input $x" begin
52-
y, ladj = if test_inferred
40+
ref_y, test_type_inference = try
41+
@inferred(f(x)), true
42+
catch err
43+
f(x), false
44+
end
45+
46+
y, ladj = if test_type_inference
5347
@inferred with_logabsdet_jacobian(f, x)
5448
else
5549
with_logabsdet_jacobian(f, x)
5650
end
57-
ref_y, ref_ladj = _auto_with_logabsdet_jacobian(f, x, getjacobian, rv_and_back)
51+
52+
V, to_x = rv_and_back(x)
53+
vf(V) = rv_and_back(f(to_x(V)))[1]
54+
ref_ladj = _generalized_logabsdet(getjacobian(vf, V))[1]
55+
5856
@test compare(y, ref_y; kwargs...)
5957
@test compare(ladj, ref_ladj; kwargs...)
6058
end
6159
return nothing
6260
end
61+
62+
63+
_generalized_logabsdet(A) = logabsdet(A)
64+
_generalized_logabsdet(x::Real) = log(abs(x))

test/test_test.jl

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,20 +22,22 @@ end
2222

2323

2424
@testset "test_with_logabsdet_jacobian" begin
25-
x = Complex(0.2, -0.7)
26-
y, ladj_y = ChangesOfVariables._auto_with_logabsdet_jacobian(inv, x, ForwardDiff.jacobian, rv_and_back)
27-
@test y == inv(x)
28-
@test ladj_y -4 * log(abs(x))
29-
30-
X = Complex.(randn(3,3), randn(3,3))
31-
Y, ladj_Y = ChangesOfVariables._auto_with_logabsdet_jacobian(inv, X, ForwardDiff.jacobian, rv_and_back)
32-
@test Y == inv(X)
33-
@test ladj_Y -4 * 3 * logabsdet(X)[1]
25+
rx = 0.5
26+
cx = Complex(0.2, -0.7)
27+
X = rand(3, 3)
28+
CX = Complex.(randn(3,3), randn(3,3))
3429

3530
myisapprox(a, b; kwargs...) = isapprox(a, b; kwargs...)
36-
test_with_logabsdet_jacobian(inv, 0.5, ForwardDiff.derivative, test_inferred = true, atol = 10^-6)
37-
test_with_logabsdet_jacobian(inv, rand(2,2), ForwardDiff.jacobian, test_inferred = true, atol = 10^-6)
38-
test_with_logabsdet_jacobian(inv, X, ForwardDiff.jacobian, rv_and_back, test_inferred = true, atol = 10^-6)
39-
test_with_logabsdet_jacobian(inv, X, ForwardDiff.jacobian, rv_and_back, test_inferred = false, atol = 10^-6)
40-
test_with_logabsdet_jacobian(inv, X, ForwardDiff.jacobian, rv_and_back, compare = myisapprox, atol = 10^-6)
31+
32+
noninferrable_inv(x) = x!=rand(size(x)...) ? inv(x) : ""
33+
ChangesOfVariables.with_logabsdet_jacobian(::typeof(noninferrable_inv), x) = noninferrable_inv(x), with_logabsdet_jacobian(inv, x)[2]
34+
@test_throws ErrorException @inferred with_logabsdet_jacobian(noninferrable_inv, rand(2, 2))
35+
36+
test_with_logabsdet_jacobian(inv, rx, ForwardDiff.derivative, atol = 10^-6)
37+
test_with_logabsdet_jacobian(inv, cx, ForwardDiff.jacobian, rv_and_back, atol = 10^-6)
38+
test_with_logabsdet_jacobian(inv, X, ForwardDiff.jacobian, atol = 10^-6)
39+
test_with_logabsdet_jacobian(inv, CX, ForwardDiff.jacobian, rv_and_back, atol = 10^-6)
40+
test_with_logabsdet_jacobian(inv, CX, ForwardDiff.jacobian, rv_and_back, atol = 10^-6)
41+
test_with_logabsdet_jacobian(inv, CX, ForwardDiff.jacobian, rv_and_back, compare = myisapprox, atol = 10^-6)
42+
test_with_logabsdet_jacobian(noninferrable_inv, CX, ForwardDiff.jacobian, rv_and_back, atol = 10^-6)
4143
end

0 commit comments

Comments
 (0)