Skip to content

Commit 8a821d9

Browse files
author
ucabc46
committed
fix comments, logs
1 parent ba8ef37 commit 8a821d9

File tree

3 files changed

+15
-15
lines changed

3 files changed

+15
-15
lines changed

extra/weak_scaling/run_particleda.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ open(parameters_file, "w") do io
4949
YAML.write(io, parameters)
5050
end
5151

52-
@info "Optimized resampling enabled: ", parameters["filter"]["optimize_resampling"]
52+
@info "Optimized resampling enabled: $(parameters["filter"]["optimize_resampling"])"
5353

5454
final_states, final_statistics = run_particle_filter(
5555
LLW2d.init, parameters_file, observation_file, filter_type, summary_stat_type

src/utils.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,12 @@ function copy_states!(
9191
to::TimerOutputs.TimerOutput = TimerOutputs.TimerOutput()
9292
) where T
9393

94-
# Same as copy_states
94+
# These are the particle indices stored on this rank
9595
particles_have = my_rank * nprt_per_rank + 1:(my_rank + 1) * nprt_per_rank
96+
97+
# These are the particle indices this rank should have after resampling
9698
particles_want = resampling_indices[particles_have]
99+
97100
reqs = Vector{MPI.Request}(undef, 0)
98101

99102
# Determine which particles need to be sent where
@@ -183,12 +186,12 @@ function _categorize_wants(particles_want::Vector{Int}, my_rank::Int, nprt_per_r
183186
for k in 1:nprt_per_rank
184187
id = particles_want[k]
185188
source_rank = floor(Int, (id - 1) / nprt_per_rank)
189+
dict = source_rank == my_rank ? local_copies : remote_copies
186190

187-
if source_rank == my_rank
188-
get!(() -> Int[], local_copies, id) |> v -> push!(v, k)
189-
else
190-
get!(() -> Int[], remote_copies, id) |> v -> push!(v, k)
191+
vec = get!(dict, id) do
192+
Int[] # initialize a new vector only if id not present
191193
end
194+
push!(vec, k)
192195
end
193196
return local_copies, remote_copies
194197
end

test/mpi_optimized_copy_states.jl

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ MPI.Init()
3737
my_rank = MPI.Comm_rank(MPI.COMM_WORLD)
3838
my_size = MPI.Comm_size(MPI.COMM_WORLD)
3939

40-
@info "Number of threads available: ", Threads.nthreads()
40+
@info "Number of threads available: $(Threads.nthreads())"
4141

4242
n_particle_per_rank = 1000
4343
n_particle = n_particle_per_rank * my_size
@@ -48,13 +48,10 @@ if output_timer
4848
error("Please provide the output filename for timers.")
4949
end
5050
output_filename = ARGS[2]
51-
@info "Outputting timers to HDF5 file '$output_filename'"
51+
@info "Outputting timers to HDF5 file '$(output_filename)'"
5252
end
53-
# default: dedup enabled for testing
54-
no_dedup = "-nd" in ARGS || "--no-dedup" in ARGS
55-
@info "Deduplication enabled: ", !no_dedup
5653
optimize_resample = "-o" in ARGS || "--optimize-resample" in ARGS
57-
@info "Optimized resampling enabled: ", optimize_resample
54+
@info "Optimized resampling enabled: $(optimize_resample)"
5855

5956
n_float_per_particle = 100000
6057
# total number of floats per rank
@@ -88,7 +85,7 @@ local_timer_dicts = Dict{String, Dict{String,Any}}()
8885

8986
for (trial_name, indices_func) in trial_sets
9087
if verbose && my_rank == 0
91-
@info "Resampling particles to indices ", indices
88+
@info "Resampling particles to indices $(indices)"
9289
end
9390
indices = collect(1:n_particle) # Placeholder for actual indices
9491
# repeat experiment 10 times to get average time
@@ -102,7 +99,7 @@ for (trial_name, indices_func) in trial_sets
10299
my_rank,
103100
n_particle_per_rank
104101
)
105-
@info "Starting timed runs for trial '$trial_name'..."
102+
@info "Starting timed runs for trial '$(trial_name)'..."
106103

107104
timer = TimerOutputs.TimerOutput("copy_states")
108105
for _ in 1:10
@@ -143,7 +140,7 @@ for (trial_name, indices_func) in trial_sets
143140
show(stdout, "text/plain", local_states);
144141
@info "rank $(my_rank): expected ="
145142
show(stdout, "text/plain", expected);
146-
@info "rank $(my_rank): match = ", match
143+
@info "rank $(my_rank): match = $(match)"
147144
end
148145
MPI.Barrier(MPI.COMM_WORLD)
149146
end

0 commit comments

Comments
 (0)