@@ -63,16 +63,18 @@ function SciMLBase.__solve(prob::OptimizationProblem,
6363 gbest = pso_solve_cpu! (prob, init_gbest, particles; kwargs... )
6464 end
6565 else
66+ backend = opt. backend
67+ init_gbest, particles = init_particles (prob, opt. num_particles)
68+ # TODO : Do the equivalent of cu()/roc()
69+ particles_eltype = eltype (particles) === Float64 ? Float32 : eltype (particles)
70+ gpu_particles = KernelAbstractions. allocate (backend, particles_eltype, size (particles))
71+ copyto! (gpu_particles, particles)
72+ gpu_init_gbest = KernelAbstractions. allocate (backend, typeof (init_gbest), (1 ,))
73+ copyto! (gpu_init_gbest, [init_gbest])
6674 if opt. async
67- init_gbest, particles = init_particles (prob, opt. num_particles)
68- gpu_particles = cu (particles)
69- init_gbest = cu ([init_gbest])
70- gbest = pso_solve_async_gpu! (prob, init_gbest, gpu_particles; kwargs... )
75+ gbest = pso_solve_async_gpu! (prob, gpu_init_gbest, gpu_particles; kwargs... )
7176 else
72- init_gbest, particles = init_particles (prob, opt. num_particles)
73- gpu_particles = cu (particles)
74- init_gbest = cu ([init_gbest])
75- gbest = pso_solve_gpu! (prob, init_gbest, gpu_particles; kwargs... )
77+ gbest = pso_solve_gpu! (prob, gpu_init_gbest, gpu_particles; kwargs... )
7678 end
7779 end
7880
0 commit comments