Skip to content

Commit d238ab2

Browse files
committed
Merge branch 'master' into teh/multiindices
2 parents 8eb2440 + ca31ba1 commit d238ab2

File tree

9 files changed

+71
-44
lines changed

9 files changed

+71
-44
lines changed

src/add_compute.jl

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -190,31 +190,36 @@ function add_compute!(
190190
instr = instruction(first(ex.args))::Symbol
191191
args = @view(ex.args[2:end])
192192
(instr === :(^) && length(args) == 2 && (args[2] isa Number)) && return add_pow!(ls, var, args[1], args[2], elementbytes, position)
193-
parents = Operation[]
193+
vparents = Operation[]
194194
deps = Symbol[]
195195
reduceddeps = Symbol[]
196196
reduction_ind = 0
197197
for (ind,arg) enumerate(args)
198198
if var === arg
199199
reduction_ind = ind
200-
add_reduction!(parents, deps, reduceddeps, ls, arg, elementbytes)
200+
add_reduction!(vparents, deps, reduceddeps, ls, arg, elementbytes)
201201
elseif arg isa Expr
202-
isref, argref = tryrefconvert(ls, arg, elementbytes)
202+
isref, argref = tryrefconvert(ls, arg, elementbytes, varname(mpref))
203203
if isref
204204
if mpref == argref
205-
reduction_ind = ind
206-
add_load!(ls, var, argref, elementbytes)
205+
if varname(mpref) === var
206+
reduction_ind = ind
207+
add_load!(ls, argref, elementbytes)
208+
else
209+
pushparent!(vparents, deps, reduceddeps, add_load!(ls, argref, elementbytes))
210+
end
207211
else
208-
pushparent!(parents, deps, reduceddeps, add_load!(ls, gensym(:tempload), argref, elementbytes))
212+
argref.varname = gensym(:tempload)
213+
pushparent!(vparents, deps, reduceddeps, add_load!(ls, argref, elementbytes))
209214
end
210215
else
211-
add_parent!(parents, deps, reduceddeps, ls, arg, elementbytes, position)
216+
add_parent!(vparents, deps, reduceddeps, ls, arg, elementbytes, position)
212217
end
213218
elseif arg ls.loopsymbols
214219
loopsymop = add_loopvalue!(ls, arg, elementbytes)
215-
pushparent!(parents, deps, reduceddeps, loopsymop)
220+
pushparent!(vparents, deps, reduceddeps, loopsymop)
216221
else
217-
add_parent!(parents, deps, reduceddeps, ls, arg, elementbytes, position)
222+
add_parent!(vparents, deps, reduceddeps, ls, arg, elementbytes, position)
218223
end
219224
end
220225
reduction = reduction_ind > 0
@@ -228,29 +233,30 @@ function add_compute!(
228233
mergesetv!(newreduceddeps, reduceddeps)
229234
deps = newloopdeps; reduceddeps = newreduceddeps
230235
end
231-
if reduction || search_tree(parents, var)
232-
parent = getop(ls, var, elementbytes)
236+
if reduction || search_tree(vparents, var)
237+
parent = ls.opdict[var]
233238
setdiffv!(reduceddeps, deps, loopdependencies(parent))
239+
# parent = getop(ls, var, elementbytes)
234240
if length(reduceddeps) == 0
235-
insert!(parents, reduction_ind, parent)
236-
op = Operation(length(operations(ls)), var, elementbytes, instruction(ls,instr), compute, deps, reduceddeps, parents)
241+
insert!(vparents, reduction_ind, parent)
242+
op = Operation(length(operations(ls)), var, elementbytes, instruction(ls,instr), compute, deps, reduceddeps, vparents)
237243
pushop!(ls, op, var)
238244
else
239-
add_reduction_update_parent!(parents, deps, reduceddeps, ls, parent, instr, reduction_ind, elementbytes)
245+
add_reduction_update_parent!(vparents, deps, reduceddeps, ls, parent, instr, reduction_ind, elementbytes)
240246
end
241247
else
242-
op = Operation(length(operations(ls)), var, elementbytes, instruction(ls,instr), compute, deps, reduceddeps, parents)
248+
op = Operation(length(operations(ls)), var, elementbytes, instruction(ls,instr), compute, deps, reduceddeps, vparents)
243249
pushop!(ls, op, var)
244250
end
245251
end
246252

247253
function add_compute!(
248-
ls::LoopSet, LHS::Symbol, instr, parents::Vector{Operation}, elementbytes
254+
ls::LoopSet, LHS::Symbol, instr, vparents::Vector{Operation}, elementbytes
249255
)
250256
deps = Symbol[]
251257
reduceddeps = Symbol[]
252-
foreach(parent -> update_deps!(deps, reduceddeps, parent), parents)
253-
op = Operation(length(operations(ls)), LHS, elementbytes, instr, compute, deps, reduceddeps, parents)
258+
foreach(parent -> update_deps!(deps, reduceddeps, parent), vparents)
259+
op = Operation(length(operations(ls)), LHS, elementbytes, instr, compute, deps, reduceddeps, vparents)
254260
pushop!(ls, op, LHS)
255261
end
256262

src/add_constants.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ function add_constant!(ls::LoopSet, var::Number, elementbytes::Int = 8)
3535
end
3636
pushop!(ls, op)
3737
end
38-
function add_constant!(ls::LoopSet, var::Symbol, mpref::ArrayReferenceMetaPosition, elementbytes::Int)
39-
op = Operation(length(operations(ls)), var, elementbytes, LOOPCONSTANT, constant, NODEPENDENCY, Symbol[], NOPARENTS, mpref.mref)
38+
function add_constant!(ls::LoopSet, mpref::ArrayReferenceMetaPosition, elementbytes::Int)
39+
op = Operation(length(operations(ls)), varname(mpref), elementbytes, LOOPCONSTANT, constant, NODEPENDENCY, Symbol[], NOPARENTS, mpref.mref)
4040
add_vptr!(ls, op)
4141
temp = gensym(:intermediateconstref)
4242
pushpreamble!(ls, Expr(:(=), temp, Expr(:call, lv(:vload), mpref.mref.ptr, mem_offset(op, UnrollArgs(0, Symbol(""), Symbol(""), nothing)))))

src/add_loads.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,14 @@ end
1717
function add_load!(
1818
ls::LoopSet, var::Symbol, array::Symbol, rawindices, elementbytes::Int
1919
)
20-
mpref = array_reference_meta!(ls, array, rawindices, elementbytes)
21-
add_load!(ls, var, mpref, elementbytes)
20+
mpref = array_reference_meta!(ls, array, rawindices, elementbytes, var)
21+
add_load!(ls, mpref, elementbytes)
2222
end
2323
function add_load!(
24-
ls::LoopSet, var::Symbol, mpref::ArrayReferenceMetaPosition, elementbytes::Int
24+
ls::LoopSet, mpref::ArrayReferenceMetaPosition, elementbytes::Int
2525
)
26-
length(mpref.loopdependencies) == 0 && return add_constant!(ls, var, mpref, elementbytes)
27-
op = Operation( ls, var, elementbytes, :getindex, memload, mpref )
26+
length(mpref.loopdependencies) == 0 && return add_constant!(ls, mpref, elementbytes)
27+
op = Operation( ls, varname(mpref), elementbytes, :getindex, memload, mpref )
2828
add_load!(ls, op, true, false)
2929
end
3030

src/add_stores.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,12 @@ function add_copystore!(
2121
)
2222
op = add_compute!(ls, gensym(), :identity, [parent], elementbytes)
2323
# pushfirst!(mpref.parents, parent)
24-
add_store!(ls, name(op), mpref, elementbytes, op)
24+
add_store!(ls, mpref, elementbytes, op)
2525
end
2626

2727

2828
function add_store!(
29-
ls::LoopSet, var::Symbol, mpref::ArrayReferenceMetaPosition, elementbytes::Int, parent = getop(ls, var, mpref.loopdependencies, elementbytes)
29+
ls::LoopSet, mpref::ArrayReferenceMetaPosition, elementbytes::Int, parent = getop(ls, varname(mpref), mpref.loopdependencies, elementbytes)
3030
)
3131
isload(parent) && return add_copystore!(ls, parent, mpref, elementbytes)
3232
vparents = mpref.parents
@@ -56,8 +56,8 @@ end
5656
function add_store!(
5757
ls::LoopSet, var::Symbol, array::Symbol, rawindices, elementbytes::Int
5858
)
59-
mpref = array_reference_meta!(ls, array, rawindices, elementbytes)
60-
add_store!(ls, var, mpref, elementbytes)
59+
mpref = array_reference_meta!(ls, array, rawindices, elementbytes, var)
60+
add_store!(ls, mpref, elementbytes)
6161
end
6262
function add_simple_store!(ls::LoopSet, parent::Operation, ref::ArrayReference, elementbytes::Int)
6363
mref = ArrayReferenceMeta(

src/graphs.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -306,8 +306,8 @@ function Operation(
306306
node_type, dependencies, reduced_deps, parents, ref
307307
)
308308
end
309-
function Operation(ls::LoopSet, var, elementbytes, instr, optype, mpref::ArrayReferenceMetaPosition)
310-
Operation(length(operations(ls)), var, elementbytes, instr, optype, mpref)
309+
function Operation(ls::LoopSet, variable, elementbytes, instr, optype, mpref::ArrayReferenceMetaPosition)
310+
Operation(length(operations(ls)), variable, elementbytes, instr, optype, mpref)
311311
end
312312

313313
# load_operations(ls::LoopSet) = ls.loadops
@@ -479,8 +479,8 @@ function add_operation!(
479479
)
480480
if RHS.head === :ref# || (RHS.head === :call && first(RHS.args) === :getindex)
481481
array, rawindices = ref_from_expr(RHS)
482-
RHS_ref = array_reference_meta!(ls, array, rawindices, elementbytes)
483-
op = add_load!(ls, gensym(LHS_sym), RHS_ref, elementbytes)
482+
RHS_ref = array_reference_meta!(ls, array, rawindices, elementbytes, gensym(LHS_sym))
483+
op = add_load!(ls, RHS_ref, elementbytes)
484484
iop = add_compute!(ls, LHS_sym, :identity, [op], elementbytes)
485485
# pushfirst!(LHS_ref.parents, iop)
486486
elseif RHS.head === :call
@@ -528,11 +528,14 @@ function Base.push!(ls::LoopSet, ex::Expr, elementbytes::Int, position::Int)
528528
# assign RHS to lrhs
529529
array, rawindices = ref_from_expr(LHS)
530530
mpref = array_reference_meta!(ls, array, rawindices, elementbytes)
531+
cachedparents = copy(mpref.parents)
531532
ref = mpref.mref.ref
532533
id = findfirst(r -> r == ref, ls.refs_aliasing_syms)
533534
lrhs = id === nothing ? gensym(:RHS) : ls.syms_aliasing_refs[id]
535+
mpref.varname = lrhs
534536
add_operation!(ls, lrhs, RHS, mpref, elementbytes, position)
535-
add_store!( ls, lrhs, mpref, elementbytes)
537+
mpref.parents = cachedparents
538+
add_store!(ls, mpref, elementbytes)
536539
else
537540
add_store_ref!(ls, RHS, LHS, elementbytes)
538541
end

src/memory_ops_common.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ function subset_vptr!(ls::LoopSet, vptr::Symbol, indnum::Int, ind, previndices,
6666
subsetvptr
6767
end
6868
const DISCONTIGUOUS = Symbol("##DISCONTIGUOUSSUBARRAY##")
69-
function array_reference_meta!(ls::LoopSet, array::Symbol, rawindices, elementbytes::Int)
69+
function array_reference_meta!(ls::LoopSet, array::Symbol, rawindices, elementbytes::Int, var::Union{Nothing,Symbol} = nothing)
7070
vptrarray = vptr(array)
7171
add_vptr!(ls, array, vptrarray) # now, subset
7272
indices = Symbol[]
@@ -109,9 +109,9 @@ function array_reference_meta!(ls::LoopSet, array::Symbol, rawindices, elementby
109109
end
110110
# (length(parents) != 0 && first(indices) !== Symbol("##DISCONTIGUOUSSUBARRAY##")) && pushfirst!(indices, Symbol("##DISCONTIGUOUSSUBARRAY##"))
111111
mref = ArrayReferenceMeta(ArrayReference( array, indices ), loopedindex, vptrarray)
112-
ArrayReferenceMetaPosition(mref, parents, loopdependencies, reduceddeps)
112+
ArrayReferenceMetaPosition(mref, parents, loopdependencies, reduceddeps, isnothing(var) ? Symbol("") : var )
113113
end
114-
function tryrefconvert(ls::LoopSet, ex::Expr, elementbytes::Int)::Tuple{Bool,ArrayReferenceMetaPosition}
114+
function tryrefconvert(ls::LoopSet, ex::Expr, elementbytes::Int, var::Union{Nothing,Symbol} = nothing)::Tuple{Bool,ArrayReferenceMetaPosition}
115115
ya, yinds = if ex.head === :ref
116116
ref_from_ref(ex)
117117
elseif ex.head === :call
@@ -126,6 +126,6 @@ function tryrefconvert(ls::LoopSet, ex::Expr, elementbytes::Int)::Tuple{Bool,Arr
126126
else
127127
return false, NOTAREFERENCEMP
128128
end
129-
true, array_reference_meta!(ls, ya, yinds, elementbytes)
129+
true, array_reference_meta!(ls, ya, yinds, elementbytes, var)
130130
end
131131

src/operations.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -200,22 +200,25 @@ function isouterreduction(op::Operation)
200200
end
201201
end
202202

203-
struct ArrayReferenceMetaPosition
203+
mutable struct ArrayReferenceMetaPosition
204204
mref::ArrayReferenceMeta
205205
parents::Vector{Operation}
206206
loopdependencies::Vector{Symbol}
207207
reduceddeps::Vector{Symbol}
208+
varname::Symbol
208209
end
209-
function ArrayReferenceMetaPosition(parents::Vector{Operation}, ldref::Vector{Symbol}, reduceddeps::Vector{Symbol})
210-
ArrayReferenceMetaPosition( NOTAREFERENCE, parents, ldref, reduceddeps )
210+
function ArrayReferenceMetaPosition(parents::Vector{Operation}, ldref::Vector{Symbol}, reduceddeps::Vector{Symbol}, varname::Symbol)
211+
ArrayReferenceMetaPosition( NOTAREFERENCE, parents, ldref, reduceddeps, varname )
211212
end
212213
function Operation(id::Int, var::Symbol, elementbytes::Int, instr, optype::OperationType, mpref::ArrayReferenceMetaPosition)
213214
Operation( id, var, elementbytes, instr, optype, mpref.loopdependencies, mpref.reduceddeps, mpref.parents, mpref.mref )
214215
end
215216
Base.:(==)(x::ArrayReferenceMetaPosition, y::ArrayReferenceMetaPosition) = x.mref.ref == y.mref.ref
216217
# Avoid memory allocations by using this for ops that aren't references
217218
const NOTAREFERENCE = ArrayReferenceMeta(ArrayReference(Symbol(""), Union{Symbol,Int}[]),Bool[],Symbol(""))
218-
const NOTAREFERENCEMP = ArrayReferenceMetaPosition(NOTAREFERENCE, NOPARENTS, Symbol[], Symbol[])
219+
const NOTAREFERENCEMP = ArrayReferenceMetaPosition(NOTAREFERENCE, NOPARENTS, Symbol[], Symbol[],Symbol(""))
220+
varname(::Nothing) = nothing
221+
varname(mpref::ArrayReferenceMetaPosition) = mpref.varname
219222
name(mpref::ArrayReferenceMetaPosition) = name(mpref.mref.ref)
220223
loopdependencies(ref::ArrayReferenceMetaPosition) = ref.loopdependencies
221224
reduceddependencies(ref::ArrayReferenceMetaPosition) = ref.reduceddeps

src/reconstruct_loopset.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -414,7 +414,7 @@ function _avx_loopset(OPSsv, ARFsv, AMsv, LPSYMsv, LBsv, vargs)
414414
)
415415
end
416416
@generated function _avx_!(::Val{UT}, ::Type{OPS}, ::Type{ARF}, ::Type{AM}, ::Type{LPSYM}, lb::LB, vargs...) where {UT, OPS, ARF, AM, LPSYM, LB}
417-
# 1 + 1 # Irrelevant line you can comment out/in to force recompilation...
417+
1 + 1 # Irrelevant line you can comment out/in to force recompilation...
418418
ls = _avx_loopset(OPS.parameters, ARF.parameters, AM.parameters, LPSYM.parameters, LB.parameters, vargs)
419419
avx_body(ls, UT)
420420
end

test/ifelsemasks.jl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,17 @@ T = Float32
118118
c[i] = 1 + ifelse(a[i] > b[i], a[i] + b[i], a[i] * b[i])
119119
end
120120
end
121+
function ifelseoverwrite!(p)
122+
for i eachindex(p)
123+
p[i] = p[i] < 0.5 ? p[i]^2 : p[i]^3
124+
end
125+
end
126+
function ifelseoverwriteavx!(p)
127+
@avx for i eachindex(p)
128+
p[i] = p[i] < 0.5 ? p[i]^2 : p[i]^3
129+
end
130+
end
131+
121132

122133

123134
function maybewriteand!(c, a, b)
@@ -286,7 +297,6 @@ T = Float32
286297
x[i] = yᵢ * zᵢ
287298
end
288299
end
289-
290300
N = 117
291301
for T (Float32, Float64, Int32, Int64)
292302
@show T, @__LINE__
@@ -343,6 +353,11 @@ T = Float32
343353
@test c1 c2
344354
fill!(c2, -999999999); andorassignment_avx!(c2, a, b);
345355
@test c1 c2
356+
357+
a1 = copy(a); a2 = copy(a);
358+
ifelseoverwrite!(a1)
359+
ifelseoverwriteavx!(a2)
360+
@test a1 a2
346361

347362
if T <: Union{Float32,Float64}
348363
a .*= 100;

0 commit comments

Comments
 (0)