1+ using ExactOptimalTransport
12
"""
    normalized_exp!(weight::AbstractVector)

Overwrite `weight` in place with `exp.(weight)` scaled so that its entries sum to one,
and return `weight`.

The maximum entry is subtracted before exponentiating, so the largest exponent is zero.
An empty `weight` is returned unchanged.
"""
function normalized_exp!(weight::AbstractVector)
    # `maximum` throws on an empty collection; there is nothing to normalise anyway.
    isempty(weight) && return weight

    # Shift so that the largest entry becomes 0 before calling `exp`.
    weight .-= maximum(weight)
    @. weight = exp(weight)
    weight ./= sum(weight)
    return weight
end
78
"""
    optimized_resample!(resampled_indices::AbstractVector{Int}, nrank::Int)

Reorder `resampled_indices` in place and return it.

The vector is treated as `nrank` consecutive blocks of
`length(resampled_indices) ÷ nrank` slots. Every index value is attributed to the block
it falls into (its source rank). A transport plan with cost 0 for staying on the same
rank and cost 1 for moving to another rank then decides how many indices from each
source rank fill each block. Within a block, indices are written grouped by ascending
source rank, in the order they were encountered in the input.

# Throws
- `ArgumentError` if `nrank` is not positive, if the length of `resampled_indices` is not
  divisible by `nrank`, or if an entry lies outside `1:length(resampled_indices)`.
- `ErrorException` if the plan returned by `emd` does not round to a consistent integer
  assignment. `resampled_indices` is left untouched in that case.
"""
function optimized_resample!(resampled_indices::AbstractVector{Int}, nrank::Int)
    nrank > 0 || throw(ArgumentError("nrank must be positive, got $nrank"))
    nprt = length(resampled_indices)
    nprt % nrank == 0 || throw(
        ArgumentError(
            "length of resampled_indices ($nprt) must be divisible by nrank ($nrank)"
        )
    )
    all(in(1:nprt), resampled_indices) || throw(
        ArgumentError("resampled_indices must only contain values in 1:$nprt")
    )

    # With a single rank every index is refilled into the same block in its original
    # order, and with no particles there is nothing to do: skip the solver entirely.
    (nrank == 1 || nprt == 0) && return resampled_indices

    nprt_per_rank = nprt ÷ nrank
    stock_queue = [Int[] for _ in 1:nrank]

    # Assign each resampled index to the rank whose block contains it.
    for resampled_idx in resampled_indices
        rank = div(resampled_idx - 1, nprt_per_rank) + 1
        push!(stock_queue[rank], resampled_idx)
    end

    supply_vector = length.(stock_queue)
    demand_vector = fill(nprt_per_rank, nrank)
    cost_matrix = ones(Int, nrank, nrank)
    for i in 1:nrank
        cost_matrix[i, i] = 0
    end

    # NOTE(review): the ExactOptimalTransport docs describe `emd(μ, ν, C, optimizer)`;
    # confirm that this three-argument call resolves to a method.
    γ = emd(supply_vector, demand_vector, cost_matrix)

    # Solver output is floating point: `Int(x)` would throw `InexactError` on values such
    # as 2.9999999, so round instead and verify the result before mutating anything.
    plan = round.(Int, γ)
    if any(<(0), plan) ||
        vec(sum(plan; dims=2)) != supply_vector ||
        vec(sum(plan; dims=1)) != demand_vector
        error("transport plan does not round to a consistent integer assignment")
    end

    # Fill block `i` with `plan[j, i]` indices taken from source rank `j`.
    for i in 1:nrank
        slot = (i - 1) * nprt_per_rank
        for j in 1:nrank
            for _ in 1:plan[j, i]
                slot += 1
                resampled_indices[slot] = popfirst!(stock_queue[j])
            end
        end
    end
    return resampled_indices
end
41+
42+
843# Resample particles from given weights using Stochastic Universal Sampling
944function resample!(
1045 resampled_indices:: AbstractVector{Int} ,
1146 weights:: AbstractVector{T} ,
12- rng:: Random.AbstractRNG = Random. TaskLocalRNG()
47+ rng:: Random.AbstractRNG = Random. TaskLocalRNG(),
48+ optimize_resample:: Bool = false ,
49+ nrank:: Int = 1 ,
1350) where T
1451
1552 nprt = length(weights)
@@ -28,6 +65,10 @@ function resample!(
2865 end
2966 resampled_indices[ip] = k
3067 end
68+
69+ if optimize_resample
70+ resampled_indices .= optimized_resample!(resampled_indices, nrank)
71+ end
3172end
3273
3374function init_states(model, nprt_per_rank:: Int , n_tasks:: Int , rng:: AbstractRNG )
@@ -174,6 +215,7 @@ function copy_states_dedup!(
174215 for (id, buffer_indices) in remote_copies
175216 if length(buffer_indices) > 1
176217 source_view = view(buffer, :, buffer_indices[1 ])
218+ # TODO : threading in chunks
177219 Threads. @threads for i in 2 : length(buffer_indices)
178220 k = buffer_indices[i]
179221 buffer[:, k] .= source_view
@@ -182,8 +224,12 @@ function copy_states_dedup!(
182224 end
183225 end
184226
185- @timeit_debug to " write from buffer" particles .= buffer
186-
227+ @timeit_debug to " write from buffer" begin
228+ Threads. @threads for j in 1 : size(particles, 2 )
229+ # @views creates a non-allocating view of the column, which is faster inside a loop
230+ @views particles[:, j] .= buffer[:, j]
231+ end
232+ end
187233end
188234
189235function _determine_sends(resampling_indices:: Vector{Int} , my_rank:: Int , nprt_per_rank:: Int )
0 commit comments