Skip to content

Commit c3762e0

Browse files
committed
Update benchmarks
1 parent 9f83adf commit c3762e0

File tree

10 files changed

+2213
-45
lines changed

10 files changed

+2213
-45
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ version = "1.0.0-DEV"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
8+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
89
DiffEqGPU = "071ae1c0-96b5-11e9-1965-c90190d839ea"
910
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
1011
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"

benchmarks/Fitzhugh_Nagumo/Manifest.toml

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
julia_version = "1.10.0"
44
manifest_format = "2.0"
5-
project_hash = "64f611d6591a3e31f2740258c68d607684b60a0f"
5+
project_hash = "5385867d5ac9adc9babdf0cd3ced91c0f6a2076f"
66

77
[[deps.ADTypes]]
88
git-tree-sha1 = "41c37aa88889c171f1300ceac1313c06e891d245"
@@ -150,6 +150,12 @@ git-tree-sha1 = "aebf55e6d7795e02ca500a689d326ac979aaf89e"
150150
uuid = "9718e550-a3fa-408a-8086-8db961cd8217"
151151
version = "0.1.1"
152152

153+
[[deps.BenchmarkTools]]
154+
deps = ["JSON", "Logging", "Printf", "Profile", "Statistics", "UUIDs"]
155+
git-tree-sha1 = "f1f03a9fa24271160ed7e73051fba3c1a759b53f"
156+
uuid = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
157+
version = "1.4.0"
158+
153159
[[deps.BitFlags]]
154160
git-tree-sha1 = "2dc09997850d68179b69dafb58ae806167a32b1b"
155161
uuid = "d1d4a3ce-64b1-5f1a-9ba4-7e7e69966f35"
@@ -358,9 +364,9 @@ uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
358364
version = "4.1.1"
359365

360366
[[deps.DataAPI]]
361-
git-tree-sha1 = "8da84edb865b0b5b0100c0666a9bc9a0b71c553c"
367+
git-tree-sha1 = "abe83f3a2f1b857aac70ef8b269080af17764bbe"
362368
uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
363-
version = "1.15.0"
369+
version = "1.16.0"
364370

365371
[[deps.DataDeps]]
366372
deps = ["HTTP", "Libdl", "Reexport", "SHA", "Scratch", "p7zip_jll"]
@@ -785,16 +791,22 @@ version = "1.0.0"
785791

786792
[[deps.JLD2]]
787793
deps = ["FileIO", "MacroTools", "Mmap", "OrderedCollections", "Pkg", "PrecompileTools", "Printf", "Reexport", "Requires", "TranscodingStreams", "UUIDs"]
788-
git-tree-sha1 = "315b508ec5df53936532097ffe6e5deacbf41861"
794+
git-tree-sha1 = "7c0008f0b7622c6c0ee5c65cbc667b69f8a65672"
789795
uuid = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
790-
version = "0.4.44"
796+
version = "0.4.45"
791797

792798
[[deps.JLLWrappers]]
793799
deps = ["Artifacts", "Preferences"]
794800
git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca"
795801
uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210"
796802
version = "1.5.0"
797803

804+
[[deps.JSON]]
805+
deps = ["Dates", "Mmap", "Parsers", "Unicode"]
806+
git-tree-sha1 = "31e996f0a15c7b280ba9f76636b3ff9e2ae58c9a"
807+
uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
808+
version = "0.21.4"
809+
798810
[[deps.JSON3]]
799811
deps = ["Dates", "Mmap", "Parsers", "PrecompileTools", "StructTypes", "UUIDs"]
800812
git-tree-sha1 = "eb3edce0ed4fa32f75a0a11217433c31d56bd48b"
@@ -1215,10 +1227,10 @@ uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
12151227
version = "1.2.0"
12161228

