Skip to content

Commit 4305d3a

Browse files
authored
Update backend in lbfgs.jl
1 parent 542598b commit 4305d3a

File tree

1 file changed

+17
-7
lines changed

1 file changed

+17
-7
lines changed

test/lbfgs.jl

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,14 @@
1-
using PSOGPU, Optimization, CUDA, StaticArrays
1+
using PSOGPU, Optimization, StaticArrays
2+
3+
DEVICE = get(ENV, "GROUP", "CUDA")
4+
5+
@eval using $(Symbol(DEVICE))
6+
7+
if DEVICE == "CUDA"
8+
backend = CUDABackend()
9+
elseif DEVICE == "AMDGPU"
10+
backend = ROCBackend()
11+
end
212

313
function objf(x, p)
414
return 1 - x[1]^2 - x[2]^2
@@ -28,17 +38,17 @@ l0 = rosenbrock(x0, p)
2838
maxiters = 20)
2939
@show sol.objective
3040
@time sol = Optimization.solve(prob,
31-
PSOGPU.ParallelPSOKernel(100, backend = CUDABackend()),
41+
PSOGPU.ParallelPSOKernel(100; backend),
3242
maxiters = 100)
3343
@show sol.objective
3444

3545
@time sol = Optimization.solve(prob,
36-
PSOGPU.HybridPSO(; backend = CUDABackend()),
46+
PSOGPU.HybridPSO(; backend),
3747
maxiters = 30)
3848
@show sol.objective
3949

4050
@time sol = Optimization.solve(prob,
41-
PSOGPU.HybridPSO(; local_opt = PSOGPU.BFGS(), backend = CUDABackend()),
51+
PSOGPU.HybridPSO(; local_opt = PSOGPU.BFGS(), backend = backend),
4252
maxiters = 30)
4353
@show sol.objective
4454

@@ -51,16 +61,16 @@ l0 = rosenbrock(x0, p)
5161
maxiters = 20)
5262
@show sol.objective
5363
@time sol = Optimization.solve(prob,
54-
PSOGPU.ParallelPSOKernel(100, backend = CUDABackend()),
64+
PSOGPU.ParallelPSOKernel(100, backend = backend),
5565
maxiters = 100)
5666
@show sol.objective
5767

5868
@time sol = Optimization.solve(prob,
59-
PSOGPU.HybridPSO(; backend = CUDABackend()),
69+
PSOGPU.HybridPSO(; backend = backend),
6070
local_maxiters = 30)
6171
@show sol.objective
6272

6373
@time sol = Optimization.solve(prob,
64-
PSOGPU.HybridPSO(; local_opt = PSOGPU.BFGS(), backend = CUDABackend()),
74+
PSOGPU.HybridPSO(; local_opt = PSOGPU.BFGS(), backend = backend),
6575
local_maxiters = 30)
6676
@show sol.objective

0 commit comments

Comments
 (0)