Skip to content

Commit 4cca414

Browse files
committed
Cleanup and more testing
1 parent c0659ae commit 4cca414

File tree

6 files changed

+128
-101
lines changed

6 files changed

+128
-101
lines changed

src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,8 @@ An algorithm that consists of running an algorithm at each iteration
196196
from a list of stored algorithms.
197197
=#
198198
@kwdef struct DefaultNestedAlgorithm{
199-
Algorithms <: AbstractVector{<:Algorithm},
199+
ChildAlgorithm <: Algorithm,
200+
Algorithms <: AbstractVector{ChildAlgorithm},
200201
StoppingCriterion <: AI.StoppingCriterion,
201202
} <: NestedAlgorithm
202203
algorithms::Algorithms
@@ -256,7 +257,8 @@ function AI.step!(
256257
end
257258

258259
@kwdef struct DefaultFlattenedAlgorithm{
259-
Algorithms <: AbstractVector{<:Algorithm},
260+
ChildAlgorithm <: Algorithm,
261+
Algorithms <: AbstractVector{ChildAlgorithm},
260262
StoppingCriterion <: AI.StoppingCriterion,
261263
} <: FlattenedAlgorithm
262264
algorithms::Algorithms
@@ -281,18 +283,19 @@ end
281283

282284
# Algorithm that only performs a single step.
283285
abstract type NonIterativeAlgorithm <: Algorithm end
286+
abstract type NonIterativeAlgorithmState <: State end
284287

285-
function Base.getproperty(algorithm::NonIterativeAlgorithm, name::Symbol)
286-
if name :stopping_criterion
287-
return AI.StopAfterIteration(1)
288-
else
289-
return getfield(algorithm, name)
290-
end
288+
function AI.initialize_state(problem::Problem, algorithm::NonIterativeAlgorithm; kwargs...)
289+
return DefaultNonIterativeAlgorithmState(; kwargs...)
290+
end
291+
function AI.solve!(
292+
problem::Problem, algorithm::NonIterativeAlgorithm, state::State; kwargs...
293+
)
294+
return throw(MethodError(AI.solve!, (problem, algorithm, state)))
291295
end
292296

293-
abstract type NonIterativeAlgorithmState <: State end
294-
295-
mutable struct DefaultNonIterativeAlgorithmState{Iterate} <: NonIterativeAlgorithmState
297+
@kwdef mutable struct DefaultNonIterativeAlgorithmState{Iterate} <:
298+
NonIterativeAlgorithmState
296299
iterate::Iterate
297300
end
298301

src/ITensorNetworksNext.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ include("abstracttensornetwork.jl")
66
include("tensornetwork.jl")
77
include("TensorNetworkGenerators/TensorNetworkGenerators.jl")
88
include("contract_network.jl")
9-
include("sweeping/sweeping.jl")
109
include("sweeping/eigenproblem.jl")
1110

1211
end

src/sweeping/eigenproblem.jl

Lines changed: 40 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,63 +1,62 @@
11
import AlgorithmsInterface as AI
22
import .AlgorithmsInterfaceExtensions as AIE
33

4-
maybe_fill(value, len::Int) = fill(value, len)
5-
function maybe_fill(v::AbstractVector, len::Int)
6-
@assert length(v) == len
7-
return v
8-
end
4+
#=
5+
EigenProblem(operator)
96
10-
function dmrg_sweep(operator, algorithm, state)
11-
problem = select_problem(dmrg_sweep, operator, algorithm, state)
12-
return AI.solve(problem, algorithm; iterate = state).iterate
7+
Represents the problem we are trying to solve and minimal algorithm-independent
8+
information, so for an eigenproblem it is the operator we want the eigenvector of.
9+
=#
10+
struct EigenProblem{Operator} <: AIE.Problem
11+
operator::Operator
1312
end
14-
function dmrg_sweep(operator, state; kwargs...)
15-
algorithm = select_algorithm(dmrg_sweep, operator, state; kwargs...)
16-
return dmrg_sweep(operator, algorithm, state)
13+
14+
struct EigsolveRegion{R, Kwargs <: NamedTuple} <: AIE.NonIterativeAlgorithm
15+
region::R
16+
kwargs::Kwargs
1717
end
18+
EigsolveRegion(region; kwargs...) = EigsolveRegion(region, (; kwargs...))
1819

19-
function select_problem(::typeof(dmrg_sweep), operator, algorithm, state)
20-
return EigenProblem(operator)
20+
function AI.solve!(
21+
problem::EigenProblem, algorithm::EigsolveRegion, state::AIE.State; kwargs...
22+
)
23+
return error("EigsolveRegion step for EigenProblem not implemented yet.")
2124
end
22-
function select_algorithm(::typeof(dmrg_sweep), operator, state; regions, region_kwargs)
23-
region_kwargs′ = maybe_fill(region_kwargs, length(regions))
24-
return Sweep(length(regions)) do i
25-
return Returns(Region(regions[i]; region_kwargs′[i]...))
26-
end
25+
26+
maybe_fill(value, len::Int) = fill(value, len)
27+
function maybe_fill(v::AbstractVector, len::Int)
28+
@assert length(v) == len
29+
return v
2730
end
2831

2932
function dmrg(operator, algorithm, state)
30-
problem = select_problem(dmrg, operator, algorithm, state)
33+
problem = EigenProblem(operator)
3134
return AI.solve(problem, algorithm; iterate = state).iterate
3235
end
3336
function dmrg(operator, state; kwargs...)
37+
problem = EigenProblem(operator)
3438
algorithm = select_algorithm(dmrg, operator, state; kwargs...)
35-
return dmrg(operator, algorithm, state)
39+
return AI.solve(problem, algorithm; iterate = state).iterate
3640
end
3741

38-
function select_problem(::typeof(dmrg), operator, algorithm, state)
39-
return EigenProblem(operator)
42+
function repeat_last(v::AbstractVector, len::Int)
43+
length(v) len && return v
44+
return [v; fill(v[end], len - length(v))]
4045
end
41-
function select_algorithm(::typeof(dmrg), operator, state; nsweeps, regions, region_kwargs)
42-
region_kwargs′ = maybe_fill(region_kwargs, nsweeps)
43-
return Sweeping(nsweeps) do i
44-
return select_algorithm(
45-
dmrg_sweep, operator, state;
46-
regions, region_kwargs = region_kwargs′[i],
47-
)
48-
end
46+
repeat_last(v, len::Int) = fill(v, len)
47+
function extend_columns(nt::NamedTuple, len::Int)
48+
return NamedTuple{keys(nt)}(map(v -> repeat_last(v, len), values(nt)))
4949
end
50-
51-
#=
52-
EigenProblem(operator)
53-
54-
Represents the problem we are trying to solve and minimal algorithm-independent
55-
information, so for an eigenproblem it is the operator we want the eigenvector of.
56-
=#
57-
struct EigenProblem{Operator} <: AIE.Problem
58-
operator::Operator
50+
function eachrow(nt::NamedTuple, len::Int)
51+
return [NamedTuple{keys(nt)}(map(v -> v[i], values(nt))) for i in 1:len]
5952
end
6053

61-
function AI.step!(problem::EigenProblem, algorithm::Region, state::AIE.State; kwargs...)
62-
return error("Region step for EigenProblem not implemented.")
54+
function select_algorithm(::typeof(dmrg), operator, state; nsweeps, regions, kwargs...)
55+
extended_kwargs = extend_columns((; kwargs...), nsweeps)
56+
region_kwargs = eachrow(extended_kwargs, nsweeps)
57+
return AIE.nested_algorithm(nsweeps) do i
58+
return AIE.nested_algorithm(length(regions)) do j
59+
return EigsolveRegion(regions[j]; region_kwargs[i]...)
60+
end
61+
end
6362
end

src/sweeping/sweeping.jl

Lines changed: 0 additions & 40 deletions
This file was deleted.

test/test_dmrg.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import AlgorithmsInterface as AI
2+
using ITensorNetworksNext: EigsolveRegion, dmrg, select_algorithm
3+
import ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE
4+
using Test: @test, @testset
5+
6+
@testset "select_algorithm(dmrg, ...)" begin
7+
operator = "operator"
8+
init = "init"
9+
nsweeps = 3
10+
regions = ["region1", "region2"]
11+
maxdim = [10, 20]
12+
cutoff = 1.0e-7
13+
algorithm = select_algorithm(dmrg, operator, init; nsweeps, regions, maxdim, cutoff)
14+
@test algorithm isa AIE.NestedAlgorithm
15+
@test length(algorithm.algorithms) == nsweeps
16+
17+
maxdims = [10, 20, 20]
18+
cutoffs = [1.0e-7, 1.0e-7, 1.0e-7]
19+
algorithm′ = AIE.nested_algorithm(nsweeps) do i
20+
return AIE.nested_algorithm(length(regions)) do j
21+
return EigsolveRegion(
22+
regions[j];
23+
maxdim = maxdims[i],
24+
cutoff = cutoffs[i],
25+
)
26+
end
27+
end
28+
for i in 1:nsweeps
29+
for j in 1:length(regions)
30+
@test algorithm.algorithms[i].algorithms[j] ==
31+
algorithm′.algorithms[i].algorithms[j]
32+
end
33+
end
34+
end

test/test_sweeping.jl

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,65 @@
11
import AlgorithmsInterface as AI
2-
using ITensorNetworksNext: Region, Sweep, Sweeping
32
import ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE
43
using Test: @test, @testset
54

65
struct TestProblem <: AIE.Problem
76
end
87

9-
function AI.step!(problem::TestProblem, algorithm::Region, state::AIE.State; kwargs...)
10-
state.iterate = algorithm.region
8+
struct TestRegion{R, Kwargs <: NamedTuple} <: AIE.NonIterativeAlgorithm
9+
region::R
10+
kwargs::Kwargs
11+
end
12+
TestRegion(region; kwargs...) = TestRegion(region, (; kwargs...))
13+
14+
function AI.solve!(problem::TestProblem, algorithm::TestRegion, state::AIE.State; kwargs...)
15+
new_iterate = (; algorithm.region, algorithm.kwargs.foo, algorithm.kwargs.bar)
16+
state.iterate = [state.iterate; [new_iterate]]
1117
return state
1218
end
1319

1420
@testset "Sweeping" begin
15-
@testset "Region" begin
16-
algorithm = Region("region"; foo = 1, bar = 2)
21+
@testset "TestRegion" begin
22+
algorithm = TestRegion("region"; foo = 1, bar = 2)
1723
@test algorithm isa AIE.NonIterativeAlgorithm
1824
@test algorithm isa AIE.Algorithm
1925
@test algorithm isa AI.Algorithm
2026
@test algorithm.region == "region"
2127
@test algorithm.kwargs == (; foo = 1, bar = 2)
22-
@test Region(; region = "region", foo = 1, bar = 2) == algorithm
2328

2429
problem = TestProblem()
25-
iterate = ""
30+
iterate = []
2631
state = AI.solve(problem, algorithm; iterate)
27-
@test state.iterate == "region"
32+
@test state.iterate == [(; region = "region", foo = 1, bar = 2)]
2833
end
2934
@testset "Sweep" begin
35+
algorithm = AIE.nested_algorithm(3) do i
36+
return TestRegion("region$i"; foo = i, bar = 2i)
37+
end
38+
problem = TestProblem()
39+
iterate = []
40+
state = AI.solve(problem, algorithm; iterate)
41+
@test state.iterate == [
42+
(; region = "region1", foo = 1, bar = 2),
43+
(; region = "region2", foo = 2, bar = 4),
44+
(; region = "region3", foo = 3, bar = 6),
45+
]
3046
end
3147
@testset "Sweeping" begin
48+
algorithm = AIE.nested_algorithm(2) do i
49+
AIE.nested_algorithm(3) do j
50+
return TestRegion("sweep$i, region$j"; foo = (i, j), bar = (2i, 2j))
51+
end
52+
end
53+
problem = TestProblem()
54+
iterate = []
55+
state = AI.solve(problem, algorithm; iterate)
56+
@test state.iterate == [
57+
(; region = "sweep1, region1", foo = (1, 1), bar = (2, 2)),
58+
(; region = "sweep1, region2", foo = (1, 2), bar = (2, 4)),
59+
(; region = "sweep1, region3", foo = (1, 3), bar = (2, 6)),
60+
(; region = "sweep2, region1", foo = (2, 1), bar = (4, 2)),
61+
(; region = "sweep2, region2", foo = (2, 2), bar = (4, 4)),
62+
(; region = "sweep2, region3", foo = (2, 3), bar = (4, 6)),
63+
]
3264
end
3365
end

0 commit comments

Comments
 (0)