Skip to content

Commit b0f5010

Browse files
Cleanup: old version, and enforce naming
1 parent 20123ea commit b0f5010

File tree

2 files changed

+50
-286
lines changed

2 files changed

+50
-286
lines changed

src/vfindminmax.jl

Lines changed: 22 additions & 258 deletions
Original file line numberDiff line numberDiff line change
@@ -76,12 +76,9 @@ vfindmin(A::AbstractArray, dims) = vfindminmax(identity, <, typemax, A, dims)
7676

7777
# over all dims
7878
@generated function vfindminmax(f::F, op::OP, init::I, A::AbstractArray{T, N}, ::Colon) where {F, OP, I, T, N}
79-
# fsym = F.instance
8079
opsym = OP.instance
8180
initsym = I.instance
82-
# Tₒ = promote_type(Base.promote_op(fsym, T), T) # attempt at protecting against Union{}
8381
quote
84-
# m = $initsym($Tₒ)
8582
m = $initsym(Base.promote_op(f, $T))
8683
j = 1
8784
@turbo for i eachindex(A)
@@ -179,14 +176,14 @@ function staticdim_findminmax_quote(OP, I, static_dims::Vector{Int}, N::Int)
179176
block = newblock
180177
end
181178
# Push to inside innermost loop
182-
cmpr = Expr(:(=), :newmax, Expr(:call, Symbol(OP.instance), Expr(:call, :f, A), ))
179+
cmpr = Expr(:(=), :newm, Expr(:call, Symbol(OP.instance), Expr(:call, :f, A), ))
183180
push!(block.args, cmpr)
184-
setmax = Expr(:(=), , Expr(:call, :ifelse, :newmax, Expr(:call, :f, A), ))
185-
push!(block.args, setmax)
181+
setm = Expr(:(=), , Expr(:call, :ifelse, :newm, Expr(:call, :f, A), ))
182+
push!(block.args, setm)
186183
for d rinds
187184
setj = Expr(:(=), Symbol(:j_, d),
188-
Expr(:call, :ifelse, :newmax, Symbol(:i_, d), Symbol(:j_, d)))
189-
# setj = :($(Symbol(:j_, d)) = ifelse(newmax, $(Symbol(:i_, d)), $(Symbol(:j_, d))))
185+
Expr(:call, :ifelse, :newm, Symbol(:i_, d), Symbol(:j_, d)))
186+
# setj = :($(Symbol(:j_, d)) = ifelse(newm, $(Symbol(:i_, d)), $(Symbol(:j_, d))))
190187
push!(block.args, setj)
191188
end
192189
# Push to after reduction loop
@@ -265,14 +262,14 @@ function staticdim_findminmax_quote(OP, I, static_dims::Vector{Int}, N::Int)
265262
block = newblock
266263
end
267264
# Push to inside innermost loop
268-
cmpr = Expr(:(=), :newmax, Expr(:call, Symbol(OP.instance), Expr(:call, :f, A), ))
265+
cmpr = Expr(:(=), :newm, Expr(:call, Symbol(OP.instance), Expr(:call, :f, A), ))
269266
push!(block.args, cmpr)
270-
setmax = Expr(:(=), , Expr(:call, :ifelse, :newmax, Expr(:call, :f, A), ))
271-
push!(block.args, setmax)
267+
setm = Expr(:(=), , Expr(:call, :ifelse, :newm, Expr(:call, :f, A), ))
268+
push!(block.args, setm)
272269
for d rinds
273270
setj = Expr(:(=), Symbol(:j_, d),
274-
Expr(:call, :ifelse, :newmax, Symbol(:i_, d), Symbol(:j_, d)))
275-
# setj = :($(Symbol(:j_, d)) = ifelse(newmax, $(Symbol(:i_, d)), $(Symbol(:j_, d))))
271+
Expr(:call, :ifelse, :newm, Symbol(:i_, d), Symbol(:j_, d)))
272+
# setj = :($(Symbol(:j_, d)) = ifelse(newm, $(Symbol(:i_, d)), $(Symbol(:j_, d))))
276273
push!(block.args, setj)
277274
end
278275
# The less efficient version (good for visualizing, though). It does not matter
@@ -364,241 +361,9 @@ end
364361

