Skip to content

Commit 0b4fc8f

Browse files
committed
Don't use cu
1 parent 864da4f commit 0b4fc8f

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

src/PSOGPU.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -63,16 +63,18 @@ function SciMLBase.__solve(prob::OptimizationProblem,
6363
gbest = pso_solve_cpu!(prob, init_gbest, particles; kwargs...)
6464
end
6565
else
66+
backend = opt.backend
67+
init_gbest, particles = init_particles(prob, opt.num_particles)
68+
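# TODO: Provide a backend-agnostic equivalent of cu()/roc() for moving data to the device (presumably including their Float64 -> Float32 demotion, hand-rolled below; confirm)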
69+
particles_eltype = eltype(particles) === Float64 ? Float32 : eltype(particles)
70+
gpu_particles = KernelAbstractions.allocate(backend, particles_eltype, size(particles))
71+
copyto!(gpu_particles, particles)
72+
gpu_init_gbest = KernelAbstractions.allocate(backend, typeof(init_gbest), (1,))
73+
copyto!(gpu_init_gbest, [init_gbest])
6674
if opt.async
67-
init_gbest, particles = init_particles(prob, opt.num_particles)
68-
gpu_particles = cu(particles)
69-
init_gbest = cu([init_gbest])
70-
gbest = pso_solve_async_gpu!(prob, init_gbest, gpu_particles; kwargs...)
75+
gbest = pso_solve_async_gpu!(prob, gpu_init_gbest, gpu_particles; kwargs...)
7176
else
72-
init_gbest, particles = init_particles(prob, opt.num_particles)
73-
gpu_particles = cu(particles)
74-
init_gbest = cu([init_gbest])
75-
gbest = pso_solve_gpu!(prob, init_gbest, gpu_particles; kwargs...)
77+
gbest = pso_solve_gpu!(prob, gpu_init_gbest, gpu_particles; kwargs...)
7678
end
7779
end
7880

0 commit comments

Comments
 (0)