Skip to content

Commit 5088cdd

Browse files
authored
Fast path for densecolumnmajor row broadcast (#325)
1 parent 644a762 commit 5088cdd

File tree

2 files changed

+21
-2
lines changed

2 files changed

+21
-2
lines changed

src/generic/broadcast.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -477,7 +477,8 @@ function _right_colvec_banded_broadcast!(dest::AbstractMatrix, f, (A,B)::Tuple{A
477477
end
478478

479479
function __left_rowvec_banded_broadcast!(dest, f, (A,B),
480-
::BandedColumns, ::Tuple{DualLayout{ArrayLayouts.DenseRowMajor}, BandedColumns},
480+
::BandedColumns,
481+
::Tuple{Union{DenseColumnMajor, DualLayout{ArrayLayouts.DenseRowMajor}}, BandedColumns},
481482
(l, u), (A_l,A_u), (m,n))
482483

483484
D = bandeddata(dest)
@@ -540,7 +541,8 @@ function _left_rowvec_banded_broadcast!(dest::AbstractMatrix, f, (A,B)::Tuple{Ab
540541
end
541542

542543
function __right_rowvec_banded_broadcast!(dest, f, (A,B),
543-
::BandedColumns, ::Tuple{BandedColumns, DualLayout{ArrayLayouts.DenseRowMajor}},
544+
::BandedColumns,
545+
::Tuple{BandedColumns, Union{DualLayout{ArrayLayouts.DenseRowMajor}, DenseColumnMajor}},
544546
(l, u), (B_l,B_u), (m,n))
545547

546548
D = bandeddata(dest)

test/test_broadcasting.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,15 +491,26 @@ import BandedMatrices: BandedStyle, BandedRows
491491
@test b_ .* A_ == b_ .* Matrix(A_)
492492
@test b_ .* A_ isa BandedMatrix
493493
@test bandwidths(b_ .* A_) == bandwidths(A_)
494+
494495
@test b_' .* A_ == b_' .* Matrix(A_)
495496
@test b_' .* A_ isa BandedMatrix
496497
@test bandwidths(b_' .* A_) == bandwidths(A_)
498+
499+
@test permutedims(b_) .* A_ == permutedims(b_) .* Matrix(A_)
500+
@test permutedims(b_) .* A_ isa BandedMatrix
501+
@test bandwidths(permutedims(b_) .* A_) == bandwidths(A_)
502+
497503
@test A_ .* b_ == Matrix(A_) .* b_
498504
@test A_ .* b_ isa BandedMatrix
499505
@test bandwidths(A_ .* b_) == bandwidths(A_)
506+
500507
@test A_ .* b_' == Matrix(A_) .* b_'
501508
@test A_ .* b_' isa BandedMatrix
502509
@test bandwidths(A_ .* b_') == bandwidths(A_)
510+
511+
@test A_ .* permutedims(b_) == Matrix(A_) .* permutedims(b_)
512+
@test A_ .* permutedims(b_) isa BandedMatrix
513+
@test bandwidths(A_ .* permutedims(b_)) == bandwidths(A_)
503514
end
504515

505516
# division tests currently don't deal with Inf/NaN correctly,
@@ -546,11 +557,17 @@ import BandedMatrices: BandedStyle, BandedRows
546557
D_ .= b_' .* A_
547558
@test D_ == b_' .* A_
548559

560+
D_ .= permutedims(b_) .* A_
561+
@test D_ == permutedims(b_) .* A_
562+
549563
D_ .= A_ .* b_
550564
@test D_ == A_ .* b_
551565

552566
D_ .= A_ .* b_'
553567
@test D_ == A_ .* b_'
568+
569+
D_ .= A_ .* permutedims(b_)
570+
@test D_ == A_ .* permutedims(b_)
554571
end
555572
end
556573
end

0 commit comments

Comments
 (0)