Skip to content

Commit fdc341a

Browse files
Merge pull request #2556 from mmesiti/multithreaded-ABM
Added threading to ABM algorithms using @..
2 parents 845ee9d + c56813f commit fdc341a

File tree

6 files changed

+242
-144
lines changed

6 files changed

+242
-144
lines changed

lib/OrdinaryDiffEqAdamsBashforthMoulton/src/adams_bashforth_moulton_caches.jl

Lines changed: 41 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ get_fsalfirstlast(cache::ABMMutableCache, u) = (cache.fsalfirst, cache.k)
44
function get_fsalfirstlast(cache::ABMVariableCoefficientMutableCache, u)
55
(cache.fsalfirst, cache.k4)
66
end
7-
@cache mutable struct AB3Cache{uType, rateType} <: ABMMutableCache
7+
@cache mutable struct AB3Cache{uType, rateType, Thread} <: ABMMutableCache
88
u::uType
99
uprev::uType
1010
fsalfirst::rateType
@@ -14,6 +14,7 @@ end
1414
k::rateType
1515
tmp::uType
1616
step::Int
17+
thread::Thread
1718
end
1819

1920
@cache mutable struct AB3ConstantCache{rateType} <: OrdinaryDiffEqConstantCache
@@ -32,7 +33,7 @@ function alg_cache(alg::AB3, u, rate_prototype, ::Type{uEltypeNoUnits},
3233
ralk2 = zero(rate_prototype)
3334
k = zero(rate_prototype)
3435
tmp = zero(u)
35-
AB3Cache(u, uprev, fsalfirst, k2, k3, ralk2, k, tmp, 1)
36+
AB3Cache(u, uprev, fsalfirst, k2, k3, ralk2, k, tmp, 1, alg.thread)
3637
end
3738

3839
function alg_cache(alg::AB3, u, rate_prototype, ::Type{uEltypeNoUnits},
@@ -44,7 +45,7 @@ function alg_cache(alg::AB3, u, rate_prototype, ::Type{uEltypeNoUnits},
4445
AB3ConstantCache(k2, k3, 1)
4546
end
4647

47-
@cache mutable struct ABM32Cache{uType, rateType} <: ABMMutableCache
48+
@cache mutable struct ABM32Cache{uType, rateType, Thread} <: ABMMutableCache
4849
u::uType
4950
uprev::uType
5051
fsalfirst::rateType
@@ -54,6 +55,7 @@ end
5455
k::rateType
5556
tmp::uType
5657
step::Int
58+
thread::Thread
5759
end
5860

5961
@cache mutable struct ABM32ConstantCache{rateType} <: OrdinaryDiffEqConstantCache
@@ -72,7 +74,7 @@ function alg_cache(alg::ABM32, u, rate_prototype, ::Type{uEltypeNoUnits},
7274
ralk2 = zero(rate_prototype)
7375
k = zero(rate_prototype)
7476
tmp = zero(u)
75-
ABM32Cache(u, uprev, fsalfirst, k2, k3, ralk2, k, tmp, 1)
77+
ABM32Cache(u, uprev, fsalfirst, k2, k3, ralk2, k, tmp, 1, alg.thread)
7678
end
7779

7880
function alg_cache(alg::ABM32, u, rate_prototype, ::Type{uEltypeNoUnits},
@@ -84,7 +86,7 @@ function alg_cache(alg::ABM32, u, rate_prototype, ::Type{uEltypeNoUnits},
8486
ABM32ConstantCache(k2, k3, 1)
8587
end
8688

87-
@cache mutable struct AB4Cache{uType, rateType} <: ABMMutableCache
89+
@cache mutable struct AB4Cache{uType, rateType, Thread} <: ABMMutableCache
8890
u::uType
8991
uprev::uType
9092
fsalfirst::rateType
@@ -98,6 +100,7 @@ end
98100
t3::rateType
99101
t4::rateType
100102
step::Int
103+
thread::Thread
101104
end
102105

103106
@cache mutable struct AB4ConstantCache{rateType} <: OrdinaryDiffEqConstantCache
@@ -121,7 +124,7 @@ function alg_cache(alg::AB4, u, rate_prototype, ::Type{uEltypeNoUnits},
121124
t2 = zero(rate_prototype)
122125
t3 = zero(rate_prototype)
123126
t4 = zero(rate_prototype)
124-
AB4Cache(u, uprev, fsalfirst, k2, k3, k4, ralk2, k, tmp, t2, t3, t4, 1)
127+
AB4Cache(u, uprev, fsalfirst, k2, k3, k4, ralk2, k, tmp, t2, t3, t4, 1, alg.thread)
125128
end
126129

127130
function alg_cache(alg::AB4, u, rate_prototype, ::Type{uEltypeNoUnits},
@@ -134,7 +137,7 @@ function alg_cache(alg::AB4, u, rate_prototype, ::Type{uEltypeNoUnits},
134137
AB4ConstantCache(k2, k3, k4, 1)
135138
end
136139

137-
@cache mutable struct ABM43Cache{uType, rateType} <: ABMMutableCache
140+
@cache mutable struct ABM43Cache{uType, rateType, Thread} <: ABMMutableCache
138141
u::uType
139142
uprev::uType
140143
fsalfirst::rateType
@@ -151,6 +154,7 @@ end
151154
t6::rateType
152155
t7::rateType
153156
step::Int
157+
thread::Thread
154158
end
155159

156160
@cache mutable struct ABM43ConstantCache{rateType} <: OrdinaryDiffEqConstantCache
@@ -177,7 +181,8 @@ function alg_cache(alg::ABM43, u, rate_prototype, ::Type{uEltypeNoUnits},
177181
t5 = zero(rate_prototype)
178182
t6 = zero(rate_prototype)
179183
t7 = zero(rate_prototype)
180-
ABM43Cache(u, uprev, fsalfirst, k2, k3, k4, ralk2, k, tmp, t2, t3, t4, t5, t6, t7, 1)
184+
ABM43Cache(u, uprev, fsalfirst, k2, k3, k4, ralk2, k,
185+
tmp, t2, t3, t4, t5, t6, t7, 1, alg.thread)
181186
end
182187

183188
function alg_cache(alg::ABM43, u, rate_prototype, ::Type{uEltypeNoUnits},
@@ -190,7 +195,7 @@ function alg_cache(alg::ABM43, u, rate_prototype, ::Type{uEltypeNoUnits},
190195
ABM43ConstantCache(k2, k3, k4, 1)
191196
end
192197

193-
@cache mutable struct AB5Cache{uType, rateType} <: ABMMutableCache
198+
@cache mutable struct AB5Cache{uType, rateType, Thread} <: ABMMutableCache
194199
u::uType
195200
uprev::uType
196201
fsalfirst::rateType
@@ -204,6 +209,7 @@ end
204209
t3::rateType
205210
t4::rateType
206211
step::Int
212+
thread::Thread
207213
end
208214

209215
@cache mutable struct AB5ConstantCache{rateType} <: OrdinaryDiffEqConstantCache
@@ -228,7 +234,7 @@ function alg_cache(alg::AB5, u, rate_prototype, ::Type{uEltypeNoUnits},
228234
t2 = zero(rate_prototype)
229235
t3 = zero(rate_prototype)
230236
t4 = zero(rate_prototype)
231-
AB5Cache(u, uprev, fsalfirst, k2, k3, k4, k5, k, tmp, t2, t3, t4, 1)
237+
AB5Cache(u, uprev, fsalfirst, k2, k3, k4, k5, k, tmp, t2, t3, t4, 1, alg.thread)
232238
end
233239

234240
function alg_cache(alg::AB5, u, rate_prototype, ::Type{uEltypeNoUnits},
@@ -242,7 +248,7 @@ function alg_cache(alg::AB5, u, rate_prototype, ::Type{uEltypeNoUnits},
242248
AB5ConstantCache(k2, k3, k4, k5, 1)
243249
end
244250

245-
@cache mutable struct ABM54Cache{uType, rateType} <: ABMMutableCache
251+
@cache mutable struct ABM54Cache{uType, rateType, Thread} <: ABMMutableCache
246252
u::uType
247253
uprev::uType
248254
fsalfirst::rateType
@@ -260,6 +266,7 @@ end
260266
t7::rateType
261267
t8::rateType
262268
step::Int
269+
thread::Thread
263270
end
264271

265272
@cache mutable struct ABM54ConstantCache{rateType} <: OrdinaryDiffEqConstantCache
@@ -288,7 +295,8 @@ function alg_cache(alg::ABM54, u, rate_prototype, ::Type{uEltypeNoUnits},
288295
t6 = zero(rate_prototype)
289296
t7 = zero(rate_prototype)
290297
t8 = zero(rate_prototype)
291-
ABM54Cache(u, uprev, fsalfirst, k2, k3, k4, k5, k, tmp, t2, t3, t4, t5, t6, t7, t8, 1)
298+
ABM54Cache(u, uprev, fsalfirst, k2, k3, k4, k5, k, tmp,
299+
t2, t3, t4, t5, t6, t7, t8, 1, alg.thread)
292300
end
293301

294302
function alg_cache(alg::ABM54, u, rate_prototype, ::Type{uEltypeNoUnits},
@@ -317,7 +325,7 @@ end
317325
end
318326

319327
@cache mutable struct VCAB3Cache{uType, rateType, TabType, bs3Type, tArrayType, cArrayType,
320-
uNoUnitsType, coefType, dtArrayType} <:
328+
uNoUnitsType, coefType, dtArrayType, Thread} <:
321329
ABMVariableCoefficientMutableCache
322330
u::uType
323331
uprev::uType
@@ -337,6 +345,7 @@ end
337345
utilde::uType
338346
tab::TabType
339347
step::Int
348+
thread::Thread
340349
end
341350

342351
function alg_cache(alg::VCAB3, u, rate_prototype, ::Type{uEltypeNoUnits},
@@ -395,7 +404,7 @@ function alg_cache(alg::VCAB3, u, rate_prototype, ::Type{uEltypeNoUnits},
395404
tmp = zero(u)
396405
utilde = zero(u)
397406
VCAB3Cache(u, uprev, fsalfirst, bs3cache, k4, ϕstar_nm1, dts, c, g, ϕ_n, ϕstar_n, β,
398-
order, atmp, tmp, utilde, tab, 1)
407+
order, atmp, tmp, utilde, tab, 1, alg.thread)
399408
end
400409

401410
@cache mutable struct VCAB4ConstantCache{rk4constcache, tArrayType, rArrayType, cArrayType,
@@ -413,7 +422,7 @@ end
413422
end
414423

415424
@cache mutable struct VCAB4Cache{uType, rateType, rk4cacheType, tArrayType, cArrayType,
416-
uNoUnitsType, coefType, dtArrayType} <:
425+
uNoUnitsType, coefType, dtArrayType, Thread} <:
417426
ABMVariableCoefficientMutableCache
418427
u::uType
419428
uprev::uType
@@ -432,6 +441,7 @@ end
432441
tmp::uType
433442
utilde::uType
434443
step::Int
444+
thread::Thread
435445
end
436446

437447
function alg_cache(alg::VCAB4, u, rate_prototype, ::Type{uEltypeNoUnits},
@@ -489,7 +499,7 @@ function alg_cache(alg::VCAB4, u, rate_prototype, ::Type{uEltypeNoUnits},
489499
tmp = zero(u)
490500
utilde = zero(u)
491501
VCAB4Cache(u, uprev, fsalfirst, rk4cache, k4, ϕstar_nm1, dts, c, g, ϕ_n, ϕstar_n, β,
492-
order, atmp, tmp, utilde, 1)
502+
order, atmp, tmp, utilde, 1, alg.thread)
493503
end
494504

495505
# VCAB5
@@ -509,7 +519,7 @@ end
509519
end
510520

511521
@cache mutable struct VCAB5Cache{uType, rateType, rk4cacheType, tArrayType, cArrayType,
512-
uNoUnitsType, coefType, dtArrayType} <:
522+
uNoUnitsType, coefType, dtArrayType, Thread} <:
513523
ABMVariableCoefficientMutableCache
514524
u::uType
515525
uprev::uType
@@ -528,6 +538,7 @@ end
528538
tmp::uType
529539
utilde::uType
530540
step::Int
541+
thread::Thread
531542
end
532543

533544
function alg_cache(alg::VCAB5, u, rate_prototype, ::Type{uEltypeNoUnits},
@@ -585,7 +596,7 @@ function alg_cache(alg::VCAB5, u, rate_prototype, ::Type{uEltypeNoUnits},
585596
tmp = zero(u)
586597
utilde = zero(u)
587598
VCAB5Cache(u, uprev, fsalfirst, rk4cache, k4, ϕstar_nm1, dts, c, g, ϕ_n, ϕstar_n, β,
588-
order, atmp, tmp, utilde, 1)
599+
order, atmp, tmp, utilde, 1, alg.thread)
589600
end
590601

591602
# VCABM3
@@ -607,7 +618,7 @@ end
607618

608619
@cache mutable struct VCABM3Cache{
609620
uType, rateType, TabType, bs3Type, tArrayType, cArrayType,
610-
uNoUnitsType, coefType, dtArrayType} <:
621+
uNoUnitsType, coefType, dtArrayType, Thread} <:
611622
ABMVariableCoefficientMutableCache
612623
u::uType
613624
uprev::uType
@@ -628,6 +639,7 @@ end
628639
utilde::uType
629640
tab::TabType
630641
step::Int
642+
thread::Thread
631643
end
632644

633645
function alg_cache(alg::VCABM3, u, rate_prototype, ::Type{uEltypeNoUnits},
@@ -691,7 +703,7 @@ function alg_cache(alg::VCABM3, u, rate_prototype, ::Type{uEltypeNoUnits},
691703
tmp = zero(u)
692704
utilde = zero(u)
693705
VCABM3Cache(u, uprev, fsalfirst, bs3cache, k4, ϕstar_nm1, dts, c, g, ϕ_n, ϕ_np1,
694-
ϕstar_n, β, order, atmp, tmp, utilde, tab, 1)
706+
ϕstar_n, β, order, atmp, tmp, utilde, tab, 1, alg.thread)
695707
end
696708

697709
# VCABM4
@@ -713,7 +725,7 @@ end
713725
end
714726

715727
@cache mutable struct VCABM4Cache{uType, rateType, rk4cacheType, tArrayType, cArrayType,
716-
uNoUnitsType, coefType, dtArrayType} <:
728+
uNoUnitsType, coefType, dtArrayType, Thread} <:
717729
ABMVariableCoefficientMutableCache
718730
u::uType
719731
uprev::uType
@@ -733,6 +745,7 @@ end
733745
tmp::uType
734746
utilde::uType
735747
step::Int
748+
thread::Thread
736749
end
737750

738751
function alg_cache(alg::VCABM4, u, rate_prototype, ::Type{uEltypeNoUnits},
@@ -796,7 +809,7 @@ function alg_cache(alg::VCABM4, u, rate_prototype, ::Type{uEltypeNoUnits},
796809
tmp = zero(u)
797810
utilde = zero(u)
798811
VCABM4Cache(u, uprev, fsalfirst, rk4cache, k4, ϕstar_nm1, dts, c, g, ϕ_n, ϕ_np1,
799-
ϕstar_n, β, order, atmp, tmp, utilde, 1)
812+
ϕstar_n, β, order, atmp, tmp, utilde, 1, alg.thread)
800813
end
801814

802815
# VCABM5
@@ -818,7 +831,7 @@ end
818831
end
819832

820833
@cache mutable struct VCABM5Cache{uType, rateType, rk4cacheType, tArrayType, cArrayType,
821-
uNoUnitsType, coefType, dtArrayType} <:
834+
uNoUnitsType, coefType, dtArrayType, Thread} <:
822835
ABMVariableCoefficientMutableCache
823836
u::uType
824837
uprev::uType
@@ -838,6 +851,7 @@ end
838851
tmp::uType
839852
utilde::uType
840853
step::Int
854+
thread::Thread
841855
end
842856

843857
function alg_cache(alg::VCABM5, u, rate_prototype, ::Type{uEltypeNoUnits},
@@ -901,7 +915,7 @@ function alg_cache(alg::VCABM5, u, rate_prototype, ::Type{uEltypeNoUnits},
901915
tmp = zero(u)
902916
utilde = zero(u)
903917
VCABM5Cache(u, uprev, fsalfirst, rk4cache, k4, ϕstar_nm1, dts, c, g, ϕ_n, ϕ_np1,
904-
ϕstar_n, β, order, atmp, tmp, utilde, 1)
918+
ϕstar_n, β, order, atmp, tmp, utilde, 1, alg.thread)
905919
end
906920

907921
# VCABM
@@ -924,7 +938,7 @@ end
924938
end
925939

926940
@cache mutable struct VCABMCache{uType, rateType, dtType, tArrayType, cArrayType,
927-
uNoUnitsType, coefType, dtArrayType} <:
941+
uNoUnitsType, coefType, dtArrayType, Thread} <:
928942
ABMVariableCoefficientMutableCache
929943
u::uType
930944
uprev::uType
@@ -952,6 +966,7 @@ end
952966
atmpm2::uNoUnitsType
953967
atmpp1::uNoUnitsType
954968
step::Int
969+
thread::Thread
955970
end
956971

957972
function alg_cache(alg::VCABM, u, rate_prototype, ::Type{uEltypeNoUnits},
@@ -1023,5 +1038,5 @@ function alg_cache(alg::VCABM, u, rate_prototype, ::Type{uEltypeNoUnits},
10231038
VCABMCache(
10241039
u, uprev, fsalfirst, k4, ϕstar_nm1, dts, c, g, ϕ_n, ϕ_np1, ϕstar_n, β, order,
10251040
max_order, atmp, tmp, ξ, ξ0, utilde, utildem1, utildem2, utildep1, atmpm1,
1026-
atmpm2, atmpp1, 1)
1041+
atmpm2, atmpp1, 1, alg.thread)
10271042
end

0 commit comments

Comments
 (0)