diff --git a/src/broadcast.jl b/src/broadcast.jl index ec722b3..48b6d72 100644 --- a/src/broadcast.jl +++ b/src/broadcast.jl @@ -96,7 +96,13 @@ import BandedMatrices: _isweakzero function blockbandwidths(bc::Broadcasted) (a,b) = size(bc) bnds = (a-1,b-1) - _isweakzero(bc.f, bc.args...) && return min.(bnds, max.(_broadcast_blockbandwidths.(Ref(bnds), bc.args, Ref(axes(bc)))...)) + if _isweakzero(bc.f, bc.args...) + ax = axes(bc) + t = map(bc.args) do x + _broadcast_blockbandwidths(bnds, x, ax) + end + return min.(bnds, max.(t...)) + end bnds end diff --git a/test/test_broadcasting.jl b/test/test_broadcasting.jl index cf1c04b..06758fb 100644 --- a/test/test_broadcasting.jl +++ b/test/test_broadcasting.jl @@ -289,6 +289,14 @@ import Base: oneto @test C == A + A end end + + @testset "blockbandwidths" begin + B = BlockArray(ones(6,6), 1:3, 1:3) + BB = BlockBandedMatrix(B, (1,1)) + bc = Broadcast.broadcasted(+, BB, BB) + bbw = @inferred blockbandwidths(bc) + @test bbw == blockbandwidths(BB) + end end end # module