Skip to content

Commit fd3396d

Browse files
committed
Support Base.Broadcast.BroadcastFunction
1 parent 0eefe9d commit fd3396d

File tree

3 files changed

+23
-7
lines changed

3 files changed

+23
-7
lines changed

docs/src/index.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,6 @@ InverseFunctions
66

77
This package defines the function [`inverse`](@ref). `inverse(f)` returns the inverse function of a function `f`, so that `inverse(f)(f(x)) ≈ x`.
88

9-
`inverse` supports mapped/broadcasted functions (via `Base.Fix1`) and (on Julia >=v1.6) function composition.
9+
`inverse` supports mapped/broadcasted functions (via `Base.Broadcast.BroadcastFunction` or `Base.Fix1`) and (on Julia >=v1.6) function composition.
1010

1111
Implementations of `inverse(f)` for `identity`, `inv`, `adjoint` and `transpose` as well as for `exp`, `log`, `exp2`, `log2`, `exp10`, `log10`, `expm1`, `log1p` and `sqrt` are included.

src/inverse.jl

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55
66
Return the inverse of function `f`.
77
8-
`inverse` supports mapped and broadcasted functions (via `Base.Fix1`) and
9-
function composition (requires Julia >= 1.6).
8+
`inverse` supports mapped and broadcasted functions (via
9+
`Base.Broadcast.BroadcastFunction` or `Base.Fix1`) and function composition
10+
(requires Julia >= 1.6).
1011
1112
# Examples
1213
@@ -27,7 +28,7 @@ true
2728
julia> inverse(inverse(foo)) === foo
2829
true
2930
30-
julia> broadcast_foo = Base.Fix1(broadcast, foo);
31+
julia> broadcast_foo = VERSION >= v"1.6" ? Base.Broadcast.BroadcastFunction(foo) : Base.Fix1(broadcast, foo);
3132
3233
julia> X = rand(10);
3334
@@ -84,14 +85,23 @@ inverse(::typeof(inverse)) = inverse
8485
Base.ComposedFunction(inv_inner, inv_outer)
8586
end
8687
end
88+
89+
function inverse(bf::Base.Broadcast.BroadcastFunction)
90+
inv_f_kernel = inverse(bf.f)
91+
if inv_f_kernel isa NoInverse
92+
NoInverse(bf)
93+
else
94+
Base.Broadcast.BroadcastFunction(inv_f_kernel)
95+
end
96+
end
8797
end
8898

8999
function inverse(mapped_f::Base.Fix1{<:Union{typeof(map),typeof(broadcast)}})
90100
inv_f_kernel = inverse(mapped_f.x)
91101
if inv_f_kernel isa NoInverse
92102
NoInverse(mapped_f)
93103
else
94-
Base.Fix1(mapped_f.f, inverse(mapped_f.x))
104+
Base.Fix1(mapped_f.f, inv_f_kernel)
95105
end
96106
end
97107

test/test_inverse.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,18 @@ InverseFunctions.inverse(f::Bar) = Bar(inv(f.A))
2020

2121

2222
@testset "inverse" begin
23+
@static if VERSION >= v"1.6"
24+
_bc_func(f) = Base.Broadcast.BroadcastFunction(f)
25+
else
26+
_bc_func(f) = Base.Fix1(broadcast, f)
27+
end
28+
2329
f_without_inverse(x) = 1
2430
@test inverse(f_without_inverse) isa NoInverse
2531
@test_throws ErrorException inverse(f_without_inverse)(42)
2632
@test inverse(inverse(f_without_inverse)) === f_without_inverse
2733

28-
for f in (f_without_inverse exp, exp f_without_inverse, Base.Fix1(broadcast, f_without_inverse), Base.Fix1(map, f_without_inverse))
34+
for f in (f_without_inverse exp, exp f_without_inverse, _bc_func(f_without_inverse), Base.Fix1(broadcast, f_without_inverse), Base.Fix1(map, f_without_inverse))
2935
@test inverse(f) == NoInverse(f)
3036
@test inverse(inverse(f)) == f
3137
end
@@ -96,7 +102,7 @@ InverseFunctions.inverse(f::Bar) = Bar(inv(f.A))
96102
end
97103

98104
X = rand(5)
99-
for f in (Base.Fix1(broadcast, foo), Base.Fix1(map, foo))
105+
for f in (_bc_func(foo), Base.Fix1(broadcast, foo), Base.Fix1(map, foo))
100106
for x in (x, fill(x, 3), X)
101107
InverseFunctions.test_inverse(f, x)
102108
end

0 commit comments

Comments
 (0)