Skip to content

Commit 5c7ec87

Browse files
committed
Fix a statically sized codegen bug
1 parent 177708e commit 5c7ec87

File tree

4 files changed

+47
-8
lines changed

4 files changed

+47
-8
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LoopVectorization"
22
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
33
authors = ["Chris Elrod <[email protected]>"]
4-
version = "0.12.158"
4+
version = "0.12.159"
55

66
[weakdeps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/codegen/lower_compute.jl

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,11 @@ end
323323
for k 1:K
324324
push!(q.args, :($gf(vargs, $k, false)))
325325
end
326-
return Expr(:block, Expr(:meta, :inline), q)
326+
if VERSION >= v"1.8"
327+
return Expr(:block, Expr(:meta, :inline), :(@inline $q))
328+
else
329+
return Expr(:block, Expr(:meta, :inline), q)
330+
end
327331
end
328332
if Sreduced
329333
M = N
@@ -346,6 +350,9 @@ end
346350
push!(call.args, Expr(:call, gf, syms[k], m, false))
347351
end
348352
end
353+
if VERSION >= v"1.8"
354+
call = :(@inline $call)
355+
end
349356
# minimal change in behavior to fix case when !Sreduced by N -> Dlen; TODO: what should Dlen be here?
350357
if Sreduced ? (N == -1) : (Dlen == -1)
351358
push!(q.args, call)
@@ -516,6 +523,14 @@ function reduce_parent!(
516523
end
517524
newp
518525
end
526+
if VERSION >= v"1.8"
527+
function inlinecall(x)
528+
Meta.isexpr(x, :call) || return x
529+
:(@inline $x)
530+
end
531+
else
532+
const inlinecall = identity
533+
end
519534
function lower_compute!(
520535
q::Expr,
521536
op::Operation,
@@ -769,7 +784,13 @@ function lower_compute!(
769784
Expr(
770785
:(=),
771786
varsym,
772-
Expr(:call, lv(ifelsefunc), MASKSYMBOL, instrcall, selfopname)
787+
Expr(
788+
:call,
789+
lv(ifelsefunc),
790+
MASKSYMBOL,
791+
inlinecall(instrcall),
792+
selfopname
793+
)
773794
)
774795
)
775796
elseif ((u₁ 1) | (selfdepreduce 0))
@@ -786,7 +807,7 @@ function lower_compute!(
786807
MASKSYMBOL,
787808
staticexpr(u₁),
788809
staticexpr(selfdepreduce),
789-
instrcall,
810+
inlinecall(instrcall),
790811
selfopname
791812
)
792813
)
@@ -799,7 +820,13 @@ function lower_compute!(
799820
Expr(
800821
:(=),
801822
varsym,
802-
Expr(:call, lv(:ifelse), MASKSYMBOL, instrcall, selfopname)
823+
Expr(
824+
:call,
825+
lv(:ifelse),
826+
MASKSYMBOL,
827+
inlinecall(instrcall),
828+
selfopname
829+
)
803830
)
804831
)
805832
end
@@ -832,11 +859,11 @@ function lower_compute!(
832859
Expr(
833860
:(=),
834861
varsym,
835-
Expr(:call, lv(:vbroadcast), VECTORWIDTHSYMBOL, instrcall)
862+
Expr(:call, lv(:vbroadcast), VECTORWIDTHSYMBOL, inlinecall(instrcall))
836863
)
837864
)
838865
else
839-
push!(q.args, Expr(:(=), varsym, instrcall))
866+
push!(q.args, Expr(:(=), varsym, inlinecall(instrcall)))
840867
end
841868
# end
842869
end

src/codegen/lowering.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ function lower_unrolled_dynamic(
311311
loopisstatic = false
312312
loopisbounded = false
313313
end
314-
Ureduct = (n == num_loops(ls) && (u₂ == -1)) ? ureduct(ls) : -1
314+
Ureduct = ((n == num_loops(ls) && (u₂ == -1))) ? ureduct(ls) : -1
315315
# for now, require loopisstatic or !Ureduct-ing for reducing UF
316316
if loopisbounded & (loopisstatic | (Ureduct < 0))
317317
UFWnew = cld(looplength, cld(looplength, UFW))
@@ -390,9 +390,11 @@ function lower_unrolled_dynamic(
390390
initialize_outer_reductions!(q, ls, Ureduct)
391391
q
392392
elseif loopisstatic
393+
blockhead = :block
393394
if length(loop) < UF * W
394395
Expr(:block)
395396
else
397+
UFt -= Ureduct
396398
Expr(
397399
:block,
398400
add_upper_outer_reductions(ls, q, Ureduct, UF, loop, nisvectorized)

test/staticsize.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,13 @@ function sum_turbo(x)
9696
end
9797
s
9898
end
99+
function sum2_10turbo(x)
100+
s = zero(eltype(x))
101+
for i = 1:10, j = 1:2
102+
s += x[j, i]
103+
end
104+
s
105+
end
99106

100107
@testset "Statically Sized Arrays" begin
101108
@show @__LINE__
@@ -124,4 +131,7 @@ end
124131
@test maxabs(x) == maximum(abs, x)
125132
@test sum_turbo(x) sum(x)
126133
end
134+
let A = rand(2, 10)
135+
@test sum2_10turbo(A) sum(A)
136+
end
127137
end

0 commit comments

Comments
 (0)