Skip to content

Commit 622fe22

Browse files
author
ucabc46
committed
optimize indices
1 parent d768e56 commit 622fe22

File tree

3 files changed

+60
-4
lines changed

3 files changed

+60
-4
lines changed

src/utils.jl

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,52 @@
1+
using ExactOptimalTransport
12

23
"""
    normalized_exp!(weight::AbstractVector)

Overwrite `weight` with `exp.(weight)` rescaled so that its entries sum to one, and
return it.

The maximum entry is subtracted before exponentiating, so the largest exponent is
`exp(0) == 1`; the shift cancels out in the final normalisation.
"""
function normalized_exp!(weight::AbstractVector)
    # Shift by the largest entry first, then exponentiate in a single fused pass.
    shift = maximum(weight)
    weight .= exp.(weight .- shift)

    # Rescale so the entries sum to one.
    total = sum(weight)
    weight .= weight ./ total
    return weight
end
78

9+
"""
    optimized_resample!(resampled_indices::AbstractVector{Int}, nrank::Int)

Reorder `resampled_indices` in place so that as few entries as possible end up in the
slot block of a rank other than the one owning the particle they refer to, and return
`resampled_indices`.

The vector is treated as `nrank` consecutive blocks of
`nprt_per_rank = length(resampled_indices) ÷ nrank` slots. A particle index `p` is owned
by rank `div(p - 1, nprt_per_rank) + 1`. Only the order of the entries changes; the
multiset of values is preserved.

Within the block of a destination rank, entries are ordered by source rank; entries
taken from the same source rank keep their original relative order.

# Throws
- `ArgumentError` if `nrank < 1`, if the length of `resampled_indices` is not a
  multiple of `nrank`, or if an entry lies outside `1:length(resampled_indices)`.
"""
function optimized_resample!(resampled_indices::AbstractVector{Int}, nrank::Int)
    nprt = length(resampled_indices)
    nrank >= 1 || throw(ArgumentError("nrank must be positive, got $nrank"))
    nprt % nrank == 0 || throw(
        ArgumentError("length(resampled_indices) = $nprt is not a multiple of nrank = $nrank"),
    )
    nprt_per_rank = nprt ÷ nrank

    # Assign each resampled index to the queue of the rank that owns it
    stock_queue = [Int[] for _ in 1:nrank]
    for resampled_idx in resampled_indices
        1 <= resampled_idx <= nprt || throw(
            ArgumentError("resampled index $resampled_idx is outside 1:$nprt"),
        )
        rank = div(resampled_idx - 1, nprt_per_rank) + 1
        push!(stock_queue[rank], resampled_idx)
    end

    # γ[j, i] = number of entries shipped from source rank j to destination rank i.
    γ = _min_move_plan(map(length, stock_queue), nprt_per_rank)

    # Fill destination blocks in order. Every column of γ sums to nprt_per_rank, so
    # writing sequentially lands each entry in the block of its destination rank.
    reordered = Vector{Int}(undef, nprt)
    consumed = zeros(Int, nrank)  # how many entries of each queue are already placed
    slot = 0
    for i in 1:nrank
        for j in 1:nrank
            for _ in 1:γ[j, i]
                consumed[j] += 1
                slot += 1
                reordered[slot] = stock_queue[j][consumed[j]]
            end
        end
    end

    copyto!(resampled_indices, reordered)
    return resampled_indices
end

