@@ -76,12 +76,9 @@ vfindmin(A::AbstractArray, dims) = vfindminmax(identity, <, typemax, A, dims)
76
76
77
77
# over all dims
78
78
@generated function vfindminmax (f:: F , op:: OP , init:: I , A:: AbstractArray{T, N} , :: Colon ) where {F, OP, I, T, N}
79
- # fsym = F.instance
80
79
opsym = OP. instance
81
80
initsym = I. instance
82
- # Tₒ = promote_type(Base.promote_op(fsym, T), T) # attempt at protecting against Union{}
83
81
quote
84
- # m = $initsym($Tₒ)
85
82
m = $ initsym (Base. promote_op (f, $ T))
86
83
j = 1
87
84
@turbo for i ∈ eachindex (A)
@@ -179,14 +176,14 @@ function staticdim_findminmax_quote(OP, I, static_dims::Vector{Int}, N::Int)
179
176
block = newblock
180
177
end
181
178
# Push to inside innermost loop
182
- cmpr = Expr (:(= ), :newmax , Expr (:call , Symbol (OP. instance), Expr (:call , :f , A), :ξ ))
179
+ cmpr = Expr (:(= ), :newm , Expr (:call , Symbol (OP. instance), Expr (:call , :f , A), :ξ ))
183
180
push! (block. args, cmpr)
184
- setmax = Expr (:(= ), :ξ , Expr (:call , :ifelse , :newmax , Expr (:call , :f , A), :ξ ))
185
- push! (block. args, setmax )
181
+ setm = Expr (:(= ), :ξ , Expr (:call , :ifelse , :newm , Expr (:call , :f , A), :ξ ))
182
+ push! (block. args, setm )
186
183
for d ∈ rinds
187
184
setj = Expr (:(= ), Symbol (:j_ , d),
188
- Expr (:call , :ifelse , :newmax , Symbol (:i_ , d), Symbol (:j_ , d)))
189
- # setj = :($(Symbol(:j_, d)) = ifelse(newmax , $(Symbol(:i_, d)), $(Symbol(:j_, d))))
185
+ Expr (:call , :ifelse , :newm , Symbol (:i_ , d), Symbol (:j_ , d)))
186
+ # setj = :($(Symbol(:j_, d)) = ifelse(newm , $(Symbol(:i_, d)), $(Symbol(:j_, d))))
190
187
push! (block. args, setj)
191
188
end
192
189
# Push to after reduction loop
@@ -265,14 +262,14 @@ function staticdim_findminmax_quote(OP, I, static_dims::Vector{Int}, N::Int)
265
262
block = newblock
266
263
end
267
264
# Push to inside innermost loop
268
- cmpr = Expr (:(= ), :newmax , Expr (:call , Symbol (OP. instance), Expr (:call , :f , A), :ξ ))
265
+ cmpr = Expr (:(= ), :newm , Expr (:call , Symbol (OP. instance), Expr (:call , :f , A), :ξ ))
269
266
push! (block. args, cmpr)
270
- setmax = Expr (:(= ), :ξ , Expr (:call , :ifelse , :newmax , Expr (:call , :f , A), :ξ ))
271
- push! (block. args, setmax )
267
+ setm = Expr (:(= ), :ξ , Expr (:call , :ifelse , :newm , Expr (:call , :f , A), :ξ ))
268
+ push! (block. args, setm )
272
269
for d ∈ rinds
273
270
setj = Expr (:(= ), Symbol (:j_ , d),
274
- Expr (:call , :ifelse , :newmax , Symbol (:i_ , d), Symbol (:j_ , d)))
275
- # setj = :($(Symbol(:j_, d)) = ifelse(newmax , $(Symbol(:i_, d)), $(Symbol(:j_, d))))
271
+ Expr (:call , :ifelse , :newm , Symbol (:i_ , d), Symbol (:j_ , d)))
272
+ # setj = :($(Symbol(:j_, d)) = ifelse(newm , $(Symbol(:i_, d)), $(Symbol(:j_, d))))
276
273
push! (block. args, setj)
277
274
end
278
275
# The less efficient version (good for visualizing, though). It does not matter
@@ -364,241 +361,9 @@ end
364
361
365
362
# In the case of rinds = ∅, this just corresponds to a map
366
363
# Method selected when `dims` is the empty tuple (no dimensions supplied).
# The generated expression is constant: it interpolates none of the type parameters.
# It applies `vvmap!(f, B, A)`, copies `LinearIndices(A)` into `C`, and returns `(B, C)`.
# NOTE(review): `op` and `init` are accepted but never referenced in the generated
# expression; presumably kept so the signature lines up with the other
# `_vfindminmax!` methods — confirm.
@generated function _vfindminmax! (f:: F , op:: OP , init:: I , B:: AbstractArray{Tₒ, N} , C:: AbstractArray{Tₗ, N} , A:: AbstractArray{T, N} , dims:: Tuple{} ) where {F, OP, I, Tₒ, Tₗ, T, N}
367
- # :(copyto!(B, A); copyto!(C, LinearIndices(A)); return B, C)
368
364
# Values go through `f` via `vvmap!` (not a plain copy); indices are the linear indices of `A`.
:(vvmap! (f, B, A); copyto! (C, LinearIndices (A)); return B, C)
369
365
end
370
366
371
-
372
- # ###############
373
- # function vfindminmax2(f::F, op::OP, init::I, A::AbstractArray{T, N}, dims::NTuple{M, Int}) where {F, OP, I, T, N, M}
374
- # Dᴬ = size(A)
375
- # Dᴮ′ = ntuple(d -> d ∈ dims ? 1 : Dᴬ[d], Val(N))
376
- # B = similar(A, Base.promote_op(f, T), Dᴮ′)
377
- # C = similar(A, Int, Dᴮ′)
378
- # _vfindminmax2!(f, op, init, B, C, A, dims)
379
- # return B, CartesianIndices(A)[C]
380
- # end
381
-
382
- # function staticdim_findminmax2_quote(OP, I, static_dims::Vector{Int}, N::Int)
383
- # A = Expr(:ref, :A, ntuple(d -> Symbol(:i_, d), N)...)
384
- # Bᵥ = Expr(:call, :view, :B)
385
- # Cᵥ = Expr(:call, :view, :C)
386
- # Bᵥ′ = Expr(:ref, :Bᵥ)
387
- # Cᵥ′ = Expr(:ref, :Cᵥ)
388
- # rinds = Int[]
389
- # nrinds = Int[]
390
- # for d = 1:N
391
- # if d ∈ static_dims
392
- # push!(Bᵥ.args, Expr(:call, :firstindex, :B, d))
393
- # push!(Cᵥ.args, Expr(:call, :firstindex, :C, d))
394
- # push!(rinds, d)
395
- # else
396
- # push!(Bᵥ.args, :)
397
- # push!(Cᵥ.args, :)
398
- # push!(nrinds, d)
399
- # push!(Bᵥ′.args, Symbol(:i_, d))
400
- # push!(Cᵥ′.args, Symbol(:i_, d))
401
- # end
402
- # end
403
- # reverse!(rinds)
404
- # reverse!(nrinds)
405
- # if !isempty(nrinds)
406
- # block = Expr(:block)
407
- # loops = Expr(:for, :($(Symbol(:i_, nrinds[1])) = indices((A, B, C), $(nrinds[1]))), block)
408
- # for d ∈ @view(nrinds[2:end])
409
- # newblock = Expr(:block)
410
- # push!(block.args, Expr(:for, :($(Symbol(:i_, d)) = indices((A, B, C), $d)), newblock))
411
- # block = newblock
412
- # end
413
- # rblock = block
414
- # # Pre-reduction
415
- # ξ = Expr(:(=), :ξ, Expr(:call, Symbol(I.instance), Expr(:call, :eltype, :Bᵥ)))
416
- # push!(rblock.args, ξ)
417
- # for d ∈ rinds
418
- # push!(rblock.args, Expr(:(=), Symbol(:j_, d), Expr(:call, :one, :Int)))
419
- # end
420
- # # Reduction loop
421
- # for d ∈ rinds
422
- # newblock = Expr(:block)
423
- # push!(block.args, Expr(:for, :($(Symbol(:i_, d)) = axes(A, $d)), newblock))
424
- # block = newblock
425
- # end
426
- # # Push to inside innermost loop
427
- # # cmpr = Expr(:(=), :newmax, Expr(:call, :(>), A, :ξ))
428
- # cmpr = Expr(:(=), :newmax, Expr(:call, Symbol(OP.instance),
429
- # Expr(:call, :f, A), :ξ))
430
- # push!(block.args, cmpr)
431
- # # setmax = Expr(:(=), :ξ, Expr(:call, :ifelse, :newmax, A, :ξ))
432
- # setmax = Expr(:(=), :ξ, Expr(:call, :ifelse, :newmax,
433
- # Expr(:call, :f, A), :ξ))
434
- # push!(block.args, setmax)
435
- # for d ∈ rinds
436
- # setj = Expr(:(=), Symbol(:j_, d),
437
- # Expr(:call, :ifelse, :newmax, Symbol(:i_, d), Symbol(:j_, d)))
438
- # push!(block.args, setj)
439
- # end
440
- # # Push to after reduction loop
441
- # setb = Expr(:(=), Bᵥ′, :ξ)
442
- # push!(rblock.args, setb)
443
- # # # Simplest variety
444
- # # postj = Expr(:(=), Cᵥ′)
445
- # # if length(rinds) == 1
446
- # # push!(postj.args, d == 1 ? :j_1 :
447
- # # Expr(:call, :*, ntuple(i -> Symbol(:D_, i), d - 1)..., Symbol(:j_, d)))
448
- # # else
449
- # # setc = Expr(:call, :+)
450
- # # for d ∈ rinds
451
- # # push!(setc.args, d == 1 ? :j_1 :
452
- # # Expr(:call, :*, ntuple(i -> Symbol(:D_, i), d - 1)..., Symbol(:j_, d)))
453
- # # end
454
- # # push!(postj.args, setc)
455
- # # end
456
- # # Potential loop-carried dependency
457
- # setc = Expr(:call, :+)
458
- # for d ∈ rinds
459
- # push!(setc.args, d == 1 ? :j_1 :
460
- # Expr(:call, :*, ntuple(i -> Symbol(:D_, i), d - 1)..., Symbol(:j_, d)))
461
- # end
462
- # for d ∈ nrinds
463
- # push!(setc.args, d == 1 ? :i_1 :
464
- # Expr(:call, :*, ntuple(i -> Symbol(:D_, i), d - 1)..., Symbol(:i_, d)))
465
- # end
466
- # push!(setc.args, 1, :Dstar)
467
- # postj = Expr(:(=), Cᵥ′, setc)
468
- # push!(rblock.args, postj)
469
- # # strides, offsets
470
- # t = Expr(:tuple)
471
- # for d = 1:N
472
- # push!(t.args, Symbol(:D_, d))
473
- # end
474
- # sz = Expr(:(=), t, Expr(:call, :size, :A))
475
- # dstar = Expr(:call, :+, 1)
476
- # for d = 2:N
477
- # push!(dstar.args, d == 2 ? :D_1 : Expr(:call, :*, ntuple(i -> Symbol(:D_, i), d - 1)...))
478
- # end
479
- # return quote
480
- # $sz
481
- # Dstar = $dstar
482
- # Dstar = -Dstar
483
- # Bᵥ = $Bᵥ
484
- # Cᵥ = $Cᵥ
485
- # @turbo $loops
486
- # return B, C
487
- # end
488
- # else
489
- # # Pre-reduction
490
- # ξ = Expr(:(=), :ξ, Expr(:call, Symbol(I.instance), Expr(:call, :eltype, :Bᵥ)))
491
- # j = Expr(:tuple)
492
- # for d = 1:N
493
- # push!(j.args, Symbol(:j_, d))
494
- # end
495
- # js = :($j = $(ntuple(_ -> 1, Val(N))))
496
- # # Reduction loop
497
- # block = Expr(:block)
498
- # loops = Expr(:for, :($(Symbol(:i_, rinds[1])) = axes(A, $(rinds[1]))), block)
499
- # for d ∈ @view(rinds[2:end])
500
- # newblock = Expr(:block)
501
- # push!(block.args, Expr(:for, :($(Symbol(:i_, d)) = axes(A, $d)), newblock))
502
- # block = newblock
503
- # end
504
- # # Push to inside innermost loop
505
- # # cmpr = Expr(:(=), :newmax, Expr(:call, :(>), A, :ξ))
506
- # cmpr = Expr(:(=), :newmax, Expr(:call, Symbol(OP.instance),
507
- # Expr(:call, :f, A), :ξ))
508
- # push!(block.args, cmpr)
509
- # # setmax = Expr(:(=), :ξ, Expr(:call, :ifelse, :newmax, A, :ξ))
510
- # setmax = Expr(:(=), :ξ, Expr(:call, :ifelse, :newmax,
511
- # Expr(:call, :f, A), :ξ))
512
- # push!(block.args, setmax)
513
- # for d ∈ rinds
514
- # setj = Expr(:(=), Symbol(:j_, d),
515
- # Expr(:call, :ifelse, :newmax, Symbol(:i_, d), Symbol(:j_, d)))
516
- # push!(block.args, setj)
517
- # end
518
- # setc = Expr(:call, :+)
519
- # for d ∈ rinds
520
- # push!(setc.args, d == 1 ? :j_1 :
521
- # Expr(:call, :*, ntuple(i -> Symbol(:D_, i), d - 1)..., Symbol(:j_, d)))
522
- # end
523
- # for d ∈ nrinds
524
- # push!(setc.args, d == 1 ? :i_1 :
525
- # Expr(:call, :*, ntuple(i -> Symbol(:D_, i), d - 1)..., Symbol(:i_, d)))
526
- # end
527
- # push!(setc.args, 1, :Dstar)
528
- # # strides, offsets
529
- # t = Expr(:tuple)
530
- # for d = 1:N
531
- # push!(t.args, Symbol(:D_, d))
532
- # end
533
- # sz = Expr(:(=), t, Expr(:call, :size, :A))
534
- # dstar = Expr(:call, :+, 1)
535
- # for d = 2:N
536
- # push!(dstar.args, d == 2 ? :D_1 : Expr(:call, :*, ntuple(i -> Symbol(:D_, i), d - 1)...))
537
- # end
538
- # return quote
539
- # $js
540
- # $sz
541
- # Dstar = $dstar
542
- # Dstar = -Dstar
543
- # Bᵥ = $Bᵥ
544
- # Cᵥ = $Cᵥ
545
- # $ξ
546
- # @turbo $loops
547
- # Bᵥ[] = ξ
548
- # Cᵥ[] = $setc
549
- # return B, C
550
- # end
551
- # end
552
- # end
553
-
554
- # function branches_findminmax2_quote(OP, I, N::Int, M::Int, D)
555
- # static_dims = Int[]
556
- # for m ∈ 1:M
557
- # param = D.parameters[m]
558
- # if param <: StaticInt
559
- # new_dim = _dim(param)::Int
560
- # push!(static_dims, new_dim)
561
- # else
562
- # # tuple of static dimensions
563
- # t = Expr(:tuple)
564
- # for n ∈ static_dims
565
- # push!(t.args, :(StaticInt{$n}()))
566
- # end
567
- # q = Expr(:block, :(dimm = dims[$m]))
568
- # qold = q
569
- # # if-elseif statements
570
- # ifsym = :if
571
- # for n ∈ 1:N
572
- # n ∈ static_dims && continue
573
- # tc = copy(t)
574
- # push!(tc.args, :(StaticInt{$n}()))
575
- # qnew = Expr(ifsym, :(dimm == $n), :(return _vfindminmax2!(f, op, init, B, C, A, $tc)))
576
- # for r ∈ m+1:M
577
- # push!(tc.args, :(dims[$r]))
578
- # end
579
- # push!(qold.args, qnew)
580
- # qold = qnew
581
- # ifsym = :elseif
582
- # end
583
- # # else statement
584
- # tc = copy(t)
585
- # for r ∈ m+1:M
586
- # push!(tc.args, :(dims[$r]))
587
- # end
588
- # push!(qold.args, Expr(:block, :(return _vfindminmax2!(f, op, init, B, C, A, $tc))))
589
- # return q
590
- # end
591
- # end
592
- # return staticdim_findminmax2_quote(OP, I, static_dims, N)
593
- # end
594
-
595
- # @generated function _vfindminmax2!(f::F, op::OP, init::I, B::AbstractArray{Tₒ, N}, C::AbstractArray{Tₗ, N}, A::AbstractArray{T, N}, dims::D) where {F, OP, I, Tₒ, Tₗ, T, N, M, D<:Tuple{Vararg{Integer, M}}}
596
- # branches_findminmax2_quote(OP, I, N, M, D)
597
- # end
598
- # @generated function _vfindminmax2!(f::F, op::OP, init::I, B::AbstractArray{Tₒ, N}, C::AbstractArray{Tₗ, N}, A::AbstractArray{T, N}, dims::Tuple{}) where {F, OP, I, Tₒ, Tₗ, T, N}
599
- # :(copyto!(B, A); copyto!(C, LinearIndices(A)); return B, C)
600
- # end
601
-
602
367
# ###########################################################################################
603
368
function vtfindminmax (f:: F , op:: OP , init:: I , A:: AbstractArray{T, N} , dims:: NTuple{M, Int} ) where {F, OP, I, T, N, M}
604
369
Dᴬ = size (A)
@@ -752,16 +517,16 @@ function staticdim_tfindminmax_quote(OP, I, static_dims::Vector{Int}, N::Int)
752
517
block = newblock
753
518
end
754
519
# Push to inside innermost loop
755
- # cmpr = Expr(:(=), :newmax , Expr(:call, :(>), A, :ξ))
756
- cmpr = Expr (:(= ), :newmax , Expr (:call , Symbol (OP. instance), Expr (:call , :f , A), :ξ ))
520
+ # cmpr = Expr(:(=), :newm , Expr(:call, :(>), A, :ξ))
521
+ cmpr = Expr (:(= ), :newm , Expr (:call , Symbol (OP. instance), Expr (:call , :f , A), :ξ ))
757
522
push! (block. args, cmpr)
758
- # setmax = Expr(:(=), :ξ, Expr(:call, :ifelse, :newmax , A, :ξ))
759
- setmax = Expr (:(= ), :ξ , Expr (:call , :ifelse , :newmax , Expr (:call , :f , A), :ξ ))
760
- push! (block. args, setmax )
523
+ # setm = Expr(:(=), :ξ, Expr(:call, :ifelse, :newm , A, :ξ))
524
+ setm = Expr (:(= ), :ξ , Expr (:call , :ifelse , :newm , Expr (:call , :f , A), :ξ ))
525
+ push! (block. args, setm )
761
526
for d ∈ rinds
762
527
setj = Expr (:(= ), Symbol (:j_ , d),
763
- Expr (:call , :ifelse , :newmax , Symbol (:i_ , d), Symbol (:j_ , d)))
764
- # setj = :($(Symbol(:j_, d)) = ifelse(newmax , $(Symbol(:i_, d)), $(Symbol(:j_, d))))
528
+ Expr (:call , :ifelse , :newm , Symbol (:i_ , d), Symbol (:j_ , d)))
529
+ # setj = :($(Symbol(:j_, d)) = ifelse(newm , $(Symbol(:i_, d)), $(Symbol(:j_, d))))
765
530
push! (block. args, setj)
766
531
end
767
532
# Push to after reduction loop
@@ -838,14 +603,14 @@ function staticdim_tfindminmax_quote(OP, I, static_dims::Vector{Int}, N::Int)
838
603
block = newblock
839
604
end
840
605
# Push to inside innermost loop
841
- cmpr = Expr (:(= ), :newmax , Expr (:call , Symbol (OP. instance), Expr (:call , :f , A), :ξ ))
606
+ cmpr = Expr (:(= ), :newm , Expr (:call , Symbol (OP. instance), Expr (:call , :f , A), :ξ ))
842
607
push! (block. args, cmpr)
843
- setmax = Expr (:(= ), :ξ , Expr (:call , :ifelse , :newmax , Expr (:call , :f , A), :ξ ))
844
- push! (block. args, setmax )
608
+ setm = Expr (:(= ), :ξ , Expr (:call , :ifelse , :newm , Expr (:call , :f , A), :ξ ))
609
+ push! (block. args, setm )
845
610
for d ∈ rinds
846
611
setj = Expr (:(= ), Symbol (:j_ , d),
847
- Expr (:call , :ifelse , :newmax , Symbol (:i_ , d), Symbol (:j_ , d)))
848
- # setj = :($(Symbol(:j_, d)) = ifelse(newmax , $(Symbol(:i_, d)), $(Symbol(:j_, d))))
612
+ Expr (:call , :ifelse , :newm , Symbol (:i_ , d), Symbol (:j_ , d)))
613
+ # setj = :($(Symbol(:j_, d)) = ifelse(newm , $(Symbol(:i_, d)), $(Symbol(:j_, d))))
849
614
push! (block. args, setj)
850
615
end
851
616
# The less efficient version (good for visualizing, though). It does not matter
933
698
branches_tfindminmax_quote (OP, I, N, M, D)
934
699
end
935
700
# Method selected when `dims` is the empty tuple (no dimensions supplied).
# The generated expression is constant: it interpolates none of the type parameters.
# It applies `vtmap!(f, B, A)`, copies `LinearIndices(A)` into `C`, and returns `(B, C)`.
# NOTE(review): `vtmap!` is presumably the multithreaded counterpart of `vvmap!` — confirm.
# NOTE(review): `op` and `init` are accepted but never referenced in the generated
# expression; presumably kept so the signature lines up with the other
# `_vtfindminmax!` methods — confirm.
@generated function _vtfindminmax! (f:: F , op:: OP , init:: I , B:: AbstractArray{Tₒ, N} , C:: AbstractArray{Tₗ, N} , A:: AbstractArray{T, N} , dims:: Tuple{} ) where {F, OP, I, Tₒ, Tₗ, T, N}
936
- # :(copyto!(B, A); copyto!(C, LinearIndices(A)); return B, C)
937
701
# Values go through `f` via `vtmap!` (not a plain copy); indices are the linear indices of `A`.
:(vtmap! (f, B, A); copyto! (C, LinearIndices (A)); return B, C)
938
702
end
0 commit comments