Skip to content

Commit 3f5c97c

Browse files
author
Jack Dunham
committed
Mutating functions now return the first argument before any additional data.
1 parent d39f09e commit 3f5c97c

File tree

8 files changed

+31
-23
lines changed

8 files changed

+31
-23
lines changed

src/solvers/applyexp.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@ end
2929
)
3030
prob = problem(region_iter)
3131

32-
iszero(abs(exponent_step)) && return local_state
32+
if iszero(abs(exponent_step))
33+
return region_iter, local_state
34+
end
3335

3436
solver_kwargs = region_kwargs(solver, region_iter)
3537

@@ -54,7 +56,7 @@ end
5456

5557
prob.current_exponent += exponent_step
5658

57-
return local_state
59+
return region_iter, local_state
5860
end
5961

6062
function default_sweep_callback(
@@ -91,7 +93,7 @@ function applyexp(
9193
]
9294
sweep_iter = SweepIterator(init_prob, kws_array)
9395

94-
converged_prob = sweep_solve!(sweep_callback, sweep_iter)
96+
converged_prob = problem(sweep_solve!(sweep_callback, sweep_iter))
9597

9698
return state(converged_prob)
9799
end

src/solvers/eigsolve.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ end
3737
if outputlevel >= 2
3838
@printf(" Region %s: energy = %.12f\n", current_region(region_iter), eigenvalue(prob))
3939
end
40-
return local_state
40+
return region_iter, local_state
4141
end
4242

4343
function default_sweep_callback(
@@ -64,7 +64,7 @@ function eigsolve(operator, init_state; nsweeps, nsites=1, outputlevel=0, sweep_
6464
state=align_indices(init_state), operator=ProjTTN(align_indices(operator))
6565
)
6666
sweep_iter = SweepIterator(init_prob, nsweeps; nsites, outputlevel, sweep_kwargs...)
67-
prob = sweep_solve!(sweep_iter)
67+
prob = problem(sweep_solve!(sweep_iter))
6868
return eigenvalue(prob), state(prob)
6969
end
7070

src/solvers/extract.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@ function extract!(region_iter::RegionIterator; subspace_algorithm="nothing")
77

88
prob.state = psi
99

10-
local_state = subspace_expand!(region_iter, local_state; subspace_algorithm)
10+
_, local_state = subspace_expand!(region_iter, local_state; subspace_algorithm)
1111
shifted_operator = position(operator(prob), state(prob), region)
1212

1313
prob.operator = shifted_operator
1414

15-
return local_state
15+
return region_iter, local_state
1616
end

src/solvers/fitting.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ function extract!(region_iter::RegionIterator{<:FittingProblem})
4141
prob.state = tn
4242
prob.gauge_region = region
4343

44-
return local_tensor
44+
return region_iter, local_tensor
4545
end
4646

4747
@define_default_kwargs function update!(
@@ -58,7 +58,7 @@ end
5858
@printf(" Region %s: squared overlap = %.12f\n", region, overlap(F))
5959
end
6060

61-
return local_tensor
61+
return region_iter, local_tensor
6262
end
6363

6464
function region_plan(F::FittingProblem; nsites, sweep_kwargs...)
@@ -90,7 +90,7 @@ function fit_tensornetwork(
9090
kwargs_array = [(; sweep_kwargs..., extra_sweep_kwargs..., sweep) for sweep in 1:nsweeps]
9191

9292
sweep_iter = SweepIterator(init_prob, kwargs_array)
93-
converged_prob = sweep_solve!(sweep_iter)
93+
converged_prob = problem(sweep_solve!(sweep_iter))
9494

9595
return rename_vertices(inv_vertex_map(overlap_network), ket(converged_prob))
9696
end
@@ -109,7 +109,11 @@ end
109109
#end
110110

111111
function ITensors.apply(
112-
A::ITensorNetwork, x::ITensorNetwork; maxdim=typemax(Int), cutoff=0.0, sweep_kwargs...
112+
A::AbstractITensorNetwork,
113+
x::AbstractITensorNetwork;
114+
maxdim=typemax(Int),
115+
cutoff=0.0,
116+
sweep_kwargs...,
113117
)
114118
init_state = ITensorNetwork(v -> inds -> delta(inds), siteinds(x); link_space=maxdim)
115119
overlap_network = inner_network(x, A, init_state)

src/solvers/insert.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,5 +28,5 @@ function insert!(region_iter, local_tensor; normalize=false, set_orthogonal_regi
2828

2929
prob.state = psi
3030

31-
return prob
31+
return region_iter
3232
end

src/solvers/iterators.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,10 @@ function increment!(region_iter::RegionIterator)
100100
end
101101

102102
function compute!(iter::RegionIterator)
103-
local_state = @with_defaults extract!(iter; region_kwargs(extract!, iter)...)
104-
local_state = @with_defaults update!(iter, local_state; region_kwargs(update!, iter)...)
103+
_, local_state = @with_defaults extract!(iter; region_kwargs(extract!, iter)...)
104+
_, local_state = @with_defaults update!(
105+
iter, local_state; region_kwargs(update!, iter)...
106+
)
105107
@with_defaults insert!(iter, local_state; region_kwargs(insert!, iter)...)
106108

107109
return iter

src/solvers/subspace/densitymatrix.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,23 +10,23 @@ using Printf: @printf
1010
psi = copy(state(prob))
1111

1212
prev_vertex_set = setdiff(pos(operator(prob)), region)
13-
(length(prev_vertex_set) != 1) && return local_state
13+
(length(prev_vertex_set) != 1) && return region_iter, local_state
1414
prev_vertex = only(prev_vertex_set)
1515
A = psi[prev_vertex]
1616

1717
next_vertices = filter(v -> (hascommoninds(psi[v], A)), region)
18-
isempty(next_vertices) && return local_state
18+
isempty(next_vertices) && return region_iter, local_state
1919
next_vertex = only(next_vertices)
2020
C = psi[next_vertex]
2121

2222
a = commonind(A, C)
23-
isnothing(a) && return local_state
23+
isnothing(a) && return region_iter, local_state
2424
basis_size = prod(dim.(uniqueinds(A, C)))
2525

2626
expanded_maxdim = compute_expansion(
2727
dim(a), basis_size; region_kwargs(compute_expansion, region_iter)...
2828
)
29-
expanded_maxdim <= 0 && return local_state
29+
expanded_maxdim <= 0 && return region_iter, local_state
3030

3131
envs = environments(operator(prob))
3232
H = operator(operator(prob))
@@ -50,7 +50,7 @@ using Printf: @printf
5050
end
5151
if norm(dag(U) * A) > 1E-10
5252
@printf("Warning: |U*A| = %.3E in subspace expansion\n", norm(dag(U) * A))
53-
return local_state
53+
return region_iter, local_state
5454
end
5555

5656
Ax, ax = directsum(A => a, U => commonind(U, D))
@@ -61,5 +61,5 @@ using Printf: @printf
6161

6262
prob.state = psi
6363

64-
return local_state
64+
return region_iter, local_state
6565
end

src/solvers/subspace/subspace.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@ using NDTensors.BackendSelection: Backend, @Backend_str
77
backend = Backend(subspace_algorithm)
88

99
if backend isa Backend"nothing"
10-
return local_state
10+
return region_iter, local_state
1111
end
1212

13-
local_state = @with_defaults subspace_expand!(
13+
_, local_state = @with_defaults subspace_expand!(
1414
backend, region_iter, local_state; region_kwargs(subspace_expand!, region_iter)...
1515
)
1616

17-
return local_state
17+
return region_iter, local_state
1818
end
1919

2020
function subspace_expand!(backend, region_iterator, local_state; kwargs...)

0 commit comments

Comments
 (0)