Skip to content

Commit 6cb2c58

Browse files
Switch findmin and friends over to more efficient linear index
1 parent b1469bf commit 6cb2c58

File tree

2 files changed

+97
-120
lines changed

2 files changed

+97
-120
lines changed

src/vfindminmax.jl

Lines changed: 77 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -192,67 +192,55 @@ function staticdim_findminmax_quote(OP, I, static_dims::Vector{Int}, N::Int)
192192
# Push to after reduction loop
193193
setb = Expr(:(=), Bᵥ′, :ξ)
194194
push!(rblock.args, setb)
195-
# # Simplest variety
196-
# postj = Expr(:(=), Cᵥ′)
197-
# if length(rinds) == 1
198-
# push!(postj.args, d == 1 ? :j_1 :
199-
# Expr(:call, :*, ntuple(i -> Symbol(:D_, i), d - 1)..., Symbol(:j_, d)))
200-
# else
201-
# setc = Expr(:call, :+)
202-
# for d ∈ rinds
203-
# push!(setc.args, d == 1 ? :j_1 :
204-
# Expr(:call, :*, ntuple(i -> Symbol(:D_, i), d - 1)..., Symbol(:j_, d)))
205-
# end
206-
# push!(postj.args, setc)
207-
# end
208195
# Potential loop-carried dependency
209196
# ∑ₖ₌₁ᴺ(∏ᵢ₌₁ᵏ⁻¹Dᵢ)Iₖ : I₁ + D₁I₂ + D₁D₂I₃ + ⋯ + D₁D₂⋯Dₖ₋₁Iₖ
210-
setc = Expr(:call, :+)
211-
for d ∈ rinds
212-
push!(setc.args, d == 1 ? :j_1 :
213-
Expr(:call, :*, ntuple(i -> Symbol(:D_, i), d - 1)..., Symbol(:j_, d)))
214-
end
215-
for d ∈ nrinds
216-
push!(setc.args, d == 1 ? :i_1 :
217-
Expr(:call, :*, ntuple(i -> Symbol(:D_, i), d - 1)..., Symbol(:i_, d)))
218-
end
219-
# These complete the expression: 1 + ∑ₖ₌₁ᴺ(∏ᵢ₌₁ᵏ⁻¹Dᵢ)(Iₖ - 1)
220-
push!(setc.args, 1, :Dstar)
221-
postj = Expr(:(=), Cᵥ′, setc)
222-
push!(rblock.args, postj)
223-
# strides, offsets
224-
t = Expr(:tuple)
225-
for d = 1:N
226-
push!(t.args, Symbol(:D_, d))
227-
end
228-
sz = Expr(:(=), t, Expr(:call, :size, :A))
229-
# ∑ₖ₌₁ᴺ(∏ᵢ₌₁ᵏ⁻¹Dᵢ) : 1 + D₁ + D₁D₂ + ⋯ + D₁D₂⋯Dₖ₋₁
230-
dstar = Expr(:call, :+, 1)
231-
for d = 2:N
232-
push!(dstar.args, d == 2 ? :D_1 : Expr(:call, :*, ntuple(i -> Symbol(:D_, i), d - 1)...))
233-
end
234-
# # One might pre-compute the unchanging components of setc, but in actuality,
235-
# it has very little effect on performance.
236-
# tl = Expr(:tuple)
237-
# tr = Expr(:tuple)
238-
# for d = 3:N
239-
# push!(tl.args, Symbol(:D_, ntuple(identity, d - 1)...))
240-
# push!(tr.args, Expr(:call, :*, ntuple(i -> Symbol(:D_, i), d - 1)...))
241-
# end
197+
# # The less efficient version (good for visualizing, though)
242198
# setc = Expr(:call, :+)
243199
# for d ∈ rinds
244200
# push!(setc.args, d == 1 ? :j_1 :
245-
# Expr(:call, :*, Symbol(:D_, ntuple(identity, d - 1)...), Symbol(:j_, d)))
201+
# Expr(:call, :*, ntuple(i -> Symbol(:D_, i), d - 1)..., Symbol(:j_, d)))
246202
# end
247203
# for d ∈ nrinds
248204
# push!(setc.args, d == 1 ? :i_1 :
249-
# Expr(:call, :*, Symbol(:D_, ntuple(identity, d - 1)...), Symbol(:i_, d)))
205+
# Expr(:call, :*, ntuple(i -> Symbol(:D_, i), d - 1)..., Symbol(:i_, d)))
250206
# end
207+
# # These complete the expression: 1 + ∑ₖ₌₁ᴺ(∏ᵢ₌₁ᵏ⁻¹Dᵢ)(Iₖ - 1)
251208
# push!(setc.args, 1, :Dstar)
252209
# postj = Expr(:(=), Cᵥ′, setc)
253210
# push!(rblock.args, postj)
211+
# # strides, offsets
212+
# t = Expr(:tuple)
213+
# for d = 1:N
214+
# push!(t.args, Symbol(:D_, d))
215+
# end
216+
# sz = Expr(:(=), t, Expr(:call, :size, :A))
217+
# # ∑ₖ₌₁ᴺ(∏ᵢ₌₁ᵏ⁻¹Dᵢ) : 1 + D₁ + D₁D₂ + ⋯ + D₁D₂⋯Dₖ₋₁
218+
# dstar = Expr(:call, :+, 1)
219+
# for d = 2:N
220+
# push!(dstar.args, d == 2 ? :D_1 : Expr(:call, :*, ntuple(i -> Symbol(:D_, i), d - 1)...))
221+
# end
222+
# Version which precomputes the unchanging components of setc.
223+
tl = Expr(:tuple)
224+
tr = Expr(:tuple)
225+
for d = 3:N
226+
push!(tl.args, Symbol(:D_, ntuple(identity, d - 1)...))
227+
push!(tr.args, Expr(:call, :*, ntuple(i -> Symbol(:D_, i), d - 1)...))
228+
end
229+
setc = Expr(:call, :+)
230+
for d ∈ rinds
231+
push!(setc.args, d == 1 ? :j_1 :
232+
Expr(:call, :*, Symbol(:D_, ntuple(identity, d - 1)...), Symbol(:j_, d)))
233+
end
234+
for d ∈ nrinds
235+
push!(setc.args, d == 1 ? :i_1 :
236+
Expr(:call, :*, Symbol(:D_, ntuple(identity, d - 1)...), Symbol(:i_, d)))
237+
end
238+
push!(setc.args, 1, :Dstar)
239+
postj = Expr(:(=), Cᵥ′, setc)
240+
push!(rblock.args, postj)
254241
return quote
255242
$sz
243+
$tl = $tr
256244
Dstar = $dstar
257245
Dstar = -Dstar
258246
Bᵥ = $Bᵥ
@@ -282,10 +270,14 @@ function staticdim_findminmax_quote(OP, I, static_dims::Vector{Int}, N::Int)
282270
setmax = Expr(:(=), :ξ, Expr(:call, :ifelse, :newmax, Expr(:call, :f, A), :ξ))
283271
push!(block.args, setmax)
284272
for d ∈ rinds
285-
setj = Expr(:(=), Symbol(:j_, d), Expr(:call, :ifelse, :newmax, Symbol(:i_, d), Symbol(:j_, d)))
273+
setj = Expr(:(=), Symbol(:j_, d),
274+
Expr(:call, :ifelse, :newmax, Symbol(:i_, d), Symbol(:j_, d)))
286275
# setj = :($(Symbol(:j_, d)) = ifelse(newmax, $(Symbol(:i_, d)), $(Symbol(:j_, d))))
287276
push!(block.args, setj)
288277
end
278+
# The less efficient version (good for visualizing, though). It does not matter
279+
# here, as this just handles the reduction over all dimensions specified via dims.
280+
# In other words, the computation of the linear index only happens once.
289281
# ∑ₖ₌₁ᴺ(∏ᵢ₌₁ᵏ⁻¹Dᵢ)Iₖ : I₁ + D₁I₂ + D₁D₂I₃ + ⋯ + D₁D₂⋯Dₖ₋₁Iₖ
290282
setc = Expr(:call, :+)
291283
for d ∈ rinds
@@ -309,24 +301,6 @@ function staticdim_findminmax_quote(OP, I, static_dims::Vector{Int}, N::Int)
309301
for d = 2:N
310302
push!(dstar.args, d == 2 ? :D_1 : Expr(:call, :*, ntuple(i -> Symbol(:D_, i), d - 1)...))
311303
end
312-
# # One might pre-compute the unchanging components of setc, but in actuality,
313-
# it has very little effect on performance.
314-
# tl = Expr(:tuple)
315-
# tr = Expr(:tuple)
316-
# for d = 3:N
317-
# push!(tl.args, Symbol(:D_, ntuple(identity, d - 1)...))
318-
# push!(tr.args, Expr(:call, :*, ntuple(i -> Symbol(:D_, i), d - 1)...))
319-
# end
320-
# setc = Expr(:call, :+)
321-
# for d ∈ rinds
322-
# push!(setc.args, d == 1 ? :j_1 :
323-
# Expr(:call, :*, Symbol(:D_, ntuple(identity, d - 1)...), Symbol(:j_, d)))
324-
# end
325-
# for d ∈ nrinds
326-
# push!(setc.args, d == 1 ? :i_1 :
327-
# Expr(:call, :*, Symbol(:D_, ntuple(identity, d - 1)...), Symbol(:i_, d)))
328-
# end
329-
# push!(setc.args, 1, :Dstar)
330304
return quote
331305
$js
332306
$sz
@@ -793,44 +767,53 @@ function staticdim_tfindminmax_quote(OP, I, static_dims::Vector{Int}, N::Int)
793767
# Push to after reduction loop
794768
setb = Expr(:(=), Bᵥ′, :ξ)
795769
push!(rblock.args, setb)
796-
# # Simplest variety
797-
# postj = Expr(:(=), Cᵥ′)
798-
# if length(rinds) == 1
799-
# push!(postj.args, d == 1 ? :j_1 :
770+
# Potential loop-carried dependency
771+
# ∑ₖ₌₁ᴺ(∏ᵢ₌₁ᵏ⁻¹Dᵢ)Iₖ : I₁ + D₁I₂ + D₁D₂I₃ + ⋯ + D₁D₂⋯Dₖ₋₁Iₖ
772+
# # The less efficient version (good for visualizing, though)
773+
# setc = Expr(:call, :+)
774+
# for d ∈ rinds
775+
# push!(setc.args, d == 1 ? :j_1 :
800776
# Expr(:call, :*, ntuple(i -> Symbol(:D_, i), d - 1)..., Symbol(:j_, d)))
801-
# else
802-
# setc = Expr(:call, :+)
803-
# for d ∈ rinds
804-
# push!(setc.args, d == 1 ? :j_1 :
805-
# Expr(:call, :*, ntuple(i -> Symbol(:D_, i), d - 1)..., Symbol(:j_, d)))
806-
# end
807-
# push!(postj.args, setc)
808777
# end
809-
# Potential loop-carried dependency
778+
# for d ∈ nrinds
779+
# push!(setc.args, d == 1 ? :i_1 :
780+
# Expr(:call, :*, ntuple(i -> Symbol(:D_, i), d - 1)..., Symbol(:i_, d)))
781+
# end
782+
# push!(setc.args, 1, :Dstar)
783+
# postj = Expr(:(=), Cᵥ′, setc)
784+
# push!(rblock.args, postj)
785+
# # strides, offsets
786+
# t = Expr(:tuple)
787+
# for d = 1:N
788+
# push!(t.args, Symbol(:D_, d))
789+
# end
790+
# sz = Expr(:(=), t, Expr(:call, :size, :A))
791+
# dstar = Expr(:call, :+, 1)
792+
# for d = 2:N
793+
# push!(dstar.args, d == 2 ? :D_1 : Expr(:call, :*, ntuple(i -> Symbol(:D_, i), d - 1)...))
794+
# end
795+
# Version which precomputes the unchanging components of setc.
796+
tl = Expr(:tuple)
797+
tr = Expr(:tuple)
798+
for d = 3:N
799+
push!(tl.args, Symbol(:D_, ntuple(identity, d - 1)...))
800+
push!(tr.args, Expr(:call, :*, ntuple(i -> Symbol(:D_, i), d - 1)...))
801+
end
810802
setc = Expr(:call, :+)
811803
for d ∈ rinds
812804
push!(setc.args, d == 1 ? :j_1 :
813-
Expr(:call, :*, ntuple(i -> Symbol(:D_, i), d - 1)..., Symbol(:j_, d)))
805+
Expr(:call, :*, Symbol(:D_, ntuple(identity, d - 1)...), Symbol(:j_, d)))
814806
end
815807
for d ∈ nrinds
816808
push!(setc.args, d == 1 ? :i_1 :
817-
Expr(:call, :*, ntuple(i -> Symbol(:D_, i), d - 1)..., Symbol(:i_, d)))
809+
Expr(:call, :*, Symbol(:D_, ntuple(identity, d - 1)...), Symbol(:i_, d)))
818810
end
819811
push!(setc.args, 1, :Dstar)
820812
postj = Expr(:(=), Cᵥ′, setc)
821813
push!(rblock.args, postj)
822-
# strides, offsets
823-
t = Expr(:tuple)
824-
for d = 1:N
825-
push!(t.args, Symbol(:D_, d))
826-
end
827-
sz = Expr(:(=), t, Expr(:call, :size, :A))
828-
dstar = Expr(:call, :+, 1)
829-
for d = 2:N
830-
push!(dstar.args, d == 2 ? :D_1 : Expr(:call, :*, ntuple(i -> Symbol(:D_, i), d - 1)...))
831-
end
832814
return quote
833815
$sz
816+
$tl = $tr
834817
Dstar = $dstar
835818
Dstar = -Dstar
836819
Bᵥ = $Bᵥ
@@ -855,10 +838,8 @@ function staticdim_tfindminmax_quote(OP, I, static_dims::Vector{Int}, N::Int)
855838
block = newblock
856839
end
857840
# Push to inside innermost loop
858-
# cmpr = Expr(:(=), :newmax, Expr(:call, :(>), A, :ξ))
859841
cmpr = Expr(:(=), :newmax, Expr(:call, Symbol(OP.instance), Expr(:call, :f, A), :ξ))
860842
push!(block.args, cmpr)
861-
# setmax = Expr(:(=), :ξ, Expr(:call, :ifelse, :newmax, A, :ξ))
862843
setmax = Expr(:(=), :ξ, Expr(:call, :ifelse, :newmax, Expr(:call, :f, A), :ξ))
863844
push!(block.args, setmax)
864845
for d ∈ rinds
@@ -867,6 +848,10 @@ function staticdim_tfindminmax_quote(OP, I, static_dims::Vector{Int}, N::Int)
867848
# setj = :($(Symbol(:j_, d)) = ifelse(newmax, $(Symbol(:i_, d)), $(Symbol(:j_, d))))
868849
push!(block.args, setj)
869850
end
851+
# The less efficient version (good for visualizing, though). It does not matter
852+
# here, as this just handles the reduction over all dimensions specified via dims.
853+
# In other words, the computation of the linear index only happens once.
854+
# ∑ₖ₌₁ᴺ(∏ᵢ₌₁ᵏ⁻¹Dᵢ)Iₖ : I₁ + D₁I₂ + D₁D₂I₃ + ⋯ + D₁D₂⋯Dₖ₋₁Iₖ
870855
setc = Expr(:call, :+)
871856
for d ∈ rinds
872857
push!(setc.args, d == 1 ? :j_1 :

src/vfindminmax_vararg.jl

Lines changed: 20 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -170,33 +170,29 @@ function staticdim_findminmax_vararg_quote(OP, I, static_dims::Vector{Int}, N::I
170170
push!(rblock.args, setb)
171171
# Potential loop-carried dependency
172172
# ∑ₖ₌₁ᴺ(∏ᵢ₌₁ᵏ⁻¹Dᵢ)Iₖ : I₁ + D₁I₂ + D₁D₂I₃ + ⋯ + D₁D₂⋯Dₖ₋₁Iₖ
173+
# Version which precomputes the unchanging components of setc.
174+
tl = Expr(:tuple)
175+
tr = Expr(:tuple)
176+
for d = 3:N
177+
push!(tl.args, Symbol(:D_, ntuple(identity, d - 1)...))
178+
push!(tr.args, Expr(:call, :*, ntuple(i -> Symbol(:D_, i), d - 1)...))
179+
end
173180
setc = Expr(:call, :+)
174181
for d ∈ rinds
175182
push!(setc.args, d == 1 ? :j_1 :
176-
Expr(:call, :*, ntuple(i -> Symbol(:D_, i), d - 1)..., Symbol(:j_, d)))
183+
Expr(:call, :*, Symbol(:D_, ntuple(identity, d - 1)...), Symbol(:j_, d)))
177184
end
178185
for d ∈ nrinds
179186
push!(setc.args, d == 1 ? :i_1 :
180-
Expr(:call, :*, ntuple(i -> Symbol(:D_, i), d - 1)..., Symbol(:i_, d)))
187+
Expr(:call, :*, Symbol(:D_, ntuple(identity, d - 1)...), Symbol(:i_, d)))
181188
end
182-
# These complete the expression: 1 + ∑ₖ₌₁ᴺ(∏ᵢ₌₁ᵏ⁻¹Dᵢ)(Iₖ - 1)
183189
push!(setc.args, 1, :Dstar)
184190
postj = Expr(:(=), Cᵥ′, setc)
185191
push!(rblock.args, postj)
186-
# strides, offsets
187-
td = Expr(:tuple)
188-
for d = 1:N
189-
push!(td.args, Symbol(:D_, d))
190-
end
191-
sz = Expr(:(=), td, Expr(:call, :size, :A_1))
192-
# ∑ₖ₌₁ᴺ(∏ᵢ₌₁ᵏ⁻¹Dᵢ) : 1 + D₁ + D₁D₂ + ⋯ + D₁D₂⋯Dₖ₋₁
193-
dstar = Expr(:call, :+, 1)
194-
for d = 2:N
195-
push!(dstar.args, d == 2 ? :D_1 : Expr(:call, :*, ntuple(i -> Symbol(:D_, i), d - 1)...))
196-
end
197192
return quote
198193
$t = As
199194
$sz
195+
$tl = $tr
200196
Dstar = $dstar
201197
Dstar = -Dstar
202198
Bᵥ = $Bᵥ
@@ -474,33 +470,29 @@ function staticdim_tfindminmax_vararg_quote(OP, I, static_dims::Vector{Int}, N::
474470
push!(rblock.args, setb)
475471
# Potential loop-carried dependency
476472
# ∑ₖ₌₁ᴺ(∏ᵢ₌₁ᵏ⁻¹Dᵢ)Iₖ : I₁ + D₁I₂ + D₁D₂I₃ + ⋯ + D₁D₂⋯Dₖ₋₁Iₖ
473+
# Version which precomputes the unchanging components of setc.
474+
tl = Expr(:tuple)
475+
tr = Expr(:tuple)
476+
for d = 3:N
477+
push!(tl.args, Symbol(:D_, ntuple(identity, d - 1)...))
478+
push!(tr.args, Expr(:call, :*, ntuple(i -> Symbol(:D_, i), d - 1)...))
479+
end
477480
setc = Expr(:call, :+)
478481
for d ∈ rinds
479482
push!(setc.args, d == 1 ? :j_1 :
480-
Expr(:call, :*, ntuple(i -> Symbol(:D_, i), d - 1)..., Symbol(:j_, d)))
483+
Expr(:call, :*, Symbol(:D_, ntuple(identity, d - 1)...), Symbol(:j_, d)))
481484
end
482485
for d ∈ nrinds
483486
push!(setc.args, d == 1 ? :i_1 :
484-
Expr(:call, :*, ntuple(i -> Symbol(:D_, i), d - 1)..., Symbol(:i_, d)))
487+
Expr(:call, :*, Symbol(:D_, ntuple(identity, d - 1)...), Symbol(:i_, d)))
485488
end
486-
# These complete the expression: 1 + ∑ₖ₌₁ᴺ(∏ᵢ₌₁ᵏ⁻¹Dᵢ)(Iₖ - 1)
487489
push!(setc.args, 1, :Dstar)
488490
postj = Expr(:(=), Cᵥ′, setc)
489491
push!(rblock.args, postj)
490-
# strides, offsets
491-
td = Expr(:tuple)
492-
for d = 1:N
493-
push!(td.args, Symbol(:D_, d))
494-
end
495-
sz = Expr(:(=), td, Expr(:call, :size, :A_1))
496-
# ∑ₖ₌₁ᴺ(∏ᵢ₌₁ᵏ⁻¹Dᵢ) : 1 + D₁ + D₁D₂ + ⋯ + D₁D₂⋯Dₖ₋₁
497-
dstar = Expr(:call, :+, 1)
498-
for d = 2:N
499-
push!(dstar.args, d == 2 ? :D_1 : Expr(:call, :*, ntuple(i -> Symbol(:D_, i), d - 1)...))
500-
end
501492
return quote
502493
$t = As
503494
$sz
495+
$tl = $tr
504496
Dstar = $dstar
505497
Dstar = -Dstar
506498
Bᵥ = $Bᵥ

0 commit comments

Comments
 (0)