Skip to content

Commit 8b4c777

Browse files
authored
matching bandwidth branch in banded broadcast (#328)
1 parent 12e99db commit 8b4c777

File tree

2 files changed

+40
-48
lines changed

2 files changed

+40
-48
lines changed

src/BandedMatrices.jl

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -96,20 +96,4 @@ include("interfaceimpl.jl")
9696

9797
include("precompile.jl")
9898

99-
# function _precompile_()
100-
# precompile(Tuple{typeof(gbmm!), Char, Char, Float64, BandedMatrix{Float64,Array{Float64,2},Base.OneTo{Int64}}, BandedMatrix{Float64,Array{Float64,2},Base.OneTo{Int64}}, Float64, BandedMatrix{Float64,Array{Float64,2},Base.OneTo{Int64}}})
101-
# end
102-
103-
# _precompile_()
104-
105-
# precompile instructions
106-
let B = BandedMatrix(0=>zeros(0)), v = zeros(size(B,2))
107-
BT = typeof(B)
108-
vT = typeof(v)
109-
@assert precompile(+, (BT, BT))
110-
@assert precompile(-, (BT,))
111-
@assert precompile(-, (BT, BT))
112-
@assert precompile(*, (BT, vT))
113-
end
114-
11599
end #module

src/generic/broadcast.jl

Lines changed: 40 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -490,7 +490,7 @@ end
490490
function __left_rowvec_banded_broadcast!(dest, f, (A,B), _1, _2, (l, u), (A_l,A_u), (m,n))
491491
for j=rowsupport(dest)
492492
for k = max(1,j-u):min(j-A_u-1,j+l,m)
493-
inbands_setindex!(dest, f(zero(T), inbands_getindex(B, k, j)), k, j)
493+
inbands_setindex!(dest, f(zero(eltype(A)), inbands_getindex(B, k, j)), k, j)
494494
end
495495
for k = max(1,j-min(A_u,u)):min(j+l,m)
496496
inbands_setindex!(dest, f(A[j], inbands_getindex(B, k, j)), k, j)
@@ -554,7 +554,7 @@ end
554554
function __right_rowvec_banded_broadcast!(dest, f, (A,B), _1, _2, (l, u), (B_l,B_u), (m,n))
555555
for j=rowsupport(dest)
556556
for k = max(1,j-u):min(j-B_u-1,j+l,m)
557-
inbands_setindex!(dest, f(inbands_getindex(A, k, j), zero(V)), k, j)
557+
inbands_setindex!(dest, f(inbands_getindex(A, k, j), zero(eltype(B))), k, j)
558558
end
559559
for k = max(1,j-min(u,B_u)):min(j+l,m)
560560
inbands_setindex!(dest, f(inbands_getindex(A, k, j), B[j]), k, j)
@@ -640,37 +640,45 @@ function __banded_broadcast!(dest::AbstractMatrix, f, (A,B)::Tuple{AbstractMatri
640640
B_l, B_u = bandwidths(B)
641641
(d_l min(l,m-1) && d_u min(u,n-1)) || throw(BandError(dest))
642642

643-
for j=rowsupport(dest)
644-
for k = max(1,j-d_u):min(j-u-1,j+d_l,m)
645-
inbands_setindex!(dest, z, k, j)
646-
end
647-
for k = max(1,j-min(A_u,d_u)):min(j-B_u-1,j+min(A_l,d_l),m)
648-
inbands_setindex!(dest, f(inbands_getindex(A, k, j), zero(V)), k, j)
649-
end
650-
for k = max(1, j+A_l+1, j-d_u):min(j-B_u-1,j+d_l,m)
651-
# This is hit in A+B with bandwidth(A) == (-2,2) && bandwidth(B) == (0,0)
652-
# The result has a bandwidth (2,0). This sets dest[band(1)] to zero
653-
inbands_setindex!(dest, z, k, j)
654-
end
655-
for k = max(1,j-min(B_u,d_u)):min(j-A_u-1,j+min(B_l,d_l),m)
656-
inbands_setindex!(dest, f(zero(T), inbands_getindex(B, k, j)), k, j)
657-
end
658-
for k = max(1,j-min(A_u,B_u,d_u)):min(j+min(A_l,B_l,d_l),m)
659-
inbands_setindex!(dest, f(inbands_getindex(A, k, j), inbands_getindex(B, k, j)), k, j)
660-
end
661-
for k = max(1,j+B_l+1,j-min(A_u,d_u)):min(j+min(A_l,d_l),m)
662-
inbands_setindex!(dest, f(inbands_getindex(A, k, j), zero(V)), k, j)
663-
end
664-
for k = max(1, j+B_l+1, j-d_u):min(j-A_u-1,j+d_l,m)
665-
# This is hit in A + B with bandwidth(A) == (2,-2) && bandwidth(B) == (0,0)
666-
# The result has a bandwidth (2,0). This sets dest[band(-1)] to zero
667-
inbands_setindex!(dest, z, k, j)
668-
end
669-
for k = max(1,j+A_l+1,j-min(d_u, B_u)):min(j+min(B_l,d_l),m)
670-
inbands_setindex!(dest, f(zero(T), inbands_getindex(B, k, j)), k, j)
643+
if (d_l,d_u) == (A_l,A_u) == (B_l,B_u) == (l,u)
644+
for j=rowsupport(dest)
645+
for k = max(1,j-u):min(j+l,m)
646+
inbands_setindex!(dest, f(inbands_getindex(A, k, j), inbands_getindex(B, k, j)), k, j)
647+
end
671648
end
672-
for k = max(1,j-d_u,j+l+1):min(j+d_l,m)
673-
inbands_setindex!(dest, z, k, j)
649+
else
650+
for j=rowsupport(dest)
651+
for k = max(1,j-d_u):min(j-u-1,j+d_l,m)
652+
inbands_setindex!(dest, z, k, j)
653+
end
654+
for k = max(1,j-min(A_u,d_u)):min(j-B_u-1,j+min(A_l,d_l),m)
655+
inbands_setindex!(dest, f(inbands_getindex(A, k, j), zero(V)), k, j)
656+
end
657+
for k = max(1, j+A_l+1, j-d_u):min(j-B_u-1,j+d_l,m)
658+
# This is hit in A+B with bandwidth(A) == (-2,2) && bandwidth(B) == (0,0)
659+
# The result has a bandwidth (2,0). This sets dest[band(1)] to zero
660+
inbands_setindex!(dest, z, k, j)
661+
end
662+
for k = max(1,j-min(B_u,d_u)):min(j-A_u-1,j+min(B_l,d_l),m)
663+
inbands_setindex!(dest, f(zero(T), inbands_getindex(B, k, j)), k, j)
664+
end
665+
for k = max(1,j-min(A_u,B_u,d_u)):min(j+min(A_l,B_l,d_l),m)
666+
inbands_setindex!(dest, f(inbands_getindex(A, k, j), inbands_getindex(B, k, j)), k, j)
667+
end
668+
for k = max(1,j+B_l+1,j-min(A_u,d_u)):min(j+min(A_l,d_l),m)
669+
inbands_setindex!(dest, f(inbands_getindex(A, k, j), zero(V)), k, j)
670+
end
671+
for k = max(1, j+B_l+1, j-d_u):min(j-A_u-1,j+d_l,m)
672+
# This is hit in A + B with bandwidth(A) == (2,-2) && bandwidth(B) == (0,0)
673+
# The result has a bandwidth (2,0). This sets dest[band(-1)] to zero
674+
inbands_setindex!(dest, z, k, j)
675+
end
676+
for k = max(1,j+A_l+1,j-min(d_u, B_u)):min(j+min(B_l,d_l),m)
677+
inbands_setindex!(dest, f(zero(T), inbands_getindex(B, k, j)), k, j)
678+
end
679+
for k = max(1,j-d_u,j+l+1):min(j+d_l,m)
680+
inbands_setindex!(dest, z, k, j)
681+
end
674682
end
675683
end
676684
dest

0 commit comments

Comments
 (0)