4
4
5
5
import Base: sort!, findall, copy!
6
6
import LinearAlgebra: promote_to_array_type, promote_to_arrays_
7
+ using LinearAlgebra: adj_or_trans
7
8
8
9
# ## The SparseVector
9
10
@@ -1768,23 +1769,52 @@ end
1768
1769
1769
1770
const _StridedOrTriangularMatrix{T} = Union{StridedMatrix{T}, LowerTriangular{T}, UnitLowerTriangular{T}, UpperTriangular{T}, UnitUpperTriangular{T}}
1770
1771
1772
+ _fliptri (A:: UpperTriangular ) = LowerTriangular (parent (parent (A)))
1773
+ _fliptri (A:: UnitUpperTriangular ) = UnitLowerTriangular (parent (parent (A)))
1774
+ _fliptri (A:: LowerTriangular ) = UpperTriangular (parent (parent (A)))
1775
+ _fliptri (A:: UnitLowerTriangular ) = UnitUpperTriangular (parent (parent (A)))
1776
+
1771
1777
function (* )(A:: _StridedOrTriangularMatrix{Ta} , x:: AbstractSparseVector{Tx} ) where {Ta,Tx}
1772
1778
require_one_based_indexing (A, x)
1773
1779
m, n = size (A)
1774
1780
length (x) == n || throw (DimensionMismatch ())
1775
1781
Ty = promote_op (matprod, eltype (A), eltype (x))
1776
1782
y = Vector {Ty} (undef, m)
1777
- mul! (y, A, x)
1783
+ mul! (y, A, x, true , false )
1778
1784
end
1779
1785
1780
- function mul! (y:: AbstractVector , A:: _StridedOrTriangularMatrix , x:: AbstractSparseVector , α:: Number , β:: Number )
1786
+ function LinearAlgebra. generic_matvecmul! (y:: AbstractVector , tA, A:: StridedMatrix , x:: AbstractSparseVector ,
1787
+ _add:: MulAddMul = MulAddMul ())
1788
+ if tA == ' N'
1789
+ _spmul! (y, A, x, _add. alpha, _add. beta)
1790
+ elseif tA == ' T'
1791
+ _At_or_Ac_mul_B! (transpose, y, A, x, _add. alpha, _add. beta)
1792
+ elseif tA == ' C'
1793
+ _At_or_Ac_mul_B! (adjoint, y, A, x, _add. alpha, _add. beta)
1794
+ else
1795
+ _spmul! (y, LinearAlgebra. wrap (A, tA), x, _add. alpha, _add. beta)
1796
+ end
1797
+ return y
1798
+ end
1799
+ function LinearAlgebra. generic_matvecmul! (y:: AbstractVector , tA, A:: UpperOrLowerTriangular , x:: AbstractSparseVector ,
1800
+ _add:: MulAddMul = MulAddMul ())
1801
+ @assert tA == ' N'
1802
+ Adata = parent (A)
1803
+ if Adata isa Transpose
1804
+ _At_or_Ac_mul_B! (transpose, y, _fliptri (A), x, _add. alpha, _add. beta)
1805
+ elseif Adata isa Adjoint
1806
+ _At_or_Ac_mul_B! (adjoint, y, _fliptri (A), x, _add. alpha, _add. beta)
1807
+ else # Adata is plain
1808
+ _spmul! (y, A, x, _add. alpha, _add. beta)
1809
+ end
1810
+ return y
1811
+ end
1812
+ function _spmul! (y:: AbstractVector , A:: AbstractMatrix , x:: AbstractSparseVector , α:: Number , β:: Number )
1781
1813
require_one_based_indexing (y, A, x)
1782
1814
m, n = size (A)
1783
1815
length (x) == n && length (y) == m || throw (DimensionMismatch ())
1784
1816
m == 0 && return y
1785
- if β != one (β)
1786
- β == zero (β) ? fill! (y, zero (eltype (y))) : rmul! (y, β)
1787
- end
1817
+ β != one (β) && LinearAlgebra. _rmul_or_fill! (y, β)
1788
1818
α == zero (α) && return y
1789
1819
1790
1820
xnzind = nonzeroinds (x)
@@ -1802,80 +1832,47 @@ function mul!(y::AbstractVector, A::_StridedOrTriangularMatrix, x::AbstractSpars
1802
1832
return y
1803
1833
end
1804
1834
1805
- # * and mul!(C, transpose(A), B)
1806
-
1807
- function * (tA:: Transpose{<:Any,<:_StridedOrTriangularMatrix{Ta}} , x:: AbstractSparseVector{Tx} ) where {Ta,Tx}
1808
- require_one_based_indexing (tA, x)
1809
- m, n = size (tA)
1810
- length (x) == n || throw (DimensionMismatch ())
1811
- Ty = promote_op (matprod, eltype (tA), eltype (x))
1812
- y = Vector {Ty} (undef, m)
1813
- mul! (y, tA, x)
1814
- end
1815
-
1816
- function mul! (y:: AbstractVector , tA:: Transpose{<:Any,<:_StridedOrTriangularMatrix} , x:: AbstractSparseVector , α:: Number , β:: Number )
1817
- require_one_based_indexing (y, tA, x)
1818
- m, n = size (tA)
1835
+ function _At_or_Ac_mul_B! (tfun:: Function ,
1836
+ y:: AbstractVector , A:: _StridedOrTriangularMatrix , x:: AbstractSparseVector ,
1837
+ α:: Number , β:: Number )
1838
+ require_one_based_indexing (y, A, x)
1839
+ n, m = size (A)
1819
1840
length (x) == n && length (y) == m || throw (DimensionMismatch ())
1820
1841
m == 0 && return y
1821
- if β != one (β)
1822
- β == zero (β) ? fill! (y, zero (eltype (y))) : rmul! (y, β)
1823
- end
1842
+ β != one (β) && LinearAlgebra. _rmul_or_fill! (y, β)
1824
1843
α == zero (α) && return y
1825
1844
1826
1845
xnzind = nonzeroinds (x)
1827
1846
xnzval = nonzeros (x)
1828
1847
_nnz = length (xnzind)
1829
1848
_nnz == 0 && return y
1830
1849
1831
- A = tA. parent
1832
1850
Ty = promote_op (matprod, eltype (A), eltype (x))
1833
1851
@inbounds for j = 1 : m
1834
1852
s = zero (Ty)
1835
1853
for i = 1 : _nnz
1836
- s += transpose (A[xnzind[i], j]) * xnzval[i]
1854
+ s += tfun (A[xnzind[i], j]) * xnzval[i]
1837
1855
end
1838
1856
y[j] += s * α
1839
1857
end
1840
1858
return y
1841
1859
end
1842
1860
1843
- # * and mul!(C, adjoint(A), B)
1844
-
1845
- function * (adjA:: Adjoint{<:Any,<:_StridedOrTriangularMatrix{Ta}} , x:: AbstractSparseVector{Tx} ) where {Ta,Tx}
1846
- require_one_based_indexing (adjA, x)
1847
- m, n = size (adjA)
1861
+ function * (A:: AdjOrTrans{<:Any,<:StridedMatrix} , x:: AbstractSparseVector )
1862
+ require_one_based_indexing (A, x)
1863
+ m, n = size (A)
1848
1864
length (x) == n || throw (DimensionMismatch ())
1849
- Ty = promote_op (matprod, eltype (adjA ), eltype (x))
1865
+ Ty = promote_op (matprod, eltype (A ), eltype (x))
1850
1866
y = Vector {Ty} (undef, m)
1851
- mul! (y, adjA , x)
1867
+ mul! (y, A , x, true , false )
1852
1868
end
1853
-
1854
- function mul! (y:: AbstractVector , adjA:: Adjoint{<:Any,<:_StridedOrTriangularMatrix} , x:: AbstractSparseVector , α:: Number , β:: Number )
1855
- require_one_based_indexing (y, adjA, x)
1856
- m, n = size (adjA)
1857
- length (x) == n && length (y) == m || throw (DimensionMismatch ())
1858
- m == 0 && return y
1859
- if β != one (β)
1860
- β == zero (β) ? fill! (y, zero (eltype (y))) : rmul! (y, β)
1861
- end
1862
- α == zero (α) && return y
1863
-
1864
- xnzind = nonzeroinds (x)
1865
- xnzval = nonzeros (x)
1866
- _nnz = length (xnzind)
1867
- _nnz == 0 && return y
1868
-
1869
- A = adjA. parent
1869
+ function * (A:: LinearAlgebra.HermOrSym{<:Any,<:StridedMatrix} , x:: AbstractSparseVector )
1870
+ require_one_based_indexing (A, x)
1871
+ m, n = size (A)
1872
+ length (x) == n || throw (DimensionMismatch ())
1870
1873
Ty = promote_op (matprod, eltype (A), eltype (x))
1871
- @inbounds for j = 1 : m
1872
- s = zero (Ty)
1873
- for i = 1 : _nnz
1874
- s += adjoint (A[xnzind[i], j]) * xnzval[i]
1875
- end
1876
- y[j] += s * α
1877
- end
1878
- return y
1874
+ y = Vector {Ty} (undef, m)
1875
+ mul! (y, A, x, true , false )
1879
1876
end
1880
1877
1881
1878
@@ -1906,15 +1903,26 @@ function densemv(A::AbstractSparseMatrixCSC, x::AbstractSparseVector; trans::Abs
1906
1903
end
1907
1904
1908
1905
# * and mul!
1906
+ function LinearAlgebra. generic_matvecmul! (y:: AbstractVector , tA, A:: AbstractSparseMatrixCSC , x:: AbstractSparseVector ,
1907
+ _add:: MulAddMul = MulAddMul ())
1908
+ if tA == ' N'
1909
+ _spmul! (y, A, x, _add. alpha, _add. beta)
1910
+ elseif tA == ' T'
1911
+ _At_or_Ac_mul_B! ((a,b) -> transpose (a) * b, y, A, x, _add. alpha, _add. beta)
1912
+ elseif tA == ' C'
1913
+ _At_or_Ac_mul_B! ((a,b) -> adjoint (a) * b, y, A, x, _add. alpha, _add. beta)
1914
+ else
1915
+ LinearAlgebra. _generic_matvecmul! (y, ' N' , LinearAlgebra. wrap (A, tA), x, _add)
1916
+ end
1917
+ return y
1918
+ end
1909
1919
1910
- function mul ! (y:: AbstractVector , A:: AbstractSparseMatrixCSC , x:: AbstractSparseVector , α:: Number , β:: Number )
1920
+ function _spmul ! (y:: AbstractVector , A:: AbstractSparseMatrixCSC , x:: AbstractSparseVector , α:: Number , β:: Number )
1911
1921
require_one_based_indexing (y, A, x)
1912
1922
m, n = size (A)
1913
1923
length (x) == n && length (y) == m || throw (DimensionMismatch ())
1914
1924
m == 0 && return y
1915
- if β != one (β)
1916
- β == zero (β) ? fill! (y, zero (eltype (y))) : rmul! (y, β)
1917
- end
1925
+ β != one (β) && LinearAlgebra. _rmul_or_fill! (y, β)
1918
1926
α == zero (α) && return y
1919
1927
1920
1928
xnzind = nonzeroinds (x)
@@ -1936,23 +1944,14 @@ function mul!(y::AbstractVector, A::AbstractSparseMatrixCSC, x::AbstractSparseVe
1936
1944
return y
1937
1945
end
1938
1946
1939
- # * and *(Transpose(A), B)
1940
- mul! (y:: AbstractVector , tA:: Transpose{<:Any,<:AbstractSparseMatrixCSC} , x:: AbstractSparseVector , α:: Number , β:: Number ) =
1941
- _At_or_Ac_mul_B! ((a,b) -> transpose (a) * b, y, tA. parent, x, α, β)
1942
-
1943
- mul! (y:: AbstractVector , adjA:: Adjoint{<:Any,<:AbstractSparseMatrixCSC} , x:: AbstractSparseVector , α:: Number , β:: Number ) =
1944
- _At_or_Ac_mul_B! ((a,b) -> adjoint (a) * b, y, adjA. parent, x, α, β)
1945
-
1946
1947
function _At_or_Ac_mul_B! (tfun:: Function ,
1947
1948
y:: AbstractVector , A:: AbstractSparseMatrixCSC , x:: AbstractSparseVector ,
1948
1949
α:: Number , β:: Number )
1949
1950
require_one_based_indexing (y, A, x)
1950
1951
m, n = size (A)
1951
1952
length (x) == m && length (y) == n || throw (DimensionMismatch ())
1952
1953
n == 0 && return y
1953
- if β != one (β)
1954
- β == zero (β) ? fill! (y, zero (eltype (y))) : rmul! (y, β)
1955
- end
1954
+ β != one (β) && LinearAlgebra. _rmul_or_fill! (y, β)
1956
1955
α == zero (α) && return y
1957
1956
1958
1957
xnzind = nonzeroinds (x)
@@ -1981,11 +1980,8 @@ function *(A::AbstractSparseMatrixCSC, x::AbstractSparseVector)
1981
1980
_dense2sparsevec (y, initcap)
1982
1981
end
1983
1982
1984
- * (tA:: Transpose{<:Any,<:AbstractSparseMatrixCSC} , x:: AbstractSparseVector ) =
1985
- _At_or_Ac_mul_B ((a,b) -> transpose (a) * b, tA. parent, x, promote_op (matprod, eltype (tA), eltype (x)))
1986
-
1987
- * (adjA:: Adjoint{<:Any,<:AbstractSparseMatrixCSC} , x:: AbstractSparseVector ) =
1988
- _At_or_Ac_mul_B ((a,b) -> adjoint (a) * b, adjA. parent, x, promote_op (matprod, eltype (adjA), eltype (x)))
1983
+ * (xA:: AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC} , x:: AbstractSparseVector ) =
1984
+ _At_or_Ac_mul_B ((a,b) -> adj_or_trans (xA)(a) * b, xA. parent, x, promote_op (matprod, eltype (xA), eltype (x)))
1989
1985
1990
1986
function _At_or_Ac_mul_B (tfun:: Function , A:: AbstractSparseMatrixCSC{TvA,TiA} , x:: AbstractSparseVector{TvX,TiX} ,
1991
1987
Tv = promote_op (matprod, TvA, TvX)) where {TvA,TiA,TvX,TiX}
0 commit comments