Skip to content

Commit 35d2f79

Browse files
committed
use SIMD.jl directly instead of LV.jl for fast_findmin()
1 parent 1b245ba commit 35d2f79

File tree

5 files changed

+36
-17
lines changed

5 files changed

+36
-17
lines changed

Project.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,31 +9,31 @@ CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193"
99
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
1010
JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
1111
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
12-
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
1312
LorentzVectorHEP = "f612022c-142a-473f-8cfd-a09cf3793c6c"
1413
LorentzVectors = "3f54b04b-17fc-5cd4-9758-90c048d965e3"
1514
MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
15+
SIMD = "fdea26ae-647d-5447-a871-4b548cad5224"
1616
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
1717

1818
[weakdeps]
19-
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
2019
EDM4hep = "eb32b910-dde9-4347-8fce-cd6be3498f0c"
20+
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
2121

2222
[extensions]
23-
JetVisualisation = "Makie"
2423
EDM4hepJets = "EDM4hep"
24+
JetVisualisation = "Makie"
2525

2626
[compat]
2727
Accessors = "0.1.36"
2828
CodecZlib = "0.7.4"
2929
EDM4hep = "0.4.0"
3030
EnumX = "1.0.4"
3131
JSON = "0.21.4"
32-
LoopVectorization = "0.12.170"
3332
LorentzVectorHEP = "0.1.6"
3433
LorentzVectors = "0.4.3"
3534
Makie = "0.20, 0.21"
3635
MuladdMacro = "0.2.4"
36+
SIMD = "3.6"
3737
StructArrays = "0.6.18"
3838
julia = "1.9"
3939

src/JetReconstruction.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ module JetReconstruction
1818
using LorentzVectorHEP
1919
using MuladdMacro
2020
using StructArrays
21+
using SIMD
2122

2223
# Import from LorentzVectorHEP methods for those 4-vector types
2324
pt2(p::LorentzVector) = LorentzVectorHEP.pt2(p)

src/PlainAlgo.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
using LoopVectorization
2-
31
"""
42
dist(i, j, rapidity_array, phi_array)
53

src/TiledAlgoLL.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
using Logging
77
using Accessors
8-
using LoopVectorization
98

109
# Include struct definitions and basic operations
1110
include("TiledAlgoLLStructs.jl")

src/Utils.jl

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ end
123123
fast_findmin(dij, n)
124124
125125
Find the minimum value and its index in the first `n` elements of the `dij`
126-
array. The use of `@turbo` macro gives a significiant performance boost.
126+
array.
127127
128128
# Arguments
129129
- `dij`: An array of values.
@@ -133,14 +133,35 @@ array. The use of `@turbo` macro gives a significiant performance boost.
133133
- `dij_min`: The minimum value in the first `n` elements of the `dij` array.
134134
- `best`: The index of the minimum value in the `dij` array.
135135
"""
136-
fast_findmin(dij, n) = begin
137-
# findmin(@inbounds @view dij[1:n])
138-
best = 1
139-
@inbounds dij_min = dij[1]
140-
@turbo for here in 2:n
141-
newmin = dij[here] < dij_min
142-
best = newmin ? here : best
143-
dij_min = newmin ? dij[here] : dij_min
136+
function fast_findmin(x, n)
    # Return `(min_value, min_index)` for the first `n` elements of `x` (the array
    # the docstring above calls `dij`), scanning 8 elements at a time with SIMD.jl.
    #
    # Preconditions: `n <= length(x)` (loads run under `@inbounds`, so a larger `n`
    # is undefined behaviour) and, because of `@fastmath`, no NaN entries.
    # For `n < 1` nothing is read and `(typemax(eltype(x)), 0)` is returned.
    # Ties resolve to the lowest index, as `Base.findmin` does.

    # Build the accumulator from the element type of `x` so that e.g. Float32
    # input is not compared against a Float64 vector.
    T = eltype(x)
    lane_indices = SIMD.Vec{8, Int64}((1, 2, 3, 4, 5, 6, 7, 8))
    minvals = SIMD.Vec{8, T}(typemax(T))
    min_indices = SIMD.Vec{8, Int64}(0)

    n_batches, remainder = divrem(n, 8)
    lane = VecRange{8}(0)
    i = 1
    @inbounds @fastmath for _ in 1:n_batches
        # One vector load per batch; every lane keeps its own running minimum
        # together with the index at which it was first seen.
        chunk = x[lane + i]
        predicate = chunk < minvals
        minvals = vifelse(predicate, chunk, minvals)
        min_indices = vifelse(predicate, lane_indices, min_indices)

        i += 8
        lane_indices += 8
    end

    # Horizontal reduction: take the smallest value across the lanes, then the
    # smallest index among the lanes that hold it. Lanes that do not hold the
    # minimum are masked out with typemax so they cannot win the index reduction.
    min_value = SIMD.minimum(minvals)
    is_min = minvals == SIMD.Vec{8, T}(min_value)
    masked_out = SIMD.Vec{8, Int64}(typemax(Int64))
    min_index = SIMD.minimum(vifelse(is_min, min_indices, masked_out))

    # Scalar tail for the last `n % 8` elements; strict `<` keeps the earlier
    # index on ties.
    @inbounds @fastmath for _ in 1:remainder
        xi = x[i]
        pred = xi < min_value
        min_value = ifelse(pred, xi, min_value)
        min_index = ifelse(pred, i, min_index)
        i += 1
    end

    # If no element compared strictly below the typemax seed (for instance every
    # entry is Inf), the index is still the 0 placeholder, which is not a valid
    # index. Fall back to the first element so callers always get a usable index.
    if min_index == 0 && n >= 1
        @inbounds min_value = x[1]
        min_index = one(min_index)
    end
    return min_value, min_index
end

0 commit comments

Comments
 (0)