Skip to content

Commit 62195b6

Browse files
author
Jack Dunham
committed
Overhaul default_kwargs such that it mirrors the function signatures of the associated function.
- This comes at the cost of some verbosity regarding getting and splatting the kwargs from the region iterator, but the usefulness of `default_kwargs` is now much wider and also more well defined - Introduce macro `@default_kwargs` for doing this automatically.
1 parent c59a9c5 commit 62195b6

File tree

10 files changed

+177
-109
lines changed

10 files changed

+177
-109
lines changed

src/ITensorNetworks.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ include("treetensornetworks/projttns/projttn.jl")
4545
include("treetensornetworks/projttns/projttnsum.jl")
4646
include("treetensornetworks/projttns/projouterprodttn.jl")
4747

48+
include("solvers/default_kwargs.jl")
4849
include("solvers/local_solvers/eigsolve.jl")
4950
include("solvers/local_solvers/exponentiate.jl")
5051
include("solvers/local_solvers/runge_kutta.jl")
@@ -66,7 +67,6 @@ include("solvers/abstract_problem.jl")
6667
include("solvers/eigsolve.jl")
6768
include("solvers/applyexp.jl")
6869
include("solvers/fitting.jl")
69-
include("solvers/default_kwargs.jl")
7070

7171
include("apply.jl")
7272
include("inner.jl")

src/solvers/applyexp.jl

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,31 +20,34 @@ function region_plan(A::ApplyExpProblem; nsites, exponent_step, sweep_kwargs...)
2020
return applyexp_regions(state(A), exponent_step; nsites, sweep_kwargs...)
2121
end
2222

