Skip to content

Commit a074fc2

Browse files
author
Katharine Hyatt
committed
Fix JET by splitting macro up a bit
1 parent 74bed3d commit a074fc2

File tree

2 files changed

+88
-18
lines changed

2 files changed

+88
-18
lines changed

src/algorithms.jl

Lines changed: 86 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -158,10 +158,10 @@ macro algdef(name)
158158
end
159159

160160
"""
161-
@functiondef f
161+
@functiondef [n_args=1] f
162162
163163
Convenience macro to define the boilerplate code that dispatches between several versions of `f` and `f!`.
164-
By default, this enables the following signatures to be defined in terms of
164+
By default, `f` accepts a single argument `A`. This enables the following signatures to be defined in terms of
165165
the final `f!(A, out, alg::Algorithm)`.
166166
167167
```julia
@@ -171,18 +171,54 @@ the final `f!(A, out, alg::Algorithm)`.
171171
f!(A, alg::Algorithm)
172172
```
173173
174+
The number of inputs can be set with the `n_args` keyword
175+
argument, so that
176+
177+
```julia
178+
@functiondef n_args=2 f
179+
```
180+
181+
would create
182+
183+
```julia
184+
f(A, B; kwargs...)
185+
f(A, B, alg::Algorithm)
186+
f!(A, B, [out]; kwargs...)
187+
f!(A, B, alg::Algorithm)
188+
```
189+
174190
See also [`copy_input`](@ref), [`select_algorithm`](@ref) and [`initialize_output`](@ref).
175191
"""
176-
macro functiondef(f)
192+
macro functiondef(args...)
193+
kwargs = map(args[1:end-1]) do kwarg
194+
if kwarg isa Symbol
195+
:($kwarg = $kwarg)
196+
elseif Meta.isexpr(kwarg, :(=))
197+
kwarg
198+
else
199+
throw(ArgumentError("Invalid keyword argument '$kwarg'"))
200+
end
201+
end
202+
isempty(kwargs) || length(kwargs) == 1 || throw(ArgumentError("Only one keyword argument to `@functiondef` is supported"))
203+
f_n_args = 1 # default
204+
if length(kwargs) == 1
205+
kwarg = only(kwargs) # only one kwarg is currently supported, TODO modify if we support more
206+
key::Symbol, val = kwarg.args
207+
key === :n_args || throw(ArgumentError("Unsupported keyword argument $key to `@functiondef`"))
208+
(isa(val, Integer) && val > 0) || throw(ArgumentError("`n_args` keyword argument to `@functiondef` should be an integer > 0"))
209+
f_n_args = val
210+
end
211+
212+
f = args[end]
177213
f isa Symbol || throw(ArgumentError("Unsupported usage of `@functiondef`"))
178214
f! = Symbol(f, :!)
179215

180-
ex = quote
216+
# TODO is the right way?
217+
@gensym A B
218+
ex_single_arg = quote
181219
# out of place to inplace
182220
$f(A; kwargs...) = $f!(copy_input($f, A); kwargs...)
183221
$f(A, alg::AbstractAlgorithm) = $f!(copy_input($f, A), alg)
184-
$f(A, B; kwargs...) = $f!(copy_input($f, A, B)...; kwargs...)
185-
$f(A, B, alg::AbstractAlgorithm) = $f!(copy_input($f, A, B)..., alg)
186222

