Skip to content

Commit ee04297

Browse files
committed
Count statistics inside calls and not in individual algorithms
1 parent 5def912 commit ee04297

File tree

7 files changed

+29
-31
lines changed

7 files changed

+29
-31
lines changed

src/gaussnewton.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ function perform_step!(cache::GaussNewtonCache{iip}) where {iip}
123123
A, b = cache.J, _vec(cache.fu)
124124
end
125125

126-
linres = dolinsolve(cache.alg.precs, cache.linsolve; A, b, linu = _vec(cache.du),
126+
linres = dolinsolve(cache, cache.alg.precs, cache.linsolve; A, b, linu = _vec(cache.du),
127127
cache.p, reltol = cache.abstol)
128128
cache.linsolve = linres.cache
129129
cache.du = _restructure(cache.du, linres.u)
@@ -142,9 +142,6 @@ function perform_step!(cache::GaussNewtonCache{iip}) where {iip}
142142
@bb copyto!(cache.u_cache, cache.u)
143143
@bb copyto!(cache.dfu, cache.fu)
144144

145-
cache.stats.njacs += 1
146-
cache.stats.nsolve += 1
147-
cache.stats.nfactors += 1
148145
return nothing
149146
end
150147

src/klement.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ function perform_step!(cache::GeneralKlementCache{iip}) where {iip}
123123
A = ifelse(cache.J isa SMatrix || cache.J isa Number || !fact_done, cache.J, nothing)
124124

125125
# u = u - J \ fu
126-
linres = dolinsolve(alg.precs, cache.linsolve; A,
126+
linres = dolinsolve(cache, alg.precs, cache.linsolve; A,
127127
b = _vec(cache.fu), linu = _vec(cache.du), cache.p, reltol = cache.abstol)
128128
cache.linsolve = linres.cache
129129
cache.du = _restructure(cache.du, linres.u)
@@ -139,9 +139,6 @@ function perform_step!(cache::GeneralKlementCache{iip}) where {iip}
139139

140140
@bb copyto!(cache.u_cache, cache.u)
141141

142-
cache.stats.nsolve += 1
143-
cache.stats.nfactors += 1
144-
145142
cache.force_stop && return nothing
146143

147144
# Update the Jacobian

src/levenberg.jl

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -263,13 +263,14 @@ function perform_step!(cache::LevenbergMarquardtCache{iip, fastls}) where {iip,
263263
else
264264
cache.rhs_tmp = _vcat(_vec(cache.fu), zero(_vec(cache.u)))
265265
end
266-
linres = dolinsolve(alg.precs, linsolve; A = cache.mat_tmp,
266+
linres = dolinsolve(cache, alg.precs, linsolve; A = cache.mat_tmp,
267267
b = cache.rhs_tmp, linu = _vec(cache.v), cache.p, reltol = cache.abstol)
268268
else
269269
@bb cache.u_cache_2 = transpose(cache.J) × cache.fu
270270
@bb @. cache.mat_tmp = cache.JᵀJ + cache.λ * cache.DᵀD
271-
linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(cache.mat_tmp),
272-
b = _vec(cache.u_cache_2), linu = _vec(cache.v), cache.p, reltol = cache.abstol)
271+
linres = dolinsolve(cache, alg.precs, linsolve;
272+
A = __maybe_symmetric(cache.mat_tmp), b = _vec(cache.u_cache_2),
273+
linu = _vec(cache.v), cache.p, reltol = cache.abstol)
273274
end
274275
cache.linsolve = linres.cache
275276
linu = _restructure(cache.v, linres.u)
@@ -293,20 +294,17 @@ function perform_step!(cache::LevenbergMarquardtCache{iip, fastls}) where {iip,
293294
else
294295
cache.rhs_tmp = _vcat(_vec(cache.fu_cache_2), zero(_vec(cache.u)))
295296
end
296-
linres = dolinsolve(alg.precs, linsolve; b = cache.rhs_tmp, linu = _vec(cache.a),
297-
cache.p, reltol = cache.abstol)
297+
linres = dolinsolve(cache, alg.precs, linsolve; b = cache.rhs_tmp,
298+
linu = _vec(cache.a), cache.p, reltol = cache.abstol)
298299
else
299300
@bb cache.u_cache_2 = transpose(cache.J) × cache.fu_cache_2
300-
linres = dolinsolve(alg.precs, linsolve; b = _vec(cache.u_cache_2),
301+
linres = dolinsolve(cache, alg.precs, linsolve; b = _vec(cache.u_cache_2),
301302
linu = _vec(cache.a), cache.p, reltol = cache.abstol)
302303
end
303304
cache.linsolve = linres.cache
304305
linu = _restructure(cache.a, linres.u)
305306
@bb @. cache.a = -linu
306307

307-
cache.stats.nsolve += 2
308-
cache.stats.nfactors += 2
309-
310308
# Require acceptable steps to satisfy the following condition.
311309
norm_v = cache.internalnorm(cache.v)
312310
if 2 * cache.internalnorm(cache.a) cache.α_geodesic * norm_v

src/pseudotransient.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ function perform_step!(cache::PseudoTransientCache{iip}) where {iip}
127127
end
128128

129129
# u = u - J \ fu
130-
linres = dolinsolve(alg.precs, cache.linsolve; A, b = _vec(cache.fu),
130+
linres = dolinsolve(cache, alg.precs, cache.linsolve; A, b = _vec(cache.fu),
131131
linu = _vec(cache.du), cache.p, reltol = cache.abstol)
132132
cache.linsolve = linres.cache
133133
cache.du = _restructure(cache.du, linres.u)
@@ -145,9 +145,6 @@ function perform_step!(cache::PseudoTransientCache{iip}) where {iip}
145145
check_and_update!(cache, cache.fu, cache.u, cache.u_cache)
146146

147147
@bb copyto!(cache.u_cache, cache.u)
148-
cache.stats.njacs += 1
149-
cache.stats.nsolve += 1
150-
cache.stats.nfactors += 1
151148
return nothing
152149
end
153150

src/raphson.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ function perform_step!(cache::NewtonRaphsonCache{iip}) where {iip}
104104
cache.J = jacobian!!(cache.J, cache)
105105

106106
# u = u - J \ fu
107-
linres = dolinsolve(alg.precs, cache.linsolve; A = cache.J, b = _vec(cache.fu),
107+
linres = dolinsolve(cache, alg.precs, cache.linsolve; A = cache.J, b = _vec(cache.fu),
108108
linu = _vec(cache.du), cache.p, reltol = cache.abstol)
109109
cache.linsolve = linres.cache
110110
cache.du = _restructure(cache.du, linres.u)
@@ -119,8 +119,5 @@ function perform_step!(cache::NewtonRaphsonCache{iip}) where {iip}
119119
check_and_update!(cache, cache.fu, cache.u, cache.u_cache)
120120

121121
@bb copyto!(cache.u_cache, cache.u)
122-
cache.stats.njacs += 1
123-
cache.stats.nsolve += 1
124-
cache.stats.nfactors += 1
125122
return nothing
126123
end

src/trustRegion.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@ function perform_step!(cache::TrustRegionCache{iip}) where {iip}
360360

361361
# do not use A = cache.H, b = _vec(cache.g) since it is equivalent
362362
# to A = cache.J, b = _vec(fu) as long as the Jacobian is non-singular
363-
linres = dolinsolve(cache.alg.precs, cache.linsolve, A = cache.J,
363+
linres = dolinsolve(cache, cache.alg.precs, cache.linsolve, A = cache.J,
364364
b = _vec(cache.fu), linu = _vec(cache.u_gauss_newton), p = cache.p,
365365
reltol = cache.abstol)
366366
cache.linsolve = linres.cache
@@ -375,8 +375,6 @@ function perform_step!(cache::TrustRegionCache{iip}) where {iip}
375375
@bb @. cache.u_cache_2 = cache.u + cache.du
376376
evaluate_f(cache, cache.u_cache_2, cache.p, Val{:fu_cache_2}())
377377
trust_region_step!(cache)
378-
cache.stats.nsolve += 1
379-
cache.stats.nfactors += 1
380378
return nothing
381379
end
382380

src/utils.jl

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,18 +88,26 @@ end
8888

8989
DEFAULT_PRECS(W, du, u, p, t, newW, Plprev, Prprev, cachedata) = nothing, nothing
9090

91-
function dolinsolve(precs::P, linsolve::FakeLinearSolveJLCache; A = nothing,
91+
function dolinsolve(cache, precs::P, linsolve::FakeLinearSolveJLCache; A = nothing,
9292
linu = nothing, b = nothing, du = nothing, p = nothing, weight = nothing,
9393
cachedata = nothing, reltol = nothing, reuse_A_if_factorization = false) where {P}
94+
# Update Statistics
95+
cache.stats.nsolve += 1
96+
cache.stats.nfactors += !(A isa Number)
97+
9498
A !== nothing && (linsolve.A = A)
9599
b !== nothing && (linsolve.b = b)
96100
linres = linsolve.A \ linsolve.b
97101
return FakeLinearSolveJLResult(linsolve, linres)
98102
end
99103

100-
function dolinsolve(precs::P, linsolve; A = nothing, linu = nothing, b = nothing,
104+
function dolinsolve(cache, precs::P, linsolve; A = nothing, linu = nothing, b = nothing,
101105
du = nothing, p = nothing, weight = nothing, cachedata = nothing, reltol = nothing,
102106
reuse_A_if_factorization = false) where {P}
107+
# Update Statistics
108+
cache.stats.nsolve += 1
109+
cache.stats.nfactors += 1
110+
103111
# Some Algorithms would reuse factorization but it causes the cache to not reset in
104112
# certain cases
105113
if A !== nothing
@@ -108,10 +116,16 @@ function dolinsolve(precs::P, linsolve; A = nothing, linu = nothing, b = nothing
108116
(alg isa LinearSolve.DefaultLinearSolver && !(alg ==
109117
LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.KrylovJL_GMRES)))
110118
# Factorization Algorithm
111-
!reuse_A_if_factorization && (linsolve.A = A)
119+
if reuse_A_if_factorization
120+
cache.stats.nfactors -= 1
121+
else
122+
linsolve.A = A
123+
end
112124
else
113125
linsolve.A = A
114126
end
127+
else
128+
cache.stats.nfactors -= 1
115129
end
116130
b !== nothing && (linsolve.b = b)
117131
linu !== nothing && (linsolve.u = linu)

0 commit comments

Comments
 (0)