Skip to content

Commit 4510812

Browse files
committed
Minor progress.
1 parent 97bc8b3 commit 4510812

File tree

4 files changed

+130
-44
lines changed

4 files changed

+130
-44
lines changed

src/LoopVectorization.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -346,11 +346,11 @@ function add_reductions!(q, ::Type{V}, reduction_symbols, unroll_factor, mod) wh
346346
pushfirst!(q.args, :($gsym = $mod.vbroadcast($V,one($T))))
347347
end
348348
if op === :+
349-
push!(q.args, :($sym = Base.FastMath.add_fast($sym, $mod.vsum($gsym))))
349+
push!(q.args, :($sym = $mod.SIMDPirates.reduced_add($sym, $gsym)))
350350
elseif op === :-
351351
push!(q.args, :($sym = Base.FastMath.sub_fast($sym, $mod.vsum($gsym))))
352352
elseif op === :*
353-
push!(q.args, :($sym = Base.FastMath.mul_fast($sym, $mod.SIMDPirates.vprod($gsym))))
353+
push!(q.args, :($sym = $mod.SIMDPirates.reduced_prod($sym, $gsym)))
354354
elseif op === :/
355355
push!(q.args, :($sym = Base.FastMath.div_fast($sym, $mod.SIMDPirates.vprod($gsym))))
356356
end

src/constructors.jl

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,23 @@
11

22
### This file contains convenience functions for constructing LoopSets.
33

4-
function loopset_from_expr(q::Expr)
5-
q = contract_pass(q)
4+
function walk_body!(ls::LoopSet, body::Expr)
5+
6+
7+
end
8+
function Base.copyto!(ls::LoopSet, q::Expr)
9+
q.head === :for || throw("Expression must be a for loop.")
10+
add_loop!(ls, q.args[1])
11+
body = q.args[2]
612

7-
postwalk(q) do ex
8-
9-
end
13+
end
14+
15+
function LoopSet(q::Expr)
16+
q = contract_pass(q)
17+
ls = LoopSet()
18+
copyto!(ls, q)
19+
resize!(ls.loop_order, num_loops(ls))
20+
ls
1021
end
1122

1223

src/graphs.jl

Lines changed: 110 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,8 @@ isdense(::Type{<:DenseArray}) = true
2828

2929
@enum OperationType begin
3030
memload
31+
compute
3132
memstore
32-
compute_new
33-
compute_update
3433
# accumulator
3534
end
3635

@@ -55,25 +54,23 @@ struct Operation
5554
elementbytes::Int
5655
instruction::Symbol
5756
node_type::OperationType
58-
# dependencies::Vector{Symbol}
5957
dependencies::Set{Symbol}
6058
reduced_deps::Set{Symbol}
61-
# dependencies::Set{Symbol}
6259
parents::Vector{Operation}
63-
children::Vector{Operation}
60+
# children::Vector{Operation}
6461
numerical_metadata::Vector{Int} # stride of -1 indicates dynamic
6562
symbolic_metadata::Vector{Symbol}
66-
# strides::Dict{Symbol,Union{Symbol,Int}}
6763
function Operation(
6864
identifier,
65+
variable,
6966
elementbytes,
7067
instruction,
7168
node_type,
7269
variable = gensym()
7370
)
7471
new(
7572
identifier, variable, elementbytes, instruction, node_type,
76-
Set{Symbol}(), Operation[], Operation[], Int[], Symbol[]#, Dict{Symbol,Union{Symbol,Int}}()
73+
Set{Symbol}(), Set{Symbol}(), Operation[], Int[], Symbol[]
7774
)
7875
end
7976
end
@@ -84,12 +81,13 @@ function isreduction(op::Operation)
8481
(op.node_type == memstore) && (length(op.symbolic_metadata) < length(op.dependencies))# && issubset(op.symbolic_metadata, op.dependencies)
8582
end
8683
isload(op::Operation) = op.node_type == memload
84+
iscompute(op::Operation) = op.node_type == compute
8785
isstore(op::Operation) = op.node_type == memstore
8886
accesses_memory(op::Operation) = isload(op) | isstore(op)
8987
elsize(op::Operation) = op.elementbytes
9088
dependson(op::Operation, sym::Symbol) = sym op.dependencies
9189
parents(op::Operation) = op.parents
92-
children(op::Operation) = op.children
90+
# children(op::Operation) = op.children
9391
loopdependencies(op::Operation) = op.dependencies
9492
reduceddependencies(op::Operation) = op.reduced_deps
9593
identifier(op::Operation) = op.identifier
@@ -159,29 +157,57 @@ end
159157

