Skip to content

Commit aca90cc

Browse files
author
Jack Dunham
committed
Add default_kwargs interface and associated macros
1 parent c4085f7 commit aca90cc

File tree

1 file changed

+115
-0
lines changed

1 file changed

+115
-0
lines changed

src/default_kwargs.jl

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
using MacroTools: @capture, splitdef, combinedef, isdef
2+
3+
"""
4+
default_kwargs(f::Function, args...; kwargs...)
5+
6+
Returns a set of default keyword arguments, as a `NamedTuple`, for the function `f`
7+
depending on an arbitrary number of positional arguments. Any number of these default
8+
keyword arguments can optionally be overwritten by passing the the keyword as a
9+
keyword argument to this function.
10+
"""
11+
function default_kwargs(f::Function, args...; kwargs...)
12+
return default_kwargs(f, map(typeof, args)...; kwargs...)
13+
end
14+
default_kwargs(f::Function, ::Vararg{<:Type}; kwargs...) = (; kwargs...)
15+
16+
"""
17+
@define_default_kwargs
18+
19+
Automatically define a `default_kwargs` method for a given function. This macro should
20+
be applied before a function definition:
21+
```
22+
@define_default_kwargs function f(arg1::T1, arg2::T2, ...; kwargs...)
23+
...
24+
end
25+
```
26+
The defined `default_kwargs` method takes the form
27+
```
28+
default_kwargs(::typeof(f), arg1::T1, arg2::T2, ...; kwargs...)
29+
```
30+
i.e. the function signature mirrors that of the function signature of `f`.
31+
"""
32+
macro define_default_kwargs(function_def)
33+
return default_kwargs_macro(function_def)
34+
end
35+
36+
function default_kwargs_macro(function_def)
37+
if !isdef(function_def)
38+
throw(
39+
ArgumentError(
40+
"The @define_default_kwargs macro must be followed by a function definition"
41+
),
42+
)
43+
end
44+
45+
ex = splitdef(function_def)
46+
new_ex = deepcopy(ex)
47+
48+
prev_kwargs = []
49+
50+
# Give very positional argument a name and escape the type.
51+
ex[:args] = map(ex[:args]) do arg
52+
@capture(arg, (name_::T_) | (::T_) | name_)
53+
if isnothing(name)
54+
name = gensym()
55+
end
56+
if isnothing(T)
57+
T = :Any
58+
end
59+
return :($(name)::$(esc(T)))
60+
end
61+
62+
# Replacing the kwargs values with the output of `default_kwargs`
63+
ex[:kwargs] = map(ex[:kwargs]) do kw
64+
@capture(kw, (key_::T_ = val_) | (key_ = val_) | key_)
65+
if !isnothing(val)
66+
kw.args[2] =
67+
:(default_kwargs($(esc(ex[:name])), $(ex[:args]...); $(prev_kwargs...)).$key)
68+
end
69+
push!(prev_kwargs, key)
70+
return kw
71+
end
72+
73+
new_ex[:args] = convert(Vector{Any}, ex[:args])
74+
75+
new_ex[:name] = :(ITensorNetworks.default_kwargs)
76+
new_ex[:args] = pushfirst!(new_ex[:args], :(::typeof($(esc(ex[:name])))))
77+
78+
# Escape anything on the right-hand side of a keyword definition.
79+
new_ex[:kwargs] = map(new_ex[:kwargs]) do kw
80+
@capture(kw, (key_ = val_) | key_)
81+
if !isnothing(val)
82+
kw.args[2] = esc(val)
83+
end
84+
return kw
85+
end
86+
87+
new_ex[:body] = :(return (; $(prev_kwargs...)))
88+
89+
# Escape the actual function name
90+
ex[:name] = :($(esc(ex[:name])))
91+
92+
rv = quote
93+
$(combinedef(ex))
94+
$(combinedef(new_ex))
95+
end
96+
97+
return rv
98+
end
99+
100+
macro with_defaults(call_expr)
101+
if @capture(call_expr, (func_(args__; kwargs__)) | (func_(args__)))
102+
if isnothing(kwargs)
103+
kwargs = []
104+
end
105+
rv = quote
106+
$(esc(func))(
107+
$(esc.(args)...);
108+
default_kwargs($(esc(func)), $(esc.(args)...); $(esc.(kwargs)...))...,
109+
)
110+
end
111+
return rv
112+
else
113+
throw(ArgumentError("unable to parse function call expression, try including brackets in the macro call."))
114+
end
115+
end

0 commit comments

Comments
 (0)