11abstract type ExtrapolationMutableCache <: OrdinaryDiffEqMutableCache end
22get_fsalfirstlast (cache:: ExtrapolationMutableCache , u) = (cache. fsalfirst, cache. k)
33
4+ # Helper function to determine appropriate thread count for array allocation
5+ # Uses maxthreadid() when threading is enabled, otherwise just 1 for maximum memory efficiency
6+ @inline function get_thread_count (alg)
7+ return isthreaded (alg. threading) ? Threads. maxthreadid () : 1
8+ end
9+
410@cache mutable struct AitkenNevilleCache{
511 uType,
612 rateType,
@@ -48,11 +54,11 @@ function alg_cache(alg::AitkenNeville, u, rate_prototype, ::Type{uEltypeNoUnits}
4854 T = Array {typeof(u), 2} (undef, alg. max_order, alg. max_order)
4955 # Array of arrays of length equal to number of threads to store intermediate
5056 # values of u and k. [Thread Safety]
51- u_tmps = Array {typeof(u), 1} (undef, Threads . nthreads ( ))
52- k_tmps = Array {typeof(k), 1} (undef, Threads . nthreads ( ))
57+ u_tmps = Array {typeof(u), 1} (undef, get_thread_count (alg ))
58+ k_tmps = Array {typeof(k), 1} (undef, get_thread_count (alg ))
5359 # Initialize each element of u_tmps and k_tmps to different instance of
5460 # zeros array similar to u and k respectively
55- for i in 1 : Threads . nthreads ( )
61+ for i in 1 : get_thread_count (alg )
5662 u_tmps[i] = zero (u)
5763 k_tmps[i] = zero (rate_prototype)
5864 end
@@ -196,26 +202,26 @@ function alg_cache(alg::ImplicitEulerExtrapolation, u, rate_prototype,
196202 :: Type{tTypeNoUnits} , uprev, uprev2, f, t, dt, reltol, p, calck,
197203 :: Val{true} ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
198204 u_tmp = zero (u)
199- u_tmps = Array {typeof(u_tmp), 1} (undef, Threads . nthreads ( ))
205+ u_tmps = Array {typeof(u_tmp), 1} (undef, get_thread_count (alg ))
200206
201207 u_tmps[1 ] = u_tmp
202- for i in 2 : Threads . nthreads ( )
208+ for i in 2 : get_thread_count (alg )
203209 u_tmps[i] = zero (u_tmp)
204210 end
205211
206- u_tmps2 = Array {typeof(u_tmp), 1} (undef, Threads . nthreads ( ))
212+ u_tmps2 = Array {typeof(u_tmp), 1} (undef, get_thread_count (alg ))
207213
208- for i in 1 : Threads . nthreads ( )
214+ for i in 1 : get_thread_count (alg )
209215 u_tmps2[i] = zero (u_tmp)
210216 end
211217
212218 utilde = zero (u)
213219 tmp = zero (u)
214220 k_tmp = zero (rate_prototype)
215- k_tmps = Array {typeof(k_tmp), 1} (undef, Threads . nthreads ( ))
221+ k_tmps = Array {typeof(k_tmp), 1} (undef, get_thread_count (alg ))
216222
217223 k_tmps[1 ] = k_tmp
218- for i in 2 : Threads . nthreads ( )
224+ for i in 2 : get_thread_count (alg )
219225 k_tmps[i] = zero (rate_prototype)
220226 end
221227
@@ -244,9 +250,9 @@ function alg_cache(alg::ImplicitEulerExtrapolation, u, rate_prototype,
244250 W_el = zero (J)
245251 end
246252
247- W = Array {typeof(W_el), 1} (undef, Threads . nthreads ( ))
253+ W = Array {typeof(W_el), 1} (undef, get_thread_count (alg ))
248254 W[1 ] = W_el
249- for i in 2 : Threads . nthreads ( )
255+ for i in 2 : get_thread_count (alg )
250256 if W_el isa WOperator
251257 W[i] = WOperator (f, dt, true )
252258 else
@@ -257,9 +263,9 @@ function alg_cache(alg::ImplicitEulerExtrapolation, u, rate_prototype,
257263 tf = TimeGradientWrapper (f, uprev, p)
258264 uf = UJacobianWrapper (f, t, p)
259265 linsolve_tmp = zero (rate_prototype)
260- linsolve_tmps = Array {typeof(linsolve_tmp), 1} (undef, Threads . nthreads ( ))
266+ linsolve_tmps = Array {typeof(linsolve_tmp), 1} (undef, get_thread_count (alg ))
261267
262- for i in 1 : Threads . nthreads ( )
268+ for i in 1 : get_thread_count (alg )
263269 linsolve_tmps[i] = zero (rate_prototype)
264270 end
265271
@@ -269,9 +275,9 @@ function alg_cache(alg::ImplicitEulerExtrapolation, u, rate_prototype,
269275 # Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight))),
270276 # Pr = Diagonal(_vec(weight)))
271277
272- linsolve = Array {typeof(linsolve1), 1} (undef, Threads . nthreads ( ))
278+ linsolve = Array {typeof(linsolve1), 1} (undef, get_thread_count (alg ))
273279 linsolve[1 ] = linsolve1
274- for i in 2 : Threads . nthreads ( )
280+ for i in 2 : get_thread_count (alg )
275281 linprob = LinearProblem (W[i], _vec (linsolve_tmps[i]); u0 = _vec (k_tmps[i]))
276282 linsolve[i] = init (linprob, alg. linsolve,
277283 alias = LinearAliasSpecifier (alias_A = true , alias_b = true ))
@@ -285,9 +291,9 @@ function alg_cache(alg::ImplicitEulerExtrapolation, u, rate_prototype,
285291 sequence = generate_sequence (constvalue (uBottomEltypeNoUnits), alg)
286292 cc = alg_cache (alg, u, rate_prototype, uEltypeNoUnits, uBottomEltypeNoUnits,
287293 tTypeNoUnits, uprev, uprev2, f, t, dt, reltol, p, calck, Val (false ))
288- diff1 = Array {typeof(u), 1} (undef, Threads . nthreads ( ))
289- diff2 = Array {typeof(u), 1} (undef, Threads . nthreads ( ))
290- for i in 1 : Threads . nthreads ( )
294+ diff1 = Array {typeof(u), 1} (undef, get_thread_count (alg ))
295+ diff2 = Array {typeof(u), 1} (undef, get_thread_count (alg ))
296+ for i in 1 : get_thread_count (alg )
291297 diff1[i] = zero (u)
292298 diff2[i] = zero (u)
293299 end
@@ -958,10 +964,10 @@ function alg_cache(alg::ExtrapolationMidpointDeuflhard, u, rate_prototype,
958964 utilde = zero (u)
959965 u_temp1 = zero (u)
960966 u_temp2 = zero (u)
961- u_temp3 = Array {typeof(u), 1} (undef, Threads . nthreads ( ))
962- u_temp4 = Array {typeof(u), 1} (undef, Threads . nthreads ( ))
967+ u_temp3 = Array {typeof(u), 1} (undef, get_thread_count (alg ))
968+ u_temp4 = Array {typeof(u), 1} (undef, get_thread_count (alg ))
963969
964- for i in 1 : Threads . nthreads ( )
970+ for i in 1 : get_thread_count (alg )
965971 u_temp3[i] = zero (u)
966972 u_temp4[i] = zero (u)
967973 end
@@ -975,8 +981,8 @@ function alg_cache(alg::ExtrapolationMidpointDeuflhard, u, rate_prototype,
975981
976982 fsalfirst = zero (rate_prototype)
977983 k = zero (rate_prototype)
978- k_tmps = Array {typeof(k), 1} (undef, Threads . nthreads ( ))
979- for i in 1 : Threads . nthreads ( )
984+ k_tmps = Array {typeof(k), 1} (undef, get_thread_count (alg ))
985+ for i in 1 : get_thread_count (alg )
980986 k_tmps[i] = zero (rate_prototype)
981987 end
982988
@@ -1097,10 +1103,10 @@ function alg_cache(alg::ImplicitDeuflhardExtrapolation, u, rate_prototype,
10971103 utilde = zero (u)
10981104 u_temp1 = zero (u)
10991105 u_temp2 = zero (u)
1100- u_temp3 = Array {typeof(u), 1} (undef, Threads . nthreads ( ))
1101- u_temp4 = Array {typeof(u), 1} (undef, Threads . nthreads ( ))
1106+ u_temp3 = Array {typeof(u), 1} (undef, get_thread_count (alg ))
1107+ u_temp4 = Array {typeof(u), 1} (undef, get_thread_count (alg ))
11021108
1103- for i in 1 : Threads . nthreads ( )
1109+ for i in 1 : get_thread_count (alg )
11041110 u_temp3[i] = zero (u)
11051111 u_temp4[i] = zero (u)
11061112 end
@@ -1114,8 +1120,8 @@ function alg_cache(alg::ImplicitDeuflhardExtrapolation, u, rate_prototype,
11141120
11151121 fsalfirst = zero (rate_prototype)
11161122 k = zero (rate_prototype)
1117- k_tmps = Array {typeof(k), 1} (undef, Threads . nthreads ( ))
1118- for i in 1 : Threads . nthreads ( )
1123+ k_tmps = Array {typeof(k), 1} (undef, get_thread_count (alg ))
1124+ for i in 1 : get_thread_count (alg )
11191125 k_tmps[i] = zero (rate_prototype)
11201126 end
11211127
@@ -1134,9 +1140,9 @@ function alg_cache(alg::ImplicitDeuflhardExtrapolation, u, rate_prototype,
11341140 W_el = zero (J)
11351141 end
11361142
1137- W = Array {typeof(W_el), 1} (undef, Threads . nthreads ( ))
1143+ W = Array {typeof(W_el), 1} (undef, get_thread_count (alg ))
11381144 W[1 ] = W_el
1139- for i in 2 : Threads . nthreads ( )
1145+ for i in 2 : get_thread_count (alg )
11401146 if W_el isa WOperator
11411147 W[i] = WOperator (f, dt, true )
11421148 else
@@ -1146,9 +1152,9 @@ function alg_cache(alg::ImplicitDeuflhardExtrapolation, u, rate_prototype,
11461152 tf = TimeGradientWrapper (f, uprev, p)
11471153 uf = UJacobianWrapper (f, t, p)
11481154 linsolve_tmp = zero (rate_prototype)
1149- linsolve_tmps = Array {typeof(linsolve_tmp), 1} (undef, Threads . nthreads ( ))
1155+ linsolve_tmps = Array {typeof(linsolve_tmp), 1} (undef, get_thread_count (alg ))
11501156
1151- for i in 1 : Threads . nthreads ( )
1157+ for i in 1 : get_thread_count (alg )
11521158 linsolve_tmps[i] = zero (rate_prototype)
11531159 end
11541160
@@ -1158,9 +1164,9 @@ function alg_cache(alg::ImplicitDeuflhardExtrapolation, u, rate_prototype,
11581164 # Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight))),
11591165 # Pr = Diagonal(_vec(weight)))
11601166
1161- linsolve = Array {typeof(linsolve1), 1} (undef, Threads . nthreads ( ))
1167+ linsolve = Array {typeof(linsolve1), 1} (undef, get_thread_count (alg ))
11621168 linsolve[1 ] = linsolve1
1163- for i in 2 : Threads . nthreads ( )
1169+ for i in 2 : get_thread_count (alg )
11641170 linprob = LinearProblem (W[i], _vec (linsolve_tmps[i]); u0 = _vec (k_tmps[i]))
11651171 linsolve[i] = init (linprob, alg. linsolve,
11661172 alias = LinearAliasSpecifier (alias_A = true , alias_b = true ))
@@ -1170,9 +1176,9 @@ function alg_cache(alg::ImplicitDeuflhardExtrapolation, u, rate_prototype,
11701176 grad_config = build_grad_config (alg, f, tf, du1, t)
11711177 jac_config = build_jac_config (alg, f, uf, du1, uprev, u, du1, du2)
11721178
1173- diff1 = Array {typeof(u), 1} (undef, Threads . nthreads ( ))
1174- diff2 = Array {typeof(u), 1} (undef, Threads . nthreads ( ))
1175- for i in 1 : Threads . nthreads ( )
1179+ diff1 = Array {typeof(u), 1} (undef, get_thread_count (alg ))
1180+ diff2 = Array {typeof(u), 1} (undef, get_thread_count (alg ))
1181+ for i in 1 : get_thread_count (alg )
11761182 diff1[i] = zero (u)
11771183 diff2[i] = zero (u)
11781184 end
@@ -1272,10 +1278,10 @@ function alg_cache(alg::ExtrapolationMidpointHairerWanner, u, rate_prototype,
12721278 utilde = zero (u)
12731279 u_temp1 = zero (u)
12741280 u_temp2 = zero (u)
1275- u_temp3 = Array {typeof(u), 1} (undef, Threads . nthreads ( ))
1276- u_temp4 = Array {typeof(u), 1} (undef, Threads . nthreads ( ))
1281+ u_temp3 = Array {typeof(u), 1} (undef, get_thread_count (alg ))
1282+ u_temp4 = Array {typeof(u), 1} (undef, get_thread_count (alg ))
12771283
1278- for i in 1 : Threads . nthreads ( )
1284+ for i in 1 : get_thread_count (alg )
12791285 u_temp3[i] = zero (u)
12801286 u_temp4[i] = zero (u)
12811287 end
@@ -1287,8 +1293,8 @@ function alg_cache(alg::ExtrapolationMidpointHairerWanner, u, rate_prototype,
12871293 res = uEltypeNoUnits .(zero (u))
12881294 fsalfirst = zero (rate_prototype)
12891295 k = zero (rate_prototype)
1290- k_tmps = Array {typeof(k), 1} (undef, Threads . nthreads ( ))
1291- for i in 1 : Threads . nthreads ( )
1296+ k_tmps = Array {typeof(k), 1} (undef, get_thread_count (alg ))
1297+ for i in 1 : get_thread_count (alg )
12921298 k_tmps[i] = zero (rate_prototype)
12931299 end
12941300
@@ -1429,10 +1435,10 @@ function alg_cache(alg::ImplicitHairerWannerExtrapolation, u, rate_prototype,
14291435 utilde = zero (u)
14301436 u_temp1 = zero (u)
14311437 u_temp2 = zero (u)
1432- u_temp3 = Array {typeof(u), 1} (undef, Threads . nthreads ( ))
1433- u_temp4 = Array {typeof(u), 1} (undef, Threads . nthreads ( ))
1438+ u_temp3 = Array {typeof(u), 1} (undef, get_thread_count (alg ))
1439+ u_temp4 = Array {typeof(u), 1} (undef, get_thread_count (alg ))
14341440
1435- for i in 1 : Threads . nthreads ( )
1441+ for i in 1 : get_thread_count (alg )
14361442 u_temp3[i] = zero (u)
14371443 u_temp4[i] = zero (u)
14381444 end
@@ -1444,8 +1450,8 @@ function alg_cache(alg::ImplicitHairerWannerExtrapolation, u, rate_prototype,
14441450 res = uEltypeNoUnits .(zero (u))
14451451 fsalfirst = zero (rate_prototype)
14461452 k = zero (rate_prototype)
1447- k_tmps = Array {typeof(k), 1} (undef, Threads . nthreads ( ))
1448- for i in 1 : Threads . nthreads ( )
1453+ k_tmps = Array {typeof(k), 1} (undef, get_thread_count (alg ))
1454+ for i in 1 : get_thread_count (alg )
14491455 k_tmps[i] = zero (rate_prototype)
14501456 end
14511457
@@ -1463,9 +1469,9 @@ function alg_cache(alg::ImplicitHairerWannerExtrapolation, u, rate_prototype,
14631469 W_el = zero (J)
14641470 end
14651471
1466- W = Array {typeof(W_el), 1} (undef, Threads . nthreads ( ))
1472+ W = Array {typeof(W_el), 1} (undef, get_thread_count (alg ))
14671473 W[1 ] = W_el
1468- for i in 2 : Threads . nthreads ( )
1474+ for i in 2 : get_thread_count (alg )
14691475 if W_el isa WOperator
14701476 W[i] = WOperator (f, dt, true )
14711477 else
@@ -1476,9 +1482,9 @@ function alg_cache(alg::ImplicitHairerWannerExtrapolation, u, rate_prototype,
14761482 tf = TimeGradientWrapper (f, uprev, p)
14771483 uf = UJacobianWrapper (f, t, p)
14781484 linsolve_tmp = zero (rate_prototype)
1479- linsolve_tmps = Array {typeof(linsolve_tmp), 1} (undef, Threads . nthreads ( ))
1485+ linsolve_tmps = Array {typeof(linsolve_tmp), 1} (undef, get_thread_count (alg ))
14801486
1481- for i in 1 : Threads . nthreads ( )
1487+ for i in 1 : get_thread_count (alg )
14821488 linsolve_tmps[i] = zero (rate_prototype)
14831489 end
14841490
@@ -1488,9 +1494,9 @@ function alg_cache(alg::ImplicitHairerWannerExtrapolation, u, rate_prototype,
14881494 # Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight))),
14891495 # Pr = Diagonal(_vec(weight)))
14901496
1491- linsolve = Array {typeof(linsolve1), 1} (undef, Threads . nthreads ( ))
1497+ linsolve = Array {typeof(linsolve1), 1} (undef, get_thread_count (alg ))
14921498 linsolve[1 ] = linsolve1
1493- for i in 2 : Threads . nthreads ( )
1499+ for i in 2 : get_thread_count (alg )
14941500 linprob = LinearProblem (W[i], _vec (linsolve_tmps[i]); u0 = _vec (k_tmps[i]))
14951501 linsolve[i] = init (linprob, alg. linsolve,
14961502 alias = LinearAliasSpecifier (alias_A = true , alias_b = true ))
@@ -1500,9 +1506,9 @@ function alg_cache(alg::ImplicitHairerWannerExtrapolation, u, rate_prototype,
15001506 grad_config = build_grad_config (alg, f, tf, du1, t)
15011507 jac_config = build_jac_config (alg, f, uf, du1, uprev, u, du1, du2)
15021508
1503- diff1 = Array {typeof(u), 1} (undef, Threads . nthreads ( ))
1504- diff2 = Array {typeof(u), 1} (undef, Threads . nthreads ( ))
1505- for i in 1 : Threads . nthreads ( )
1509+ diff1 = Array {typeof(u), 1} (undef, get_thread_count (alg ))
1510+ diff2 = Array {typeof(u), 1} (undef, get_thread_count (alg ))
1511+ for i in 1 : get_thread_count (alg )
15061512 diff1[i] = zero (u)
15071513 diff2[i] = zero (u)
15081514 end
@@ -1627,10 +1633,10 @@ function alg_cache(alg::ImplicitEulerBarycentricExtrapolation, u, rate_prototype
16271633 utilde = zero (u)
16281634 u_temp1 = zero (u)
16291635 u_temp2 = zero (u)
1630- u_temp3 = Array {typeof(u), 1} (undef, Threads . nthreads ( ))
1631- u_temp4 = Array {typeof(u), 1} (undef, Threads . nthreads ( ))
1636+ u_temp3 = Array {typeof(u), 1} (undef, get_thread_count (alg ))
1637+ u_temp4 = Array {typeof(u), 1} (undef, get_thread_count (alg ))
16321638
1633- for i in 1 : Threads . nthreads ( )
1639+ for i in 1 : get_thread_count (alg )
16341640 u_temp3[i] = zero (u)
16351641 u_temp4[i] = zero (u)
16361642 end
@@ -1642,8 +1648,8 @@ function alg_cache(alg::ImplicitEulerBarycentricExtrapolation, u, rate_prototype
16421648 res = uEltypeNoUnits .(zero (u))
16431649 fsalfirst = zero (rate_prototype)
16441650 k = zero (rate_prototype)
1645- k_tmps = Array {typeof(k), 1} (undef, Threads . nthreads ( ))
1646- for i in 1 : Threads . nthreads ( )
1651+ k_tmps = Array {typeof(k), 1} (undef, get_thread_count (alg ))
1652+ for i in 1 : get_thread_count (alg )
16471653 k_tmps[i] = zero (rate_prototype)
16481654 end
16491655
@@ -1661,9 +1667,9 @@ function alg_cache(alg::ImplicitEulerBarycentricExtrapolation, u, rate_prototype
16611667 W_el = zero (J)
16621668 end
16631669
1664- W = Array {typeof(W_el), 1} (undef, Threads . nthreads ( ))
1670+ W = Array {typeof(W_el), 1} (undef, get_thread_count (alg ))
16651671 W[1 ] = W_el
1666- for i in 2 : Threads . nthreads ( )
1672+ for i in 2 : get_thread_count (alg )
16671673 if W_el isa WOperator
16681674 W[i] = WOperator (f, dt, true )
16691675 else
@@ -1674,9 +1680,9 @@ function alg_cache(alg::ImplicitEulerBarycentricExtrapolation, u, rate_prototype
16741680 tf = TimeGradientWrapper (f, uprev, p)
16751681 uf = UJacobianWrapper (f, t, p)
16761682 linsolve_tmp = zero (rate_prototype)
1677- linsolve_tmps = Array {typeof(linsolve_tmp), 1} (undef, Threads . nthreads ( ))
1683+ linsolve_tmps = Array {typeof(linsolve_tmp), 1} (undef, get_thread_count (alg ))
16781684
1679- for i in 1 : Threads . nthreads ( )
1685+ for i in 1 : get_thread_count (alg )
16801686 linsolve_tmps[i] = zero (rate_prototype)
16811687 end
16821688
@@ -1686,9 +1692,9 @@ function alg_cache(alg::ImplicitEulerBarycentricExtrapolation, u, rate_prototype
16861692 # Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight))),
16871693 # Pr = Diagonal(_vec(weight)))
16881694
1689- linsolve = Array {typeof(linsolve1), 1} (undef, Threads . nthreads ( ))
1695+ linsolve = Array {typeof(linsolve1), 1} (undef, get_thread_count (alg ))
16901696 linsolve[1 ] = linsolve1
1691- for i in 2 : Threads . nthreads ( )
1697+ for i in 2 : get_thread_count (alg )
16921698 linprob = LinearProblem (W[i], _vec (linsolve_tmps[i]); u0 = _vec (k_tmps[i]))
16931699 linsolve[i] = init (linprob, alg. linsolve,
16941700 alias = LinearAliasSpecifier (alias_A = true , alias_b = true ))
@@ -1698,9 +1704,9 @@ function alg_cache(alg::ImplicitEulerBarycentricExtrapolation, u, rate_prototype
16981704 grad_config = build_grad_config (alg, f, tf, du1, t)
16991705 jac_config = build_jac_config (alg, f, uf, du1, uprev, u, du1, du2)
17001706
1701- diff1 = Array {typeof(u), 1} (undef, Threads . nthreads ( ))
1702- diff2 = Array {typeof(u), 1} (undef, Threads . nthreads ( ))
1703- for i in 1 : Threads . nthreads ( )
1707+ diff1 = Array {typeof(u), 1} (undef, get_thread_count (alg ))
1708+ diff2 = Array {typeof(u), 1} (undef, get_thread_count (alg ))
1709+ for i in 1 : get_thread_count (alg )
17041710 diff1[i] = zero (u)
17051711 diff2[i] = zero (u)
17061712 end
0 commit comments