160158
# load/compute/store × isunroled × istiled × pre/post loop × Loop number
161159
struct LoopOrder <: AbstractArray{Vector{Operation},5}
162-
oporder::Array{Vector{Operation},5}
160+
oporder::Vector{Vector{Operation}}
163161
loopnames::Vector{Symbol}
164162
end
165163
function LoopOrder(N::Int)
166-
LoopOrder( [ Operation[] for i 1:3, j 1:2, k 1:2, l 1:2, n 1:N ], Vector{Symbol}(undef, N) )
164+
LoopOrder( [ Operation[] for i 1:24N ], Vector{Symbol}(undef, N) )
167165
end
166+
LoopOrder() = LoopOrder(Vector{Operation}[])
168167
Base.empty!(lo::LoopOrder) = foreach(empty!, lo.oporder)
169-
Base.size(lo::LoopOrder) = (3,2,2,2,size(lo.oporder,5))
170-
Base.@propagate_inbounds Base.getindex(lo::LoopOrder, i...) = lo.oporder[i...]
168+
function Base.resize!(lo::LoopOrder, N::Int)
169+
Nold = length(lo.loopnames)
170+
resize!(lo.oporder, 24N)
171+
for n 24Nold+1:24N
172+
lo.oporder[n] = Operation[]
173+
end
174+
resize!(lo.loopnames, N)
175+
lo
176+
end
177+
Base.size(lo::LoopOrder) = (3,2,2,2,length(lo.loopnames))
178+
Base.@propagate_inbounds Base.getindex(lo::LoopOrder, i::Int) = lo.oporder[i]
179+
Base.@propagate_inbounds Base.getindex(lo::LoopOrder, i...) = lo.oporder[LinearIndices(size(lo))[i...]]
171180