12171229
[[deps.NonlinearSolve]]
1218-
deps = ["ADTypes", "ArrayInterface", "ConcreteStructs", "DiffEqBase", "EnumX", "FastBroadcast", "FastClosures", "FiniteDiff", "ForwardDiff", "LazyArrays", "LineSearches", "LinearAlgebra", "LinearSolve", "MaybeInplace", "PrecompileTools", "Printf", "RecursiveArrayTools", "Reexport", "SciMLBase", "SciMLOperators", "SimpleNonlinearSolve", "SparseArrays", "SparseDiffTools", "StaticArrays", "UnPack"]
1219-
git-tree-sha1 = "72b036b728461272ae1b1c3f7096cb4c319d8793"
1230+
deps = ["ADTypes", "ArrayInterface", "ConcreteStructs", "DiffEqBase", "FastBroadcast", "FastClosures", "FiniteDiff", "ForwardDiff", "LazyArrays", "LineSearches", "LinearAlgebra", "LinearSolve", "MaybeInplace", "PrecompileTools", "Preferences", "Printf", "RecursiveArrayTools", "Reexport", "SciMLBase", "SimpleNonlinearSolve", "SparseArrays", "SparseDiffTools", "StaticArraysCore", "TimerOutputs"]
1231+
git-tree-sha1 = "78bdd3a4a62865cf43c53d63783b0cbfddcdbbe6"
12201232
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
1221-
version = "3.4.0"
1233+
version = "3.5.0"
12221234

12231235
[deps.NonlinearSolve.extensions]
12241236
NonlinearSolveBandedMatricesExt = "BandedMatrices"
@@ -1295,9 +1307,9 @@ version = "0.5.5+0"
12951307

12961308
[[deps.Optimization]]
12971309
deps = ["ADTypes", "ArrayInterface", "ConsoleProgressMonitor", "DocStringExtensions", "LinearAlgebra", "Logging", "LoggingExtras", "Pkg", "Printf", "ProgressLogging", "Reexport", "SciMLBase", "SparseArrays", "SymbolicIndexingInterface", "TerminalLoggers"]
1298-
git-tree-sha1 = "31fe8abda56b2168d262b515ecabdb44d7c36b4d"
1310+
git-tree-sha1 = "e24a89f3f15fd4beff32a12bde4310768f47c5bc"
12991311
uuid = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
1300-
version = "3.21.1"
1312+
version = "3.21.2"
13011313

13021314
[deps.Optimization.extensions]
13031315
OptimizationEnzymeExt = "Enzyme"
@@ -1447,6 +1459,10 @@ version = "0.5.5"
14471459
deps = ["Unicode"]
14481460
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
14491461

1462+
[[deps.Profile]]
1463+
deps = ["Printf"]
1464+
uuid = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79"
1465+
14501466
[[deps.ProgressLogging]]
14511467
deps = ["Logging", "SHA", "UUIDs"]
14521468
git-tree-sha1 = "80d919dee55b9c50e8d9e2da5eeafff3fe58b539"
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
[deps]
22
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
3+
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
34
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
45
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
56
PSOGPU = "ab63da0c-63b4-40fa-a3b7-d2cba5be6419"
7+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Lines changed: 39 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,27 @@
1+
using Pkg
2+
Pkg.activate(@__DIR__)
13
using PSOGPU, OrdinaryDiffEq, StaticArrays
24

35
function f(u, p, t)
4-
a,b,τinv,l = p
5-
v,w = u
6-
dv = v - v^3/3 -w + l
7-
dw = τinv*(v + a - b*w)
6+
a, b, τinv, l = p
7+
v, w = u
8+
dv = v - v^3 / 3 - w + l
9+
dw = τinv * (v + a - b * w)
810
return SVector{2}(dv, dw)
911
end
1012

11-
p = @SArray [0.7f0,0.8f0,0.08f0,0.5f0] # Parameters used to construct the dataset
13+
p = @SArray [0.7f0, 0.8f0, 0.08f0, 0.5f0] # Parameters used to construct the dataset
1214
r0 = @SArray [1.0f0; 1.0f0] # initial value
1315
tspan = (0.0f0, 30.0f0) # sample of 3000 observations over the (0,30) timespan
14-
prob = ODEProblem(f, r0, tspan,p)
16+
prob = ODEProblem(f, r0, tspan, p)
1517