# Build an integer transport plan between ranks for the cost "0 to stay, 1 to move".
# `supply[r]` entries currently belong to rank r and every rank must receive `capacity`.
# A rank can keep at most `min(supply[r], capacity)` entries, so keeping exactly that
# many and shipping only the surplus to ranks with free slots moves the fewest entries.
function _min_move_plan(supply::AbstractVector{Int}, capacity::Int)
    nrank = length(supply)
    plan = zeros(Int, nrank, nrank)
    surplus = zeros(Int, nrank)
    deficit = zeros(Int, nrank)
    for r in 1:nrank
        kept = min(supply[r], capacity)
        plan[r, r] = kept
        surplus[r] = supply[r] - kept
        deficit[r] = capacity - kept
    end

    # Total surplus equals total deficit (sum(supply) == nrank * capacity), so `src`
    # always finds a rank with surplus while some deficit remains.
    src = 1
    for dst in 1:nrank
        while deficit[dst] > 0
            while surplus[src] == 0
                src += 1
            end
            moved = min(surplus[src], deficit[dst])
            plan[src, dst] += moved
            surplus[src] -= moved
            deficit[dst] -= moved
        end
    end
    return plan
end
41+
42+
843
# Resample particles from given weights using Stochastic Universal Sampling
944
function resample!(
1045
resampled_indices::AbstractVector{Int},
1146
weights::AbstractVector{T},
12-
rng::Random.AbstractRNG=Random.TaskLocalRNG()
47+
rng::Random.AbstractRNG=Random.TaskLocalRNG(),
48+
optimize_resample::Bool=false,
49+
nrank::Int=1,
1350
) where T
1451

1552
nprt = length(weights)
@@ -28,6 +65,10 @@ function resample!(
2865
end
2966
resampled_indices[ip] = k
3067
end
68+
69+
if optimize_resample
70+
resampled_indices .= optimized_resample!(resampled_indices, nrank)
71+
end
3172
end
3273

3374
function init_states(model, nprt_per_rank::Int, n_tasks::Int, rng::AbstractRNG)
@@ -174,6 +215,7 @@ function copy_states_dedup!(
174215
for (id, buffer_indices) in remote_copies
175216
if length(buffer_indices) > 1
176217
source_view = view(buffer, :, buffer_indices[1])
218+
# TODO: threading in chunks
177219
Threads.@threads for i in 2:length(buffer_indices)
178220
k = buffer_indices[i]
179221
buffer[:, k] .= source_view
@@ -182,8 +224,12 @@ function copy_states_dedup!(
182224
end
183225
end
184226

185-
@timeit_debug to "write from buffer" particles .= buffer
186-
227+
@timeit_debug to "write from buffer" begin
228+
Threads.@threads for j in 1:size(particles, 2)
229+
# @views creates a non-allocating view of the column, which is faster inside a loop
230+
@views particles[:, j] .= buffer[:, j]
231+
end
232+
end
187233
end
188234

189235
function _determine_sends(resampling_indices::Vector{Int}, my_rank::Int, nprt_per_rank::Int)

test/mpi_copy_states.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ my_size = MPI.Comm_size(MPI.COMM_WORLD)
3939

4040
println("Number of threads available: ", Threads.nthreads())
4141

42-
n_particle_per_rank = 1000
42+
n_particle_per_rank = 100
4343
n_particle = n_particle_per_rank * my_size
4444
verbose = "-v" in ARGS || "--verbose" in ARGS
4545
output_timer = "-t" in ARGS || "--output-timer" in ARGS

test/optimized_resample.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# Manual smoke script for `ParticleDA.optimized_resample!`: it prints a vector of
# resampled indices before and after the reordering. It contains no assertions.
using ParticleDA
2+
3+
# Number of ranks and total number of particles passed to the reordering.
nrank = 5
4+
n_particle = 10
5+
6+
# NOTE(review): `sample_indices` is not defined in this script; presumably it is
# provided elsewhere and returns a Vector{Int} of length `n_particle` — confirm.
resampled_indices = sample_indices(n_particle, k=5, p=0.99)
7+
println("Resampled Indices: ", resampled_indices)
8+
9+
# `optimized_resample!` is called through the module-qualified name and its return
# value is rebound to the same variable.
resampled_indices = ParticleDA.optimized_resample!(resampled_indices, nrank)
10+
println("Optimized Resampled Indices: ", resampled_indices)

0 commit comments

Comments
 (0)