172181
# Must make it easy to iterate
173182
struct LoopSet
174183
loops::Dict{Symbol,Loop} # sym === loops[sym].itersymbol
175-
# operations::Vector{Operation}
176-
loadops::Vector{Operation} # Split them to make it easier to iterate over just a subset
177-
computeops::Vector{Operation}
178-
storeops::Vector{Operation}
179-
inner_reductions::Set{UInt} # IDs of reduction operations nested within loops and stored.
184+
opdict::Dict{Symbol,Operation}
185+
operations::Vector{Operation} # Split them to make it easier to iterate over just a subset
186+
# computeops::Vector{Operation}
187+
# storeops::Vector{Operation}
180188
outer_reductions::Set{UInt} # IDs of reduction operations that need to be reduced at end.
181189
loop_order::LoopOrder
182-
# strideset::Vector{}
190+
preamble::Expr # TODO: add preamble to lowering
191+
end
192+
function LoopSet()
193+
LoopSet(
194+
Dict{Symbol,Loop}(),
195+
Dict{Symbol,Operation}(),
196+
Operation[],
197+
# Operation[],
198+
# Operation[],
199+
# Set{UInt}(),
200+
Set{UInt}(),
201+
LoopOrder(),
202+
Expr(:block,)
203+
)
183204
end
184205
num_loops(ls::LoopSet) = length(ls.loops)
206+
function oporder(ls::LoopSet)
207+
N = length(ls.loop_order.loopnames)
208+
reshape(ls.loop_order.oporder, (3,2,2,2,N))
209+
end
210+
names(ls::LoopSet) = ls.loop_order.loopnames
185211
isstaticloop(ls::LoopSet, s::Symbol) = ls.loops[s].hintexact
186212
looprangehint(ls::LoopSets, s::Symbol) = ls.loops[s].rangehint
187213
looprangesym(ls::LoopSets, s::Symbol) = ls.loops[s].rangesym
@@ -198,15 +224,71 @@ end
198224
function Base.length(ls::LoopSet, is::Symbol)
199225
ls.loops[is].rangehint
200226
end
201-
load_operations(ls::LoopSet) = ls.loadops
202-
compute_operations(ls::LoopSet) = ls.computeops
203-
store_operations(ls::LoopSet) = ls.storeops
204-
function operations(ls::LoopSet)
205-
Base.Iterators.flatten((
206-
load_operations(ls),
207-
compute_operations(ls),
208-
store_operations(ls)
209-
))
227+
# load_operations(ls::LoopSet) = ls.loadops
228+
# compute_operations(ls::LoopSet) = ls.computeops
229+
# store_operations(ls::LoopSet) = ls.storeops
230+
# function operations(ls::LoopSet)
231+
# Base.Iterators.flatten((
232+
# load_operations(ls),
233+
# compute_operations(ls),
234+
# store_operations(ls)
235+
# ))
236+
# end
237+
operations(ls::LoopSet) = ls.operations
238+
function add_loop!(ls::LoopSet, looprange::Expr)
239+
itersym = (looprange.args[1])::Symbol
240+
r = (looprange.args[2])::Expr
241+
@assert r.head === :call
242+
f = first(r.args)
243+
loop::Loop = if f === :(:)
244+
lower = r.args[2]
245+
upper = r.args[3]
246+
lii::Bool = lower isa Integer
247+
uii::Bool = upper isa Integer
248+
if lii & uii
249+
Loop(itersym, 1 + convert(Int,upper) - convert(Int,lower))
250+
else
251+
N = gensym(:loop, itersym)
252+
ex = if lii
253+
Expr(:call, :-, upper, lower - 1)
254+
elseif uii
255+
Expr(:call, :-, upper + 1, lower)
256+
else
257+
Expr(:call, :-, Expr(:call, :+, upper, 1), lower)
258+
end
259+
push!(ls.preamble.args, Expr(:(=), N, ex))
260+
Loop(itersym, N)
261+
end
262+
elseif f === :eachindex
263+
N = gensym(:loop, itersym)
264+
push!(ls.preamble.args, Expr(:(=), N, Expr(:call, :length, r.args[2])))
265+
Loop(itersym, N)
266+
else
267+
throw("Unrecognized loop range type: $r.")
268+
end
269+
ls.loops[itersym] = loop
270+
nothing
271+
end
272+
function add_load!(ls::LoopSet, indexed::Symbol, indices::AbstractVector)
273+
Ninds = length(indices)
274+
275+
276+
277+
end
278+
function add_load_getindex!(ls::LoopSet, ex::Expr)
279+
add_load!(ls, ex.args[2], @view(ex.args[3:end]))
280+
end
281+
function add_load_ref!(ls::LoopSet, ex::Expr)
282+
add_load!(ls, ex.args[1], @view(ex.args[2:end]))
283+
end
284+
function add_compute!(ls::LoopSet, ex::Expr)
285+
286+
end
287+
function add_store!(ls::LoopSet, ex::Expr)
288+
289+
end
290+
function Base.push!(ls::LoopSet, ex::Expr)
291+
210292
end
211293

212294
function fillorder!(ls::LoopSet, order::Vector{Symbol}, loopistiled::Bool)
@@ -233,13 +315,7 @@ function fillorder!(ls::LoopSet, order::Vector{Symbol}, loopistiled::Bool)
233315
included_vars[id] = true
234316
isunrolled = (unrolled loopdependencies(op)) + 1
235317
istiled = (loopistiled ? false : (tiled loopdependencies(op))) + 1
236-
optype = if isload(op)
237-
1
238-
elseif isstore(op)
239-
3
240-
else#if compute
241-
2
242-
end
318+
optype = Int(op.node_type)
243319
after_loop = (length(reduceddependencies(op)) > 0) + 1
244320
push!(lo[optype,isunrolled,istiled,after_loop,_n], op)
245321
end

src/lowering.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -298,9 +298,8 @@ function lower_nest(
298298
loopstart::Union{Int,Symbol}, W::Int,
299299
mask::Union{Nothing,Symbol,Unsigned} = nothing, exprtype::Symbol = :while
300300
)
301-
lo = ls.loop_order
302-
ops = lo.oporder
303-
order = lo.loopnames
301+
ops = oporder(ls)
302+
order = names(ls)
304303
istiled = T != -1
305304
loopsym = order[n]
306305
nloops = num_loops(ls)

0 commit comments

Comments
 (0)