@@ -83,73 +83,6 @@ function init_states(model, nprt_per_rank::Int, n_tasks::Int, rng::AbstractRNG)
 end
 
 function copy_states!(
-    particles::AbstractMatrix{T},
-    buffer::AbstractMatrix{T},
-    resampling_indices::Vector{Int},
-    my_rank::Int,
-    nprt_per_rank::Int,
-    to::TimerOutputs.TimerOutput = TimerOutputs.TimerOutput(),
-    dedup::Bool = false
-) where T
-
-    if dedup
-        return copy_states_dedup!(particles, buffer, resampling_indices, my_rank, nprt_per_rank, to)
-    end
-
-    # These are the particle indices stored on this rank
-    particles_have = my_rank * nprt_per_rank + 1:(my_rank + 1) * nprt_per_rank
-
-    # These are the particle indices this rank should have after resampling
-    particles_want = resampling_indices[particles_have]
-
-    # These are the ranks that have the particles this rank should have
-    rank_has = floor.(Int, (particles_want .- 1) / nprt_per_rank)
-
-    # We could work out how many sends and receives we have to do and allocate
-    # this appropriately but, lazy
-    reqs = Vector{MPI.Request}(undef, 0)
-
-    # Send particles to processes that want them
-    @timeit_debug to "send loop" begin
-        for (k, id) in enumerate(resampling_indices)
-            rank_wants = floor(Int, (k - 1) / nprt_per_rank)
-            if id in particles_have && rank_wants != my_rank
-                local_id = id - my_rank * nprt_per_rank
-                req = MPI.Isend(view(particles, :, local_id), rank_wants, id, MPI.COMM_WORLD)
-                push!(reqs, req)
-            end
-        end
-    end
-
-    # Receive particles this rank wants from ranks that have them
-    # If I already have them, just do a local copy
-    # Receive into a buffer so we don't accidentally overwrite stuff
-    @timeit_debug to "receive loop" begin
-        for (k, proc, id) in zip(1:nprt_per_rank, rank_has, particles_want)
-            if proc == my_rank
-                @timeit_debug to "local copy" begin
-                    local_id = id - my_rank * nprt_per_rank
-                    buffer[:, k] .= view(particles, :, local_id)
-                end
-            else
-                @timeit_debug to "remote receive" begin
-                    req = MPI.Irecv!(view(buffer, :, k), proc, id, MPI.COMM_WORLD)
-                    push!(reqs, req)
-                end
-            end
-        end
-    end
-
-    # Wait for all comms to complete
-    @timeit_debug to "waitall phase" MPI.Waitall(reqs)
-
-    @timeit_debug to "buffer write-back" particles .= buffer
-
-end
-
-# An optimized version of copy_states that minimizes the number of messages sent
-# by deduplicating particles that need to be sent between ranks.
-function copy_states_dedup!(
     particles::AbstractMatrix{T},
     buffer::AbstractMatrix{T},
     resampling_indices::Vector{Int},
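A note on the index arithmetic used by the removed copy_states! (and kept by the dedup path): each rank owns a contiguous block of nprt_per_rank particles, so the owning rank of a global particle index follows from floor division and its local column from an offset. A minimal sketch, with illustrative helper names that are not part of the repository:

    # Which rank stores global particle `id`, and in which local column (1-based ids)?
    owning_rank(id, nprt_per_rank) = floor(Int, (id - 1) / nprt_per_rank)
    local_index(id, rank, nprt_per_rank) = id - rank * nprt_per_rank

    # e.g. with nprt_per_rank = 4, global particle 6 lives on rank 1, local column 2
    @assert owning_rank(6, 4) == 1
    @assert local_index(6, 1, 4) == 2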
@@ -243,7 +176,7 @@ function _determine_sends(resampling_indices::Vector{Int}, my_rank::Int, nprt_pe
     return sends_to
 end
 
-function _categorize_wants(particles_want, my_rank::Int, nprt_per_rank::Int)
+function _categorize_wants(particles_want::Vector{Int}, my_rank::Int, nprt_per_rank::Int)
     local_copies = Dict{Int, Vector{Int}}()
     remote_copies = Dict{Int, Vector{Int}}()
 
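The body of _categorize_wants is not shown in this hunk; as a rough sketch only (an assumption based on the two dictionaries it initializes), it could split the wanted indices into particles already owned by this rank and particles to request from other ranks, deduplicating so each remote particle is received at most once:

    # Hypothetical sketch, not the committed implementation.
    function categorize_wants_sketch(particles_want::Vector{Int}, my_rank::Int, nprt_per_rank::Int)
        local_copies = Dict{Int, Vector{Int}}()   # particle id => destination columns on this rank
        remote_copies = Dict{Int, Vector{Int}}()  # source rank => unique particle ids to receive
        for (k, id) in enumerate(particles_want)
            owner = floor(Int, (id - 1) / nprt_per_rank)
            if owner == my_rank
                push!(get!(local_copies, id, Int[]), k)
            else
                ids = get!(remote_copies, owner, Int[])
                id in ids || push!(ids, id)
            end
        end
        return local_copies, remote_copies
    end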