Skip to content

Commit ee91afd

Browse files
Merge pull request #1338 from AayushSabharwal/as/inverse
feat: add ability to define and query function inverses
2 parents 942c56e + affad5d commit ee91afd

File tree

5 files changed

+217
-3
lines changed

5 files changed

+217
-3
lines changed

docs/src/manual/functions.md

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,3 +188,21 @@ function Symbolics.derivative(interp::AbstractInterpolation, args::NTuple{1, Any
188188
Symbolics.unwrap(derivative(interp, Symbolics.wrap(args[1])))
189189
end
190190
```
191+
192+
## Inverse function registration
193+
194+
Symbolics.jl allows defining and querying the inverses of functions.
195+
196+
```@docs
197+
inverse
198+
left_inverse
199+
right_inverse
200+
@register_inverse
201+
has_inverse
202+
has_left_inverse
203+
has_right_inverse
204+
```
205+
206+
Symbolics.jl implements inverses for standard trigonometric and logarithmic functions,
207+
as well as their variants from `NaNMath`. It also implements inverses of
208+
`ComposedFunction`s.

src/Symbolics.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,9 @@ using Primes
1919

2020
using Reexport
2121

22-
using DomainSets
23-
2422
using Setfield
2523

26-
import DomainSets: Domain
24+
import DomainSets: Domain, DomainSets
2725

2826
using TermInterface
2927
import TermInterface: maketerm, iscall, operation, arguments, metadata
@@ -234,4 +232,7 @@ function __init__()
234232
end
235233
end
236234

235+
export inverse, left_inverse, right_inverse, @register_inverse, has_inverse, has_left_inverse, has_right_inverse
236+
include("inverse.jl")
237+
237238
end # module

src/inverse.jl

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
"""
2+
inverse(f)
3+
4+
Given a single-input single-output function `f`, return its inverse `g`. This requires
5+
that `f` is bijective. If `inverse` is defined for a function, `left_inverse` and
6+
`right_inverse` should return `inverse(f)`. `inverse(g)` should also be defined to
7+
return `f`.
8+
9+
See also: [`left_inverse`](@ref), [`right_inverse`](@ref), [`@register_inverse`](@ref).
10+
"""
11+
function inverse end
12+
13+
"""
14+
left_inverse(f)
15+
16+
Given a single-input single-output function `f`, return its left inverse `g`. This
17+
requires that `f` is injective. If `left_inverse` is defined for a function,
18+
`right_inverse` and `inverse` must not be defined and should error. `right_inverse(g)`
19+
should also be defined to return `f`.
20+
21+
See also: [`inverse`](@ref), [`right_inverse`](@ref), [`@register_inverse`](@ref).
22+
"""
23+
function left_inverse end
24+
25+
"""
26+
right_inverse(f)
27+
28+
Given a single-input single-output function `f`, return its right inverse `g`. This
29+
requires that `f` is surjective. If `right_inverse` is defined for a function,
30+
`left_inverse` and `inverse` must not be defined and should error. `left_inverse(g)`
31+
should also be defined to return `f`.
32+
33+
See also [`inverse`](@ref), [`left_inverse`](@ref), [`@register_inverse`](@ref).
34+
"""
35+
function right_inverse end
36+
37+
"""
38+
@register_inverse f g
39+
@register_inverse f g left
40+
@register_inverse f g right
41+
42+
Mark `f` and `g` as inverses of each other. By default, assume that `f` and `g` are
43+
bijective. Also defines `left_inverse` and `right_inverse` to call `inverse`. If the
44+
third argument is `left`, assume that `f` is injective and `g` is its left inverse. If
45+
the third argument is `right`, assume that `f` is surjective and `g` is its right
46+
inverse.
47+
"""
48+
macro register_inverse(f, g, dir::QuoteNode = :(:both))
49+
dir = dir.value
50+
if dir == :both
51+
quote
52+
(::typeof($inverse))(::typeof($f)) = $g
53+
(::typeof($inverse))(::typeof($g)) = $f
54+
(::typeof($left_inverse))(::typeof($f)) = $(inverse)($f)
55+
(::typeof($right_inverse))(::typeof($f)) = $(inverse)($f)
56+
(::typeof($left_inverse))(::typeof($g)) = $(inverse)($g)
57+
(::typeof($right_inverse))(::typeof($g)) = $(inverse)($g)
58+
end
59+
elseif dir == :left
60+
quote
61+
(::typeof($left_inverse))(::typeof($f)) = $g
62+
(::typeof($right_inverse))(::typeof($g)) = $f
63+
end
64+
elseif dir == :right
65+
quote
66+
(::typeof($right_inverse))(::typeof($f)) = $g
67+
(::typeof($left_inverse))(::typeof($g)) = $f
68+
end
69+
else
70+
throw(ArgumentError("The third argument to `@register_inverse` must be `left` or `right`"))
71+
end
72+
end
73+
74+
"""
75+
$(TYPEDSIGNATURES)
76+
77+
Check if the provided function has an inverse defined via [`inverse`](@ref). Uses
78+
`hasmethod` to perform the check.
79+
"""
80+
has_inverse(::T) where {T} = hasmethod(inverse, Tuple{T})
81+
82+
"""
83+
$(TYPEDSIGNATURES)
84+
85+
Check if the provided function has a left inverse defined via [`left_inverse`](@ref)
86+
Uses `hasmethod` to perform the check.
87+
"""
88+
has_left_inverse(::T) where {T} = hasmethod(left_inverse, Tuple{T})
89+
90+
"""
91+
$(TYPEDSIGNATURES)
92+
93+
Check if the provided function has a left inverse defined via [`left_inverse`](@ref)
94+
Uses `hasmethod` to perform the check.
95+
"""
96+
has_right_inverse(::T) where {T} = hasmethod(right_inverse, Tuple{T})
97+
98+
"""
99+
$(TYPEDSIGNATURES)
100+
101+
A simple utility function which returns the square of the input. Used to define
102+
the inverse of `sqrt`.
103+
"""
104+
square(x) = x ^ 2
105+
106+
"""
107+
$(TYPEDSIGNATURES)
108+
109+
A simple utility function which returns the cube of the input. Used to define
110+
the inverse of `cbrt`.
111+
"""
112+
cube(x) = x ^ 3
113+
114+
"""
115+
$(TYPEDSIGNATURES)
116+
117+
A simple utility function which takes `x` and returns `acos(x) / pi`. Used to
118+
define the inverse of `acospi`.
119+
"""
120+
acosbypi(x) = acos(x) / pi
121+
122+
@register_inverse sin asin
123+
@register_inverse cos acos
124+
@register_inverse tan atan
125+
@register_inverse csc acsc
126+
@register_inverse sec asec
127+
@register_inverse cot acot
128+
@register_inverse sind asind
129+
@register_inverse cosd acosd
130+
@register_inverse tand atand
131+
@register_inverse cscd acscd
132+
@register_inverse secd asecd
133+
@register_inverse cotd acotd
134+
@register_inverse sinh asinh
135+
@register_inverse cosh acosh
136+
@register_inverse tanh atanh
137+
@register_inverse csch acsch
138+
@register_inverse sech asech
139+
@register_inverse coth acoth
140+
@register_inverse cospi acosbypi
141+
@register_inverse SpecialFunctions.digamma SpecialFunctions.invdigamma
142+
@register_inverse log exp
143+
@register_inverse log2 exp2
144+
@register_inverse log10 exp10
145+
@register_inverse log1p expm1
146+
@register_inverse deg2rad rad2deg
147+
@register_inverse sqrt square :left
148+
@register_inverse cbrt cube
149+
@register_inverse NaNMath.sin NaNMath.asin
150+
@register_inverse NaNMath.cos NaNMath.acos
151+
# can't use macro since it would be a re-definition of `inverse(atan)`
152+
inverse(::typeof(NaNMath.tan)) = inverse(tan)
153+
inverse(::typeof(NaNMath.acosh)) = inverse(acosh)
154+
inverse(::typeof(NaNMath.atanh)) = inverse(atanh)
155+
inverse(::typeof(NaNMath.log)) = inverse(log)
156+
inverse(::typeof(NaNMath.log10)) = inverse(log10)
157+
inverse(::typeof(NaNMath.log1p)) = inverse(log1p)
158+
inverse(::typeof(NaNMath.log2)) = inverse(log2)
159+
left_inverse(::typeof(NaNMath.sqrt)) = left_inverse(sqrt)
160+
161+
function inverse(f::ComposedFunction)
162+
return inverse(f.inner) inverse(f.outer)
163+
end
164+
has_inverse(f::ComposedFunction) = has_inverse(f.inner) && has_inverse(f.outer)
165+
function left_inverse(f::ComposedFunction)
166+
return left_inverse(f.inner) left_inverse(f.outer)
167+
end
168+
function has_left_inverse(f::ComposedFunction)
169+
return has_left_inverse(f.inner) && has_left_inverse(f.outer)
170+
end
171+
function right_inverse(f::ComposedFunction)
172+
return right_inverse(f.inner) right_inverse(f.outer)
173+
end
174+
function has_right_inverse(f::ComposedFunction)
175+
return has_right_inverse(f.inner) && has_right_inverse(f.outer)
176+
end

test/inverse.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
using Symbolics
2+
3+
@test inverse(sin) == left_inverse(sin) == right_inverse(sin) == asin
4+
@test inverse(asin) == left_inverse(asin) == right_inverse(asin) == sin
5+
@test has_inverse(sin) && has_left_inverse(sin) && has_right_inverse(sin)
6+
fn = left_inverse(sqrt)
7+
@test right_inverse(fn) == sqrt
8+
@test_throws MethodError inverse(sqrt)
9+
@test_throws MethodError right_inverse(sqrt)
10+
@test !has_inverse(fn) && !has_left_inverse(fn)
11+
@test !has_inverse(sqrt) && !has_right_inverse(sqrt)
12+
@test has_inverse(sin cos)
13+
@test !has_inverse(sin sqrt)
14+
@test has_left_inverse(sin sqrt)
15+
@test inverse(sin cos) == acos asin
16+
@test inverse(inverse(sin cos)) == sin cos
17+
@test right_inverse(left_inverse(sin sqrt)) == sin sqrt
18+
@test inverse(sin cos tan) == atan (acos asin)

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ if GROUP == "All" || GROUP == "Core"
6060
@safetestset "Show Test" begin include("show.jl") end
6161
@safetestset "Utility Function Test" begin include("utils.jl") end
6262
@safetestset "RootFinding solver" begin include("solver.jl") end
63+
@safetestset "Function inverses test" begin include("inverse.jl") end
6364
end
6465
end
6566

0 commit comments

Comments
 (0)