Skip to content

Commit deec62f

Browse files
committed
Fixed a symbol-index bug.
1 parent e532a98 commit deec62f

File tree

4 files changed

+127
-36
lines changed

4 files changed

+127
-36
lines changed

src/graphs.jl

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,21 @@ function add_vptr!(ls::LoopSet, indexed::Symbol, id::Int)
297297
end
298298
nothing
299299
end
300-
300+
function intersection(depsplus, ls)
301+
deps = Symbol[]
302+
for dep ∈ depsplus
303+
dep ∈ ls && push!(deps, dep)
304+
end
305+
deps
306+
end
307+
function loopdependencies(ref::ArrayReference, ls::LoopSet)
308+
deps = loopdependencies(ref)
309+
loopset = keys(ls.loops)
310+
for dep ∈ deps
311+
dep ∈ loopset || return intersection(deps, loopset)
312+
end
313+
deps
314+
end
301315
function add_load!(
302316
ls::LoopSet, var::Symbol, ref::ArrayReference, elementbytes::Int = 8
303317
)
@@ -313,7 +327,7 @@ function add_load!(
313327
# ls.ref_to_sym_aliases[ ref ] = var
314328
op = Operation(
315329
length(operations(ls)), var, elementbytes,
316-
:getindex, memload, loopdependencies(ref),
330+
:getindex, memload, loopdependencies(ref, ls),
317331
NODEPENDENCY, NOPARENTS, ref
318332
)
319333
add_vptr!(ls, ref.array, identifier(op))

src/lowering.jl

Lines changed: 37 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,33 +9,43 @@ function variable_name(op::Operation, suffix)
99
end
1010

1111

12-
function append_deps!(ret, deps)
13-
if first(deps) === Symbol("##DISCONTIGUOUSSUBARRAY##")
14-
append!(ret.args, @view(deps[2:end]))
15-
else
16-
append!(ret.args, deps)
12+
function append_inds!(ret, indices, deps)
13+
start = (first(indices) === Symbol("##DISCONTIGUOUSSUBARRAY##")) + 1# && return append_inds!(ret, @view(indices[2:end]), deps)
14+
for ind ∈ @view(indices[start:end])
15+
if ind isa Int
16+
push!(ret.args, ind - 1)
17+
elseif ind ∈ deps
18+
push!(ret.args, ind)
19+
else
20+
push!(ret.args, Expr(:call, :-, ind, 1))
21+
end
1722
end
1823
ret
1924
end
2025

2126
function mem_offset(op::Operation)
2227
@assert accesses_memory(op) "Computing memory offset only makes sense for operations that access memory."
23-
append_deps!(Expr(:tuple), op.ref.ref)
28+
append_inds!(Expr(:tuple), op.ref.ref, op.dependencies)
2429
end
2530
function mem_offset(op::Operation, incr::Int, unrolled::Symbol)
2631
@assert accesses_memory(op) "Computing memory offset only makes sense for operations that access memory."
2732
ret = Expr(:tuple)
28-
deps = op.ref.ref
33+
indices = op.ref.ref
34+
deps = op.dependencies
2935
if incr == 0
30-
append_deps!(ret, deps)
36+
append_inds!(ret, indices, deps)
3137
else
32-
for n ∈ 1:length(deps)
33-
dep = deps[n]
34-
n == 1 && dep === Symbol("##DISCONTIGUOUSSUBARRAY##") && continue
35-
if dep === unrolled
36-
push!(ret.args, Expr(:call, :+, dep, incr))
38+
for n ∈ 1:length(indices)
39+
ind = indices[n]
40+
n == 1 && ind === Symbol("##DISCONTIGUOUSSUBARRAY##") && continue
41+
if ind isa Int
42+
push!(ret.args, ind - 1)
43+
elseif ind === unrolled
44+
push!(ret.args, Expr(:call, :+, ind, incr))
45+
elseif ind ∈ deps
46+
push!(ret.args, ind)
3747
else
38-
push!(ret.args, dep)
48+
push!(ret.args, Expr(:call, :-, ind, 1))
3949
end
4050
end
4151
end
@@ -45,16 +55,22 @@ function mem_offset(op::Operation, mul::Symbol, incr::Int, unrolled::Symbol)
4555
@assert accesses_memory(op) "Computing memory offset only makes sense for operations that access memory."
4656
ret = Expr(:tuple)
4757
deps = op.ref.ref
58+
indices = op.ref.ref
59+
deps = op.dependencies
4860
if incr == 0
49-
append_deps!(ret, deps)
61+
append_inds!(ret, indices, deps)
5062
else
51-
for n ∈ 1:length(deps)
52-
dep = deps[n]
53-
n == 1 && dep === Symbol("##DISCONTIGUOUSSUBARRAY##") && continue
54-
if dep === unrolled
55-
push!(ret.args, Expr(:call, :+, dep, Expr(:call, lv(:valmul), mul, incr)))
63+
for n ∈ 1:length(indices)
64+
ind = indices[n]
65+
n == 1 && ind === Symbol("##DISCONTIGUOUSSUBARRAY##") && continue
66+
if ind isa Int
67+
push!(ret.args, ind - 1)
68+
elseif ind === unrolled
69+
push!(ret.args, Expr(:call, :+, ind, Expr(:call, lv(:valmul), mul, incr)))
70+
elseif ind ∈ deps
71+
push!(ret.args, ind)
5672
else
57-
push!(ret.args, dep)
73+
push!(ret.args, Expr(:call, :-, ind, 1))
5874
end
5975
end
6076
end

src/operations.jl

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,18 @@ struct ArrayReference
22
array::Symbol
33
ref::Vector{Union{Symbol,Int}}
44
loaded::Base.RefValue{Bool}
5-
function ArrayReference(
6-
array, refsin, loadedin = Ref{Bool}(false)
7-
)
8-
ref = Vector{Union{Symbol,Int}}(undef, length(refsin))
9-
for i ∈ eachindex(ref)
10-
refᵢ = (refsin[i])::Union{Symbol,Int}
11-
ref[i] = refᵢ isa Int ? refᵢ - 1 : refᵢ
12-
end
13-
new(array, ref, loadedin)
14-
end
5+
# function ArrayReference(
6+
# array, refsin, loadedin = Ref{Bool}(false)
7+
# )
8+
# ref = Vector{Union{Symbol,Int}}(undef, length(refsin))
9+
# for i ∈ eachindex(ref)
10+
# refᵢ = (refsin[i])::Union{Symbol,Int}
11+
# ref[i] = refᵢ isa Int ? refᵢ - 1 : refᵢ
12+
# end
13+
# new(array, ref, loadedin)
14+
# end
1515
end
16+
ArrayReference(array::Symbol, ref) = ArrayReference(array, ref, Ref{Bool}(false))
1617
function ArrayReference(
1718
array::Symbol,
1819
ref::AbstractVector

test/runtests.jl

Lines changed: 63 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,17 +95,77 @@ using LinearAlgebra
9595
C[i,j] = Cᵢⱼ
9696
end
9797
end
98+
99+
function rank2AmulB!(C, Aₘ, Aₖ, B)
100+
@inbounds for i ∈ 1:size(C,1), j ∈ 1:size(C,2)
101+
Cᵢⱼ = zero(eltype(C))
102+
@fastmath for k ∈ 1:size(B,1)
103+
Cᵢⱼ += (Aₘ[i,1]*Aₖ[1,k]+Aₘ[i,2]*Aₖ[2,k]) * B[k,j]
104+
end
105+
C[i,j] = Cᵢⱼ
106+
end
107+
end
108+
function rank2AmulBavx!(C, Aₘ, Aₖ, B)
109+
@avx for i ∈ 1:size(C,1), j ∈ 1:size(C,2)
110+
Cᵢⱼ = zero(eltype(C))
111+
for k ∈ 1:size(B,1)
112+
Cᵢⱼ += (Aₘ[i,1]*Aₖ[1,k]+Aₘ[i,2]*Aₖ[2,k]) * B[k,j]
113+
end
114+
C[i,j] = Cᵢⱼ
115+
end
116+
end
117+
118+
function mulCAtB_2x2block!(C, A, B)
119+
M, N = size(C); K = size(B,1)
120+
@assert size(C, 1) == size(A, 2)
121+
@assert size(C, 2) == size(B, 2)
122+
@assert size(A, 1) == size(B, 1)
123+
T = eltype(C)
124+
if mod(M, 2) == 0 && mod(N, 2) == 0
125+
for m ∈ 1:2:M
126+
m1 = m + 1
127+
for n ∈ 1:2:N
128+
n1 = n + 1
129+
C11, C21, C12, C22 = zero(T), zero(T), zero(T), zero(T)
130+
@avx for k ∈ 1:K
131+
C11 += A[k,m] * B[k,n]
132+
C21 += A[k,m1] * B[k,n]
133+
C12 += A[k,m] * B[k,n1]
134+
C22 += A[k,m1] * B[k,n1]
135+
end
136+
C[m,n] = C11
137+
C[m1,n] = C21
138+
C[m,n1] = C12
139+
C[m1,n1] = C22
140+
end
141+
end
142+
else
143+
@inbounds for n ∈ 1:N, m ∈ 1:M
144+
Cmn = 0.0
145+
@inbounds for k ∈ 1:K
146+
Cmn += A[k,m] * B[k,n]
147+
end
148+
C[m,n] = Cmn
149+
end
150+
end
151+
return C
152+
end
98153

99154
for T ∈ (Float32, Float64)
100-
M, K, N = 72, 75, 71;
155+
M, K, N = 72, 75, 68;
101156
C = Matrix{T}(undef, M, N); A = randn(T, M, K); B = randn(T, K, N);
102157
C2 = similar(C);
103158
AmulBavx!(C, A, B)
104159
AmulB!(C2, A, B)
105160
@test C ≈ C2
106161
At = copy(A');
107-
fill!(C, 9999.999);
108-
AtmulBavx!(C, At, B)
162+
fill!(C, 9999.999); AtmulBavx!(C, At, B)
163+
@test C ≈ C2
164+
fill!(C, 9999.999); mulCAtB_2x2block!(C, At, B);
165+
@test C ≈ C2
166+
Aₘ= rand(T, M, 2); Aₖ = rand(T, 2, K);
167+
rank2AmulBavx!(C, Aₘ, Aₖ, B)
168+
rank2AmulB!(C2, Aₘ, Aₖ, B)
109169
@test C ≈ C2
110170
end
111171
end

0 commit comments

Comments
 (0)