Skip to content

Commit 62d9bd5

Browse files
mattleblancgraeme-a-stewart
authored andcommitted
Fix enums, pass R2 as invR2 throughout
1 parent cad70b9 commit 62d9bd5

File tree

2 files changed

+40
-40
lines changed

2 files changed

+40
-40
lines changed

src/AlgorithmStrategyEnums.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,7 @@ A dictionary that maps algorithm names to their corresponding power values.
4747
const algorithm2power = Dict(JetAlgorithm.AntiKt => -1,
4848
JetAlgorithm.CA => 0,
4949
JetAlgorithm.Kt => 1,
50-
JetAlgorithm.Durham => 1,
51-
JetAlgorithm.Valencia => 1)
50+
JetAlgorithm.Durham => 1)
5251

5352
"""
5453
get_algorithm_power(; algorithm::JetAlgorithm.Algorithm, p::Union{Real, Nothing}) -> Real

src/EEAlgorithm.jl

Lines changed: 39 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -23,20 +23,17 @@ Calculate the angular distance between two jets `i` and `j` using the formula
2323
end
2424

2525
"""
26-
dij_dist(eereco, i, j, dij_factor, algorithm::JetAlgorithm.Algorithm, R=4.0)
26+
dij_dist(eereco, i, j, dij_factor, algorithm::JetAlgorithm.Algorithm, invR2)
2727
2828
Compute dij distance for (Durham, EEKt, Valencia) using simple conditionals.
2929
Beam index `j==0` returns a large sentinel. For Valencia we use the full
3030
Valencia metric (independent of `dij_factor`).
3131
"""
32-
@inline function dij_dist(eereco, i, j, dij_factor, algorithm, R = 4.0)
33-
if !(algorithm isa JetAlgorithm.Algorithm)
34-
throw(ArgumentError("algorithm must be a JetAlgorithm.Algorithm"))
35-
end
32+
@inline function dij_dist(eereco, i, j, dij_factor, algorithm, invR2)
3633
j == 0 && return large_dij
3734
@inbounds begin
3835
if algorithm == JetAlgorithm.Valencia
39-
return valencia_distance(eereco, i, j, R)
36+
return valencia_distance(eereco, i, j, invR2)
4037
else
4138
# Durham & EEKt share same form here (min(E2p_i,E2p_j) * dij_factor * angular_metric)
4239
return min(eereco[i].E2p, eereco[j].E2p) * dij_factor * eereco[i].nndist
@@ -45,22 +42,21 @@ Valencia metric (independent of `dij_factor`).
4542
end
4643

4744
"""
48-
valencia_distance(eereco, i, j, R) -> Float64
45+
valencia_distance(eereco, i, j, invR2) -> Float64
4946
5047
Calculate the Valencia distance between two jets `i` and `j` as
51-
``min(E_i^{2β}, E_j^{2β}) * 2 * (1 - cos(θ_{ij})) / R²``.
48+
``min(E_i^{2β}, E_j^{2β}) * 2 * (1 - cos(θ_{ij})) * invR2``.
5249
5350
# Arguments
5451
- `eereco`: The array of `EERecoJet` objects.
5552
- `i`: The first jet.
5653
- `j`: The second jet.
57-
- `R`: The jet radius parameter.
54+
- `invR2`: The inverse square of the radius, i.e. ``1 / R^2``.
5855
5956
# Returns
6057
- `Float64`: The Valencia distance between `i` and `j`.
6158
"""
62-
Base.@propagate_inbounds @inline function valencia_distance(eereco, i, j, R)
63-
invR2 = inv(R * R)
59+
Base.@propagate_inbounds @inline function valencia_distance(eereco, i, j, invR2)
6460
@muladd angular_dist = 1.0 - eereco[i].nx * eereco[j].nx - eereco[i].ny * eereco[j].ny -
6561
eereco[i].nz * eereco[j].nz
6662
return min(eereco[i].E2p, eereco[j].E2p) * 2 * angular_dist * invR2
@@ -96,7 +92,7 @@ Base.@propagate_inbounds @inline function valencia_beam_distance(eereco, i, γ,
9692
end
9793
end
9894

99-
function get_angular_nearest_neighbours!(eereco, algorithm, dij_factor, p, γ = 1.0, R = 4.0)
95+
function get_angular_nearest_neighbours!(eereco, algorithm, dij_factor, p, invR2, γ = 1.0)
10096
# Get the initial nearest neighbours for each jet
10197
N = length(eereco)
10298
# Initialise sentinels so the first comparison always wins
@@ -114,7 +110,7 @@ function get_angular_nearest_neighbours!(eereco, algorithm, dij_factor, p, γ =
114110
# Metric used to pick the nearest neighbour
115111
if algorithm == JetAlgorithm.Valencia
116112
# Use canonical Valencia distance (StructArray-aware)
117-
this_metric = valencia_distance(eereco, i, j, R)
113+
this_metric = valencia_distance(eereco, i, j, invR2)
118114
else
119115
this_metric = angular_distance(eereco, i, j)
120116
end
@@ -131,9 +127,9 @@ function get_angular_nearest_neighbours!(eereco, algorithm, dij_factor, p, γ =
131127
# Nearest neighbour dij distance
132128
@inbounds for i in 1:N
133129
if algorithm == JetAlgorithm.Valencia
134-
eereco.dijdist[i] = valencia_distance(eereco, i, eereco[i].nni, R)
130+
eereco.dijdist[i] = valencia_distance(eereco, i, eereco[i].nni, invR2)
135131
else
136-
eereco.dijdist[i] = dij_dist(eereco, i, eereco[i].nni, dij_factor, algorithm, R)
132+
eereco.dijdist[i] = dij_dist(eereco, i, eereco[i].nni, dij_factor, algorithm, invR2)
137133
end
138134
end
139135
# For the EEKt and Valencia algorithms, we need to check the beam distance as well
@@ -155,7 +151,7 @@ function get_angular_nearest_neighbours!(eereco, algorithm, dij_factor, p, γ =
155151
end
156152

157153
@inline function update_nn_no_cross!(eereco, i, N, algorithm::JetAlgorithm.Algorithm,
158-
dij_factor, β = 1.0, γ = 1.0, R = 4.0)
154+
dij_factor, invR2, β = 1.0, γ = 1.0)
159155
# Valencia metric is unbounded, others use a large finite value
160156
if algorithm == JetAlgorithm.Valencia
161157
eereco.nndist[i] = Inf
@@ -167,15 +163,15 @@ end
167163
@inbounds for j in 1:N
168164
if j != i
169165
this_metric = algorithm == JetAlgorithm.Valencia ?
170-
valencia_distance(eereco, i, j, R) :
166+
valencia_distance(eereco, i, j, invR2) :
171167
angular_distance(eereco, i, j)
172168
better_i = this_metric < eereco[i].nndist
173169
eereco.nndist[i] = better_i ? this_metric : eereco.nndist[i]
174170
eereco.nni[i] = better_i ? j : eereco.nni[i]
175171
end
176172
end
177173
# Set dij for i using unified dispatcher and apply beam checks
178-
eereco.dijdist[i] = dij_dist(eereco, i, eereco[i].nni, dij_factor, algorithm, R)
174+
eereco.dijdist[i] = dij_dist(eereco, i, eereco[i].nni, dij_factor, algorithm, invR2)
179175
if algorithm == JetAlgorithm.EEKt
180176
beam_close = eereco[i].E2p < eereco[i].dijdist
181177
eereco.dijdist[i] = beam_close ? eereco[i].E2p : eereco.dijdist[i]
@@ -189,7 +185,7 @@ end
189185
end
190186

191187
@inline function update_nn_cross!(eereco, i, N, algorithm::JetAlgorithm.Algorithm,
192-
dij_factor, β = 1.0, γ = 1.0, R = 4.0)
188+
dij_factor, invR2, β = 1.0, γ = 1.0)
193189
# Valencia metric is unbounded, others use a large finite value
194190
if algorithm == JetAlgorithm.Valencia
195191
eereco.nndist[i] = Inf
@@ -201,7 +197,7 @@ end
201197
@inbounds for j in 1:N
202198
if j != i
203199
this_metric = algorithm == JetAlgorithm.Valencia ?
204-
valencia_distance(eereco, i, j, R) :
200+
valencia_distance(eereco, i, j, invR2) :
205201
angular_distance(eereco, i, j)
206202
better_i = this_metric < eereco[i].nndist
207203
eereco.nndist[i] = better_i ? this_metric : eereco.nndist[i]
@@ -210,7 +206,7 @@ end
210206
if this_metric < eereco[j].nndist
211207
eereco.nndist[j] = this_metric
212208
eereco.nni[j] = i
213-
eereco.dijdist[j] = dij_dist(eereco, j, i, dij_factor, algorithm, R)
209+
eereco.dijdist[j] = dij_dist(eereco, j, i, dij_factor, algorithm, invR2)
214210
if algorithm == JetAlgorithm.EEKt
215211
if eereco[j].E2p < eereco[j].dijdist
216212
eereco.dijdist[j] = eereco[j].E2p
@@ -227,7 +223,7 @@ end
227223
end
228224
end
229225
# Set dij for i using unified dispatcher and apply beam checks
230-
eereco.dijdist[i] = dij_dist(eereco, i, eereco[i].nni, dij_factor, algorithm, R)
226+
eereco.dijdist[i] = dij_dist(eereco, i, eereco[i].nni, dij_factor, algorithm, invR2)
231227
if algorithm == JetAlgorithm.EEKt
232228
beam_close = eereco[i].E2p < eereco[i].dijdist
233229
eereco.dijdist[i] = beam_close ? eereco[i].E2p : eereco.dijdist[i]
@@ -257,11 +253,11 @@ function ee_check_consistency(clusterseq, eereco, N)
257253
@debug "Consistency check passed"
258254
end
259255

260-
Base.@propagate_inbounds @inline function fill_reco_array!(eereco, particles, R2, p)
256+
Base.@propagate_inbounds @inline function fill_reco_array!(eereco, particles, invR2, p)
261257
@inbounds for i in eachindex(particles)
262258
eereco.index[i] = i
263259
eereco.nni[i] = 0
264-
eereco.nndist[i] = R2
260+
eereco.nndist[i] = inv(invR2) # R^2 as initial sentinel for angular algorithms
265261
# eereco.dijdist[i] = UNDEF # Does not need to be initialised
266262
eereco.nx[i] = nx(particles[i])
267263
eereco.ny[i] = ny(particles[i])
@@ -271,12 +267,12 @@ Base.@propagate_inbounds @inline function fill_reco_array!(eereco, particles, R2
271267
end
272268
end
273269

274-
Base.@propagate_inbounds @inline function insert_new_jet!(eereco, i, newjet_k, R2,
275-
merged_jet, p)
270+
Base.@propagate_inbounds @inline function insert_new_jet!(eereco, i, newjet_k, invR2,
271+
merged_jet, p)
276272
@inbounds begin
277273
eereco.index[i] = newjet_k
278274
eereco.nni[i] = 0
279-
eereco.nndist[i] = R2
275+
eereco.nndist[i] = inv(invR2)
280276
eereco.nx[i] = nx(merged_jet)
281277
eereco.ny[i] = ny(merged_jet)
282278
eereco.nz[i] = nz(merged_jet)
@@ -396,15 +392,17 @@ function ee_genkt_algorithm(particles::AbstractVector{T}; algorithm::JetAlgorith
396392
end
397393
end
398394

395+
# Compute invR2 once and thread it through
396+
invR2 = inv(R * R)
399397
# Now call the unified implementation with conditional logic.
400398
return _ee_genkt_algorithm(particles = recombination_particles, p = p, R = R,
401-
algorithm = algorithm, recombine = recombine, γ = γ)
399+
invR2 = invR2, algorithm = algorithm, recombine = recombine, γ = γ)
402400
end
403401

404402
"""
405403
_ee_genkt_algorithm(particles::AbstractVector{EEJet};
406-
algorithm::JetAlgorithm.Algorithm, p::Real, R = 4.0,
407-
recombine = addjets, γ::Real = 1.0,
404+
algorithm::JetAlgorithm.Algorithm, p::Real, R::Real,
405+
invR2::Union{Real, Nothing} = nothing, recombine = addjets, γ::Real = 1.0,
408406
beta::Union{Real, Nothing} = nothing)
409407
410408
This function is the internal implementation of the e+e- jet clustering
@@ -433,13 +431,16 @@ entry point to this jet reconstruction.
433431
reconstructed jets.
434432
"""
435433
function _ee_genkt_algorithm(; particles::AbstractVector{EEJet},
436-
algorithm::JetAlgorithm.Algorithm, p::Real, R = 4.0,
437-
recombine = addjets, γ::Real = 1.0,
434+
algorithm::JetAlgorithm.Algorithm, p::Real, R::Real,
435+
invR2::Union{Real, Nothing} = nothing, recombine = addjets, γ::Real = 1.0,
438436
beta::Union{Real, Nothing} = nothing)
439437
# Bounds
440438
N::Int = length(particles)
441439

442-
R2 = R^2
440+
# invR2 provided by caller when available; otherwise compute from R once here
441+
if invR2 === nothing
442+
invR2 = inv(R * R)
443+
end
443444
if algorithm == JetAlgorithm.Valencia && beta !== nothing
444445
p = beta
445446
end
@@ -464,7 +465,7 @@ function _ee_genkt_algorithm(; particles::AbstractVector{EEJet},
464465
# We need N slots for this array
465466
eereco = StructArray{EERecoJet}(undef, N)
466467

467-
fill_reco_array!(eereco, particles, R2, p)
468+
fill_reco_array!(eereco, particles, invR2, p)
468469

469470
# Setup the initial history and get the total energy
470471
history, Qtot = initial_history(particles)
@@ -473,7 +474,7 @@ function _ee_genkt_algorithm(; particles::AbstractVector{EEJet},
473474
Qtot)
474475

475476
# Run over initial pairs of jets to find nearest neighbours
476-
get_angular_nearest_neighbours!(eereco, algorithm, dij_factor, p, γ, R)
477+
get_angular_nearest_neighbours!(eereco, algorithm, dij_factor, p, invR2, γ)
477478

478479
# Only for debugging purposes...
479480
# ee_check_consistency(clusterseq, clusterseq_index, N, nndist, nndij, nni, "Start")
@@ -522,7 +523,7 @@ function _ee_genkt_algorithm(; particles::AbstractVector{EEJet},
522523
newjet_k, dij_min)
523524

524525
# Update the compact arrays, reusing the JetA slot
525-
insert_new_jet!(eereco, ijetA, newjet_k, R2, merged_jet, p)
526+
insert_new_jet!(eereco, ijetA, newjet_k, invR2, merged_jet, p)
526527
end
527528

528529
# Squash step - copy the final jet's compact data into the jetB slot
@@ -544,15 +545,15 @@ function _ee_genkt_algorithm(; particles::AbstractVector{EEJet},
544545
# plus "belt and braces" check for an invalid NN (>N)
545546
if (eereco[i].nni == ijetA) || (eereco[i].nni == ijetB) ||
546547
(eereco[i].nni > N)
547-
update_nn_no_cross!(eereco, i, N, algorithm, dij_factor, p, γ, R)
548+
update_nn_no_cross!(eereco, i, N, algorithm, dij_factor, invR2, p, γ)
548549
end
549550
end
550551
end
551552

552553
# Finally, we need to update the nearest neighbours for the new jet, checking both ways
553554
# (But only if there was a new jet!)
554555
if ijetA != ijetB
555-
update_nn_cross!(eereco, ijetA, N, algorithm, dij_factor, p, γ, R)
556+
update_nn_cross!(eereco, ijetA, N, algorithm, dij_factor, invR2, p, γ)
556557
end
557558

558559
# Only for debugging purposes...

0 commit comments

Comments
 (0)