1618
tspan2 = (0.0f0, 3.0f0) # sample of 300 observations with a timestep of 0.01
17-
prob_short = ODEProblem(f, r0, tspan2,p)
19+
prob_short = ODEProblem(f, r0, tspan2, p)
1820

19-
dt = 30.0f0/3000f0
21+
dt = 30.0f0 / 3000.0f0
2022
tf = 30.0f0
21-
tinterval = 0f0:dt:tf
22-
t = collect(tinterval)
23+
tinterval = 0.0f0:dt:tf
24+
t = collect(tinterval)
2325

2426
h = 0.01f0
2527
M = 300
@@ -28,20 +30,23 @@ tstop = tstart + M * h
2830
tinterval_short = 0:h:tstop
2931
t_short = collect(tinterval_short)
3032

31-
data_sol_short = solve(prob_short,Vern9(),saveat=t_short,reltol=1f-6,abstol=1f-6)
33+
data_sol_short = solve(prob_short,
34+
Vern9(),
35+
saveat = t_short,
36+
reltol = 1.0f-6,
37+
abstol = 1.0f-6)
3238
data_short = convert(Array, data_sol_short) # This operation produces column major dataset obs as columns, equations as rows
33-
data_sol = solve(prob,Vern9(),saveat=t,reltol=1f-6,abstol=1f-6)
39+
data_sol = solve(prob, Vern9(), saveat = t, reltol = 1.0f-6, abstol = 1.0f-6)
3440
data = convert(Array, data_sol)
3541

36-
using Plots
42+
# using Plots
3743

38-
plot(data_sol_short)
44+
# plot(data_sol_short)
3945

40-
plot(data_sol)
46+
# plot(data_sol)
4147

4248
n_particles = 10_000
4349

44-
4550
# obj_short = build_loss_objective(prob_short,Tsit5(),L2Loss(t_short,data_short),tstops=t_short)
4651
function loss(u, p)
4752
odeprob, t = p
@@ -50,8 +55,8 @@ function loss(u, p)
5055
sum(abs2, data_short .- pred)
5156
end
5257

53-
lb = @SArray fill(0.f0, 4)
54-
ub = @SArray fill(5.f0, 4)
58+
lb = @SArray fill(0.0f0, 4)
59+
ub = @SArray fill(5.0f0, 4)
5560

5661
optprob = OptimizationProblem(loss, prob.p, (prob, t_short); lb = lb, ub = ub)
5762

@@ -68,20 +73,30 @@ gpu_particles = cu(particles)
6873

6974
CUDA.allowscalar(false)
7075

71-
function prob_func(prob, gpu_particle)
72-
return remake(prob, p = (prob.p[1], gpu_particle.position))
73-
end
74-
7576
using Adapt
7677

7778
losses = adapt(CUDABackend(), ones(eltype(prob.u0), (1, n_particles)))
7879

7980
solver_cache = (; losses, gpu_particles, gpu_data, gbest)
8081

82+
adaptive = false
83+
8184
@time gsol = PSOGPU.parameter_estim_ode!(prob,
8285
solver_cache,
8386
lb,
84-
ub;
87+
ub, Val(adaptive);
8588
saveat = t_short,
8689
dt = 0.1f0,
87-
maxiters = 1)
90+
maxiters = 100)
91+
92+
using BenchmarkTools
93+
94+
@benchmark PSOGPU.parameter_estim_ode!($prob,
95+
$(deepcopy(solver_cache)),
96+
$lb,
97+
$ub, $Val(adaptive);
98+
saveat = t_short,
99+
dt = 0.1f0,
100+
maxiters = 100)
101+
102+
@show gbest.cost, gsol

0 commit comments

Comments
 (0)