23-
function update!(
24-
region_iterator::RegionIterator{<:ApplyExpProblem},
23+
@default_kwargs function update!(
24+
region_iter::RegionIterator{<:ApplyExpProblem},
2525
local_state;
2626
nsites,
2727
exponent_step,
2828
solver=runge_kutta_solver,
29-
kws...,
3029
)
31-
prob = problem(region_iterator)
30+
prob = problem(region_iter)
3231

3332
iszero(abs(exponent_step)) && return local_state
3433

34+
solver_kwargs = region_kwargs(solver, region_iter)
35+
3536
local_state, _ = solver(
36-
x -> optimal_map(operator(prob), x), exponent_step, local_state; kws...
37+
x -> optimal_map(operator(prob), x), exponent_step, local_state; solver_kwargs...
3738
)
3839
if nsites == 1
39-
curr_reg = current_region(region_iterator)
40-
next_reg = next_region(region_iterator)
40+
curr_reg = current_region(region_iter)
41+
next_reg = next_region(region_iter)
4142
if !isnothing(next_reg) && next_reg != curr_reg
4243
next_edge = first(edge_sequence_between_regions(state(prob), curr_reg, next_reg))
4344
v1, v2 = src(next_edge), dst(next_edge)
4445
psi = copy(state(prob))
4546
psi[v1], R = qr(local_state, uniqueinds(local_state, psi[v2]))
4647
shifted_operator = position(operator(prob), psi, NamedEdge(v1 => v2))
47-
R_t, _ = solver(x -> optimal_map(shifted_operator, x), -exponent_step, R; kws...)
48+
R_t, _ = solver(
49+
x -> optimal_map(shifted_operator, x), -exponent_step, R; solver_kwargs...
50+
)
4851
local_state = psi[v1] * R_t
4952
end
5053
end

src/solvers/default_kwargs.jl

Lines changed: 101 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,53 +1,115 @@
1+
using MacroTools
2+
13
"""
2-
default_kwargs(f, [obj = Any])
4+
default_kwargs(f::Function, args...; kwargs...)
35
4-
Return the default keyword arguments for the function `f`. These defaults may be
5-
derived from the contents or type of the second arugment `obj`.
6+
Returns a set of default keyword arguments, as a `NamedTuple`, for the function `f`
7+
depending on an arbitrary number of positional arguments. Any number of these default
8+
keyword arguments can optionally be overwritten by passing the the keyword as a
9+
keyword argument to this function.
10+
"""
11+
function default_kwargs(f::Function, args...; kwargs...)
12+
return default_kwargs(f, map(typeof, args)...; kwargs...)
13+
end
14+
default_kwargs(f::Function, ::Vararg{<:Type}; kwargs...) = (; kwargs...)
615

7-
## Interface
16+
"""
17+
@default_kwargs
818
9-
Given a function `f`, one can optionally set the default keyword arguments for this
10-
function by specializing either of the following two-argument methods:
19+
Automatically define a `default_kwargs` method for a given function. This macro should
20+
be applied before a function definition:
1121
```
12-
ITensorNetworks.default_kwargs(::typeof(f), prob::AbstractProblem)
13-
ITensorNetworks.default_kwargs(::typeof(f), ::Type{<:AbstractProblem})
22+
@default_kwargs astypes = true function f(args...; kwargs...)
23+
...
24+
end
25+
```
26+
If `astypes = true` then the `default_kwargs` method is defined in the
27+
type domain with respect to `args`, i.e.
28+
```
29+
default_kwargs(::typeof(f), arg::T; kwargs...) # astypes = false
30+
default_kwargs(::typeof(f), arg::Type{<:T}; kwargs...) # astypes = true
1431
```
15-
If one does not require the contents of `prob::Prob` to generate the defaults then it is
16-
recommended to dispatch on `Type{<:Prob}` directly (second method) so the defaults
17-
can be accessed without constructing an instance of a `Prob`.
18-
19-
The return value of `default_kwargs` should be a `NamedTuple`, and will overwrite any
20-
default values set in the function signature.
2132
"""
22-
default_kwargs(f) = default_kwargs(f, Any)
23-
default_kwargs(f, obj) = _default_kwargs_fallback(f, obj)
24-
25-
# To avoid annoying potential method ambiguities.
26-
function _default_kwargs_fallback(f, iter::RegionIterator)
27-
return default_kwargs(f, problem(iter))
28-
end
29-
function _default_kwargs_fallback(f, problem::AbstractProblem)
30-
return default_kwargs(f, typeof(problem))
33+
macro default_kwargs(args...)
34+
kwargs = (;)
35+
for opt in args
36+
if @capture(opt, key_ = val_)
37+
@info "" key val
38+
kwargs = merge(kwargs, NamedTuple{(key,)}((val,)))
39+
elseif opt === last(args)
40+
return default_kwargs_macro(opt; kwargs...)
41+
else
42+
throw(ArgumentError("Unknown expression object"))
43+
end
44+
end
3145
end
3246

33-
# Eventually we reach this if nothing is specialized.
34-
_default_kwargs_fallback(::Any, ::DataType) = (;)
47+
function default_kwargs_macro(function_def; astypes=true)
48+
if !isdef(function_def)
49+
throw(
50+
ArgumentError("The @default_kwargs macro must be followed by a function definition")
51+
)
52+
end
3553

36-
"""
37-
current_kwargs(f, iter::RegionIterator)
54+
ex = splitdef(function_def)
55+
new_ex = deepcopy(ex)
3856

39-
Return the keyword arguments to be passed to the function `f` for the current region
40-
defined by the stateful iterator `iter`.
41-
"""
42-
function current_kwargs(f::Function, iter::RegionIterator)
43-
region_kwargs = get(current_region_kwargs(iter), Symbol(f, :_kwargs), (;))
44-
rv = merge(default_kwargs(f, iter), region_kwargs)
45-
return rv
46-
end
57+
prev_kwargs = []
58+
59+
# Give very positional argument a name and escape the type.
60+
ex[:args] = map(ex[:args]) do arg
61+
@capture(arg, (name_::T_) | (::T_) | name_)
62+
if isnothing(name)
63+
name = gensym()
64+
end
65+
if isnothing(T)
66+
T = :Any
67+
end
68+
return :($(name)::$(esc(T)))
69+
end
70+
71+
# Replacing the kwargs values with the output of `default_kwargs`
72+
ex[:kwargs] = map(ex[:kwargs]) do kw
73+
@capture(kw, (key_::T_ = val_) | (key_ = val_) | key_)
74+
if !isnothing(val)
75+
kw.args[2] =
76+
:(default_kwargs($(esc(ex[:name])), $(ex[:args]...); $(prev_kwargs...)).$key)
77+
end
78+
push!(prev_kwargs, key)
79+
return kw
80+
end
81+
82+
# Promote to the type domain if wanted
83+
if astypes
84+
new_ex[:args] = map(ex[:args]) do arg
85+
@capture(arg, name_::T_)
86+
return :($(name)::Type{<:$T})
87+
end
88+
end
89+
90+
new_ex[:name] = :(ITensorNetworks.default_kwargs)
91+
new_ex[:args] = convert(Vector{Any}, ex[:args])
4792

48-
# Generic
93+
new_ex[:args] = pushfirst!(new_ex[:args], :(::typeof($(esc(ex[:name])))))
4994

50-
# I think these should be set independent of a function, but for now:
51-
function default_kwargs(::typeof(factorize), ::Any)
52-
return (; maxdim=typemax(Int), cutoff=0.0, mindim=1)
95+
# Escape anything on the right-hand side of a keyword definition.
96+
new_ex[:kwargs] = map(new_ex[:kwargs]) do kw
97+
@capture(kw, (key_ = val_) | key_)
98+
if !isnothing(val)
99+
kw.args[2] = esc(val)
100+
end
101+
return kw
102+
end
103+
104+
new_ex[:body] = :(return (; $(prev_kwargs...)))
105+
106+
# Escape the actual function name
107+
ex[:name] = :($(esc(ex[:name])))
108+
109+
rv = quote
110+
$(combinedef(ex))
111+
$(combinedef(new_ex))
112+
end
113+
114+
return rv
53115
end

src/solvers/eigsolve.jl

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,31 +20,26 @@ function set_truncation_info!(E::EigsolveProblem; spectrum=nothing)
2020
return E
2121
end
2222

23-
function update!(
24-
region_iterator::RegionIterator{<:EigsolveProblem}, local_state; outputlevel, solver
23+
@default_kwargs function update!(
24+
region_iter::RegionIterator{<:EigsolveProblem},
25+
local_state;
26+
outputlevel=0,
27+
solver=eigsolve_solver,
2528
)
26-
prob = problem(region_iterator)
29+
prob = problem(region_iter)
2730

2831
eigval, local_state = solver(
29-
ψ -> optimal_map(operator(prob), ψ),
30-
local_state;
31-
current_kwargs(solver, region_iterator)...,
32+
ψ -> optimal_map(operator(prob), ψ), local_state; region_kwargs(solver, region_iter)...
3233
)
3334

3435
prob.eigenvalue = eigval
3536

3637
if outputlevel >= 2
37-
@printf(
38-
" Region %s: energy = %.12f\n", current_region(region_iterator), eigenvalue(prob)
39-
)
38+
@printf(" Region %s: energy = %.12f\n", current_region(region_iter), eigenvalue(prob))
4039
end
4140
return local_state
4241
end
4342

44-
function default_kwargs(::typeof(update!), ::Type{<:EigsolveProblem})
45-
return (; outputlevel=0, solver=eigsolve_solver)
46-
end
47-
4843
function default_sweep_callback(
4944
sweep_iterator::SweepIterator{<:EigsolveProblem}; outputlevel=0
5045
)

src/solvers/extract.jl

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,4 @@
1-
function extract!(iter; kwargs...)
2-
return _extract_fallback!(iter; subspace_algorithm="nothing", kwargs...)
3-
end
4-
5-
# Internal function such that a method error can be thrown while still allowing a user
6-
# to specialize on `extract!`
7-
function _extract_fallback!(region_iter::RegionIterator; subspace_algorithm)
1+
function extract!(region_iter::RegionIterator; subspace_algorithm="nothing")
82
prob = problem(region_iter)
93
region = current_region(region_iter)
104

src/solvers/fitting.jl

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,9 @@ function extract!(region_iter::RegionIterator{<:FittingProblem})
4444
return local_tensor
4545
end
4646

47-
function update!(region_iter::RegionIterator{<:FittingProblem}, local_tensor; outputlevel)
47+
@default_kwargs function update!(
48+
region_iter::RegionIterator{<:FittingProblem}, local_tensor; outputlevel=0
49+
)
4850
F = problem(region_iter)
4951

5052
region = current_region(region_iter)
@@ -70,8 +72,7 @@ function fit_tensornetwork(
7072
nsites=1,
7173
outputlevel=0,
7274
normalize=true,
73-
maxdim=default_kwargs(factorize).maxdim,
74-
cutoff=default_kwargs(factorize).cutoff,
75+
factorize_kwargs,
7576
extra_sweep_kwargs...,
7677
)
7778
bpc = BeliefPropagationCache(overlap_network, args...)
@@ -84,7 +85,6 @@ function fit_tensornetwork(
8485

8586
insert!_kwargs = (; normalize, set_orthogonal_region=false)
8687
update!_kwargs = (; outputlevel)
87-
factorize_kwargs = (; maxdim, cutoff)
8888

8989
sweep_kwargs = (; nsites, outputlevel, update!_kwargs, insert!_kwargs, factorize_kwargs)
9090
kwargs_array = [(; sweep_kwargs..., extra_sweep_kwargs..., sweep) for sweep in 1:nsweeps]
@@ -109,12 +109,11 @@ end
109109
#end
110110

111111
function ITensors.apply(
112-
A::ITensorNetwork,
113-
x::ITensorNetwork;
114-
maxdim=default_kwargs(factorize).maxdim,
115-
sweep_kwargs...,
112+
A::ITensorNetwork, x::ITensorNetwork; maxdim=typemax(Int), cutoff=0.0, sweep_kwargs...
116113
)
117114
init_state = ITensorNetwork(v -> inds -> delta(inds), siteinds(x); link_space=maxdim)
118115
overlap_network = inner_network(x, A, init_state)
119-
return fit_tensornetwork(overlap_network; maxdim, sweep_kwargs...)
116+
return fit_tensornetwork(
117+
overlap_network; factorize_kwargs=(; maxdim, cutoff), sweep_kwargs...
118+
)
120119
end

src/solvers/insert.jl

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,6 @@
11
using NamedGraphs: edgetype
22

3-
function insert!(region_iter, local_tensor; kwargs...)
4-
return _insert_fallback!(
5-
region_iter, local_tensor; normalize=false, set_orthogonal_region=true, kwargs...
6-
)
7-
end
8-
9-
function _insert_fallback!(region_iter, local_tensor; normalize, set_orthogonal_region)
3+
function insert!(region_iter, local_tensor; normalize=false, set_orthogonal_region=true)
104
prob = problem(region_iter)
115

126
region = current_region(region_iter)
@@ -19,7 +13,7 @@ function _insert_fallback!(region_iter, local_tensor; normalize, set_orthogonal_
1913
tags = ITensors.tags(psi, e)
2014

2115
U, C, spectrum = factorize(
22-
local_tensor, indsTe; tags, current_kwargs(factorize, region_iter)...
16+
local_tensor, indsTe; tags, region_kwargs(factorize, region_iter)...
2317
)
2418

2519
@preserve_graph psi[first(region)] = U

src/solvers/iterators.jl

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,13 @@ function current_region(region_iter::RegionIterator)
7171
return region
7272
end
7373

74-
function current_region_kwargs(region_iter::RegionIterator)
74+
function region_kwargs(region_iter::RegionIterator)
7575
_, kwargs = current_region_plan(region_iter)
7676
return kwargs
7777
end
78+
function region_kwargs(f::Function, iter::RegionIterator)
79+
return get(region_kwargs(iter), Symbol(f, :_kwargs), (;))
80+
end
7881

7982
function prev_region(region_iter::RegionIterator)
8083
state(region_iter) <= 1 && return nothing
@@ -96,10 +99,27 @@ function increment!(region_iter::RegionIterator)
9699
return region_iter
97100
end
98101

102+
# Purely for our convenience:
103+
function extract!_kwargs(iter)
104+
f = extract!
105+
kwargs = region_kwargs(f, iter)
106+
return default_kwargs(f, iter; kwargs...)
107+
end
108+
function update!_kwargs(iter, local_state)
109+
f = update!
110+
kwargs = region_kwargs(f, iter)
111+
return default_kwargs(f, iter, local_state; kwargs...)
112+
end
113+
function insert!_kwargs(iter, local_state)
114+
f = insert!
115+
kwargs = region_kwargs(f, iter)
116+
return default_kwargs(f, iter, local_state; kwargs...)
117+
end
118+
99119
function compute!(iter::RegionIterator)
100-
local_state = extract!(iter; current_kwargs(extract!, iter)...)
101-
local_state = update!(iter, local_state; current_kwargs(update!, iter)...)
102-
insert!(iter, local_state; current_kwargs(insert!, iter)...)
120+
local_state = extract!(iter; extract!_kwargs(iter)...)
121+
local_state = update!(iter, local_state; update!_kwargs(iter, local_state)...)
122+
insert!(iter, local_state; insert!_kwargs(iter, local_state)...)
103123

104124
return iter
105125
end

0 commit comments

Comments
 (0)