Skip to content

Commit 260b9bc

Browse files
committed
Improve prefetching behavior slightly, update README to note loops will generally be faster than broadcasts.
1 parent d9541a7 commit 260b9bc

File tree

4 files changed

+60
-7
lines changed

4 files changed

+60
-7
lines changed

README.md

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -213,17 +213,56 @@ julia> buf1 = Matrix{Float64}(undef, size(C,1), size(C,2));
213213
julia> buf2 = similar(X1);
214214

215215
julia> @btime $X1 .= view($A,1,:) .+ mul!($buf2, $B, ($buf1 .= $C .+ $D'));
216-
7.896 μs (0 allocations: 0 bytes)
216+
9.188 μs (0 allocations: 0 bytes)
217217

218218
julia> @btime @avx $X2 .= view($A,1,:) .+ $B .*ˡ ($C .+ $D');
219-
7.647 μs (0 allocations: 0 bytes)
219+
6.751 μs (0 allocations: 0 bytes)
220+
221+
julia> @test X1 ≈ X2
222+
Test Passed
223+
224+
julia> AmulBtest!(X1, B, C, D, view(A,1,:))
225+
226+
julia> AmulBtest2!(X2, B, C, D, view(A,1,:))
220227

221228
julia> @test X1 ≈ X2
222229
Test Passed
223230
```
224231
The lazy matrix multiplication operator `*ˡ` escapes broadcasts and fuses, making it easy to write code that avoids intermediates. However, I would recommend always checking whether splitting the operation into pieces, or at least isolating the matrix multiplication, increases performance. That will often be the case, especially if the matrices are large, where a separate multiplication can leverage BLAS (and perhaps take advantage of threads).
225232
This may improve as the optimizations within LoopVectorization improve.
226233

234+
Note that loops will generally be faster than broadcasting. This is because the behavior of a broadcast is determined by runtime information: any dimension of size `1`, other than the leading dimension, will be broadcast, and which dimensions these are is not known at compile time.
235+
```julia
236+
julia> function AmulBtest!(C,A,Bk,Bn,d)
237+
@avx for m ∈ axes(A,1), n ∈ axes(Bk,2)
238+
ΔCₘₙ = zero(eltype(C))
239+
for k ∈ axes(A,2)
240+
ΔCₘₙ += A[m,k] * (Bk[k,n] + Bn[n,k])
241+
end
242+
C[m,n] = ΔCₘₙ + d[m]
243+
end
244+
end
245+
AmulBtest! (generic function with 1 method)
246+
247+
julia> AmulBtest!(X2, B, C, D, view(A,1,:))
248+
249+
julia> @test X1 ≈ X2
250+
Test Passed
251+
252+
julia> @benchmark AmulBtest!($X2, $B, $C, $D, view($A,1,:))
253+
BenchmarkTools.Trial:
254+
memory estimate: 0 bytes
255+
allocs estimate: 0
256+
--------------
257+
minimum time: 5.793 μs (0.00% GC)
258+
median time: 5.816 μs (0.00% GC)
259+
mean time: 5.824 μs (0.00% GC)
260+
maximum time: 14.234 μs (0.00% GC)
261+
--------------
262+
samples: 10000
263+
evals/sample: 6
264+
```
265+
227266
</p>
228267
</details>
229268

src/determinestrategy.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,7 @@ function solve_unroll(X, R, u₁L, u₂L, u₁step, u₂step)
315315
if !(isfinite(u₂float) & isfinite(u₁float)) # brute force
316316
u₁low = u₂low = 1
317317
u₁high = u₂high = REGISTER_COUNT == 32 ? 10 : 6#8
318+
println("Fail")
318319
return solve_unroll_iter(X, R, u₁L, u₂L, u₁low:u₁step:u₁high, u₂low:u₂step:u₂high)
319320
end
320321
u₁low = floor(Int, u₁float)

src/lower_compute.jl

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,22 @@
11

2-
"""
    load_constrained(op, u₁loop, u₂loop, forprefetch = false)

Return `true` if any parent of `op` is a load whose loop dependencies contain
every unrolled loop (`u₁loop`, `u₂loop`) that `op` itself depends on.

When `forprefetch` is `true`, return `false` early unless `op` depends on all
of the unrolled loops in play: on `u₁loop` alone when `u₂loop` is
`Symbol("##undefined##")`, otherwise on both `u₁loop` and `u₂loop`.
"""
function load_constrained(op, u₁loop, u₂loop, forprefetch = false)
    loopdeps = loopdependencies(op)
    # NOTE(review): `Symbol("##undefined##")` appears to be the sentinel for
    # "no second unrolled loop" -- confirm against the callers.
    hasu₂ = u₂loop !== Symbol("##undefined##")
    dependsonu₁ = u₁loop ∈ loopdeps
    # Always define `dependsonu₂`; leaving it unassigned on the sentinel path
    # raised an `UndefVarError` at the `push!` below.
    dependsonu₂ = hasu₂ && u₂loop ∈ loopdeps
    if forprefetch
        if hasu₂
            (dependsonu₁ & dependsonu₂) || return false
        else
            dependsonu₁ || return false
        end
    end
    unrolleddeps = Symbol[]
    dependsonu₁ && push!(unrolleddeps, u₁loop)
    dependsonu₂ && push!(unrolleddeps, u₂loop)
    return any(parents(op)) do opp
        isload(opp) && all(in(loopdependencies(opp)), unrolleddeps)
    end
end
922

src/lower_load.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ function prefetchisagoodidea(ls::LoopSet, op::Operation, td::UnrollArgs)
8989
if prod(s -> length(getloop(ls, s)), @view(indices[1:innermostloopind-1])) 120 && length(getloop(ls, innermostloopsym)) 120
9090
if op.ref.ref.offsets[innermostloopind] < 120
9191
for opp operations(ls)
92-
iscompute(opp) && (innermostloopsym loopdependencies(opp)) && load_constrained(opp, u₁loopsym, u₂loopsym) && return 0
92+
iscompute(opp) && (innermostloopsym loopdependencies(opp)) && load_constrained(opp, u₁loopsym, u₂loopsym, true) && return 0
9393
end
9494
return innermostloopind
9595
end

0 commit comments

Comments
 (0)