Skip to content

Commit a14cc8e

Browse files
committed
Comments from Jerry, try to fix branching
1 parent 87b241f commit a14cc8e

File tree

1 file changed

+42
-25
lines changed

1 file changed

+42
-25
lines changed

src/EEAlgorithm.jl

Lines changed: 42 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -61,39 +61,56 @@ for unit direction cosines. Since ``sin^2 θ = 1 - nz^2``, we implement
6161
# Returns
6262
- `Float64`: The Valencia beam distance for jet `i`.
6363
"""
64-
@inline function valencia_beam_distance(eereco, i, γ, β)
64+
Base.@propagate_inbounds function valencia_beam_distance(eereco, i, γ, β)
6565
nz = @inbounds eereco[i].nz
6666
# sin^2(theta) = 1 - nz^2; beam distance independent of R
6767
sin2 = 1 - nz * nz
68-
@inbounds eereco[i].E2p * sin2^γ
68+
return eereco[i].E2p * sin2^γ
6969
end
7070

7171
"""
72-
dij_dist(eereco, i, j, dij_factor, algorithm = JetAlgorithm.Durham, R = 4.0)
72+
dij_dist(eereco, i, j, dij_factor, algorithm::JetAlgorithm.Algorithm, R = 4.0)
7373
74-
Calculate the dij distance between two ``e^+e^-``jets.
74+
Calculate the dij distance between two e⁺e⁻ jets. This is the public entry point.
75+
Internally, this forwards to a Val-based method for the given algorithm, which
76+
allows the compiler to specialize away branches when `algorithm` is a constant.
77+
"""
78+
@inline function dij_dist(eereco, i, j, dij_factor, algorithm::JetAlgorithm.Algorithm,
79+
R = 4.0)
80+
return dij_dist(eereco, i, j, dij_factor, Val(algorithm), R)
81+
end
7582

76-
# Arguments
77-
- `eereco`: The array of `EERecoJet` objects.
78-
- `i`: The first jet.
79-
- `j`: The second jet.
80-
- `dij_factor`: The scaling factor to multiply the dij distance by.
81-
- `algorithm`: The jet algorithm being used.
82-
- `R`: the radius or resolution parameter
83+
"""
84+
dij_dist(eereco, i, j, dij_factor, ::Val{JetAlgorithm.Durham}, R = 4.0)
85+
dij_dist(eereco, i, j, dij_factor, ::Val{JetAlgorithm.EEKt}, R = 4.0)
8386
84-
# Returns
85-
- The dij distance between `i` and `j`.
87+
Durham/EEKt dij distance:
88+
min(E_i^{2p}, E_j^{2p}) * dij_factor * (angular NN metric stored in nndist).
89+
For EEKt, dij_factor encodes the R-dependent normalization.
8690
"""
87-
@inline function dij_dist(eereco, i, j, dij_factor, algorithm = JetAlgorithm.Durham,
88-
R = 4.0)
89-
# Calculate the dij distance for jet i from jet j
91+
@inline function dij_dist(eereco, i, j, dij_factor, ::Val{JetAlgorithm.Durham}, R = 4.0)
9092
j == 0 && return large_dij
93+
@inbounds min(eereco[i].E2p, eereco[j].E2p) * dij_factor * eereco[i].nndist
94+
end
9195

92-
if algorithm == JetAlgorithm.Valencia
93-
@inbounds valencia_distance(eereco, i, j, R)
94-
else
95-
@inbounds min(eereco[i].E2p, eereco[j].E2p) * dij_factor * eereco[i].nndist
96-
end
96+
@inline function dij_dist(eereco, i, j, dij_factor, ::Val{JetAlgorithm.EEKt}, R = 4.0)
97+
j == 0 && return large_dij
98+
@inbounds min(eereco[i].E2p, eereco[j].E2p) * dij_factor * eereco[i].nndist
99+
end
100+
101+
"""
102+
dij_dist(eereco, i, j, dij_factor, ::Val{JetAlgorithm.Valencia}, R)
103+
104+
Valencia dij distance uses the full Valencia metric, including the 2*(1-cosθ)/R² factor.
105+
"""
106+
@inline function dij_dist(eereco, i, j, dij_factor, ::Val{JetAlgorithm.Valencia}, R)
107+
j == 0 && return large_dij
108+
@inbounds valencia_distance(eereco, i, j, R)
109+
end
110+
111+
# Fallback if a non-Algorithm token is passed
112+
@inline function dij_dist(eereco, i, j, dij_factor, algorithm, R = 4.0)
113+
throw(ArgumentError("Algorithm $algorithm not supported for dij_dist"))
97114
end
98115

99116
function get_angular_nearest_neighbours!(eereco, algorithm, dij_factor, p, γ = 1.0, R = 4.0)
@@ -141,7 +158,7 @@ function get_angular_nearest_neighbours!(eereco, algorithm, dij_factor, p, γ =
141158
end
142159
elseif algorithm == JetAlgorithm.Valencia
143160
@inbounds for i in 1:N
144-
valencia_beam_dist = valencia_beam_distance(eereco, i, γ, p)
161+
valencia_beam_dist = @inbounds valencia_beam_distance(eereco, i, γ, p)
145162
beam_closer = valencia_beam_dist < eereco[i].dijdist
146163
eereco.dijdist[i] = beam_closer ? valencia_beam_dist : eereco.dijdist[i]
147164
eereco.nni[i] = beam_closer ? 0 : eereco.nni[i]
@@ -174,7 +191,7 @@ function update_nn_no_cross!(eereco, i, N, algorithm, dij_factor, β = 1.0, γ =
174191
eereco.dijdist[i] = beam_close ? eereco[i].E2p : eereco.dijdist[i]
175192
eereco.nni[i] = beam_close ? 0 : eereco.nni[i]
176193
elseif algorithm == JetAlgorithm.Valencia
177-
valencia_beam_dist = valencia_beam_distance(eereco, i, γ, β)
194+
valencia_beam_dist = @inbounds valencia_beam_distance(eereco, i, γ, β)
178195
beam_close = valencia_beam_dist < eereco[i].dijdist
179196
eereco.dijdist[i] = beam_close ? valencia_beam_dist : eereco.dijdist[i]
180197
eereco.nni[i] = beam_close ? 0 : eereco.nni[i]
@@ -200,7 +217,7 @@ function update_nn_cross!(eereco, i, N, algorithm, dij_factor, β = 1.0, γ = 1.
200217
eereco.nni[j] = i
201218
# j will not be revisited, so update metric distance here
202219
if algorithm == JetAlgorithm.Valencia
203-
eereco.dijdist[j] = @inbounds valencia_distance(eereco, j, i, R)
220+
eereco.dijdist[j] = valencia_distance(eereco, j, i, R)
204221
else
205222
eereco.dijdist[j] = dij_dist(eereco, j, i, dij_factor, algorithm, R)
206223
end
@@ -229,7 +246,7 @@ function update_nn_cross!(eereco, i, N, algorithm, dij_factor, β = 1.0, γ = 1.
229246
eereco.dijdist[i] = beam_close ? eereco[i].E2p : eereco.dijdist[i]
230247
eereco.nni[i] = beam_close ? 0 : eereco.nni[i]
231248
elseif algorithm == JetAlgorithm.Valencia
232-
valencia_beam_dist = valencia_beam_distance(eereco, i, γ, β)
249+
valencia_beam_dist = @inbounds valencia_beam_distance(eereco, i, γ, β)
233250
beam_close = valencia_beam_dist < eereco[i].dijdist
234251
eereco.dijdist[i] = beam_close ? valencia_beam_dist : eereco.dijdist[i]
235252
eereco.nni[i] = beam_close ? 0 : eereco.nni[i]

0 commit comments

Comments
 (0)