Skip to content

Commit e74f057

Browse files
authored
Bandwidth-preserving branches in row/col broadcast (#310)
* Bandwidth preserving branches in row/col broadcast * left/right rowvec bandedcol broadcast * Add tests with sparse row/col
1 parent 6c66f3e commit e74f057

File tree

2 files changed

+217
-124
lines changed

2 files changed

+217
-124
lines changed

src/generic/broadcast.jl

Lines changed: 146 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -385,30 +385,39 @@ function _left_colvec_banded_broadcast!(dest::AbstractMatrix, f, (A,B)::Tuple{Ab
385385

386386
d_l, d_u = bandwidths(dest)
387387
A_l, A_u = _broadcast_bandwidths((m-1,n-1),A)
388+
@assert A_u == n-1
388389
B_l, B_u = bandwidths(B)
389390
(d_l ≥ min(l,m-1) && d_u ≥ min(u,n-1)) || throw(BandError(dest))
390391

391-
for j=1:n
392-
for k = max(1,j-d_u):min(j-u-1,m)
393-
inbands_setindex!(dest, z, k, j)
394-
end
395-
for k = max(1,j-d_u,j-A_u):min(j-B_u-1,j+d_l,m)
396-
inbands_setindex!(dest, f(A[k], zero(V)), k, j)
397-
end
398-
for k = max(1,j-d_u,j-B_u):min(j-A_u-1,j+d_l,m)
399-
inbands_setindex!(dest, f(zero(T), inbands_getindex(B, k, j)), k, j)
400-
end
401-
for k = max(1,j-min(A_u,B_u)):min(j+min(A_l,B_l),m)
402-
inbands_setindex!(dest, f(A[k], inbands_getindex(B, k, j)), k, j)
403-
end
404-
for k = max(1,j-d_u,j+B_l+1):min(j+A_l,j+d_l,m)
405-
inbands_setindex!(dest, f(A[k], zero(V)), k, j)
406-
end
407-
for k = max(1,j-d_u,j+A_l+1):min(j+B_l,j+d_l,m)
408-
inbands_setindex!(dest, f(zero(T), inbands_getindex(B, k, j)), k, j)
392+
if d_l == B_l == l && d_u == B_u == u
393+
for j=rowsupport(dest)
394+
for k = max(1,j-u):min(j+min(A_l,l),m)
395+
inbands_setindex!(dest, f(A[k], inbands_getindex(B, k, j)), k, j)
396+
end
397+
for k = max(1,j-u,j+A_l+1):min(j+l,m)
398+
inbands_setindex!(dest, f(zero(T), inbands_getindex(B, k, j)), k, j)
399+
end
409400
end
410-
for k = max(1,j+l+1):min(j+d_l,m)
411-
inbands_setindex!(dest, z, k, j)
401+
else
402+
for j=rowsupport(dest)
403+
for k = max(1,j-d_u):min(j-u-1,m)
404+
inbands_setindex!(dest, z, k, j)
405+
end
406+
for k = max(1,j-d_u):min(j-B_u-1,j+d_l,m)
407+
inbands_setindex!(dest, f(A[k], zero(V)), k, j)
408+
end
409+
for k = max(1,j-B_u):min(j+min(A_l,B_l),m)
410+
inbands_setindex!(dest, f(A[k], inbands_getindex(B, k, j)), k, j)
411+
end
412+
for k = max(1,j-d_u,j+B_l+1):min(j+A_l,j+d_l,m)
413+
inbands_setindex!(dest, f(A[k], zero(V)), k, j)
414+
end
415+
for k = max(1,j-d_u,j+A_l+1):min(j+B_l,j+d_l,m)
416+
inbands_setindex!(dest, f(zero(T), inbands_getindex(B, k, j)), k, j)
417+
end
418+
for k = max(1,j+l+1):min(j+d_l,m)
419+
inbands_setindex!(dest, z, k, j)
420+
end
412421
end
413422
end
414423
dest
@@ -427,32 +436,62 @@ function _right_colvec_banded_broadcast!(dest::AbstractMatrix, f, (A,B)::Tuple{A
427436
d_l, d_u = bandwidths(dest)
428437
A_l, A_u = bandwidths(A)
429438
B_l, B_u = _broadcast_bandwidths((m-1,n-1),B)
439+
@assert B_u == n-1
430440
(d_l ≥ min(l,m-1) && d_u ≥ min(u,n-1)) || throw(BandError(dest))
431441

432-
for j=1:n
433-
for k = max(1,j-d_u):min(j-u-1,m)
434-
inbands_setindex!(dest, z, k, j)
435-
end
436-
for k = max(1,j-d_u,j-A_u):min(j-B_u-1,j+d_l,m)
437-
inbands_setindex!(dest, f(inbands_getindex(A, k, j), zero(V)), k, j)
438-
end
439-
for k = max(1,j-d_u,j-B_u):min(j-A_u-1,j+d_l,m)
440-
inbands_setindex!(dest, f(zero(T), B[k]), k, j)
441-
end
442-
for k = max(1,j-min(A_u,B_u)):min(j+min(A_l,B_l),m)
443-
inbands_setindex!(dest, f(inbands_getindex(A, k, j), B[k]), k, j)
442+
if d_l == A_l == l && d_u == A_u == u
443+
for j=rowsupport(dest)
444+
for k = max(1,j-u):min(j+min(l,B_l),m)
445+
inbands_setindex!(dest, f(inbands_getindex(A, k, j), B[k]), k, j)
446+
end
447+
for k = max(1,j-u,j+B_l+1):min(j+l,m)
448+
inbands_setindex!(dest, f(inbands_getindex(A, k, j), zero(V)), k, j)
449+
end
444450
end
445-
for k = max(1,j-d_u,j+B_l+1):min(j+A_l,j+d_l,m)
446-
inbands_setindex!(dest, f(inbands_getindex(A, k, j), zero(V)), k, j)
451+
else
452+
for j=rowsupport(dest)
453+
for k = max(1,j-d_u):min(j-u-1,m)
454+
inbands_setindex!(dest, z, k, j)
455+
end
456+
for k = max(1,j-d_u):min(j-A_u-1,j+d_l,m)
457+
inbands_setindex!(dest, f(zero(T), B[k]), k, j)
458+
end
459+
for k = max(1,j-A_u):min(j+min(A_l,B_l),m)
460+
inbands_setindex!(dest, f(inbands_getindex(A, k, j), B[k]), k, j)
461+
end
462+
for k = max(1,j-d_u,j+B_l+1):min(j+A_l,j+d_l,m)
463+
inbands_setindex!(dest, f(inbands_getindex(A, k, j), zero(V)), k, j)
464+
end
465+
for k = max(1,j-d_u,j+A_l+1):min(j+B_l,j+d_l,m)
466+
inbands_setindex!(dest, f(zero(T), B[k]), k, j)
467+
end
468+
for k = max(1,j+l+1):min(j+d_l,m)
469+
inbands_setindex!(dest, z, k, j)
470+
end
447471
end
448-
for k = max(1,j-d_u,j+A_l+1):min(j+B_l,j+d_l,m)
449-
inbands_setindex!(dest, f(zero(T), B[k]), k, j)
472+
end
473+
dest
474+
end
475+
476+
# Specialization of the bandwidth-preserving `f.(rowvector, banded)` broadcast, selected when
# the destination layout is `BandedColumns` and the arguments are a dense row-major dual
# (row-vector) `A` and a `BandedColumns` matrix `B`.
#
# Instead of looping over in-band entries, `f` is broadcast over `A` and the `bandeddata`
# of `B`, writing straight into the `bandeddata` of `dest`.
# NOTE(review): presumably `bandeddata` exposes band storage with one column per matrix
# column, so the row vector `A` lines up column-by-column — confirm.
#
# The trailing `(l, u)`, `(A_l, A_u)` and `(m, n)` arguments are accepted for signature
# parity with the generic fallback method and are not used here.
# Always returns `nothing`; `dest` is modified in place.
function __left_rowvec_banded_broadcast!(dest, f, (A,B),
        ::BandedColumns, ::Tuple{DualLayout{ArrayLayouts.DenseRowMajor}, BandedColumns},
        (l, u), (A_l,A_u), (m,n))
    dest_bands = bandeddata(dest)
    B_bands = bandeddata(B)
    @. dest_bands = f(A, B_bands)
    return nothing
end
485+
486+
# Generic fallback for the bandwidth-preserving `f.(rowvector, banded)` broadcast.
#
# Arguments:
#   dest        destination matrix, written in place via `inbands_setindex!`
#   f           binary function applied as `f(a, b)`
#   (A, B)      row-vector-like left operand (indexed as `A[j]`) and banded right operand
#   _1, _2      layout arguments; only used for dispatch, ignored here
#   (l, u)      lower/upper bandwidths of the result (the caller only takes this path when
#               they coincide with those of `dest` and `B`)
#   (A_l, A_u)  broadcast bandwidths of `A`; only `A_u` is used
#   (m, n)      matrix size; only the row count `m` is used
#
# For every column `j` in `rowsupport(dest)`:
#   * rows `k <= j - A_u - 1` (above the upper band of `A`) are filled with `f(0, B[k, j])`,
#     where `0` is the zero of `A`'s element type;
#   * rows from `j - min(A_u, u)` down to `min(j + l, m)` are filled with `f(A[j], B[k, j])`.
#
# Returns `nothing`.
function __left_rowvec_banded_broadcast!(dest, f, (A,B), _1, _2, (l, u), (A_l,A_u), (m,n))
    # This method has no `where {T}` clause, so the zero of the left operand must be taken
    # from `eltype(A)`; referring to a bare `T` here would throw `UndefVarError`.
    zero_A = zero(eltype(A))
    for j = rowsupport(dest)
        # Part of the band lying above the upper bandwidth of `A`: `A` contributes zero.
        for k = max(1,j-u):min(j-A_u-1,j+l,m)
            inbands_setindex!(dest, f(zero_A, inbands_getindex(B, k, j)), k, j)
        end
        # Remaining part of the band: combine the row-vector entry with `B`.
        for k = max(1,j-min(A_u,u)):min(j+l,m)
            inbands_setindex!(dest, f(A[j], inbands_getindex(B, k, j)), k, j)
        end
    end
    return nothing
end
457496

458497
function _left_rowvec_banded_broadcast!(dest::AbstractMatrix, f, (A,B)::Tuple{AbstractMatrix{T},AbstractMatrix{V}}, _1, _2) where {T,V}
@@ -465,33 +504,57 @@ function _left_rowvec_banded_broadcast!(dest::AbstractMatrix, f, (A,B)::Tuple{Ab
465504

466505
d_l, d_u = bandwidths(dest)
467506
A_l, A_u = _broadcast_bandwidths((m-1,n-1),A)
507+
@assert A_l == m-1
468508
B_l, B_u = bandwidths(B)
469509
(d_l ≥ min(l,m-1) && d_u ≥ min(u,n-1)) || throw(BandError(dest))
470510

471-
for j=1:n
472-
for k = max(1,j-d_u):min(j-u-1,m)
473-
inbands_setindex!(dest, z, k, j)
474-
end
475-
for k = max(1,j-d_u,j-A_u):min(j-B_u-1,j+d_l,m)
476-
inbands_setindex!(dest, f(A[j], zero(V)), k, j)
477-
end
478-
for k = max(1,j-d_u,j-B_u):min(j-A_u-1,j+d_l,m)
479-
inbands_setindex!(dest, f(zero(T), inbands_getindex(B, k, j)), k, j)
480-
end
481-
for k = max(1,j-min(A_u,B_u)):min(j+min(A_l,B_l),m)
482-
inbands_setindex!(dest, f(A[j], inbands_getindex(B, k, j)), k, j)
483-
end
484-
for k = max(1,j-d_u,j+B_l+1):min(j+A_l,j+d_l,m)
485-
inbands_setindex!(dest, f(A[j], zero(V)), k, j)
511+
if d_l == B_l == l && d_u == B_u == u
512+
__left_rowvec_banded_broadcast!(dest, f, (A,B), _1, _2,
513+
(l, u), (A_l,A_u), (m,n))
514+
else
515+
for j=rowsupport(dest)
516+
for k = max(1,j-d_u):min(j-u-1,m)
517+
inbands_setindex!(dest, z, k, j)
518+
end
519+
for k = max(1,j-d_u,j-A_u):min(j-B_u-1,j+d_l,m)
520+
inbands_setindex!(dest, f(A[j], zero(V)), k, j)
521+
end
522+
for k = max(1,j-d_u,j-B_u):min(j-A_u-1,j+d_l,m)
523+
inbands_setindex!(dest, f(zero(T), inbands_getindex(B, k, j)), k, j)
524+
end
525+
for k = max(1,j-min(A_u,B_u)):min(j+min(A_l,B_l),m)
526+
inbands_setindex!(dest, f(A[j], inbands_getindex(B, k, j)), k, j)
527+
end
528+
for k = max(1,j-d_u,j+B_l+1):min(j+d_l,m)
529+
inbands_setindex!(dest, f(A[j], zero(V)), k, j)
530+
end
531+
for k = max(1,j+l+1):min(j+d_l,m)
532+
inbands_setindex!(dest, z, k, j)
533+
end
486534
end
487-
for k = max(1,j-d_u,j+A_l+1):min(j+B_l,j+d_l,m)
488-
inbands_setindex!(dest, f(zero(T), inbands_getindex(B, k, j)), k, j)
535+
end
536+
dest
537+
end
538+
539+
# Specialization of the bandwidth-preserving `f.(banded, rowvector)` broadcast, selected when
# the destination layout is `BandedColumns` and the arguments are a `BandedColumns` matrix
# `A` and a dense row-major dual (row-vector) `B`.
#
# Instead of looping over in-band entries, `f` is broadcast over the `bandeddata` of `A`
# and `B`, writing straight into the `bandeddata` of `dest`.
# NOTE(review): presumably `bandeddata` exposes band storage with one column per matrix
# column, so the row vector `B` lines up column-by-column — confirm.
#
# The trailing `(l, u)`, `(B_l, B_u)` and `(m, n)` arguments are accepted for signature
# parity with the generic fallback method and are not used here.
# Always returns `nothing`; `dest` is modified in place.
function __right_rowvec_banded_broadcast!(dest, f, (A,B),
        ::BandedColumns, ::Tuple{BandedColumns, DualLayout{ArrayLayouts.DenseRowMajor}},
        (l, u), (B_l,B_u), (m,n))
    dest_bands = bandeddata(dest)
    A_bands = bandeddata(A)
    @. dest_bands = f(A_bands, B)
    return nothing
end
548+
549+
# Generic fallback for the bandwidth-preserving `f.(banded, rowvector)` broadcast.
#
# Arguments:
#   dest        destination matrix, written in place via `inbands_setindex!`
#   f           binary function applied as `f(a, b)`
#   (A, B)      banded left operand and row-vector-like right operand (indexed as `B[j]`)
#   _1, _2      layout arguments; only used for dispatch, ignored here
#   (l, u)      lower/upper bandwidths of the result (the caller only takes this path when
#               they coincide with those of `dest` and `A`)
#   (B_l, B_u)  broadcast bandwidths of `B`; only `B_u` is used
#   (m, n)      matrix size; only the row count `m` is used
#
# For every column `j` in `rowsupport(dest)`:
#   * rows `k <= j - B_u - 1` (above the upper band of `B`) are filled with `f(A[k, j], 0)`,
#     where `0` is the zero of `B`'s element type;
#   * rows from `j - min(u, B_u)` down to `min(j + l, m)` are filled with `f(A[k, j], B[j])`.
#
# Returns `nothing`.
function __right_rowvec_banded_broadcast!(dest, f, (A,B), _1, _2, (l, u), (B_l,B_u), (m,n))
    # This method has no `where {V}` clause, so the zero of the right operand must be taken
    # from `eltype(B)`; referring to a bare `V` here would throw `UndefVarError`.
    zero_B = zero(eltype(B))
    for j = rowsupport(dest)
        # Part of the band lying above the upper bandwidth of `B`: `B` contributes zero.
        for k = max(1,j-u):min(j-B_u-1,j+l,m)
            inbands_setindex!(dest, f(inbands_getindex(A, k, j), zero_B), k, j)
        end
        # Remaining part of the band: combine `A` with the row-vector entry.
        for k = max(1,j-min(u,B_u)):min(j+l,m)
            inbands_setindex!(dest, f(inbands_getindex(A, k, j), B[j]), k, j)
        end
    end
    return nothing
end
496559

497560
function _right_rowvec_banded_broadcast!(dest::AbstractMatrix, f, (A,B)::Tuple{AbstractMatrix{T},AbstractMatrix{V}}, _1, _2) where {T,V}
@@ -505,29 +568,32 @@ function _right_rowvec_banded_broadcast!(dest::AbstractMatrix, f, (A,B)::Tuple{A
505568
d_l, d_u = bandwidths(dest)
506569
A_l, A_u = bandwidths(A)
507570
B_l, B_u = _broadcast_bandwidths((m-1,n-1),B)
571+
@assert B_l == m-1
508572
(d_l ≥ min(l,m-1) && d_u ≥ min(u,n-1)) || throw(BandError(dest))
509573

510-
for j=1:n
511-
for k = max(1,j-d_u):min(j-u-1,m)
512-
inbands_setindex!(dest, z, k, j)
513-
end
514-
for k = max(1,j-d_u,j-A_u):min(j-B_u-1,j+d_l,m)
515-
inbands_setindex!(dest, f(inbands_getindex(A, k, j), zero(V)), k, j)
516-
end
517-
for k = max(1,j-d_u,j-B_u):min(j-A_u-1,j+d_l,m)
518-
inbands_setindex!(dest, f(zero(T), B[j]), k, j)
519-
end
520-
for k = max(1,j-min(A_u,B_u)):min(j+min(A_l,B_l),m)
521-
inbands_setindex!(dest, f(inbands_getindex(A, k, j), B[j]), k, j)
522-
end
523-
for k = max(1,j-d_u,j+B_l+1):min(j+A_l,j+d_l,m)
524-
inbands_setindex!(dest, f(inbands_getindex(A, k, j), zero(V)), k, j)
525-
end
526-
for k = max(1,j-d_u,j+A_l+1):min(j+B_l,j+d_l,m)
527-
inbands_setindex!(dest, f(zero(T), B[j]), k, j)
528-
end
529-
for k = max(1,j+l+1):min(j+d_l,m)
530-
inbands_setindex!(dest, z, k, j)
574+
if d_l == A_l == l && d_u == A_u == u
575+
__right_rowvec_banded_broadcast!(dest, f, (A,B), _1, _2,
576+
(l, u), (B_l,B_u), (m,n))
577+
else
578+
for j=rowsupport(dest)
579+
for k = max(1,j-d_u):min(j-u-1,m)
580+
inbands_setindex!(dest, z, k, j)
581+
end
582+
for k = max(1,j-d_u,j-A_u):min(j-B_u-1,j+d_l,m)
583+
inbands_setindex!(dest, f(inbands_getindex(A, k, j), zero(V)), k, j)
584+
end
585+
for k = max(1,j-d_u,j-B_u):min(j-A_u-1,j+d_l,m)
586+
inbands_setindex!(dest, f(zero(T), B[j]), k, j)
587+
end
588+
for k = max(1,j-min(A_u,B_u)):min(j+A_l,m)
589+
inbands_setindex!(dest, f(inbands_getindex(A, k, j), B[j]), k, j)
590+
end
591+
for k = max(1,j-d_u,j+A_l+1):min(j+d_l,m)
592+
inbands_setindex!(dest, f(zero(T), B[j]), k, j)
593+
end
594+
for k = max(1,j+l+1):min(j+d_l,m)
595+
inbands_setindex!(dest, z, k, j)
596+
end
531597
end
532598
end
533599
dest

0 commit comments

Comments
 (0)