Skip to content

Commit bbc8f29

Browse files
committed
Fix tests
1 parent 4cca414 commit bbc8f29

File tree

3 files changed

+15
-9
lines changed

3 files changed

+15
-9
lines changed

src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,12 @@ abstract type Algorithm <: AI.Algorithm end
99
abstract type State <: AI.State end
1010

1111
function AI.initialize_state!(
12-
problem::Problem, algorithm::Algorithm, state::State; kwargs...
12+
problem::Problem, algorithm::Algorithm, state::State; iteration = 0, kwargs...
1313
)
14+
for (k, v) in pairs(kwargs)
15+
setproperty!(state, k, v)
16+
end
17+
state.iteration = iteration
1418
AI.initialize_state!(
1519
problem, algorithm, algorithm.stopping_criterion, state.stopping_criterion_state
1620
)

test/test_algorithmsinterfaceextensions.jl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,13 @@ end
5555
stopping_criterion_state = AI.initialize_state(
5656
problem, algorithm, algorithm.stopping_criterion
5757
)
58-
state = AIE.DefaultState(; iterate = [0.0, 0.0], stopping_criterion_state)
59-
60-
initial_iterate = [1.0, 2.0]
61-
AI.initialize_state!(problem, algorithm, state; iterate = initial_iterate)
62-
@test state.iterate == initial_iterate
58+
state = AIE.DefaultState(;
59+
iteration = 2, iterate = [0.0, 0.0], stopping_criterion_state
60+
)
61+
AI.initialize_state!(problem, algorithm, state)
62+
@test state.iterate == [0.0, 0.0]
6363
@test state.iteration == 0
64+
@test state.stopping_criterion_state == stopping_criterion_state
6465
end
6566

6667
@testset "initialize_state" begin
@@ -99,13 +100,14 @@ end
99100
state = AI.initialize_state(problem, algorithm; iterate = copy(initial_iterate))
100101

101102
# Solve with custom initial iterate
103+
initial_iterate = [5.0, 10.0]
102104
final_state = AI.solve!(
103105
problem, algorithm, state; iterate = copy(initial_iterate)
104106
)
105107

106108
@test final_state.iteration == 3
107-
# Each step increments by 1, so after 3 steps: [10, 20] + 3 = [13, 23]
108-
@test final_state.iterate [13.0, 23.0]
109+
# Each step increments by 1, so after 3 steps: [5, 10] + 3 = [8, 13]
110+
@test final_state.iterate [8.0, 13.0]
109111

110112
# Test solve without exclamation
111113
problem2 = TestProblem([1.0, 2.0])

test/test_aqua.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,5 @@ using Aqua: Aqua
33
using Test: @testset
44

55
@testset "Code quality (Aqua.jl)" begin
6-
Aqua.test_all(ITensorNetworksNext)
6+
Aqua.test_all(ITensorNetworksNext; persistent_tasks = false)
77
end

0 commit comments

Comments
 (0)