Skip to content

Commit cea00f5

Browse files
committed
Improve test_inverse interface
1 parent 2724c4c commit cea00f5

File tree

2 files changed

+13
-19
lines changed

2 files changed

+13
-19
lines changed

src/test.jl

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,22 @@
11
"""
2-
InverseFunctions.test_inverse(f, x; inv_inv_test = ===, kwargs...)
2+
InverseFunctions.test_inverse(f, x; compare=isapprox, kwargs...)
33
4-
Check if [`inverse(f)`](@ref) is implemented correctly.
4+
Test if [`inverse(f)`](@ref) is implemented correctly.
55
6-
The function checks if
7-
- `inverse(f)(f(x)) ≈ x` and
8-
- `inv_inv_test(inverse(inverse(f)), f)`.
6+
The function tests (as a `Test.@testset`) if
97
10-
With `inv_inv_test = ≈`, tests if the result of `inverse(inverse(f))(x)`
11-
is equal or approximately equal to `f(x)`.
8+
* `compare(inverse(f)(f(x)), x) == true` and
9+
* `compare(inverse(inverse(f))(x), f(x)) == true`.
1210
13-
`kwargs...` are passed to `isapprox`.
11+
`kwargs...` are forwarded to `compare`.
1412
"""
15-
function test_inverse(f, x; inv_inv_test = ===, kwargs...)
13+
function test_inverse(f, x; compare=isapprox, kwargs...)
1614
@testset "test_inverse: $f with input $x" begin
1715
y = f(x)
18-
@test (x2 = inverse(f)(y); x2 == x || isapprox(x2, x; kwargs...))
19-
@test let inv_inv_f = inverse(inverse(f))
20-
if inv_inv_test ==
21-
(y2 = inv_inv_f(x); y2 == y || isapprox(y2, y; kwargs...))
22-
else
23-
inv_inv_test(inv_inv_f, f)
24-
end
25-
end
16+
inverse_f = inverse(f)
17+
@test compare(inverse_f(y), x; kwargs...)
18+
inverse_inverse_f = inverse(inverse_f)
19+
@test compare(inverse_inverse_f(x), y; kwargs...)
2620
end
2721
return nothing
2822
end

test/test_inverse.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ InverseFunctions.inverse(f) = Bar(inv(f.A))
2020

2121

2222
@testset "inverse" begin
23-
InverseFunctions.test_inverse(inverse, log)
23+
InverseFunctions.test_inverse(inverse, log, compare = ===)
2424

2525
x = rand()
2626
for f in (foo, inv_foo, exp, log, exp2, log2, exp10, log10, expm1, log1p)
@@ -39,7 +39,7 @@ InverseFunctions.inverse(f) = Bar(inv(f.A))
3939
end
4040
end
4141

42-
InverseFunctions.test_inverse(Bar(rand(3,3)), rand(3), inv_inv_test = )
42+
InverseFunctions.test_inverse(Bar(rand(3,3)), rand(3))
4343

4444
@static if VERSION >= v"1.6"
4545
InverseFunctions.test_inverse(log foo, x)

0 commit comments

Comments
 (0)