diff --git a/Project.toml b/Project.toml index c635eea1..cbef6eb7 100644 --- a/Project.toml +++ b/Project.toml @@ -4,32 +4,34 @@ authors = ["Tobias Knopp "] version = "0.13.1-DEV" [deps] -Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -LinearOperatorCollection = "a4a2c56f-fead-462a-a3ab-85921a5f2575" -LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -LinearOperators = "5c8ed15e-5a4c-59e4-a42b-c7e8811fb125" +Coverage = "a2441757-f6aa-5fb2-8edb-039e3f45d037" FLoops = "cc61a311-1640-44b5-9fba-1b764f453329" +InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153" -StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LinearOperatorCollection = "a4a2c56f-fead-462a-a3ab-85921a5f2575" +LinearOperators = "5c8ed15e-5a4c-59e4-a42b-c7e8811fb125" +ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" VectorizationBase = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f" -InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" - -[extras] -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" [compat] -IterativeSolvers = "0.9" -julia = "1.9" -StatsBase = "0.33, 0.34" +FFTW = "1.0" FLoops = "0.2" -VectorizationBase = "0.19, 0.21" +IterativeSolvers = "0.9" LinearOperatorCollection = "1.2" LinearOperators = "2.3.3" -FFTW = "1.0" +StatsBase = "0.33, 0.34" +VectorizationBase = "0.19, 0.21" +julia = "1.9" + +[extras] +FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] test = ["Test", "Random", "FFTW"] diff --git a/src/Callbacks.jl b/src/Callbacks.jl index 7bd2eb4f..82faed86 100644 --- a/src/Callbacks.jl +++ b/src/Callbacks.jl @@ -1,3 +1,6 @@ +using ProgressMeter + + export CompareSolutionCallback mutable struct CompareSolutionCallback{T, F} ref::Vector{T} @@ -49,4 +52,33 @@ function (cb::StoreConvergenceCallback)(solver::AbstractLinearSolver, _) push!(values, meas[key]) cb.convMeas[key] = values end -end \ No newline at end of file +end + + +export ProgressBarCallback +""" + ProgressBarCallback() + +Callback that displays a progress bar for a solver. +""" +Base.@kwdef mutable struct ProgressBarCallback + meter::Union{Progress,Nothing} = nothing +end +ProgressBarCallback(solver::AbstractLinearSolver) = ProgressBarCallback(Progress(solver.iterations)) +ProgressBarCallback(iterations::Int) = ProgressBarCallback(Progress(iterations)) + +""" + (self::ProgressBarCallback)(solver::AbstractLinearSolver, iter_n::Int) + +Initializes the callback when `iter_n` is zero, then updates the progress bar. +""" +function (self::ProgressBarCallback)(solver::AbstractLinearSolver, iter_n::Int) + if iter_n != 0 + next!(self.meter) + end + + # lazy init for iter_n = 0 + if iter_n == 0 && isnothing(self.meter) + self.meter = Progress(solver.iterations) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index ffd3b5b4..6a652839 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -6,4 +6,5 @@ using FFTW include("testKaczmarz.jl") include("testProxMaps.jl") include("testSolvers.jl") -include("testRegularization.jl") \ No newline at end of file +include("testRegularization.jl") +include("testCallbacks.jl") \ No newline at end of file diff --git a/test/testCallbacks.jl b/test/testCallbacks.jl new file mode 100644 index 00000000..e356b028 --- /dev/null +++ b/test/testCallbacks.jl @@ -0,0 +1,15 @@ +@testset "ProgressBarCallback" begin + A = [ + 0.831658 0.96717 + 0.383056 0.39043 + 0.820692 0.08118 + ] + x = [0.593; 0.269] + b = A * x + + solver = ADMM(A; iterations=50) + + _ = solve!(solver, b, callbacks=ProgressBarCallback()) + _ = solve!(solver, b, callbacks=ProgressBarCallback(solver)) + _ = solve!(solver, b, callbacks=ProgressBarCallback(50)) +end \ No newline at end of file