187223
# fill in arguments
188224
function $f!(A; alg=nothing, kwargs...)
@@ -191,12 +227,6 @@ macro functiondef(f)
191227
function $f!(A, out; alg=nothing, kwargs...)
192228
return $f!(A, out, select_algorithm($f!, A, alg; kwargs...))
193229
end
194-
function $f!(A, B, out; alg=nothing, kwargs...)
195-
return $f!(A, B, out, select_algorithm($f!, (A, B), alg; kwargs...))
196-
end
197-
function $f!(A, B, alg::AbstractAlgorithm)
198-
return $f!(A, B, initialize_output($f!, A, B, alg), alg)
199-
end
200230
function $f!(A, alg::AbstractAlgorithm)
201231
return $f!(A, initialize_output($f!, A, alg), alg)
202232
end
@@ -210,9 +240,6 @@ macro functiondef(f)
210240
@inline function default_algorithm(::typeof($f), A; kwargs...)
211241
return default_algorithm($f!, A; kwargs...)
212242
end
213-
@inline function default_algorithm(::typeof($f), A, B; kwargs...)
214-
return default_algorithm($f!, A, B; kwargs...)
215-
end
216243
# define default algorithm fallbacks for out-of-place functions
217244
# in terms of the corresponding in-place function for types,
218245
# in principle this is covered by the definition above but
@@ -226,14 +253,57 @@ macro functiondef(f)
226253
@inline function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A}
227254
return default_algorithm($f!, A; kwargs...)
228255
end
256+
257+
# copy documentation to both functions
258+
Core.@__doc__ $f, $f!
259+
end
260+
ex_double_arg = quote
261+
# out of place to inplace
262+
$f(A, B; kwargs...) = $f!(copy_input($f, A, B)...; kwargs...)
263+
$f(A, B, alg::AbstractAlgorithm) = $f!(copy_input($f, A, B)..., alg)
264+
265+
# fill in arguments
266+
function $f!(A, B; alg=nothing, kwargs...)
267+
return $f!(A, B, select_algorithm($f!, (A, B), alg; kwargs...))
268+
end
269+
function $f!(A, B, out; alg=nothing, kwargs...)
270+
return $f!(A, B, out, select_algorithm($f!, (A, B), alg; kwargs...))
271+
end
272+
function $f!(A, B, alg::AbstractAlgorithm)
273+
return $f!(A, B, initialize_output($f!, A, B, alg), alg)
274+
end
275+
276+
# define fallbacks for algorithm selection
277+
@inline function select_algorithm(::typeof($f), A, alg::Alg; kwargs...) where {Alg}
278+
return select_algorithm($f!, A, alg; kwargs...)
279+
end
280+
# define default algorithm fallbacks for out-of-place functions
281+
# in terms of the corresponding in-place function
282+
@inline function default_algorithm(::typeof($f), A, B; kwargs...)
283+
return default_algorithm($f!, A, B; kwargs...)
284+
end
285+
# define default algorithm fallbacks for out-of-place functions
286+
# in terms of the corresponding in-place function for types,
287+
# in principle this is covered by the definition above but
288+
# it is necessary to avoid ambiguity errors with the generic definitions:
289+
# ```julia
290+
# default_algorithm(f::F, A; kwargs...) where {F} = default_algorithm(f, typeof(A); kwargs...)
291+
# function default_algorithm(f::F, ::Type{T}; kwargs...) where {F,T}
292+
# throw(MethodError(default_algorithm, (f, T)))
293+
# end
294+
# ```
229295
@inline function default_algorithm(::typeof($f), ::Type{A}, ::Type{B}; kwargs...) where {A, B}
230296
return default_algorithm($f!, A, B; kwargs...)
231297
end
232298

233299
# copy documentation to both functions
234300
Core.@__doc__ $f, $f!
235301
end
236-
return esc(ex)
302+
if f_n_args == 1
303+
return esc(ex_single_arg)
304+
elseif f_n_args == 2
305+
return esc(ex_double_arg)
306+
end
237307
end
238308

239309
"""

src/interface/gen_eig.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ and the diagonal matrix `W` contains the associated generalized eigenvalues.
4040
4141
See also [`gen_eig_vals(!)`](@ref eig_vals).
4242
"""
43-
@functiondef gen_eig_full
43+
@functiondef n_args=2 gen_eig_full
4444

4545
"""
4646
gen_eig_vals(A, B; kwargs...) -> W
@@ -61,7 +61,7 @@ Compute the list of generalized eigenvalues of `A` and `B`.
6161
6262
See also [`gen_eig_full(!)`](@ref gen_eig_full).
6363
"""
64-
@functiondef gen_eig_vals
64+
@functiondef n_args=2 gen_eig_vals
6565

6666
# Algorithm selection
6767
# -------------------

0 commit comments

Comments
 (0)