Skip to content

Commit 3ae00db

Browse files
committed
change find_min
1 parent cd87e31 commit 3ae00db

File tree

5 files changed

+60
-13
lines changed

5 files changed

+60
-13
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@ CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193"
99
CodecZstd = "6b39b394-51ab-5f42-8807-6242bab2b4c2"
1010
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
1111
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
12-
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
1312
LorentzVectorHEP = "f612022c-142a-473f-8cfd-a09cf3793c6c"
1413
MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
14+
SIMD = "fdea26ae-647d-5447-a871-4b548cad5224"
1515
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
1616

1717
[weakdeps]
@@ -31,10 +31,10 @@ EDM4hep = "0.4.0"
3131
EnumX = "1.0.4"
3232
JSON = "0.21"
3333
Logging = "1.9"
34-
LoopVectorization = "0.12.170"
3534
LorentzVectorHEP = "0.1.6"
3635
Makie = "0.20, 0.21, 0.22"
3736
MuladdMacro = "0.2.4"
37+
SIMD = "3.7.1"
3838
StructArrays = "0.6.18, 0.7"
3939
Test = "1.9"
4040
julia = "1.9"

src/JetReconstruction.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ module JetReconstruction
1818
using LorentzVectorHEP
1919
using MuladdMacro
2020
using StructArrays
21+
using SIMD
2122

2223
# Import from LorentzVectorHEP methods for those 4-vector types
2324
# Forward `pt2` for `LorentzVector` arguments to the LorentzVectorHEP method.
function pt2(p::LorentzVector)
    return LorentzVectorHEP.pt2(p)
end

src/PlainAlgo.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
using LoopVectorization
21

32
"""
43
dist(i, j, rapidity_array, phi_array)

src/TiledAlgoLL.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
using Logging
77
using Accessors
8-
using LoopVectorization
98

109
# Include struct definitions and basic operations
1110
include("TiledAlgoLLStructs.jl")

src/Utils.jl

Lines changed: 57 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -149,14 +149,62 @@ array. On platforms other than aarch64 the search over more than 8 entries uses an explicit SIMD.jl kernel, which gives a significant performance boost.
149149
- `dij_min`: The minimum value in the first `n` elements of the `dij` array.
150150
- `best`: The index of the minimum value in the `dij` array.
151151
"""
152-
function fast_findmin end

# Generic method: scan the first `n` entries of `dij` with the scalar helper.
# It accepts any indexable collection, so non-dense inputs (strided views, lazy
# wrappers, ...) keep working on every architecture instead of hitting a
# `MethodError` in the `DenseVector`-only SIMD kernel.
fast_findmin(dij, n) = _naive_fast_findmin(@view(dij[begin:n]))

# Dense vectors additionally get the hand-written SIMD kernel, except on aarch64,
# where the scalar helper is used for every input.
# NOTE(review): the reason aarch64 is excluded is not visible here -- presumably
# the SIMD kernel brings no benefit there; confirm.
if Sys.ARCH != :aarch64
    function fast_findmin(dij::DenseVector, n)
        # The SIMD kernel consumes 8 entries per step, so short inputs are
        # handled by the scalar helper instead.
        n <= 8 && return _naive_fast_findmin(@view(dij[begin:n]))
        return _simd_fast_findmin(dij, n)
    end
end
165+
166+
"""
    _naive_fast_findmin(dij)

Return `(dij_min, best)`: the smallest entry of `dij` and the index of its first
occurrence.

The minimum is reduced first (under `@fastmath`) and then located with a second
scan, so `dij` must be non-empty.
"""
function _naive_fast_findmin(dij)
    smallest = @fastmath foldl(min, dij)
    # The reduced value is expected to be present in `dij`; the `::Int` assertion
    # keeps the return type concrete.
    position = findfirst(value -> value == smallest, dij)::Int
    return smallest, position
end
171+
172+
"""
    _simd_fast_findmin(dij::DenseVector{T}, n) where {T}

Return `(dij_min, best)` for the first `n` entries of `dij`, scanning 8 entries
per step with SIMD.jl vectors.

Each of the 8 lanes tracks the smallest value it has seen together with that
value's position in `dij`; the lanes are combined at the end. When the minimum
occurs more than once, the smallest index is returned, which matches the scalar
helper `_naive_fast_findmin`.

# Arguments
- `dij`: dense vector to search; indexing is assumed to start at 1.
- `n`: number of leading entries to consider; must satisfy `8 <= n <= length(dij)`.

# Throws
- `ArgumentError` if `n < 8`.
- `BoundsError` if `n > length(dij)`.

NaN entries are not supported: the main loop runs under `@fastmath`.
"""
function _simd_fast_findmin(dij::DenseVector{T}, n) where {T}
    # Every load below is a full 8-wide batch, so shorter inputs cannot be handled.
    n >= 8 || throw(ArgumentError("_simd_fast_findmin requires n >= 8, got $n"))
    # Validate once up front because the main loop runs under `@inbounds`.
    checkbounds(dij, 1:n)

    lane = VecRange{8}(0)

    # Seed the running minima with the first batch instead of an `Inf` / index-0
    # placeholder: the returned index is then always a valid position, even when
    # no entry is strictly below `Inf`, and `T` does not have to represent `Inf`.
    laneIndices = SIMD.Vec{8, Int}((1, 2, 3, 4, 5, 6, 7, 8))
    minvals = dij[lane + 1]
    min_indices = laneIndices

    n_batches, remainder = divrem(n, 8)
    i = 9
    laneIndices += 8
    @inbounds @fastmath for _ in 2:n_batches
        dijs = dij[lane + i]
        predicate = dijs < minvals
        minvals = vifelse(predicate, dijs, minvals)
        min_indices = vifelse(predicate, laneIndices, min_indices)

        i += 8
        laneIndices += 8
    end

    # Last batch: step back so that it ends exactly at entry `n`. It overlaps
    # entries that were already processed, which is harmless because only
    # strictly smaller values replace a lane's current minimum.
    back_track = 8 - remainder
    i -= back_track
    laneIndices -= back_track

    dijs = dij[lane + i]
    predicate = dijs < minvals
    minvals = vifelse(predicate, dijs, minvals)
    min_indices = vifelse(predicate, laneIndices, min_indices)

    min_value = SIMD.minimum(minvals)

    # Several lanes may hold the minimum. Mask out the others and take the
    # smallest stored index so that ties resolve to the first occurrence.
    holds_min = minvals == SIMD.Vec{8, T}(min_value)
    no_index = SIMD.Vec{8, Int}(typemax(Int))
    min_index = SIMD.minimum(vifelse(holds_min, min_indices, no_index))

    return min_value, min_index
end

0 commit comments

Comments
 (0)