Skip to content

Commit 61d3a3c

Browse files
committed
Support multiple assignments, fixes #271.
1 parent abdb4e0 commit 61d3a3c

File tree

3 files changed

+82
-44
lines changed

3 files changed

+82
-44
lines changed

src/modeling/graphs.jl

Lines changed: 53 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1102,6 +1102,58 @@ function prepare_rhs_for_storage!(ls::LoopSet, RHS::Union{Symbol,Expr}, array, r
11021102
add_store!(ls, mpref, elementbytes)
11031103
end
11041104

1105+
function add_assignment!(ls::LoopSet, LHS, RHS, elementbytes::Int, position::Int)
1106+
if LHS isa Symbol
1107+
if RHS isa Expr
1108+
maybe_const_compute!(ls, LHS, add_operation!(ls, LHS, RHS, elementbytes, position), elementbytes, position)
1109+
else
1110+
add_constant!(ls, RHS, ls.loopsymbols[1:position], LHS, elementbytes)
1111+
end
1112+
elseif LHS isa Expr
1113+
if LHS.head === :ref
1114+
if RHS isa Symbol
1115+
add_store_ref!(ls, RHS, LHS, elementbytes)
1116+
elseif RHS isa Expr
1117+
# need to check if LHS appears in RHS
1118+
# assign RHS to lrhs
1119+
array, rawindices = ref_from_expr!(ls, LHS)
1120+
prepare_rhs_for_storage!(ls, RHS, array, rawindices, elementbytes, position)
1121+
else
1122+
add_store_ref!(ls, RHS, LHS, elementbytes) # is this necessary? (Extension API?)
1123+
end
1124+
elseif LHS.head === :tuple
1125+
if RHS.head === :tuple
1126+
for i eachindex(LHS.args)
1127+
add_assignment!(ls, LHS.args[i], RHS.args[i], elementbytes, position)
1128+
end
1129+
return
1130+
end
1131+
@assert length(LHS.args) 9 "Functions returning more than 9 values aren't currently supported."
1132+
lhstemp = gensym!(ls, "lhstuple")
1133+
vparents = Operation[maybe_const_compute!(ls, lhstemp, add_operation!(ls, lhstemp, RHS, elementbytes, position), elementbytes, position)]
1134+
for i eachindex(LHS.args)
1135+
f = (:first,:second,:third,:fourth,:fifth,:sixth,:seventh,:eighth,:ninth)[i]
1136+
lhsi = LHS.args[i]
1137+
if lhsi isa Symbol
1138+
add_compute!(ls, lhsi, f, vparents, elementbytes)
1139+
elseif lhsi isa Expr && lhsi.head === :ref
1140+
tempunpacksym = gensym!(ls, "tempunpack")
1141+
add_compute!(ls, tempunpacksym, f, vparents, elementbytes)
1142+
add_store_ref!(ls, tempunpacksym, lhsi, elementbytes)
1143+
else
1144+
throw(LoopError("Unpacking the above expression in the left hand side was not understood/supported.", lhsi))
1145+
end
1146+
end
1147+
first(vparents)
1148+
else
1149+
throw(LoopError("LHS not understood; only `:ref`s and `:tuple`s are currently supported.", LHS))
1150+
end
1151+
else
1152+
throw(LoopError("LHS not understood.", LHS))
1153+
end
1154+
nothing
1155+
end
1156+
11051157
function Base.push!(ls::LoopSet, ex::Expr, elementbytes::Int, position::Int)
11061158
if ex.head === :call
11071159
finex = first(ex.args)::Symbol
@@ -1112,50 +1164,7 @@ function Base.push!(ls::LoopSet, ex::Expr, elementbytes::Int, position::Int)
11121164
error("Function $finex not recognized.")
11131165
end
11141166
elseif ex.head === :(=)
1115-
LHS = ex.args[1]
1116-
RHS = ex.args[2]
1117-
if LHS isa Symbol
1118-
if RHS isa Expr
1119-
maybe_const_compute!(ls, LHS, add_operation!(ls, LHS, RHS, elementbytes, position), elementbytes, position)
1120-
else
1121-
add_constant!(ls, RHS, ls.loopsymbols[1:position], LHS, elementbytes)
1122-
end
1123-
elseif LHS isa Expr
1124-
if LHS.head === :ref
1125-
if RHS isa Symbol
1126-
add_store_ref!(ls, RHS, LHS, elementbytes)
1127-
elseif RHS isa Expr
1128-
# need to check if LHS appears in RHS
1129-
# assign RHS to lrhs
1130-
array, rawindices = ref_from_expr!(ls, LHS)
1131-
prepare_rhs_for_storage!(ls, RHS, array, rawindices, elementbytes, position)
1132-
else
1133-
add_store_ref!(ls, RHS, LHS, elementbytes) # is this necessary? (Extension API?)
1134-
end
1135-
elseif LHS.head === :tuple
1136-
@assert length(LHS.args) 9 "Functions returning more than 9 values aren't currently supported."
1137-
lhstemp = gensym!(ls, "lhstuple")
1138-
vparents = Operation[maybe_const_compute!(ls, lhstemp, add_operation!(ls, lhstemp, RHS, elementbytes, position), elementbytes, position)]
1139-
for i eachindex(LHS.args)
1140-
f = (:first,:second,:third,:fourth,:fifth,:sixth,:seventh,:eighth,:ninth)[i]
1141-
lhsi = LHS.args[i]
1142-
if lhsi isa Symbol
1143-
add_compute!(ls, lhsi, f, vparents, elementbytes)
1144-
elseif lhsi isa Expr && lhsi.head === :ref
1145-
tempunpacksym = gensym!(ls, "tempunpack")
1146-
add_compute!(ls, tempunpacksym, f, vparents, elementbytes)
1147-
add_store_ref!(ls, tempunpacksym, lhsi, elementbytes)
1148-
else
1149-
throw(LoopError("Unpacking the above expression in the left hand side was not understood/supported.", lhsi))
1150-
end
1151-
end
1152-
first(vparents)
1153-
else
1154-
throw(LoopError("LHS not understood; only `:ref`s and `:tuple`s are currently supported.", LHS))
1155-
end
1156-
else
1157-
throw(LoopError("LHS not understood.", LHS))
1158-
end
1167+
add_assignment!(ls, ex.args[1], ex.args[2], elementbytes, position)
11591168
elseif ex.head === :block
11601169
add_block!(ls, ex, elementbytes, position)
11611170
elseif ex.head === :for

test/multiassignments.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
using LoopVectorization, Test
2+
3+
4+
function multiassign!(y, x)
5+
@assert length(y)+3 == length(x)
6+
@inbounds for i eachindex(y)
7+
x₁, ((x₂,x₃), (x₄,x₅)) = x[i], (sincos(x[i+1]), (x[i+2], x[i+3]))
8+
y[i] = x₁ * x₄ - x₂ * x₃
9+
end
10+
y
11+
end
12+
multiassign(x) = multiassign!(similar(x, length(x)-3), x)
13+
function multiassign_turbo!(y, x)
14+
@assert length(y)+3 == length(x)
15+
@turbo for i eachindex(y)
16+
x₁, ((x₂,x₃), (x₄,x₅)) = x[i], (sincos(x[i+1]), (x[i+2], x[i+3]))
17+
y[i] = x₁ * x₄ - x₂ * x₃
18+
end
19+
y
20+
end
21+
multiassign_turbo(x) = multiassign_turbo!(similar(x, length(x)-3), x)
22+
23+
@testset "Multiple assignments" begin
24+
x = rand(111);
25+
@test multiassign(x) multiassign_turbo(x)
26+
end
27+

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ const START_TIME = time()
5454
@time include("dot.jl")
5555

5656
@time include("special.jl")
57+
58+
@time include("multiassignments.jl")
5759
end
5860

5961
@time if LOOPVECTORIZATION_TEST == "all" || LOOPVECTORIZATION_TEST == "part2"

0 commit comments

Comments
 (0)