365362
# In the case of rinds = ∅, this just corresponds to a map
366363
@generated function _vfindminmax!(f::F, op::OP, init::I, B::AbstractArray{Tₒ, N}, C::AbstractArray{Tₗ, N}, A::AbstractArray{T, N}, dims::Tuple{}) where {F, OP, I, Tₒ, Tₗ, T, N}
367-
# :(copyto!(B, A); copyto!(C, LinearIndices(A)); return B, C)
368364
:(vvmap!(f, B, A); copyto!(C, LinearIndices(A)); return B, C)
369365
end
370366

371-
372-
################
373-
# function vfindminmax2(f::F, op::OP, init::I, A::AbstractArray{T, N}, dims::NTuple{M, Int}) where {F, OP, I, T, N, M}
374-
# Dᴬ = size(A)
375-
# Dᴮ′ = ntuple(d -> d ∈ dims ? 1 : Dᴬ[d], Val(N))
376-
# B = similar(A, Base.promote_op(f, T), Dᴮ′)
377-
# C = similar(A, Int, Dᴮ′)
378-
# _vfindminmax2!(f, op, init, B, C, A, dims)
379-
# return B, CartesianIndices(A)[C]
380-
# end
381-
382-
# function staticdim_findminmax2_quote(OP, I, static_dims::Vector{Int}, N::Int)
383-
# A = Expr(:ref, :A, ntuple(d -> Symbol(:i_, d), N)...)
384-
# Bᵥ = Expr(:call, :view, :B)
385-
# Cᵥ = Expr(:call, :view, :C)
386-
# Bᵥ′ = Expr(:ref, :Bᵥ)
387-
# Cᵥ′ = Expr(:ref, :Cᵥ)
388-
# rinds = Int[]
389-
# nrinds = Int[]
390-
# for d = 1:N
391-
# if d ∈ static_dims
392-
# push!(Bᵥ.args, Expr(:call, :firstindex, :B, d))
393-
# push!(Cᵥ.args, Expr(:call, :firstindex, :C, d))
394-
# push!(rinds, d)
395-
# else
396-
# push!(Bᵥ.args, :)
397-
# push!(Cᵥ.args, :)
398-
# push!(nrinds, d)
399-
# push!(Bᵥ′.args, Symbol(:i_, d))
400-
# push!(Cᵥ′.args, Symbol(:i_, d))
401-
# end
402-
# end
403-
# reverse!(rinds)
404-
# reverse!(nrinds)
405-
# if !isempty(nrinds)
406-
# block = Expr(:block)
407-
# loops = Expr(:for, :($(Symbol(:i_, nrinds[1])) = indices((A, B, C), $(nrinds[1]))), block)
408-
# for d ∈ @view(nrinds[2:end])
409-
# newblock = Expr(:block)
410-
# push!(block.args, Expr(:for, :($(Symbol(:i_, d)) = indices((A, B, C), $d)), newblock))
411-
# block = newblock
412-
# end
413-
# rblock = block
414-
# # Pre-reduction
415-
# ξ = Expr(:(=), :ξ, Expr(:call, Symbol(I.instance), Expr(:call, :eltype, :Bᵥ)))
416-
# push!(rblock.args, ξ)
417-
# for d ∈ rinds
418-
# push!(rblock.args, Expr(:(=), Symbol(:j_, d), Expr(:call, :one, :Int)))
419-
# end
420-
# # Reduction loop
421-
# for d ∈ rinds
422-
# newblock = Expr(:block)
423-
# push!(block.args, Expr(:for, :($(Symbol(:i_, d)) = axes(A, $d)), newblock))
424-
# block = newblock
425-
# end
426-
# # Push to inside innermost loop
427-
# # cmpr = Expr(:(=), :newmax, Expr(:call, :(>), A, :ξ))
428-
# cmpr = Expr(:(=), :newmax, Expr(:call, Symbol(OP.instance),
429-
# Expr(:call, :f, A), :ξ))
430-
# push!(block.args, cmpr)
431-
# # setmax = Expr(:(=), :ξ, Expr(:call, :ifelse, :newmax, A, :ξ))
432-
# setmax = Expr(:(=), :ξ, Expr(:call, :ifelse, :newmax,
433-
# Expr(:call, :f, A), :ξ))
434-
# push!(block.args, setmax)
435-
# for d ∈ rinds
436-
# setj = Expr(:(=), Symbol(:j_, d),
437-
# Expr(:call, :ifelse, :newmax, Symbol(:i_, d), Symbol(:j_, d)))
438-
# push!(block.args, setj)
439-
# end
440-
# # Push to after reduction loop
441-
# setb = Expr(:(=), Bᵥ′, :ξ)
442-
# push!(rblock.args, setb)
443-
# # # Simplest variety
444-
# # postj = Expr(:(=), Cᵥ′)
445-
# # if length(rinds) == 1
446-
# # push!(postj.args, d == 1 ? :j_1 :
447-
# # Expr(:call, :*, ntuple(i -> Symbol(:D_, i), d - 1)..., Symbol(:j_, d)))
448-
# # else
449-
# # setc = Expr(:call, :+)
450-
# # for d ∈ rinds
451-
# # push!(setc.args, d == 1 ? :j_1 :
452-
# # Expr(:call, :*, ntuple(i -> Symbol(:D_, i), d - 1)..., Symbol(:j_, d)))
453-
# # end
454-
# # push!(postj.args, setc)
455-
# # end
456-
# # Potential loop-carried dependency
457-
# setc = Expr(:call, :+)
458-
# for d ∈ rinds
459-
# push!(setc.args, d == 1 ? :j_1 :
460-
# Expr(:call, :*, ntuple(i -> Symbol(:D_, i), d - 1)..., Symbol(:j_, d)))
461-
# end
462-
# for d ∈ nrinds
463-
# push!(setc.args, d == 1 ? :i_1 :
464-
# Expr(:call, :*, ntuple(i -> Symbol(:D_, i), d - 1)..., Symbol(:i_, d)))
465-
# end
466-
# push!(setc.args, 1, :Dstar)
467-
# postj = Expr(:(=), Cᵥ′, setc)
468-
# push!(rblock.args, postj)
469-
# # strides, offsets
470-
# t = Expr(:tuple)
471-
# for d = 1:N
472-
# push!(t.args, Symbol(:D_, d))
473-
# end
474-
# sz = Expr(:(=), t, Expr(:call, :size, :A))
475-
# dstar = Expr(:call, :+, 1)
476-
# for d = 2:N
477-
# push!(dstar.args, d == 2 ? :D_1 : Expr(:call, :*, ntuple(i -> Symbol(:D_, i), d - 1)...))
478-
# end
479-
# return quote
480-
# $sz
481-
# Dstar = $dstar
482-
# Dstar = -Dstar
483-
# Bᵥ = $Bᵥ
484-
# Cᵥ = $Cᵥ
485-
# @turbo $loops
486-
# return B, C
487-
# end
488-
# else
489-
# # Pre-reduction
490-
# ξ = Expr(:(=), :ξ, Expr(:call, Symbol(I.instance), Expr(:call, :eltype, :Bᵥ)))
491-
# j = Expr(:tuple)
492-
# for d = 1:N
493-
# push!(j.args, Symbol(:j_, d))
494-
# end
495-
# js = :($j = $(ntuple(_ -> 1, Val(N))))
496-
# # Reduction loop
497-
# block = Expr(:block)
498-
# loops = Expr(:for, :($(Symbol(:i_, rinds[1])) = axes(A, $(rinds[1]))), block)
499-
# for d ∈ @view(rinds[2:end])
500-
# newblock = Expr(:block)
501-
# push!(block.args, Expr(:for, :($(Symbol(:i_, d)) = axes(A, $d)), newblock))
502-
# block = newblock
503-
# end
504-
# # Push to inside innermost loop
505-
# # cmpr = Expr(:(=), :newmax, Expr(:call, :(>), A, :ξ))
506-
# cmpr = Expr(:(=), :newmax, Expr(:call, Symbol(OP.instance),
507-
# Expr(:call, :f, A), :ξ))
508-
# push!(block.args, cmpr)
509-
# # setmax = Expr(:(=), :ξ, Expr(:call, :ifelse, :newmax, A, :ξ))
510-
# setmax = Expr(:(=), :ξ, Expr(:call, :ifelse, :newmax,
511-
# Expr(:call, :f, A), :ξ))
512-
# push!(block.args, setmax)
513-
# for d ∈ rinds
514-
# setj = Expr(:(=), Symbol(:j_, d),
515-
# Expr(:call, :ifelse, :newmax, Symbol(:i_, d), Symbol(:j_, d)))
516-
# push!(block.args, setj)
517-
# end
518-
# setc = Expr(:call, :+)
519-
# for d ∈ rinds
520-
# push!(setc.args, d == 1 ? :j_1 :
521-
# Expr(:call, :*, ntuple(i -> Symbol(:D_, i), d - 1)..., Symbol(:j_, d)))
522-
# end
523-
# for d ∈ nrinds
524-
# push!(setc.args, d == 1 ? :i_1 :
525-
# Expr(:call, :*, ntuple(i -> Symbol(:D_, i), d - 1)..., Symbol(:i_, d)))
526-
# end
527-
# push!(setc.args, 1, :Dstar)
528-
# # strides, offsets
529-
# t = Expr(:tuple)
530-
# for d = 1:N
531-
# push!(t.args, Symbol(:D_, d))
532-
# end
533-
# sz = Expr(:(=), t, Expr(:call, :size, :A))
534-
# dstar = Expr(:call, :+, 1)
535-
# for d = 2:N
536-
# push!(dstar.args, d == 2 ? :D_1 : Expr(:call, :*, ntuple(i -> Symbol(:D_, i), d - 1)...))
537-
# end
538-
# return quote
539-
# $js
540-
# $sz
541-
# Dstar = $dstar
542-
# Dstar = -Dstar
543-
# Bᵥ = $Bᵥ
544-
# Cᵥ = $Cᵥ
545-
#
546-
# @turbo $loops
547-
# Bᵥ[] = ξ
548-
# Cᵥ[] = $setc
549-
# return B, C
550-
# end
551-
# end
552-
# end
553-
554-
# function branches_findminmax2_quote(OP, I, N::Int, M::Int, D)
555-
# static_dims = Int[]
556-
# for m ∈ 1:M
557-
# param = D.parameters[m]
558-
# if param <: StaticInt
559-
# new_dim = _dim(param)::Int
560-
# push!(static_dims, new_dim)
561-
# else
562-
# # tuple of static dimensions
563-
# t = Expr(:tuple)
564-
# for n ∈ static_dims
565-
# push!(t.args, :(StaticInt{$n}()))
566-
# end
567-
# q = Expr(:block, :(dimm = dims[$m]))
568-
# qold = q
569-
# # if-elseif statements
570-
# ifsym = :if
571-
# for n ∈ 1:N
572-
# n ∈ static_dims && continue
573-
# tc = copy(t)
574-
# push!(tc.args, :(StaticInt{$n}()))
575-
# qnew = Expr(ifsym, :(dimm == $n), :(return _vfindminmax2!(f, op, init, B, C, A, $tc)))
576-
# for r ∈ m+1:M
577-
# push!(tc.args, :(dims[$r]))
578-
# end
579-
# push!(qold.args, qnew)
580-
# qold = qnew
581-
# ifsym = :elseif
582-
# end
583-
# # else statement
584-
# tc = copy(t)
585-
# for r ∈ m+1:M
586-
# push!(tc.args, :(dims[$r]))
587-
# end
588-
# push!(qold.args, Expr(:block, :(return _vfindminmax2!(f, op, init, B, C, A, $tc))))
589-
# return q
590-
# end
591-
# end
592-
# return staticdim_findminmax2_quote(OP, I, static_dims, N)
593-
# end
594-
595-
# @generated function _vfindminmax2!(f::F, op::OP, init::I, B::AbstractArray{Tₒ, N}, C::AbstractArray{Tₗ, N}, A::AbstractArray{T, N}, dims::D) where {F, OP, I, Tₒ, Tₗ, T, N, M, D<:Tuple{Vararg{Integer, M}}}
596-
# branches_findminmax2_quote(OP, I, N, M, D)
597-
# end
598-
# @generated function _vfindminmax2!(f::F, op::OP, init::I, B::AbstractArray{Tₒ, N}, C::AbstractArray{Tₗ, N}, A::AbstractArray{T, N}, dims::Tuple{}) where {F, OP, I, Tₒ, Tₗ, T, N}
599-
# :(copyto!(B, A); copyto!(C, LinearIndices(A)); return B, C)
600-
# end
601-
602367
############################################################################################
603368
function vtfindminmax(f::F, op::OP, init::I, A::AbstractArray{T, N}, dims::NTuple{M, Int}) where {F, OP, I, T, N, M}
604369
Dᴬ = size(A)
@@ -752,16 +517,16 @@ function staticdim_tfindminmax_quote(OP, I, static_dims::Vector{Int}, N::Int)
752517
block = newblock
753518
end
754519
# Push to inside innermost loop
755-
# cmpr = Expr(:(=), :newmax, Expr(:call, :(>), A, :ξ))
756-
cmpr = Expr(:(=), :newmax, Expr(:call, Symbol(OP.instance), Expr(:call, :f, A), ))
520+
# cmpr = Expr(:(=), :newm, Expr(:call, :(>), A, :ξ))
521+
cmpr = Expr(:(=), :newm, Expr(:call, Symbol(OP.instance), Expr(:call, :f, A), ))
757522
push!(block.args, cmpr)
758-
# setmax = Expr(:(=), :ξ, Expr(:call, :ifelse, :newmax, A, :ξ))
759-
setmax = Expr(:(=), , Expr(:call, :ifelse, :newmax, Expr(:call, :f, A), ))
760-
push!(block.args, setmax)
523+
# setm = Expr(:(=), :ξ, Expr(:call, :ifelse, :newm, A, :ξ))
524+
setm = Expr(:(=), , Expr(:call, :ifelse, :newm, Expr(:call, :f, A), ))
525+
push!(block.args, setm)
761526
for d rinds
762527
setj = Expr(:(=), Symbol(:j_, d),
763-
Expr(:call, :ifelse, :newmax, Symbol(:i_, d), Symbol(:j_, d)))
764-
# setj = :($(Symbol(:j_, d)) = ifelse(newmax, $(Symbol(:i_, d)), $(Symbol(:j_, d))))
528+
Expr(:call, :ifelse, :newm, Symbol(:i_, d), Symbol(:j_, d)))
529+
# setj = :($(Symbol(:j_, d)) = ifelse(newm, $(Symbol(:i_, d)), $(Symbol(:j_, d))))
765530
push!(block.args, setj)
766531
end
767532
# Push to after reduction loop
@@ -838,14 +603,14 @@ function staticdim_tfindminmax_quote(OP, I, static_dims::Vector{Int}, N::Int)
838603
block = newblock
839604
end
840605
# Push to inside innermost loop
841-
cmpr = Expr(:(=), :newmax, Expr(:call, Symbol(OP.instance), Expr(:call, :f, A), ))
606+
cmpr = Expr(:(=), :newm, Expr(:call, Symbol(OP.instance), Expr(:call, :f, A), ))
842607
push!(block.args, cmpr)
843-
setmax = Expr(:(=), , Expr(:call, :ifelse, :newmax, Expr(:call, :f, A), ))
844-
push!(block.args, setmax)
608+
setm = Expr(:(=), , Expr(:call, :ifelse, :newm, Expr(:call, :f, A), ))
609+
push!(block.args, setm)
845610
for d rinds
846611
setj = Expr(:(=), Symbol(:j_, d),
847-
Expr(:call, :ifelse, :newmax, Symbol(:i_, d), Symbol(:j_, d)))
848-
# setj = :($(Symbol(:j_, d)) = ifelse(newmax, $(Symbol(:i_, d)), $(Symbol(:j_, d))))
612+
Expr(:call, :ifelse, :newm, Symbol(:i_, d), Symbol(:j_, d)))
613+
# setj = :($(Symbol(:j_, d)) = ifelse(newm, $(Symbol(:i_, d)), $(Symbol(:j_, d))))
849614
push!(block.args, setj)
850615
end
851616
# The less efficient version (good for visualizing, though). It does not matter
@@ -933,6 +698,5 @@ end
933698
branches_tfindminmax_quote(OP, I, N, M, D)
934699
end
935700
@generated function _vtfindminmax!(f::F, op::OP, init::I, B::AbstractArray{Tₒ, N}, C::AbstractArray{Tₗ, N}, A::AbstractArray{T, N}, dims::Tuple{}) where {F, OP, I, Tₒ, Tₗ, T, N}
936-
# :(copyto!(B, A); copyto!(C, LinearIndices(A)); return B, C)
937701
:(vtmap!(f, B, A); copyto!(C, LinearIndices(A)); return B, C)
938702
end

0 commit comments

Comments
 (0)