@@ -192,67 +192,55 @@ function staticdim_findminmax_quote(OP, I, static_dims::Vector{Int}, N::Int)
192
192
# Push to after reduction loop
193
193
setb = Expr (:(= ), Bᵥ′, :ξ )
194
194
push! (rblock. args, setb)
195
- # # Simplest variety
196
- # postj = Expr(:(=), Cᵥ′)
197
- # if length(rinds) == 1
198
- # push!(postj.args, d == 1 ? :j_1 :
199
- # Expr(:call, :*, ntuple(i -> Symbol(:D_, i), d - 1)..., Symbol(:j_, d)))
200
- # else
201
- # setc = Expr(:call, :+)
202
- # for d ∈ rinds
203
- # push!(setc.args, d == 1 ? :j_1 :
204
- # Expr(:call, :*, ntuple(i -> Symbol(:D_, i), d - 1)..., Symbol(:j_, d)))
205
- # end
206
- # push!(postj.args, setc)
207
- # end
208
195
# Potential loop-carried dependency
209
196
# ∑ₖ₌₁ᴺ(∏ᵢ₌₁ᵏ⁻¹Dᵢ)Iₖ : I₁ + D₁I₂ + D₁D₂I₃ + ⋯ + D₁D₂⋯D_{N-1}I_N
210
- setc = Expr (:call , :+ )
211
- for d ∈ rinds
212
- push! (setc. args, d == 1 ? :j_1 :
213
- Expr (:call , :* , ntuple (i -> Symbol (:D_ , i), d - 1 )... , Symbol (:j_ , d)))
214
- end
215
- for d ∈ nrinds
216
- push! (setc. args, d == 1 ? :i_1 :
217
- Expr (:call , :* , ntuple (i -> Symbol (:D_ , i), d - 1 )... , Symbol (:i_ , d)))
218
- end
219
- # These complete the expression: 1 + ∑ₖ₌₁ᴺ(∏ᵢ₌₁ᵏ⁻¹Dᵢ)(Iₖ - 1)
220
- push! (setc. args, 1 , :Dstar )
221
- postj = Expr (:(= ), Cᵥ′, setc)
222
- push! (rblock. args, postj)
223
- # strides, offsets
224
- t = Expr (:tuple )
225
- for d = 1 : N
226
- push! (t. args, Symbol (:D_ , d))
227
- end
228
- sz = Expr (:(= ), t, Expr (:call , :size , :A ))
229
- # ∑ₖ₌₁ᴺ(∏ᵢ₌₁ᵏ⁻¹Dᵢ) : 1 + D₁ + D₁D₂ + ⋯ + D₁D₂⋯Dₖ₋₁
230
- dstar = Expr (:call , :+ , 1 )
231
- for d = 2 : N
232
- push! (dstar. args, d == 2 ? :D_1 : Expr (:call , :* , ntuple (i -> Symbol (:D_ , i), d - 1 )... ))
233
- end
234
- # # One might pre-compute the unchanging components of setc, but in actuality,
235
- # it has very little effect on performance.
236
- # tl = Expr(:tuple)
237
- # tr = Expr(:tuple)
238
- # for d = 3:N
239
- # push!(tl.args, Symbol(:D_, ntuple(identity, d - 1)...))
240
- # push!(tr.args, Expr(:call, :*, ntuple(i -> Symbol(:D_, i), d - 1)...))
241
- # end
197
+ # # The less efficient version (good for visualizing, though)
242
198
# setc = Expr(:call, :+)
243
199
# for d ∈ rinds
244
200
# push!(setc.args, d == 1 ? :j_1 :
245
- # Expr(:call, :*, Symbol(:D_, ntuple(identity , d - 1)...) , Symbol(:j_, d)))
201
+ # Expr(:call, :*, ntuple(i -> Symbol(:D_, i) , d - 1)..., Symbol(:j_, d)))
246
202
# end
247
203
# for d ∈ nrinds
248
204
# push!(setc.args, d == 1 ? :i_1 :
249
- # Expr(:call, :*, Symbol(:D_, ntuple(identity , d - 1)...) , Symbol(:i_, d)))
205
+ # Expr(:call, :*, ntuple(i -> Symbol(:D_, i) , d - 1)..., Symbol(:i_, d)))
250
206
# end
207
+ # # These complete the expression: 1 + ∑ₖ₌₁ᴺ(∏ᵢ₌₁ᵏ⁻¹Dᵢ)(Iₖ - 1)
251
208
# push!(setc.args, 1, :Dstar)
252
209
# postj = Expr(:(=), Cᵥ′, setc)
253
210
# push!(rblock.args, postj)
211
+ # # strides, offsets
212
+ # t = Expr(:tuple)
213
+ # for d = 1:N
214
+ # push!(t.args, Symbol(:D_, d))
215
+ # end
216
+ # sz = Expr(:(=), t, Expr(:call, :size, :A))
217
+ # # ∑ₖ₌₁ᴺ(∏ᵢ₌₁ᵏ⁻¹Dᵢ) : 1 + D₁ + D₁D₂ + ⋯ + D₁D₂⋯D_{N-1}
218
+ # dstar = Expr(:call, :+, 1)
219
+ # for d = 2:N
220
+ # push!(dstar.args, d == 2 ? :D_1 : Expr(:call, :*, ntuple(i -> Symbol(:D_, i), d - 1)...))
221
+ # end
222
+ # Version which precomputes the unchanging components of setc.
223
+ tl = Expr (:tuple )
224
+ tr = Expr (:tuple )
225
+ for d = 3 : N
226
+ push! (tl. args, Symbol (:D_ , ntuple (identity, d - 1 )... ))
227
+ push! (tr. args, Expr (:call , :* , ntuple (i -> Symbol (:D_ , i), d - 1 )... ))
228
+ end
229
+ setc = Expr (:call , :+ )
230
+ for d ∈ rinds
231
+ push! (setc. args, d == 1 ? :j_1 :
232
+ Expr (:call , :* , Symbol (:D_ , ntuple (identity, d - 1 )... ), Symbol (:j_ , d)))
233
+ end
234
+ for d ∈ nrinds
235
+ push! (setc. args, d == 1 ? :i_1 :
236
+ Expr (:call , :* , Symbol (:D_ , ntuple (identity, d - 1 )... ), Symbol (:i_ , d)))
237
+ end
238
+ push! (setc. args, 1 , :Dstar )
239
+ postj = Expr (:(= ), Cᵥ′, setc)
240
+ push! (rblock. args, postj)
254
241
return quote
255
242
$ sz
243
+ $ tl = $ tr
256
244
Dstar = $ dstar
257
245
Dstar = - Dstar
258
246
Bᵥ = $ Bᵥ
@@ -282,10 +270,14 @@ function staticdim_findminmax_quote(OP, I, static_dims::Vector{Int}, N::Int)
282
270
setmax = Expr (:(= ), :ξ , Expr (:call , :ifelse , :newmax , Expr (:call , :f , A), :ξ ))
283
271
push! (block. args, setmax)
284
272
for d ∈ rinds
285
- setj = Expr (:(= ), Symbol (:j_ , d), Expr (:call , :ifelse , :newmax , Symbol (:i_ , d), Symbol (:j_ , d)))
273
+ setj = Expr (:(= ), Symbol (:j_ , d),
274
+ Expr (:call , :ifelse , :newmax , Symbol (:i_ , d), Symbol (:j_ , d)))
286
275
# setj = :($(Symbol(:j_, d)) = ifelse(newmax, $(Symbol(:i_, d)), $(Symbol(:j_, d))))
287
276
push! (block. args, setj)
288
277
end
278
+ # The less efficient version (good for visualizing, though). It does not matter
279
+ # here, as this just handles the reduction over all dimensions specified via dims.
280
+ # In other words, the computation of the linear index only happens once.
289
281
# ∑ₖ₌₁ᴺ(∏ᵢ₌₁ᵏ⁻¹Dᵢ)Iₖ : I₁ + D₁I₂ + D₁D₂I₃ + ⋯ + D₁D₂⋯D_{N-1}I_N
290
282
setc = Expr (:call , :+ )
291
283
for d ∈ rinds
@@ -309,24 +301,6 @@ function staticdim_findminmax_quote(OP, I, static_dims::Vector{Int}, N::Int)
309
301
for d = 2 : N
310
302
push! (dstar. args, d == 2 ? :D_1 : Expr (:call , :* , ntuple (i -> Symbol (:D_ , i), d - 1 )... ))
311
303
end
312
- # # One might pre-compute the unchanging components of setc, but in actuality,
313
- # it has very little effect on performance.
314
- # tl = Expr(:tuple)
315
- # tr = Expr(:tuple)
316
- # for d = 3:N
317
- # push!(tl.args, Symbol(:D_, ntuple(identity, d - 1)...))
318
- # push!(tr.args, Expr(:call, :*, ntuple(i -> Symbol(:D_, i), d - 1)...))
319
- # end
320
- # setc = Expr(:call, :+)
321
- # for d ∈ rinds
322
- # push!(setc.args, d == 1 ? :j_1 :
323
- # Expr(:call, :*, Symbol(:D_, ntuple(identity, d - 1)...), Symbol(:j_, d)))
324
- # end
325
- # for d ∈ nrinds
326
- # push!(setc.args, d == 1 ? :i_1 :
327
- # Expr(:call, :*, Symbol(:D_, ntuple(identity, d - 1)...), Symbol(:i_, d)))
328
- # end
329
- # push!(setc.args, 1, :Dstar)
330
304
return quote
331
305
$ js
332
306
$ sz
@@ -793,44 +767,53 @@ function staticdim_tfindminmax_quote(OP, I, static_dims::Vector{Int}, N::Int)
793
767
# Push to after reduction loop
794
768
setb = Expr (:(= ), Bᵥ′, :ξ )
795
769
push! (rblock. args, setb)
796
- # # Simplest variety
797
- # postj = Expr(:(=), Cᵥ′)
798
- # if length(rinds) == 1
799
- # push!(postj.args, d == 1 ? :j_1 :
770
+ # Potential loop-carried dependency
771
+ # ∑ₖ₌₁ᴺ(∏ᵢ₌₁ᵏ⁻¹Dᵢ)Iₖ : I₁ + D₁I₂ + D₁D₂I₃ + ⋯ + D₁D₂⋯D_{N-1}I_N
772
+ # # The less efficient version (good for visualizing, though)
773
+ # setc = Expr(:call, :+)
774
+ # for d ∈ rinds
775
+ # push!(setc.args, d == 1 ? :j_1 :
800
776
# Expr(:call, :*, ntuple(i -> Symbol(:D_, i), d - 1)..., Symbol(:j_, d)))
801
- # else
802
- # setc = Expr(:call, :+)
803
- # for d ∈ rinds
804
- # push!(setc.args, d == 1 ? :j_1 :
805
- # Expr(:call, :*, ntuple(i -> Symbol(:D_, i), d - 1)..., Symbol(:j_, d)))
806
- # end
807
- # push!(postj.args, setc)
808
777
# end
809
- # Potential loop-carried dependency
778
+ # for d ∈ nrinds
779
+ # push!(setc.args, d == 1 ? :i_1 :
780
+ # Expr(:call, :*, ntuple(i -> Symbol(:D_, i), d - 1)..., Symbol(:i_, d)))
781
+ # end
782
+ # push!(setc.args, 1, :Dstar)
783
+ # postj = Expr(:(=), Cᵥ′, setc)
784
+ # push!(rblock.args, postj)
785
+ # # strides, offsets
786
+ # t = Expr(:tuple)
787
+ # for d = 1:N
788
+ # push!(t.args, Symbol(:D_, d))
789
+ # end
790
+ # sz = Expr(:(=), t, Expr(:call, :size, :A))
791
+ # dstar = Expr(:call, :+, 1)
792
+ # for d = 2:N
793
+ # push!(dstar.args, d == 2 ? :D_1 : Expr(:call, :*, ntuple(i -> Symbol(:D_, i), d - 1)...))
794
+ # end
795
+ # Version which precomputes the unchanging components of setc.
796
+ tl = Expr (:tuple )
797
+ tr = Expr (:tuple )
798
+ for d = 3 : N
799
+ push! (tl. args, Symbol (:D_ , ntuple (identity, d - 1 )... ))
800
+ push! (tr. args, Expr (:call , :* , ntuple (i -> Symbol (:D_ , i), d - 1 )... ))
801
+ end
810
802
setc = Expr (:call , :+ )
811
803
for d ∈ rinds
812
804
push! (setc. args, d == 1 ? :j_1 :
813
- Expr (:call , :* , ntuple (i -> Symbol (:D_ , i) , d - 1 )... , Symbol (:j_ , d)))
805
+ Expr (:call , :* , Symbol (:D_ , ntuple (identity , d - 1 )... ) , Symbol (:j_ , d)))
814
806
end
815
807
for d ∈ nrinds
816
808
push! (setc. args, d == 1 ? :i_1 :
817
- Expr (:call , :* , ntuple (i -> Symbol (:D_ , i) , d - 1 )... , Symbol (:i_ , d)))
809
+ Expr (:call , :* , Symbol (:D_ , ntuple (identity , d - 1 )... ) , Symbol (:i_ , d)))
818
810
end
819
811
push! (setc. args, 1 , :Dstar )
820
812
postj = Expr (:(= ), Cᵥ′, setc)
821
813
push! (rblock. args, postj)
822
- # strides, offsets
823
- t = Expr (:tuple )
824
- for d = 1 : N
825
- push! (t. args, Symbol (:D_ , d))
826
- end
827
- sz = Expr (:(= ), t, Expr (:call , :size , :A ))
828
- dstar = Expr (:call , :+ , 1 )
829
- for d = 2 : N
830
- push! (dstar. args, d == 2 ? :D_1 : Expr (:call , :* , ntuple (i -> Symbol (:D_ , i), d - 1 )... ))
831
- end
832
814
return quote
833
815
$ sz
816
+ $ tl = $ tr
834
817
Dstar = $ dstar
835
818
Dstar = - Dstar
836
819
Bᵥ = $ Bᵥ
@@ -855,10 +838,8 @@ function staticdim_tfindminmax_quote(OP, I, static_dims::Vector{Int}, N::Int)
855
838
block = newblock
856
839
end
857
840
# Push to inside innermost loop
858
- # cmpr = Expr(:(=), :newmax, Expr(:call, :(>), A, :ξ))
859
841
cmpr = Expr (:(= ), :newmax , Expr (:call , Symbol (OP. instance), Expr (:call , :f , A), :ξ ))
860
842
push! (block. args, cmpr)
861
- # setmax = Expr(:(=), :ξ, Expr(:call, :ifelse, :newmax, A, :ξ))
862
843
setmax = Expr (:(= ), :ξ , Expr (:call , :ifelse , :newmax , Expr (:call , :f , A), :ξ ))
863
844
push! (block. args, setmax)
864
845
for d ∈ rinds
@@ -867,6 +848,10 @@ function staticdim_tfindminmax_quote(OP, I, static_dims::Vector{Int}, N::Int)
867
848
# setj = :($(Symbol(:j_, d)) = ifelse(newmax, $(Symbol(:i_, d)), $(Symbol(:j_, d))))
868
849
push! (block. args, setj)
869
850
end
851
+ # The less efficient version (good for visualizing, though). It does not matter
852
+ # here, as this just handles the reduction over all dimensions specified via dims.
853
+ # In other words, the computation of the linear index only happens once.
854
+ # ∑ₖ₌₁ᴺ(∏ᵢ₌₁ᵏ⁻¹Dᵢ)Iₖ : I₁ + D₁I₂ + D₁D₂I₃ + ⋯ + D₁D₂⋯D_{N-1}I_N
870
855
setc = Expr (:call , :+ )
871
856
for d ∈ rinds
872
857
push! (setc. args, d == 1 ? :j_1 :
0 commit comments