Skip to content

Commit bd6ef24

Browse files
committed
Clean up inbounds throughout EEAlgorithm.jl, switch to Val-based forwarders to avoid branching, fast path when p isa int
1 parent 1bb3d60 commit bd6ef24

File tree

1 file changed

+286
-24
lines changed

1 file changed

+286
-24
lines changed

src/EEAlgorithm.jl

Lines changed: 286 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,15 @@ Calculate the Valencia distance between two jets `i` and `j` as
3737
# Returns
3838
- `Float64`: The Valencia distance between `i` and `j`.
3939
"""
40-
Base.@propagate_inbounds @inline function valencia_distance(eereco, i, j, R)
40+
Base.@propagate_inbounds @inline function valencia_distance_inv(eereco, i, j, invR2)
4141
angular_dist = angular_distance(eereco, i, j)
42-
# Valencia dij : min(E_i^{2β}, E_j^{2β}) * 2 * (1 - cos θ) / R²
42+
# Valencia dij : min(E_i^{2β}, E_j^{2β}) * 2 * (1 - cos θ) * invR2
4343
# Note that β plays the role of p in other algorithms, so E2p can be used.
44-
min(eereco[i].E2p, eereco[j].E2p) * 2 * angular_dist / (R * R)
44+
min(eereco[i].E2p, eereco[j].E2p) * 2 * angular_dist * invR2
45+
end
46+
47+
Base.@propagate_inbounds @inline function valencia_distance(eereco, i, j, R)
48+
return valencia_distance_inv(eereco, i, j, inv(R * R))
4549
end
4650

4751
"""
@@ -167,6 +171,90 @@ function get_angular_nearest_neighbours!(eereco, algorithm, dij_factor, p, γ =
167171
end
168172
end
169173

174+
# Val-specialized nearest neighbour search (removes runtime branches in hot loops)
175+
@inline function get_angular_nearest_neighbours!(eereco, ::Val{JetAlgorithm.Durham},
176+
dij_factor, p, γ = 1.0, R = 4.0)
177+
N = length(eereco)
178+
# Nearest neighbour search using angular metric
179+
@inbounds for i in 1:N
180+
@inbounds for j in (i + 1):N
181+
this_metric = angular_distance(eereco, i, j)
182+
better_nndist_i = this_metric < eereco[i].nndist
183+
eereco.nndist[i] = better_nndist_i ? this_metric : eereco.nndist[i]
184+
eereco.nni[i] = better_nndist_i ? j : eereco.nni[i]
185+
better_nndist_j = this_metric < eereco[j].nndist
186+
eereco.nndist[j] = better_nndist_j ? this_metric : eereco.nndist[j]
187+
eereco.nni[j] = better_nndist_j ? i : eereco.nni[j]
188+
end
189+
end
190+
@inbounds for i in 1:N
191+
eereco.dijdist[i] = dij_dist(eereco, i, eereco[i].nni, dij_factor,
192+
Val(JetAlgorithm.Durham), R)
193+
end
194+
end
195+
196+
@inline function get_angular_nearest_neighbours!(eereco, ::Val{JetAlgorithm.EEKt},
197+
dij_factor, p, γ = 1.0, R = 4.0)
198+
N = length(eereco)
199+
@inbounds for i in 1:N
200+
@inbounds for j in (i + 1):N
201+
this_metric = angular_distance(eereco, i, j)
202+
better_nndist_i = this_metric < eereco[i].nndist
203+
eereco.nndist[i] = better_nndist_i ? this_metric : eereco.nndist[i]
204+
eereco.nni[i] = better_nndist_i ? j : eereco.nni[i]
205+
better_nndist_j = this_metric < eereco[j].nndist
206+
eereco.nndist[j] = better_nndist_j ? this_metric : eereco.nndist[j]
207+
eereco.nni[j] = better_nndist_j ? i : eereco.nni[j]
208+
end
209+
end
210+
@inbounds for i in 1:N
211+
eereco.dijdist[i] = dij_dist(eereco, i, eereco[i].nni, dij_factor,
212+
Val(JetAlgorithm.EEKt), R)
213+
end
214+
@inbounds for i in 1:N
215+
beam_closer = eereco[i].E2p < eereco[i].dijdist
216+
eereco.dijdist[i] = beam_closer ? eereco[i].E2p : eereco.dijdist[i]
217+
eereco.nni[i] = beam_closer ? 0 : eereco.nni[i]
218+
end
219+
end
220+
221+
@inline function get_angular_nearest_neighbours!(eereco, ::Val{JetAlgorithm.Valencia},
222+
dij_factor, p, γ = 1.0, R = 4.0)
223+
N = length(eereco)
224+
invR2 = inv(R * R)
225+
@inbounds for i in 1:N
226+
eereco.nndist[i] = Inf
227+
eereco.nni[i] = i
228+
end
229+
@inbounds for i in 1:N
230+
@inbounds for j in (i + 1):N
231+
this_metric = valencia_distance_inv(eereco, i, j, invR2)
232+
better_nndist_i = this_metric < eereco[i].nndist
233+
eereco.nndist[i] = better_nndist_i ? this_metric : eereco.nndist[i]
234+
eereco.nni[i] = better_nndist_i ? j : eereco.nni[i]
235+
better_nndist_j = this_metric < eereco[j].nndist
236+
eereco.nndist[j] = better_nndist_j ? this_metric : eereco.nndist[j]
237+
eereco.nni[j] = better_nndist_j ? i : eereco.nni[j]
238+
end
239+
end
240+
@inbounds for i in 1:N
241+
eereco.dijdist[i] = valencia_distance_inv(eereco, i, eereco[i].nni, invR2)
242+
end
243+
@inbounds for i in 1:N
244+
valencia_beam_dist = valencia_beam_distance(eereco, i, γ, p)
245+
beam_closer = valencia_beam_dist < eereco[i].dijdist
246+
eereco.dijdist[i] = beam_closer ? valencia_beam_dist : eereco.dijdist[i]
247+
eereco.nni[i] = beam_closer ? 0 : eereco.nni[i]
248+
end
249+
end
250+
251+
# Forwarder to Val-specialized version
252+
@inline function get_angular_nearest_neighbours!(eereco,
253+
algorithm::JetAlgorithm.Algorithm,
254+
dij_factor, p, γ = 1.0, R = 4.0)
255+
return get_angular_nearest_neighbours!(eereco, Val(algorithm), dij_factor, p, γ, R)
256+
end
257+
170258
# Update the nearest neighbour for jet i, w.r.t. all other active jets
171259
function update_nn_no_cross!(eereco, i, N, algorithm, dij_factor, β = 1.0, γ = 1.0, R = 4.0)
172260
eereco.nndist[i] = algorithm == JetAlgorithm.Valencia ? Inf : large_distance
@@ -199,6 +287,68 @@ function update_nn_no_cross!(eereco, i, N, algorithm, dij_factor, β = 1.0, γ =
199287
end
200288
end
201289

290+
# Val-specialized no-cross update
291+
@inline function update_nn_no_cross!(eereco, i, N, ::Val{JetAlgorithm.Durham}, dij_factor,
292+
β = 1.0, γ = 1.0, R = 4.0)
293+
eereco.nndist[i] = large_distance
294+
eereco.nni[i] = i
295+
@inbounds for j in 1:N
296+
if j != i
297+
this_metric = angular_distance(eereco, i, j)
298+
better_nndist_i = this_metric < eereco[i].nndist
299+
eereco.nndist[i] = better_nndist_i ? this_metric : eereco.nndist[i]
300+
eereco.nni[i] = better_nndist_i ? j : eereco.nni[i]
301+
end
302+
end
303+
eereco.dijdist[i] = dij_dist(eereco, i, eereco[i].nni, dij_factor,
304+
Val(JetAlgorithm.Durham), R)
305+
end
306+
307+
@inline function update_nn_no_cross!(eereco, i, N, ::Val{JetAlgorithm.EEKt}, dij_factor,
308+
β = 1.0, γ = 1.0, R = 4.0)
309+
eereco.nndist[i] = large_distance
310+
eereco.nni[i] = i
311+
@inbounds for j in 1:N
312+
if j != i
313+
this_metric = angular_distance(eereco, i, j)
314+
better_nndist_i = this_metric < eereco[i].nndist
315+
eereco.nndist[i] = better_nndist_i ? this_metric : eereco.nndist[i]
316+
eereco.nni[i] = better_nndist_i ? j : eereco.nni[i]
317+
end
318+
end
319+
eereco.dijdist[i] = dij_dist(eereco, i, eereco[i].nni, dij_factor,
320+
Val(JetAlgorithm.EEKt), R)
321+
beam_close = eereco[i].E2p < eereco[i].dijdist
322+
eereco.dijdist[i] = beam_close ? eereco[i].E2p : eereco.dijdist[i]
323+
eereco.nni[i] = beam_close ? 0 : eereco.nni[i]
324+
end
325+
326+
@inline function update_nn_no_cross!(eereco, i, N, ::Val{JetAlgorithm.Valencia},
327+
dij_factor, β = 1.0, γ = 1.0, R = 4.0)
328+
eereco.nndist[i] = Inf
329+
eereco.nni[i] = i
330+
invR2 = inv(R * R)
331+
@inbounds for j in 1:N
332+
if j != i
333+
this_metric = valencia_distance_inv(eereco, i, j, invR2)
334+
better_nndist_i = this_metric < eereco[i].nndist
335+
eereco.nndist[i] = better_nndist_i ? this_metric : eereco.nndist[i]
336+
eereco.nni[i] = better_nndist_i ? j : eereco.nni[i]
337+
end
338+
end
339+
eereco.dijdist[i] = valencia_distance_inv(eereco, i, eereco[i].nni, invR2)
340+
valencia_beam_dist = valencia_beam_distance(eereco, i, γ, β)
341+
beam_close = valencia_beam_dist < eereco[i].dijdist
342+
eereco.dijdist[i] = beam_close ? valencia_beam_dist : eereco.dijdist[i]
343+
eereco.nni[i] = beam_close ? 0 : eereco.nni[i]
344+
end
345+
346+
# Forwarder
347+
@inline function update_nn_no_cross!(eereco, i, N, algorithm::JetAlgorithm.Algorithm,
348+
dij_factor, β = 1.0, γ = 1.0, R = 4.0)
349+
return update_nn_no_cross!(eereco, i, N, Val(algorithm), dij_factor, β, γ, R)
350+
end
351+
202352
function update_nn_cross!(eereco, i, N, algorithm, dij_factor, β = 1.0, γ = 1.0, R = 4.0)
203353
# Update the nearest neighbour for jet i, w.r.t. all other active jets
204354
# also doing the cross check for the other jet
@@ -254,6 +404,93 @@ function update_nn_cross!(eereco, i, N, algorithm, dij_factor, β = 1.0, γ = 1.
254404
end
255405
end
256406

407+
# Val-specialized cross update
408+
@inline function update_nn_cross!(eereco, i, N, ::Val{JetAlgorithm.Durham}, dij_factor,
409+
β = 1.0, γ = 1.0, R = 4.0)
410+
eereco.nndist[i] = large_distance
411+
eereco.nni[i] = i
412+
@inbounds for j in 1:N
413+
if j != i
414+
this_metric = angular_distance(eereco, i, j)
415+
better_nndist_i = this_metric < eereco[i].nndist
416+
eereco.nndist[i] = better_nndist_i ? this_metric : eereco.nndist[i]
417+
eereco.nni[i] = better_nndist_i ? j : eereco.nni[i]
418+
if this_metric < eereco[j].nndist
419+
eereco.nndist[j] = this_metric
420+
eereco.nni[j] = i
421+
eereco.dijdist[j] = dij_dist(eereco, j, i, dij_factor,
422+
Val(JetAlgorithm.Durham), R)
423+
end
424+
end
425+
end
426+
eereco.dijdist[i] = dij_dist(eereco, i, eereco[i].nni, dij_factor,
427+
Val(JetAlgorithm.Durham), R)
428+
end
429+
430+
@inline function update_nn_cross!(eereco, i, N, ::Val{JetAlgorithm.EEKt}, dij_factor,
431+
β = 1.0, γ = 1.0, R = 4.0)
432+
eereco.nndist[i] = large_distance
433+
eereco.nni[i] = i
434+
@inbounds for j in 1:N
435+
if j != i
436+
this_metric = angular_distance(eereco, i, j)
437+
better_nndist_i = this_metric < eereco[i].nndist
438+
eereco.nndist[i] = better_nndist_i ? this_metric : eereco.nndist[i]
439+
eereco.nni[i] = better_nndist_i ? j : eereco.nni[i]
440+
if this_metric < eereco[j].nndist
441+
eereco.nndist[j] = this_metric
442+
eereco.nni[j] = i
443+
eereco.dijdist[j] = dij_dist(eereco, j, i, dij_factor,
444+
Val(JetAlgorithm.EEKt), R)
445+
if eereco[j].E2p < eereco[j].dijdist
446+
eereco.dijdist[j] = eereco[j].E2p
447+
eereco.nni[j] = 0
448+
end
449+
end
450+
end
451+
end
452+
eereco.dijdist[i] = dij_dist(eereco, i, eereco[i].nni, dij_factor,
453+
Val(JetAlgorithm.EEKt), R)
454+
beam_close = eereco[i].E2p < eereco[i].dijdist
455+
eereco.dijdist[i] = beam_close ? eereco[i].E2p : eereco.dijdist[i]
456+
eereco.nni[i] = beam_close ? 0 : eereco.nni[i]
457+
end
458+
459+
@inline function update_nn_cross!(eereco, i, N, ::Val{JetAlgorithm.Valencia}, dij_factor,
460+
β = 1.0, γ = 1.0, R = 4.0)
461+
eereco.nndist[i] = Inf
462+
eereco.nni[i] = i
463+
invR2 = inv(R * R)
464+
@inbounds for j in 1:N
465+
if j != i
466+
this_metric = valencia_distance_inv(eereco, i, j, invR2)
467+
better_nndist_i = this_metric < eereco[i].nndist
468+
eereco.nndist[i] = better_nndist_i ? this_metric : eereco.nndist[i]
469+
eereco.nni[i] = better_nndist_i ? j : eereco.nni[i]
470+
if this_metric < eereco[j].nndist
471+
eereco.nndist[j] = this_metric
472+
eereco.nni[j] = i
473+
eereco.dijdist[j] = valencia_distance_inv(eereco, j, i, invR2)
474+
valencia_beam_dist = valencia_beam_distance(eereco, j, γ, β)
475+
if valencia_beam_dist < eereco[j].dijdist
476+
eereco.dijdist[j] = valencia_beam_dist
477+
eereco.nni[j] = 0
478+
end
479+
end
480+
end
481+
end
482+
eereco.dijdist[i] = valencia_distance_inv(eereco, i, eereco[i].nni, invR2)
483+
valencia_beam_dist = valencia_beam_distance(eereco, i, γ, β)
484+
beam_close = valencia_beam_dist < eereco[i].dijdist
485+
eereco.dijdist[i] = beam_close ? valencia_beam_dist : eereco.dijdist[i]
486+
eereco.nni[i] = beam_close ? 0 : eereco.nni[i]
487+
end
488+
489+
@inline function update_nn_cross!(eereco, i, N, algorithm::JetAlgorithm.Algorithm,
490+
dij_factor, β = 1.0, γ = 1.0, R = 4.0)
491+
return update_nn_cross!(eereco, i, N, Val(algorithm), dij_factor, β, γ, R)
492+
end
493+
257494
function ee_check_consistency(clusterseq, eereco, N)
258495
# Check the consistency of the reconstruction state
259496
for i in 1:N
@@ -268,46 +505,71 @@ function ee_check_consistency(clusterseq, eereco, N)
268505
end
269506
end
270507
end
271-
@debug "Consistency check passed at $msg"
508+
@debug "Consistency check passed"
272509
end
273510

274-
function fill_reco_array!(eereco, particles, R2, p)
275-
for i in eachindex(particles)
511+
Base.@propagate_inbounds @inline function fill_reco_array!(eereco, particles, R2, p)
512+
@inbounds for i in eachindex(particles)
276513
eereco.index[i] = i
277514
eereco.nni[i] = 0
278515
eereco.nndist[i] = R2
279516
# eereco.dijdist[i] = UNDEF # Does not need to be initialised
280517
eereco.nx[i] = nx(particles[i])
281518
eereco.ny[i] = ny(particles[i])
282519
eereco.nz[i] = nz(particles[i])
283-
eereco.E2p[i] = energy(particles[i])^(2p)
520+
E = energy(particles[i])
521+
if p isa Int
522+
if p == 1
523+
eereco.E2p[i] = E * E
524+
else
525+
E2 = E * E
526+
eereco.E2p[i] = E2^p
527+
end
528+
else
529+
eereco.E2p[i] = E^(2p)
530+
end
284531
end
285532
end
286533

287-
@inline function insert_new_jet!(eereco, i, newjet_k, R2, merged_jet, p)
288-
eereco.index[i] = newjet_k
289-
eereco.nni[i] = 0
290-
eereco.nndist[i] = R2
291-
eereco.nx[i] = nx(merged_jet)
292-
eereco.ny[i] = ny(merged_jet)
293-
eereco.nz[i] = nz(merged_jet)
294-
eereco.E2p[i] = energy(merged_jet)^(2p)
534+
Base.@propagate_inbounds @inline function insert_new_jet!(eereco, i, newjet_k, R2,
535+
merged_jet, p)
536+
@inbounds begin
537+
eereco.index[i] = newjet_k
538+
eereco.nni[i] = 0
539+
eereco.nndist[i] = R2
540+
eereco.nx[i] = nx(merged_jet)
541+
eereco.ny[i] = ny(merged_jet)
542+
eereco.nz[i] = nz(merged_jet)
543+
E = energy(merged_jet)
544+
if p isa Int
545+
if p == 1
546+
eereco.E2p[i] = E * E
547+
else
548+
E2 = E * E
549+
eereco.E2p[i] = E2^p
550+
end
551+
else
552+
eereco.E2p[i] = E^(2p)
553+
end
554+
end
295555
end
296556

297557
"""
298558
copy_to_slot!(eereco, i, j)
299559
300560
Copy the contents of slot `i` in the `eereco` array to slot `j`.
301561
"""
302-
@inline function copy_to_slot!(eereco, i, j)
303-
eereco.index[j] = eereco.index[i]
304-
eereco.nni[j] = eereco.nni[i]
305-
eereco.nndist[j] = eereco.nndist[i]
306-
eereco.dijdist[j] = eereco.dijdist[i]
307-
eereco.nx[j] = eereco.nx[i]
308-
eereco.ny[j] = eereco.ny[i]
309-
eereco.nz[j] = eereco.nz[i]
310-
eereco.E2p[j] = eereco.E2p[i]
562+
Base.@propagate_inbounds @inline function copy_to_slot!(eereco, i, j)
563+
@inbounds begin
564+
eereco.index[j] = eereco.index[i]
565+
eereco.nni[j] = eereco.nni[i]
566+
eereco.nndist[j] = eereco.nndist[i]
567+
eereco.dijdist[j] = eereco.dijdist[i]
568+
eereco.nx[j] = eereco.nx[i]
569+
eereco.ny[j] = eereco.ny[i]
570+
eereco.nz[j] = eereco.nz[i]
571+
eereco.E2p[j] = eereco.E2p[i]
572+
end
311573
end
312574

313575
"""

0 commit comments

Comments
 (0)