Skip to content

Commit 2724c4c

Browse files
committed
Relax requirements on inverse of inverse
1 parent 491a54b commit 2724c4c

File tree

3 files changed

+29
-8
lines changed

3 files changed

+29
-8
lines changed

src/inverse.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,9 @@ true
4343
# Implementation
4444
4545
Implementations of `inverse(::typeof(f))` have to satisfy
46+
4647
* `inverse(f)(f(x)) ≈ x` for all `x` in the domain of `f`, and
47-
* `inverse(inverse(f)) === f`.
48+
* `inverse(inverse(f))` is equivalent (ideally identical/equal) to `f`.
4849
4950
You can check your implementation with [`InverseFunctions.test_inverse`](@ref).
5051
"""

src/test.jl

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,28 @@
11
"""
2-
InverseFunctions.test_inverse(f, x; kwargs...)
2+
InverseFunctions.test_inverse(f, x; inv_inv_test = ===, kwargs...)
33
44
Check if [`inverse(f)`](@ref) is implemented correctly.
55
66
The function checks if
77
- `inverse(f)(f(x)) ≈ x` and
8-
- `inverse(inverse(f)) === f`.
8+
- `inv_inv_test(inverse(inverse(f)), f)`.
99
10-
All keyword arguments are passed to `isapprox`.
10+
With `inv_inv_test = ≈`, tests if the result of `inverse(inverse(f))(x)`
11+
is equal or approximately equal to `f(x)`.
12+
13+
`kwargs...` are passed to `isapprox`.
1114
"""
12-
function test_inverse(f, x; kwargs...)
15+
function test_inverse(f, x; inv_inv_test = ===, kwargs...)
1316
@testset "test_inverse: $f with input $x" begin
14-
inverse_f = inverse(f)
15-
@test (x2 = inverse_f(f(x)); x2 == x || isapprox(inverse_f(f(x)), x; kwargs...))
16-
@test inverse(inverse_f) === f
17+
y = f(x)
18+
@test (x2 = inverse(f)(y); x2 == x || isapprox(x2, x; kwargs...))
19+
@test let inv_inv_f = inverse(inverse(f))
20+
if inv_inv_test ==
21+
(y2 = inv_inv_f(x); y2 == y || isapprox(y2, y; kwargs...))
22+
else
23+
inv_inv_test(inv_inv_f, f)
24+
end
25+
end
1726
end
1827
return nothing
1928
end

test/test_inverse.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,15 @@ inv_foo(y) = log(y / (1 - y))
1010
InverseFunctions.inverse(::typeof(foo)) = inv_foo
1111
InverseFunctions.inverse(::typeof(inv_foo)) = foo
1212

13+
14+
struct Bar{MT<:AbstractMatrix}
15+
A::MT
16+
end
17+
18+
(f::Bar)(x) = f.A * x
19+
InverseFunctions.inverse(f) = Bar(inv(f.A))
20+
21+
1322
@testset "inverse" begin
1423
InverseFunctions.test_inverse(inverse, log)
1524

@@ -30,6 +39,8 @@ InverseFunctions.inverse(::typeof(inv_foo)) = foo
3039
end
3140
end
3241

42+
InverseFunctions.test_inverse(Bar(rand(3,3)), rand(3), inv_inv_test = )
43+
3344
@static if VERSION >= v"1.6"
3445
InverseFunctions.test_inverse(log foo, x)
3546
end

0 commit comments

Comments
 (0)