Skip to content

Commit 5840fc6

Browse files
authored
Resolve some adj/trans and triangular matrix multiplication ambiguities (#325)
* Remove ambiguity in transpose matrix * zeros vector * Resolve transpose vec * Zeros Matrix * disambiguate with transpose-adjoint wrapper * disambiguate against AbstractTriangular
1 parent cf8c78d commit 5840fc6

File tree

3 files changed

+55
-9
lines changed

3 files changed

+55
-9
lines changed

src/FillArrays.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import Base: size, getindex, setindex!, IndexStyle, checkbounds, convert,
1111

1212
import LinearAlgebra: rank, svdvals!, tril, triu, tril!, triu!, diag, transpose, adjoint, fill!,
1313
dot, norm2, norm1, normInf, normMinusInf, normp, lmul!, rmul!, diagzero, AdjointAbsVec, TransposeAbsVec,
14-
issymmetric, ishermitian, AdjOrTransAbsVec, checksquare, mul!, kron
14+
issymmetric, ishermitian, AdjOrTransAbsVec, checksquare, mul!, kron, AbstractTriangular
1515

1616

1717
import Base.Broadcast: broadcasted, DefaultArrayStyle, broadcast_shape, BroadcastStyle, Broadcasted

src/fillalgebra.jl

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,18 @@ mult_ones(a, b) = mult_ones(a, b, mult_axes(a, b))
9393
*(a::AbstractFillMatrix, b::AbstractZerosMatrix) = mult_zeros(a, b)
9494
*(a::AbstractFillMatrix, b::AbstractZerosVector) = mult_zeros(a, b)
9595

96-
*(a::AbstractZerosMatrix, b::AbstractMatrix) = mult_zeros(a, b)
97-
*(a::AbstractMatrix, b::AbstractZerosVector) = mult_zeros(a, b)
98-
*(a::AbstractMatrix, b::AbstractZerosMatrix) = mult_zeros(a, b)
96+
for MT in (:AbstractMatrix, :AbstractTriangular)
97+
@eval *(a::AbstractZerosMatrix, b::$MT) = mult_zeros(a, b)
98+
@eval *(a::$MT, b::AbstractZerosMatrix) = mult_zeros(a, b)
99+
end
100+
# Odd way to deal with the type-parameters to avoid ambiguities
101+
for MT in (:(AbstractMatrix{T}), :(Transpose{<:Any, <:AbstractMatrix{T}}), :(Adjoint{<:Any, <:AbstractMatrix{T}}),
102+
:(AbstractTriangular{T}))
103+
@eval *(a::$MT, b::AbstractZerosVector) where {T} = mult_zeros(a, b)
104+
end
105+
for MT in (:(Transpose{<:Any, <:AbstractVector}), :(Adjoint{<:Any, <:AbstractVector}))
106+
@eval *(a::$MT, b::AbstractZerosMatrix) = mult_zeros(a, b)
107+
end
99108
*(a::AbstractZerosMatrix, b::AbstractVector) = mult_zeros(a, b)
100109

101110
function lmul_diag(a::Diagonal, b)
@@ -290,13 +299,25 @@ function _adjvec_mul_zeros(a, b)
290299
return a1 * b[1]
291300
end
292301

293-
*(a::AdjointAbsVec{<:Any,<:AbstractZerosVector}, b::AbstractMatrix) = (b' * a')'
302+
for MT in (:AbstractMatrix, :AbstractTriangular, :(Adjoint{<:Any,<:TransposeAbsVec}))
303+
@eval *(a::AdjointAbsVec{<:Any,<:AbstractZerosVector}, b::$MT) = (b' * a')'
304+
end
305+
# ambiguity
306+
function *(a::AdjointAbsVec{<:Any,<:AbstractZerosVector}, b::TransposeAbsVec{<:Any,<:AdjointAbsVec})
307+
# change from Transpose ∘ Adjoint to Adjoint ∘ Transpose
308+
b2 = adjoint(transpose(adjoint(transpose(b))))
309+
a * b2
310+
end
294311
*(a::AdjointAbsVec{<:Any,<:AbstractZerosVector}, b::AbstractZerosMatrix) = (b' * a')'
295-
*(a::TransposeAbsVec{<:Any,<:AbstractZerosVector}, b::AbstractMatrix) = transpose(transpose(b) * transpose(a))
312+
for MT in (:AbstractMatrix, :AbstractTriangular, :(Transpose{<:Any,<:AdjointAbsVec}))
313+
@eval *(a::TransposeAbsVec{<:Any,<:AbstractZerosVector}, b::$MT) = transpose(transpose(b) * transpose(a))
314+
end
296315
*(a::TransposeAbsVec{<:Any,<:AbstractZerosVector}, b::AbstractZerosMatrix) = transpose(transpose(b) * transpose(a))
297316

298317
*(a::AbstractVector, b::AdjOrTransAbsVec{<:Any,<:AbstractZerosVector}) = a * permutedims(parent(b))
299-
*(a::AbstractMatrix, b::AdjOrTransAbsVec{<:Any,<:AbstractZerosVector}) = a * permutedims(parent(b))
318+
for MT in (:AbstractMatrix, :AbstractTriangular)
319+
@eval *(a::$MT, b::AdjOrTransAbsVec{<:Any,<:AbstractZerosVector}) = a * permutedims(parent(b))
320+
end
300321
*(a::AbstractZerosVector, b::AdjOrTransAbsVec{<:Any,<:AbstractZerosVector}) = a * permutedims(parent(b))
301322
*(a::AbstractZerosMatrix, b::AdjOrTransAbsVec{<:Any,<:AbstractZerosVector}) = a * permutedims(parent(b))
302323

@@ -307,7 +328,8 @@ end
307328

308329
*(a::Adjoint{T, <:AbstractMatrix{T}} where T, b::AbstractZeros{<:Any, 1}) = mult_zeros(a, b)
309330

310-
*(D::Diagonal, a::AdjointAbsVec{<:Any,<:AbstractZerosVector}) = (a' * D')'
331+
*(D::Diagonal, a::Adjoint{<:Any,<:AbstractZerosVector}) = (a' * D')'
332+
*(D::Diagonal, a::Transpose{<:Any,<:AbstractZerosVector}) = transpose(transpose(a) * transpose(D))
311333
*(a::AdjointAbsVec{<:Any,<:AbstractZerosVector}, D::Diagonal) = (D' * a')'
312334
*(a::TransposeAbsVec{<:Any,<:AbstractZerosVector}, D::Diagonal) = transpose(D*transpose(a))
313335
function _triple_zeromul(x, D::Diagonal, y)
@@ -325,7 +347,7 @@ end
325347
*(x::TransposeAbsVec{<:Any,<:AbstractZerosVector}, D::Diagonal, y::AbstractZerosVector) = _triple_zeromul(x, D, y)
326348

327349

328-
function *(a::Transpose{T, <:AbstractVector{T}}, b::AbstractZerosVector{T}) where T<:Real
350+
function *(a::Transpose{T, <:AbstractVector}, b::AbstractZerosVector{T}) where T<:Real
329351
la, lb = length(a), length(b)
330352
if la lb
331353
throw(DimensionMismatch("dot product arguments have lengths $la and $lb"))

test/runtests.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1579,6 +1579,14 @@ end
15791579
@test A*Zeros(nA,1) Zeros(mA,1)
15801580
@test a*Zeros(na,3) Zeros(la,3)
15811581

1582+
@test transpose(A) * Zeros(mA) Zeros(nA)
1583+
@test A' * Zeros(mA) Zeros(nA)
1584+
1585+
@test transpose(a) * Zeros(la, 3) Zeros(1,3)
1586+
@test a' * Zeros(la,3) Zeros(1,3)
1587+
1588+
@test Zeros(la)' * Transpose(Adjoint(a)) == 0.0
1589+
15821590
w = zeros(mA)
15831591
@test mul!(w, A, Fill(2,nA), true, false) A * fill(2,nA)
15841592
w .= 2
@@ -1658,6 +1666,22 @@ end
16581666
@test adjoint(A)*fillvec adjoint(A)*Array(fillvec)
16591667
@test adjoint(A)*fillmat adjoint(A)*Array(fillmat)
16601668
end
1669+
1670+
@testset "ambiguities" begin
1671+
UT33 = UpperTriangular(ones(3,3))
1672+
UT11 = UpperTriangular(ones(1,1))
1673+
@test transpose(Zeros(3)) * Transpose(Adjoint([1,2,3])) == 0
1674+
@test Zeros(3)' * Adjoint(Transpose([1,2,3])) == 0
1675+
@test Zeros(3)' * UT33 == Zeros(3)'
1676+
@test transpose(Zeros(3)) * UT33 == transpose(Zeros(3))
1677+
@test UT11 * Zeros(3)' == Zeros(1,3)
1678+
@test UT11 * transpose(Zeros(3)) == Zeros(1,3)
1679+
@test Zeros(2,3) * UT33 == Zeros(2,3)
1680+
@test UT33 * Zeros(3,2) == Zeros(3,2)
1681+
@test UT33 * Zeros(3) == Zeros(3)
1682+
@test Diagonal([1]) * transpose(Zeros(3)) == Zeros(1,3)
1683+
@test Diagonal([1]) * Zeros(3)' == Zeros(1,3)
1684+
end
16611685
end
16621686

16631687
@testset "count" begin

0 commit comments

Comments
 (0)