Skip to content

Commit c9367f7

Browse files
authored
Triangular of Tridiagonal, use unalias (#37)
* Triangular of Tridiagonal * Use unalias, improve array-valued mul * use fill! for MulAdd * Increase coverage
1 parent ebf0312 commit c9367f7

File tree

7 files changed

+82
-68
lines changed

7 files changed

+82
-68
lines changed

src/ArrayLayouts.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ import Base: AbstractArray, AbstractMatrix, AbstractVector,
2727
similar, @_gc_preserve_end, @_gc_preserve_begin,
2828
@nexprs, @ncall, @ntuple, tuple_type_tail,
2929
all, any, isbitsunion, issubset, replace_in_print_matrix, replace_with_centered_mark,
30-
strides, unsafe_convert, first_index
30+
strides, unsafe_convert, first_index, unalias
3131

3232
import Base.Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted, broadcasted,
3333
combine_eltypes, DefaultArrayStyle, instantiate, materialize,

src/ldiv.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,9 @@ macro _layoutldiv(Typ)
138138

139139
LinearAlgebra.ldiv!(A::Factorization, x::$Typ) = ArrayLayouts.ldiv!(A,x)
140140

141+
LinearAlgebra.ldiv!(A::Bidiagonal, B::$Typ) = ArrayLayouts.ldiv!(A,B)
142+
143+
141144
Base.:\(A::$Typ, x::AbstractVector) = ArrayLayouts.ldiv(A,x)
142145
Base.:\(A::$Typ, x::AbstractMatrix) = ArrayLayouts.ldiv(A,x)
143146

src/memorylayout.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -485,7 +485,7 @@ MemoryLayout(A::Type{UpperTriangular{T,P}}) where {T,P} = triangularlayout(Upper
485485
MemoryLayout(A::Type{UnitUpperTriangular{T,P}}) where {T,P} = triangularlayout(UnitUpperTriangularLayout, MemoryLayout(P))
486486
MemoryLayout(A::Type{LowerTriangular{T,P}}) where {T,P} = triangularlayout(LowerTriangularLayout, MemoryLayout(P))
487487
MemoryLayout(A::Type{UnitLowerTriangular{T,P}}) where {T,P} = triangularlayout(UnitLowerTriangularLayout, MemoryLayout(P))
488-
triangularlayout(_::Type{Tri}, ::MemoryLayout) where Tri = Tri{UnknownLayout}()
488+
triangularlayout(::Type{Tri}, ::MemoryLayout) where Tri = Tri{UnknownLayout}()
489489
triangularlayout(::Type{Tri}, ::ML) where {Tri, ML<:AbstractColumnMajor} = Tri{ML}()
490490
triangularlayout(::Type{Tri}, ::ML) where {Tri, ML<:AbstractRowMajor} = Tri{ML}()
491491
triangularlayout(::Type{Tri}, ::ML) where {Tri, ML<:ConjLayout{<:AbstractRowMajor}} = Tri{ML}()
@@ -574,6 +574,17 @@ transposelayout(ml::SymTridiagonalLayout) = ml
574574
transposelayout(ml::TridiagonalLayout) = ml
575575
transposelayout(ml::ConjLayout{DiagonalLayout}) = ml
576576

577+
triangularlayout(::Type{<:TriangularLayout{UPLO,'N'}}, ::TridiagonalLayout{ML}) where {UPLO,ML} = BidiagonalLayout{ML}()
578+
triangularlayout(::Type{<:TriangularLayout{UPLO,'N'}}, ::TridiagonalLayout{FillLayout}) where UPLO = BidiagonalLayout{FillLayout}()
579+
triangularlayout(::Type{<:TriangularLayout{UPLO,'U'}}, ::TridiagonalLayout{FillLayout}) where UPLO = BidiagonalLayout{FillLayout}()
580+
581+
bidiagonaluplo(::Union{UpperTriangular,UnitUpperTriangular}) = 'U'
582+
bidiagonaluplo(::Union{LowerTriangular,UnitLowerTriangular}) = 'L'
583+
diagonaldata(U::Union{UnitUpperTriangular{T},UnitLowerTriangular{T}}) where T = Ones{T}(size(U,1))
584+
diagonaldata(U::Union{UpperTriangular{T},LowerTriangular{T}}) where T = diagonaldata(triangulardata(U))
585+
supdiagonaldata(U::Union{UnitUpperTriangular,UpperTriangular}) = supdiagonaldata(triangulardata(U))
586+
subdiagonaldata(U::Union{UnitLowerTriangular,LowerTriangular}) = subdiagonaldata(triangulardata(U))
587+
577588
adjointlayout(::Type{<:Real}, ml::SymTridiagonalLayout) = ml
578589
adjointlayout(::Type{<:Real}, ml::TridiagonalLayout) = ml
579590
adjointlayout(::Type{<:Real}, ml::BidiagonalLayout) = ml

src/mul.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,12 @@ axes(M::Mul, p::Int) = axes(M)[p]
1919
length(M::Mul) = prod(size(M))
2020
size(M::Mul) = map(length,axes(M))
2121

22-
_mul_axes(A::Tuple{<:Any,<:Any}, ::Tuple{<:Any}) = (A[1],)
23-
_mul_axes(A::Tuple{<:Any}, B::Tuple{<:Any,<:Any}) = (A[1],B[2])
24-
_mul_axes(A::Tuple{<:Any,<:Any}, B::Tuple{<:Any,<:Any}) = (A[1],B[2])
22+
_mul_axes(A::Tuple{<:Any,<:Any}, ::Tuple{<:Any}) = (A[1],) # matrix * vector
23+
_mul_axes(A::Tuple{<:Any}, B::Tuple{<:Any,<:Any}) = (A[1],B[2]) # vector * matrix
24+
_mul_axes(A::Tuple{<:Any,<:Any}, B::Tuple{<:Any,<:Any}) = (A[1],B[2]) # matrix * matrix
25+
_mul_axes(::Tuple{}, ::Tuple{}) = () # scalar * scalar
26+
_mul_axes(::Tuple{}, B::Tuple) = B # scalar * B
27+
_mul_axes(A::Tuple, B::Tuple{}) = A # A * scalar
2528

2629
axes(M::Mul) = _mul_axes(axes(M.A), axes(M.B))
2730

src/muladd.jl

Lines changed: 39 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,13 @@ end
1919
MulAdd{StyleA,StyleB,StyleC}(α, A, B, β, C)
2020
end
2121

22-
@inline MulAdd(α, A::AA, B::BB, β, C::CC) where {AA,BB,CC} =
22+
@inline MulAdd(α, A::AA, B::BB, β, C::CC) where {AA,BB,CC} =
2323
MulAdd{typeof(MemoryLayout(AA)), typeof(MemoryLayout(BB)), typeof(MemoryLayout(CC))}(α, A, B, β, C)
2424

2525
MulAdd(A, B) = MulAdd(Mul(A, B))
2626
function MulAdd(M::Mul)
2727
TV = eltype(M)
28-
MulAdd(scalarone(TV), M.A, M.B, scalarzero(TV), fillzeros(TV,axes(M)))
28+
MulAdd(scalarone(TV), M.A, M.B, scalarzero(TV), mulzeros(TV,M))
2929
end
3030

3131
@inline eltype(::MulAdd{StyleA,StyleB,StyleC,T,AA,BB,CC}) where {StyleA,StyleB,StyleC,T,AA,BB,CC} =
@@ -69,18 +69,11 @@ muladd!(α, A, B, β, C) = materialize!(MulAdd(α, A, B, β, C))
6969
materialize(M::MulAdd) = copy(instantiate(M))
7070
copy(M::MulAdd) = copyto!(similar(M), M)
7171

72-
@inline function copyto!(dest::AbstractArray{T}, M::MulAdd) where T
73-
M.C === dest || copyto!(dest, M.C)
74-
muladd!(M.α, M.A, M.B, M.β, dest)
75-
end
72+
_fill_copyto!(dest, C) = copyto!(dest, C)
73+
_fill_copyto!(dest, C::Zeros) = zero!(dest) # exploit special fill! overload
7674

77-
@inline function copyto!(dest::AbstractArray{T}, M::MulAdd{<:Any,<:Any,ZerosLayout}) where T
78-
α,A,B,β,C = M.α, M.A, M.B, M.β, M.C
79-
if !isbitstype(T) # instantiate
80-
dest .= β .* view(A,:,1) .* Ref(B[1]) # get shape right
81-
end
82-
muladd!(α, A, B, β, dest)
83-
end
75+
@inline copyto!(dest::AbstractArray{T}, M::MulAdd) where T =
76+
muladd!(M.α, unalias(dest,M.A), unalias(dest,M.B), M.β, _fill_copyto!(dest, M.C))
8477

8578
# Modified from LinearAlgebra._generic_matmatmul!
8679
function tile_size(T, S, R)
@@ -226,32 +219,28 @@ function _default_blasmul!(::IndexCartesian, α, A::AbstractMatrix, B::AbstractV
226219
C
227220
end
228221

229-
default_blasmul!(α, A::AbstractMatrix, B::AbstractVector, β, C::AbstractVector) =
222+
default_blasmul!(α, A::AbstractMatrix, B::AbstractVector, β, C::AbstractVector) =
230223
_default_blasmul!(Base.IndexStyle(typeof(A)), α, A, B, β, C)
231224

232225
function materialize!(M::MatMulMatAdd)
233226
α, A, B, β, C = M.α, M.A, M.B, M.β, M.C
234-
if C B
235-
B = copy(B)
236-
end
237-
default_blasmul!(α, A, B, iszero(β) ? false : β, C)
227+
default_blasmul!(α, unalias(C,A), unalias(C,B), iszero(β) ? false : β, C)
238228
end
239229

240230
function materialize!(M::MatMulMatAdd{<:AbstractStridedLayout,<:AbstractStridedLayout,<:AbstractStridedLayout})
241-
α, A, B, β, C = M.α, M.A, M.B, M.β, M.C
242-
if C B
243-
B = copy(B)
244-
end
231+
α, Ain, Bin, β, C = M.α, M.A, M.B, M.β, M.C
232+
A = unalias(C, Ain)
233+
B = unalias(C, Bin)
245234
ts = tile_size(eltype(A), eltype(B), eltype(C))
246235
if iszero(β) # false is a "strong" zero to wipe out NaNs
247236
if ts == 0 || !(axes(A) isa NTuple{2,OneTo{Int}}) || !(axes(B) isa NTuple{2,OneTo{Int}}) || !(axes(C) isa NTuple{2,OneTo{Int}})
248-
default_blasmul!(α, A, B, false, C)
249-
else
237+
default_blasmul!(α, A, B, false, C)
238+
else
250239
tiled_blasmul!(ts, α, A, B, false, C)
251240
end
252241
else
253242
if ts == 0 || !(axes(A) isa NTuple{2,OneTo{Int}}) || !(axes(B) isa NTuple{2,OneTo{Int}}) || !(axes(C) isa NTuple{2,OneTo{Int}})
254-
default_blasmul!(α, A, B, β, C)
243+
default_blasmul!(α, A, B, β, C)
255244
else
256245
tiled_blasmul!(ts, α, A, B, β, C)
257246
end
@@ -260,29 +249,11 @@ end
260249

261250
function materialize!(M::MatMulVecAdd)
262251
α, A, B, β, C = M.α, M.A, M.B, M.β, M.C
263-
if C B
264-
B = copy(B)
265-
end
266-
default_blasmul!(α, A, B, iszero(β) ? false : β, C)
252+
default_blasmul!(α, unalias(C,A), unalias(C,B), iszero(β) ? false : β, C)
267253
end
268254

269-
# make copy to make sure always works
270-
@inline function _gemv!(tA, α, A, x, β, y)
271-
if x y
272-
BLAS.gemv!(tA, α, A, copy(x), β, y)
273-
else
274-
BLAS.gemv!(tA, α, A, x, β, y)
275-
end
276-
end
277-
278-
# make copy to make sure always works
279-
@inline function _gemm!(tA, tB, α, A, B, β, C)
280-
if B C
281-
BLAS.gemm!(tA, tB, α, A, copy(B), β, C)
282-
else
283-
BLAS.gemm!(tA, tB, α, A, B, β, C)
284-
end
285-
end
255+
@inline _gemv!(tA, α, A, x, β, y) = BLAS.gemv!(tA, α, unalias(y,A), unalias(y,x), β, y)
256+
@inline _gemm!(tA, tB, α, A, B, β, C) = BLAS.gemm!(tA, tB, α, unalias(C,A), unalias(C,B), β, C)
286257

287258
# work around pointer issues
288259
@inline materialize!(M::BlasMatMulVecAdd{<:AbstractColumnMajor,<:AbstractStridedLayout,<:AbstractStridedLayout}) =
@@ -350,21 +321,8 @@ end
350321
###
351322

352323
# make copy to make sure always works
353-
@inline function _symv!(tA, α, A, x, β, y)
354-
if x y
355-
BLAS.symv!(tA, α, A, copy(x), β, y)
356-
else
357-
BLAS.symv!(tA, α, A, x, β, y)
358-
end
359-
end
360-
361-
@inline function _hemv!(tA, α, A, x, β, y)
362-
if x y
363-
BLAS.hemv!(tA, α, A, copy(x), β, y)
364-
else
365-
BLAS.hemv!(tA, α, A, x, β, y)
366-
end
367-
end
324+
@inline _symv!(tA, α, A, x, β, y) = BLAS.symv!(tA, α, unalias(y,A), unalias(y,x), β, y)
325+
@inline _hemv!(tA, α, A, x, β, y) = BLAS.hemv!(tA, α, unalias(y,A), unalias(y,x), β, y)
368326

369327

370328
materialize!(M::BlasMatMulVecAdd{<:SymmetricLayout{<:AbstractColumnMajor},<:AbstractStridedLayout,<:AbstractStridedLayout}) =
@@ -411,10 +369,28 @@ scalarone(::Type{<:AbstractArray{T}}) where T = scalarone(T)
411369
scalarzero(::Type{T}) where T = zero(T)
412370
scalarzero(::Type{<:AbstractArray{T}}) where T = scalarzero(T)
413371

414-
fillzeros(::Type{T}, ax) where T = Zeros{T}(ax)
372+
fillzeros(::Type{T}, ax) where T<:Number = Zeros{T}(ax)
373+
mulzeros(::Type{T}, M) where T<:Number = fillzeros(T, axes(M))
374+
375+
# initiate array-valued MulAdd
376+
function _mulzeros!(dest::AbstractVector{T}, A, B) where T
377+
for k in axes(dest,1)
378+
dest[k] = similar(Mul(A[k,1],B[1]), eltype(T))
379+
end
380+
dest
381+
end
382+
383+
function _mulzeros!(dest::AbstractMatrix{T}, A, B) where T
384+
for j in axes(dest,2), k in axes(dest,1)
385+
dest[k,j] = similar(Mul(A[k,1],B[1,j]), eltype(T))
386+
end
387+
dest
388+
end
389+
390+
mulzeros(::Type{T}, M) where T<:AbstractArray = _mulzeros!(similar(Array{T}, axes(M)), M.A, M.B)
415391

416392
###
417-
# Fill
393+
# Fill
418394
###
419395

420396
copy(M::MulAdd{<:AbstractFillLayout,<:AbstractFillLayout,<:AbstractFillLayout}) = M.α*M.A*M.B + M.β*M.C

test/test_layouts.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,19 @@ struct FooNumber <: Number end
126126
@test MemoryLayout(Bidiagonal(view(randn(10),[1,2,3]), view(randn(10),[1,2]), :U)) isa BidiagonalLayout{UnknownLayout}
127127
@test MemoryLayout(SymTridiagonal(view(randn(10),[1,2,3]), view(randn(10),[1,2]))) isa SymTridiagonalLayout{UnknownLayout}
128128
@test MemoryLayout(Tridiagonal(view(randn(10),[1,2]), view(randn(10),[1,2,3]), view(randn(10),[1,2]))) isa TridiagonalLayout{UnknownLayout}
129+
130+
@testset "Triangular of Tridiagonal" begin
131+
U = UpperTriangular(T)
132+
L = LowerTriangular(T)
133+
@test MemoryLayout(U) isa BidiagonalLayout{DenseColumnMajor}
134+
@test MemoryLayout(L) isa BidiagonalLayout{DenseColumnMajor}
135+
@test diagonaldata(U) == diagonaldata(L) == diagonaldata(T)
136+
@test subdiagonaldata(L) == subdiagonaldata(T)
137+
@test supdiagonaldata(U) == supdiagonaldata(T)
138+
@test bidiagonaluplo(U) == 'U'
139+
@test bidiagonaluplo(L) == 'L'
140+
@test_throws MethodError subdiagonaldata(U)
141+
end
129142
end
130143

131144
@testset "Symmetric/Hermitian" begin

test/test_muladd.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,14 @@ Random.seed!(0)
7474
@test all(c .=== BLAS.gemv!('N', 2one(T), A, b, one(T), copy(b)))
7575
end
7676
end
77+
78+
@testset "Matrix{Int} * Vector{Vector{Int}}" begin
79+
A, x = [1 2; 3 4] , [[1,2],[3,4]]
80+
X = reshape([[1 2],[3 4], [5 6], [7 8]],2,2)
81+
@test mul(A,x) == A*x
82+
@test mul(A,X) == A*X
83+
@test mul(X,A) == X*A
84+
end
7785
end
7886

7987
@testset "Matrix * Matrix" begin

0 commit comments

Comments
 (0)