@@ -14,26 +14,49 @@ function SciMLBase.init(
1414 backend = opt. backend
1515 @assert prob. u0 isa SArray
1616
17- # # initialize cache
18-
1917 # # Bounds check
2018 lb, ub = check_init_bounds (prob)
2119 lb, ub = check_init_bounds (prob)
2220 prob = remake (prob; lb = lb, ub = ub)
2321
24- init_gbest, particles = init_particles (prob, opt, typeof (prob. u0))
22+ particles = KernelAbstractions. allocate (
23+ backend, SPSOParticle{typeof (prob. u0), eltype (typeof (prob. u0))}, opt. num_particles)
24+ kernel! = gpu_init_particles! (backend)
2525
26- # TODO : Do the equivalent of cu()/roc()
27- particles_eltype = eltype (particles) === Float64 ? Float32 : eltype (particles)
28- gpu_particles = KernelAbstractions. allocate (backend,
29- particles_eltype,
30- size (particles))
31- copyto! (gpu_particles, particles)
32- gpu_init_gbest = KernelAbstractions. allocate (backend, typeof (init_gbest), (1 ,))
33- copyto! (gpu_init_gbest, [init_gbest])
26+ kernel! (particles, prob, opt, typeof (prob. u0); ndrange = opt. num_particles)
27+
28+ best_particle = minimum (particles)
29+ _init_gbest = SPSOGBest (best_particle. best_position, best_particle. best_cost)
30+
31+ init_gbest = KernelAbstractions. allocate (backend, typeof (_init_gbest), (1 ,))
32+ copyto! (init_gbest, [_init_gbest])
3433 return PSOCache{
35- typeof (prob), typeof (opt), typeof (gpu_particles), typeof (gpu_init_gbest)}(
36- prob, opt, gpu_particles, gpu_init_gbest)
34+ typeof (prob), typeof (opt), typeof (particles), typeof (init_gbest)}(
35+ prob, opt, particles, init_gbest)
36+ end
37+
38+ function SciMLBase. reinit! (cache:: PSOCache ; kwargs... )
39+ reinit_cache! (cache, cache. alg)
40+ end
41+
42+ function reinit_cache! (cache:: PSOCache , opt:: ParallelPSOKernel )
43+ prob = cache. prob
44+ backend = opt. backend
45+ particles = cache. particles
46+
47+ kernel! = PSOGPU. gpu_init_particles! (backend)
48+ kernel! (particles, prob, opt, typeof (prob. u0); ndrange = opt. num_particles)
49+
50+ best_particle = minimum (particles)
51+ _init_gbest = SPSOGBest (best_particle. best_position, best_particle. best_cost)
52+
53+ copyto! (cache. gbest, [_init_gbest])
54+
55+ return nothing
56+ end
57+
58+ function SciMLBase. solve! (cache, args... ; maxiters = 100 , kwargs... )
59+ solve! (cache, cache. alg, args... ; maxiters, kwargs... )
3760end
3861
3962function SciMLBase. solve! (
0 commit comments