@@ -138,10 +138,11 @@ struct OperationStruct <: AbstractLoopOperation
138
138
loopdeps:: UInt128
139
139
reduceddeps:: UInt128
140
140
childdeps:: UInt128
141
- parents:: UInt128
141
+ parents₀:: UInt128
142
+ parents₁:: UInt128
142
143
node_type:: OperationType
144
+ symid:: UInt16
143
145
array:: UInt8
144
- symid:: UInt8
145
146
end
146
147
optype (os) = os. node_type
147
148
@@ -166,14 +167,22 @@ end
166
167
loopdeps_uint (ls:: LoopSet , op:: Operation ) = shifted_loopset (ls, loopdependencies (op))
167
168
reduceddeps_uint (ls:: LoopSet , op:: Operation ) = shifted_loopset (ls, reduceddependencies (op))
168
169
childdeps_uint (ls:: LoopSet , op:: Operation ) = shifted_loopset (ls, reducedchildren (op))
169
- function parents_uint (ls :: LoopSet , op :: Operation )
170
+ function parents_uint (oppv :: AbstractVector{ Operation} )
170
171
p = zero (UInt128)
171
- for parent ∈ parents (op)
172
- p <<= 8
172
+ for parent ∈ oppv
173
+ p <<= 16
173
174
p |= identifier (parent)
174
175
end
175
176
p
176
177
end
178
+ function parents_uint (op:: Operation )
179
+ opv = parents (op)
180
+ N = length (opv)
181
+ @assert N ≤ 16
182
+ p0 = parents_uint (view (opv, 1 : min (8 ,N)))
183
+ p1 = N > 8 ? parents_uint (view (opv, 9 : N)) : zero (p0)
184
+ p0, p1
185
+ end
177
186
function recursively_set_parents_true! (x:: Vector{Bool} , op:: Operation )
178
187
x[identifier (op)] && return nothing # don't redescend
179
188
x[identifier (op)] = true
@@ -199,16 +208,16 @@ function getroots!(rooted::Vector{Bool}, ls::LoopSet)
199
208
return rooted
200
209
end
201
210
function OperationStruct! (varnames:: Vector{Symbol} , ids:: Vector{Int} , ls:: LoopSet , op:: Operation )
202
- instr = instruction (op)
203
- ld = loopdeps_uint (ls, op)
204
- rd = reduceddeps_uint (ls, op)
205
- cd = childdeps_uint (ls, op)
206
- p = parents_uint (ls, op)
207
- array = accesses_memory (op) ? findmatchingarray (ls, op. ref) : 0x00
208
- ids[identifier (op)] = id = findindoradd! (varnames, name (op))
209
- OperationStruct (
210
- ld, rd, cd, p, op. node_type, array, id
211
- )
211
+ instr = instruction (op)
212
+ ld = loopdeps_uint (ls, op)
213
+ rd = reduceddeps_uint (ls, op)
214
+ cd = childdeps_uint (ls, op)
215
+ p0, p1 = parents_uint (op)
216
+ array = accesses_memory (op) ? findmatchingarray (ls, op. ref) : 0x00
217
+ ids[identifier (op)] = id = findindoradd! (varnames, name (op))
218
+ OperationStruct (
219
+ ld, rd, cd, p0, p1, op. node_type, id, array
220
+ )
212
221
end
213
222
# # turn a LoopSet into a type object which can be used to reconstruct the LoopSet.
214
223
@@ -527,10 +536,10 @@ end
527
536
:: Val{CNFARG} , :: StaticInt{W} , :: StaticInt{RS} , :: StaticInt{AR} , :: StaticInt{NT} ,
528
537
:: StaticInt{CLS} , :: StaticInt{L1} , :: StaticInt{L2} , :: StaticInt{L3}
529
538
) where {CNFARG,W,RS,AR,CLS,L1,L2,L3,NT}
530
- inline,u₁,u₂,BROADCAST,thread = CNFARG
539
+ inline,u₁,u₂,v, BROADCAST,thread = CNFARG
531
540
nt = min (thread % UInt, NT % UInt)
532
- t = Expr (:tuple , inline, u₁, u₂, BROADCAST, W, RS, AR, CLS, L1,L2,L3, nt)
533
- length (CNFARG) == 6 && push! (t. args, last ( CNFARG) )
541
+ t = Expr (:tuple , inline, u₁, u₂, v, BROADCAST, W, RS, AR, CLS, L1, L2, L3, nt)
542
+ length (CNFARG) == 7 && push! (t. args, CNFARG[ 7 ] )
534
543
Expr (:call , Expr (:curly , :Val , t))
535
544
end
536
545
@inline function avx_config_val (
563
572
564
573
565
574
function split_ifelse! (
566
- ls:: LoopSet , preserve:: Vector{Symbol} , shouldindbyind:: Vector{Bool} , roots:: Vector{Bool} , extra_args:: Expr , k:: Int , inlineu₁u₂:: Tuple{Bool,Int8,Int8} , thread:: UInt , debug:: Bool
575
+ ls:: LoopSet , preserve:: Vector{Symbol} , shouldindbyind:: Vector{Bool} , roots:: Vector{Bool} , extra_args:: Expr , k:: Int ,
576
+ inlineu₁u₂:: Tuple{Bool,Int8,Int8,Int8} , thread:: UInt , debug:: Bool
567
577
)
568
578
roots[k] = false
569
579
op = operations (ls)[k]
@@ -617,13 +627,14 @@ function split_ifelse!(
617
627
prepre
618
628
end
619
629
620
- function generate_call (ls:: LoopSet , inlineu₁u₂:: Tuple{Bool,Int8,Int8} , thread:: UInt , debug:: Bool )
630
+ function generate_call (ls:: LoopSet , inlineu₁u₂:: Tuple{Bool,Int8,Int8,Int8 } , thread:: UInt , debug:: Bool )
621
631
extra_args = Expr (:tuple )
622
632
preserve, shouldindbyind, roots = add_grouped_strided_pointer! (extra_args, ls)
623
633
generate_call_split (ls, preserve, shouldindbyind, roots, extra_args, inlineu₁u₂, thread, debug)
624
634
end
625
635
function generate_call_split (
626
- ls:: LoopSet , preserve:: Vector{Symbol} , shouldindbyind:: Vector{Bool} , roots:: Vector{Bool} , extra_args:: Expr , inlineu₁u₂:: Tuple{Bool,Int8,Int8} , thread:: UInt , debug:: Bool
636
+ ls:: LoopSet , preserve:: Vector{Symbol} , shouldindbyind:: Vector{Bool} , roots:: Vector{Bool} , extra_args:: Expr ,
637
+ inlineu₁u₂:: Tuple{Bool,Int8,Int8,Int8} , thread:: UInt , debug:: Bool
627
638
)
628
639
for (k,op) ∈ enumerate (operations (ls))
629
640
parents_op = parents (op)
636
647
637
648
# Try to condense in type stable manner
638
649
function generate_call_types (
639
- ls:: LoopSet , preserve:: Vector{Symbol} , shouldindbyind:: Vector{Bool} , roots:: Vector{Bool} , extra_args:: Expr , (inline,u₁,u₂):: Tuple{Bool,Int8,Int8} , thread:: UInt , debug:: Bool
650
+ ls:: LoopSet , preserve:: Vector{Symbol} , shouldindbyind:: Vector{Bool} , roots:: Vector{Bool} , extra_args:: Expr ,
651
+ (inline,u₁,u₂,v):: Tuple{Bool,Int8,Int8,Int8} , thread:: UInt , debug:: Bool
640
652
)
641
653
# good place to check for split
642
654
operation_descriptions = Expr (:tuple )
@@ -665,7 +677,7 @@ function generate_call_types(
665
677
loop_syms = tuple_expr (QuoteNode, ls. loopsymbols)
666
678
func = debug ? lv (:_turbo_loopset_debug ) : lv (:_turbo_! )
667
679
lbarg = debug ? Expr (:call , :typeof , loop_bounds) : loop_bounds
668
- configarg = (inline,u₁,u₂,ls. isbroadcast,thread)
680
+ configarg = (inline,u₁,u₂,v, ls. isbroadcast,thread)
669
681
unroll_param_tup = Expr (:call , lv (:avx_config_val ), :(Val {$configarg} ()), VECTORWIDTHSYMBOL)
670
682
q = Expr (:call , func, unroll_param_tup, val (operation_descriptions), val (arrayref_descriptions), val (argmeta), val (loop_syms))
671
683
@@ -697,9 +709,10 @@ function generate_call_types(
697
709
end
698
710
# @inline reductinittype(::T) where {T} = StaticType{T}()
699
711
typeof_expr (op:: Operation ) = Expr (:call , GlobalRef (Base,:typeof ), name (op))
712
+ eltype_expr (op:: Operation ) = Expr (:call , GlobalRef (Base,:eltype ), name (op))
700
713
function add_outerreduct_types! (extra_args:: Expr , ls:: LoopSet ) # extract_outerreduct_types!
701
714
for or ∈ ls. outer_reductions
702
- push! (extra_args. args, typeof_expr (operations (ls)[or]))
715
+ push! (extra_args. args, eltype_expr (operations (ls)[or]))
703
716
end
704
717
end
705
718
"""
@@ -735,6 +748,7 @@ Returns true if the element type is supported.
735
748
"""
736
749
@inline check_type (:: Type{T} ) where {T <: NativeTypes } = true
737
750
@inline check_type (:: Type{T} ) where {T} = false
751
+ @inline check_type (:: Type{T} ) where {T <: AbstractSIMD } = true
738
752
@inline check_device (:: ArrayInterface.CPUPointer ) = true
739
753
@inline check_device (:: ArrayInterface.CPUTuple ) = true
740
754
@inline check_device (x) = false
@@ -787,10 +801,10 @@ function setup_call_final(ls::LoopSet, q::Expr)
787
801
return ls. preamble
788
802
end
789
803
function setup_call_debug (ls:: LoopSet )
790
- generate_call (ls, (false ,zero (Int8),zero (Int8)), zero (UInt), true )
804
+ generate_call (ls, (false ,zero (Int8),zero (Int8), zero (Int8) ), zero (UInt), true )
791
805
end
792
806
function setup_call (
793
- ls:: LoopSet , q:: Expr , source:: LineNumberNode , inline:: Bool , check_empty:: Bool , u₁:: Int8 , u₂:: Int8 , thread:: Int , warncheckarg:: Int
807
+ ls:: LoopSet , q:: Expr , source:: LineNumberNode , inline:: Bool , check_empty:: Bool , u₁:: Int8 , u₂:: Int8 , v :: Int8 , thread:: Int , warncheckarg:: Int
794
808
)
795
809
# We outline/inline at the macro level by creating/not creating an anonymous function.
796
810
# The old API instead was based on inlining or not inline the generated function, but
@@ -799,7 +813,7 @@ function setup_call(
799
813
# inlining the generated function into the loop preamble.
800
814
lnns = extract_all_lnns (q)
801
815
pushfirst! (lnns, source)
802
- call = generate_call (ls, (inline, u₁, u₂), thread% UInt, false )
816
+ call = generate_call (ls, (inline, u₁, u₂, v ), thread% UInt, false )
803
817
call = check_empty ? check_if_empty (ls, call) : call
804
818
argfailure = make_crashy (make_fast (q))
805
819
if warncheckarg ≠ 0
0 commit comments