Skip to content

Commit 28233d2

Browse files
committed
init sym for CartesianIndices{0} loops correctly more often
1 parent 527dffc commit 28233d2

File tree

2 files changed

+71
-1
lines changed

2 files changed

+71
-1
lines changed

src/transforms.jl

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,55 @@ function return_empty_reductinit(op::Operation, var::Symbol)
4040
end
4141

4242

43+
44+
"""
    constant_symbol!(ls::LoopSet, op::Operation)

Return a value usable as the initializer of the constant operation `op`.

If `skip_constant(instruction(op))` is `false`, `instruction(op).instr` is returned
directly and `ls` is left untouched.

Otherwise the preamble tables of `ls` are scanned in this order: `preamble_symsym`,
`preamble_symint`, `preamble_symfloat`, `preamble_zeros`, `preamble_funcofeltypes`.
For the first entry that passes the id guard, an assignment `constantopname(op) = <init>`
is pushed onto the preamble with `pushpreamble!` and `constantopname(op)` is returned.

Throws the string `"Constant operation symbol not found."` when no table yields an entry.
"""
function constant_symbol!(ls::LoopSet, op::Operation)
    # hack
    # relowers, but should make it work
    # TODO: DRY with `lower_licm_constants!` from `src/codegen/lower_constants.jl`
    skip_constant(instruction(op)) || return instruction(op).instr
    idcheck = identifier(op)
    symname = constantopname(op)
    # NOTE(review): each guard below `continue`s past the entry whose id EQUALS
    # `identifier(op)` and therefore uses the first entry with a different id. This looks
    # like it may be inverted (`|| continue` would select the matching entry) -- confirm
    # against `lower_licm_constants!` before relying on it. Behavior is left unchanged.
    for (id, sym) in ls.preamble_symsym
        (idcheck !== nothing) && ((idcheck == id) && continue)
        pushpreamble!(ls, Expr(:(=), symname, sym))
        return symname
        # setconstantop!(ls, op, sym)
        # setconstantop!(ls, op, Expr(:call, lv(:maybeconvert), ls.T, sym))
    end
    for (id, (intval, intsz, signed)) in ls.preamble_symint
        (idcheck !== nothing) && ((idcheck == id) && continue)
        if intsz == 1
            # A size of 1 is emitted as a `Bool` literal.
            pushpreamble!(ls, Expr(:(=), symname, intval % Bool))
        else
            pushpreamble!(
                ls, Expr(:(=), symname, sizeequivalent_symint_expr(intval, signed))
            )
        end
        return symname
    end
    for (id, floatval) in ls.preamble_symfloat
        (idcheck !== nothing) && ((idcheck == id) && continue)
        floatinit = Expr(:call, lv(:sizeequivalentfloat), ELTYPESYMBOL, floatval)
        pushpreamble!(ls, Expr(:(=), symname, floatinit))
        return symname
    end
    for (id, typ) in ls.preamble_zeros
        (idcheck !== nothing) && ((idcheck == id) && continue)
        # Zero entries are only used when `op` itself is a plain loop constant.
        instruction(op) === LOOPCONSTANT || continue
        if typ == IntOrFloat
            pushpreamble!(ls, Expr(:(=), symname, Expr(:call, :zero, ELTYPESYMBOL)))
        elseif typ == HardInt
            zeroinit = Expr(:call, lv(:zerointeger), ELTYPESYMBOL)
            pushpreamble!(ls, Expr(:(=), symname, zeroinit))
        else#if typ == HardFloat
            zeroinit = Expr(:call, lv(:zerofloat), ELTYPESYMBOL)
            pushpreamble!(ls, Expr(:(=), symname, zeroinit))
        end
        return symname
    end
    for (id, f) in ls.preamble_funcofeltypes
        (idcheck !== nothing) && ((idcheck == id) && continue)
        pushpreamble!(
            ls, Expr(:(=), symname, Expr(:call, reduction_zero(f), ELTYPESYMBOL))
        )
        return symname
    end
    throw("Constant operation symbol not found.")
end
91+
4392
function hoist_constant_store!(q::Expr, ls::LoopSet, op::Operation)
4493
op.instruction = DROPPEDCONSTANT
4594
op.node_type = constant
@@ -52,7 +101,10 @@ function hoist_constant_store!(q::Expr, ls::LoopSet, op::Operation)
52101
end
53102
push!(ls.outer_reductions, identifier(opr))
54103

55-
init = return_empty_reductinit(opr, name(opr)).instruction.instr
104+
initop = return_empty_reductinit(opr, name(opr))
105+
# @show last(ls.preamble.args)
106+
init = constant_symbol!(ls, initop)
107+
# @show last(ls.preamble.args)
56108
pushpreamble!(ls, Expr(:(=), outer_reduct_init_typename(opr), Expr(:call, lv(:typeof), init)))
57109
qpre = Expr(:block)
58110
push!(q.args, Expr(:call, lv(:unsafe_store!), Expr(:call, lv(:pointer), op.ref.ptr), outer_reduction_to_scalar_reduceq!(qpre, opr, init)))

test/miscellaneous.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1182,6 +1182,22 @@ end
11821182
end
11831183
out
11841184
end
1185+
# Regression-test kernel: for every `Ipost in Rpost` and `i in axout_tile`, a `@turbo`
# loop over `Ipre in Rpre` accumulates
#     tmp = sum over j in axes(kernel, 1) of oftype(z, src[Ipre, i+j, Ipost]) * kernel[j]
# starting from `zero(eltype(out))`, stores it in `out[Ipre, i, Ipost]`, and finally
# returns `out`. The accumulator is initialized inside the `@turbo` loop over `Ipre`;
# the commented-out `convert` variant below records an initializer that failed to hoist.
function smoothdim_kernel_tile_avx_2!(out, z, src::AbstractArray, kernel::AbstractVector, Rpre::CartesianIndices, axout_tile, Rpost::CartesianIndices)
1186+
axkernel = axes(kernel, 1)
1187+
for Ipost in Rpost
1188+
for i in axout_tile
1189+
@turbo for Ipre in Rpre
1190+
tmp = zero(eltype(out))
1191+
# tmp = convert(eltype(out), z) # failing to hoist this leads to an "UndefVarError: tmp not defined"
1192+
for j in axkernel
1193+
tmp += oftype(z, src[Ipre,i+j,Ipost])*kernel[j]
1194+
end
1195+
out[Ipre,i,Ipost] = tmp
1196+
end
1197+
end
1198+
end
1199+
out
1200+
end
11851201

11861202

11871203
for T ∈ (UInt8,Float32, Float64)
@@ -1229,6 +1245,8 @@ end
12291245
smoothdim_kernel_tile!( dest1, float(zero(T)), src, kernel, Rpre, axes(dest1, d), Rpost);
12301246
smoothdim_kernel_tile_avx!(dest2, float(zero(T)), src, kernel, Rpre, axes(dest2, d), Rpost);
12311247
@test dest1 ≈ dest2
1248+
fill!(dest2,NaN); smoothdim_kernel_tile_avx_2!(dest2, float(zero(T)), src, kernel, Rpre, axes(dest2, d), Rpost);
1249+
@test dest1 ≈ dest2
12321250
end
12331251
end
12341252
end

0 commit comments

Comments
 (0)