Skip to content
Open
1 change: 1 addition & 0 deletions src/FillArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,7 @@ Base.replace_in_print_matrix(A::RectDiagonal, i::Integer, j::Integer, s::Abstrac
const RectOrDiagonal{T,V,Axes} = Union{RectDiagonal{T,V,Axes}, Diagonal{T,V}}
const RectOrDiagonalFill{T,V<:AbstractFillVector{T},Axes} = RectOrDiagonal{T,V,Axes}
const RectDiagonalFill{T,V<:AbstractFillVector{T}} = RectDiagonal{T,V}
const DiagonalFill{T,V<:AbstractFillVector{T}} = Diagonal{T,V}
const SquareEye{T,Axes} = Diagonal{T,Ones{T,1,Tuple{Axes}}}
const Eye{T,Axes} = RectOrDiagonal{T,Ones{T,1,Tuple{Axes}}}

Expand Down
240 changes: 240 additions & 0 deletions src/fillalgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
end
$OP(a::AbstractOnesMatrix) = fillsimilar(a, reverse(axes(a)))
$OP(a::FillMatrix) = Fill($OP(a.value), reverse(a.axes))
$OP(a::RectDiagonal) = RectDiagonal(vec($OP(a.diag)), reverse(a.axes))
end
end

Expand Down Expand Up @@ -80,10 +81,57 @@

*(a::AbstractFillMatrix, b::AbstractFillMatrix) = mult_fill(a,b)
*(a::AbstractFillMatrix, b::AbstractFillVector) = mult_fill(a,b)
for type in (AdjointAbsVec{<:Any,<:AbstractOnesVector}, TransposeAbsVec{<:Any,<:AbstractOnesVector}, AdjointAbsVec{<:Any,<:AbstractFillVector}, TransposeAbsVec{<:Any,<:AbstractFillVector})
@eval begin
function *(A::AbstractFillVector, B::$type)
size(A,2) == size(B,1) ||
throw(DimensionMismatch("second dimension of A, $(size(A,2)) does not match first dimension of B, $(size(B,1))"))
Fill(getindex_value(A) * getindex_value(B), size(A, 1), size(B, 2))
end
function *(A::AbstractFillMatrix, B::$type)
size(A,2) == size(B,1) ||
throw(DimensionMismatch("second dimension of A, $(size(A,2)) does not match first dimension of B, $(size(B,1))"))
Fill(getindex_value(A) * getindex_value(B), size(A, 1), size(B, 2))
end
end
end

# this treats a size (n,) vector as a nx1 matrix, so b needs to have 1 row
# special cased, as OnesMatrix * OnesMatrix isn't a Ones
*(a::AbstractOnesVector, b::AbstractOnesMatrix) = mult_ones(a, b)
for type in (AdjointAbsVec{<:Any,<:AbstractOnesVector}, TransposeAbsVec{<:Any,<:AbstractOnesVector})
@eval begin
*(A::AbstractOnesVector, B::$type) = mult_ones(A, B)
*(A::AbstractOnesMatrix, B::$type) = mult_ones(A, B)
end
end

for type2 in (AdjointAbsVec{<:Any,<:AbstractZerosVector}, TransposeAbsVec{<:Any,<:AbstractZerosVector})
for type1 in (AbstractFillVector, AbstractZerosVector, AbstractOnesVector, AbstractFillMatrix, AbstractZerosMatrix, AbstractOnesMatrix)
@eval begin
function *(A::$type1, B::$type2)
size(A,2) == size(B,1) ||
throw(DimensionMismatch("second dimension of A, $(size(A,2)) does not match first dimension of B, $(size(B,1))"))
Zeros{promote_type(eltype(A), eltype(B))}(size(A, 1), size(B, 2))
end
end
end
end

for type in (AdjointAbsVec{<:Any,<:AbstractOnesVector}, TransposeAbsVec{<:Any,<:AbstractOnesVector}, AdjointAbsVec{<:Any,<:AbstractFillVector}, TransposeAbsVec{<:Any,<:AbstractFillVector}, )
@eval begin
function *(A::AbstractZerosVector, B::$type)
size(A,2) == size(B,1) ||
throw(DimensionMismatch("second dimension of A, $(size(A,2)) does not match first dimension of B, $(size(B,1))"))
Zeros{promote_type(eltype(A), eltype(B))}(size(A, 1), size(B, 2))
end
function *(A::AbstractZerosMatrix, B::$type)
size(A,2) == size(B,1) ||
throw(DimensionMismatch("second dimension of A, $(size(A,2)) does not match first dimension of B, $(size(B,1))"))
Zeros{promote_type(eltype(A), eltype(B))}(size(A, 1), size(B, 2))
end
end
end

*(a::AbstractZerosMatrix, b::AbstractZerosMatrix) = mult_zeros(a, b)
*(a::AbstractZerosMatrix, b::AbstractZerosVector) = mult_zeros(a, b)
Expand Down Expand Up @@ -485,6 +533,198 @@
@inline elconvert(::Type{T}, A::AbstractUnitRange) where T<:Integer = AbstractUnitRange{T}(A)
@inline elconvert(::Type{T}, A::AbstractArray) where T = AbstractArray{T}(A)

# RectDiagonal Multiplication
const RectDiagonalZeros{T,V<:AbstractZerosVector{T}} = RectDiagonal{T,V}
const RectDiagonalOnes{T,V<:AbstractOnesVector{T}} = RectDiagonal{T,V}

function *(A::RectDiagonal, B::Diagonal)
check_matmul_sizes(A, B)
len = minimum(size(A))
RectDiagonal(view(A.diag, Base.OneTo(len)) .* view(B.diag, Base.OneTo(len)), (size(A, 1), size(B, 2)))
end
function *(A::Diagonal, B::RectDiagonal)
check_matmul_sizes(A, B)
len = minimum(size(B))
RectDiagonal(view(A.diag, Base.OneTo(len)) .* view(B.diag, Base.OneTo(len)), (size(A, 1), size(B, 2)))
end

for type in (AbstractMatrix, AbstractTriangular, AdjointAbsVec, TransposeAbsVec, AdjOrTransAbsVec{<:Any,<:AbstractZerosVector})
@eval begin
function *(A::RectDiagonal, B::$type)
check_matmul_sizes(A, B)
TS = Base.promote_op(LinearAlgebra.matprod, eltype(A), eltype(B))
diag = A.diag
out = fill!(similar(diag, TS, axes(A,1), axes(B,2)), 0)
len = Base.OneTo(minimum(size(A)))
out[len, :] .= view(diag, len) .* view(B, len, :)
out

Check warning on line 560 in src/fillalgebra.jl

View check run for this annotation

Codecov / codecov/patch

src/fillalgebra.jl#L560

Added line #L560 was not covered by tests
end

function *(A::$type, B::RectDiagonal)
check_matmul_sizes(A, B)
TS = Base.promote_op(LinearAlgebra.matprod, eltype(A), eltype(B))
out = fill!(similar(A, TS, axes(A,1), axes(B, 2)), 0)
len = Base.OneTo(minimum(size(B)))
out[:, len] .= view(A, :, len) .* view(reshape(B.diag, 1, :), Base.OneTo(1), len)
out

Check warning on line 569 in src/fillalgebra.jl

View check run for this annotation

Codecov / codecov/patch

src/fillalgebra.jl#L569

Added line #L569 was not covered by tests
end
end
end

function *(A::RectDiagonal, x::AbstractVector)
check_matmul_sizes(A, x)
TS = Base.promote_op(LinearAlgebra.matprod, eltype(A), eltype(x))
diag = A.diag
out = fill!(similar(diag, TS, axes(A,1)), 0)
len = Base.OneTo(minimum(size(A)))
out[len] .= view(diag, len) .* view(x, len)
out

Check warning on line 581 in src/fillalgebra.jl

View check run for this annotation

Codecov / codecov/patch

src/fillalgebra.jl#L581

Added line #L581 was not covered by tests
end

function *(A::RectDiagonal, B::RectDiagonal)
check_matmul_sizes(A, B)
TS = Base.promote_op(LinearAlgebra.matprod, eltype(A), eltype(B))
out = fill!(similar(A.diag, TS, min(size(A, 1), size(B, 2))), 0)
len = Base.OneTo(min(minimum(size(A)), minimum(size(B))))
out[len] .= view(A.diag, len) .* view(B.diag, len)
RectDiagonal(out, (size(A,1), size(B,2)))
end

for type in (RectDiagonal, RectDiagonalZeros)
@eval begin
function *(A::$type, B::AbstractZerosMatrix)
check_matmul_sizes(A, B)
Zeros{promote_type(eltype(A),eltype(B))}(size(A, 1), size(B, 2))
end

function *(A::$type, B::AbstractZerosVector)
check_matmul_sizes(A, B)
Zeros{promote_type(eltype(A),eltype(B))}(size(A, 1))
end

function *(A::AbstractZerosMatrix, B::$type)
check_matmul_sizes(A, B)
Zeros{promote_type(eltype(A),eltype(B))}(size(A, 1), size(B, 2))
end

*(A::AdjointAbsVec{<:Any,<:AbstractZerosVector}, B::$type) = Zeros(A) * B
*(A::TransposeAbsVec{<:Any,<:AbstractZerosVector}, B::$type) = Zeros(A) * B
*(A::$type, B::AdjointAbsVec{<:Any,<:AbstractZerosVector}) = A * Zeros(B)
*(A::$type, B::TransposeAbsVec{<:Any,<:AbstractZerosVector}) = A * Zeros(B)
end
end

for type in (AbstractMatrix, RectDiagonal, Diagonal, AdjointAbsVec, TransposeAbsVec, AdjOrTransAbsVec{<:Any,<:AbstractZerosVector}, AbstractTriangular)
@eval begin
function *(A::$type, B::RectDiagonalZeros)
check_matmul_sizes(A, B)
Zeros{promote_type(eltype(A),eltype(B))}(size(A,1), size(B,2))
end
function *(A::RectDiagonalZeros, B::$type)
check_matmul_sizes(A, B)
Zeros{promote_type(eltype(A),eltype(B))}(size(A,1), size(B,2))
end
end
end
function *(A::RectDiagonalZeros, B::AbstractVector)
check_matmul_sizes(A, B)
Zeros{promote_type(eltype(A),eltype(B))}(size(A,1))
end
function *(A::RectDiagonalZeros, B::RectDiagonalZeros)
check_matmul_sizes(A, B)
Zeros{promote_type(eltype(A),eltype(B))}(size(A,1), size(B,2))
end

*(a::RectDiagonalFill, b::Number) = RectDiagonal(a.diag * b, a.axes)
*(a::Number, b::RectDiagonalFill) = RectDiagonal(a * b.diag, b.axes)

# DiagonalFill Multiplication
const DiagonalZeros{T,V<:AbstractZerosVector{T}} = Diagonal{T,V}
const DiagonalOnes{T,V<:AbstractOnesVector{T}} = Diagonal{T,V}
mat_types = (AbstractMatrix, RectDiagonal, AbstractZerosMatrix,
AbstractFillMatrix, AdjointAbsVec, TransposeAbsVec, UnitUpperTriangular, UnitLowerTriangular,
LowerTriangular, UpperTriangular, AbstractTriangular, Symmetric, Hermitian, LinearAlgebra.HermOrSym,
SymTridiagonal, UpperHessenberg, LinearAlgebra.AdjOrTransAbsMat, AdjOrTransAbsVec{<:Any,<:AbstractZerosVector})#, OneElement)
for type in tuple(AbstractVector, AbstractZerosVector, mat_types...)
@eval begin
function *(A::DiagonalFill, B::$type)
check_matmul_sizes(A, B)
getindex_value(A.diag) * B
end
*(A::DiagonalZeros, B::$type) = Zeros(A) * B
function *(A::DiagonalOnes, B::$type)
check_matmul_sizes(A, B)
convert(AbstractArray{promote_type(eltype(A), eltype(B))}, deepcopy(B))
end
end
end
*(A::DiagonalOnes, B::AbstractRange) = one(eltype(A)) * B

for type in mat_types
@eval begin
function *(A::$type, B::DiagonalFill)
check_matmul_sizes(A, B)
getindex_value(B.diag) * A
end
*(A::$type, B::DiagonalZeros) = A * Zeros(B)
function *(A::$type, B::DiagonalOnes)
check_matmul_sizes(A, B)
convert(AbstractMatrix{promote_type(eltype(A), eltype(B))}, deepcopy(A))
end
end
end

for type1 in (DiagonalFill, DiagonalOnes, DiagonalZeros)
for type2 in (AdjointAbsVec{<:Any,<:AbstractZerosVector}, TransposeAbsVec{<:Any,<:AbstractZerosVector}, RectDiagonalZeros)
@eval begin
*(A::$type2, B::$type1) = Zeros(A) * B
*(A::$type1, B::$type2) = A * Zeros(B)
end
end
@eval begin
*(A::Diagonal, B::$type1) = Diagonal(A.diag .* B.diag)
*(A::$type1, B::Diagonal) = Diagonal(A.diag .* B.diag)
end
end

for type1 in (DiagonalFill, DiagonalOnes, DiagonalZeros)
for type2 in (DiagonalFill, DiagonalOnes, DiagonalZeros)
@eval begin
*(A::$type1, B::$type2) = Diagonal(A.diag .* B.diag)
end
end
end

*(A::RectDiagonalFill, B::DiagonalZeros) = A * Zeros(B)
*(A::DiagonalZeros, B::RectDiagonalFill) = Zeros(A) * B
for type in (DiagonalFill, DiagonalOnes)
@eval begin
function *(A::$type, B::RectDiagonalFill)
check_matmul_sizes(A, B)
len = Base.OneTo(minimum(size(B)))
RectDiagonal(view(A.diag, len) .* view(B.diag, len), size(B))
end

function *(A::RectDiagonalFill, B::$type)
check_matmul_sizes(A, B)
len = Base.OneTo(minimum(size(A)))
RectDiagonal(view(A.diag, len) .* view(B.diag, len), size(A))
end
end
end

function *(Da::Diagonal, A::RectDiagonal, Db::Diagonal)
check_matmul_sizes(Da, A)
check_matmul_sizes(A, Db)
len = Base.OneTo(minimum(size(A)))
diag = view(Da.diag, len) .* view(A.diag, len) .* view(Db.diag, len)
if diag isa Zeros
Zeros{eltype(diag)}(axes(A))
else
RectDiagonal(diag, axes(A))
end
end

####
# norm
####
Expand Down
15 changes: 15 additions & 0 deletions src/fillbroadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,21 @@ broadcasted(::DefaultArrayStyle{N}, op, x::Number, r::AbstractFill{T,N}) where {
broadcasted(::DefaultArrayStyle{N}, op, r::AbstractFill{T,N}, x::Ref) where {T,N} = broadcasted_fill(op, r, op(getindex_value(r),x[]), axes(r))
broadcasted(::DefaultArrayStyle{N}, op, x::Ref, r::AbstractFill{T,N}) where {T,N} = broadcasted_fill(op, r, op(x[], getindex_value(r)), axes(r))

# ternary broadcasting
for type1 in (AbstractArray, AbstractFill, AbstractZeros)
for type2 in (AbstractArray, AbstractFill, AbstractZeros)
for type3 in (AbstractArray, AbstractFill, AbstractZeros)
if type1 === AbstractZeros || type2 === AbstractZeros || type3 === AbstractZeros
@eval begin
broadcasted(::DefaultArrayStyle, ::typeof(*), a::$type1, b::$type2, c::$type3) = Zeros{promote_type(eltype(a),eltype(b),eltype(c))}(broadcast_shape(axes(a), axes(b), axes(c)))
end
end
end
end
end
broadcasted(::DefaultArrayStyle, ::typeof(*), a::AbstractOnes, b::AbstractOnes, c::AbstractOnes) = Ones{promote_type(eltype(a),eltype(b),eltype(c))}(broadcast_shape(axes(a), axes(b), axes(c)))
broadcasted(::DefaultArrayStyle, ::typeof(*), a::AbstractFill, b::AbstractFill, c::AbstractFill) = Fill(getindex_value(a)*getindex_value(b)*getindex_value(c), broadcast_shape(axes(a), axes(b), axes(c)))

# support AbstractFill .^ k
broadcasted(::DefaultArrayStyle{N}, op::typeof(Base.literal_pow), ::Base.RefValue{typeof(^)}, r::AbstractFill{T,N}, ::Base.RefValue{Val{k}}) where {T,N,k} = broadcasted_fill(op, r, getindex_value(r)^k, axes(r))
broadcasted(::DefaultArrayStyle{N}, op::typeof(Base.literal_pow), ::Base.RefValue{typeof(^)}, r::AbstractOnes{T,N}, ::Base.RefValue{Val{k}}) where {T,N,k} = broadcasted_ones(op, r, T, axes(r))
Expand Down
24 changes: 24 additions & 0 deletions src/oneelement.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,30 @@ function *(D::Diagonal, A::OneElementMatrix)
OneElement(val, A.ind, size(A))
end

function *(A::OneElementMatrix, D::DiagonalZeros)
check_matmul_sizes(A, D)
Zeros{promote_type(eltype(A),eltype(D))}(size(A, 1), size(D, 2))
end

function *(D::DiagonalZeros, A::OneElementMatrix)
check_matmul_sizes(D, A)
Zeros{promote_type(eltype(A),eltype(D))}(size(D, 1), size(A, 2))
end

for type in (DiagonalFill, DiagonalOnes)
@eval begin
function *(A::OneElementMatrix, D::$type)
check_matmul_sizes(A, D)
getindex_value(D.diag) * A
end

function *(D::$type, A::OneElementMatrix)
check_matmul_sizes(D, A)
getindex_value(D.diag) * A
end
end
end

# Inplace multiplication

# We use this for out overloads for _mul! for OneElement because its more efficient
Expand Down
Loading
Loading