Skip to content

Commit be12653

Browse files
committed
ldiv! for Diagonal and a sparse vector
1 parent 9548149 commit be12653

File tree

2 files changed

+179
-19
lines changed

2 files changed

+179
-19
lines changed

src/linalg.jl

Lines changed: 176 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1312,7 +1312,7 @@ function rdiv!(A::AbstractSparseMatrixCSC{T}, D::Diagonal{T}) where T
13121312
A
13131313
end
13141314

1315-
function ldiv!(D::Diagonal{T}, A::AbstractSparseMatrixCSC{T}) where {T}
1315+
function ldiv!(D::Diagonal{T}, A::Union{AbstractSparseMatrixCSC{T}, AbstractSparseVector{T}}) where {T}
13161316
# require_one_based_indexing(A)
13171317
if size(A, 1) != length(D.diag)
13181318
throw(DimensionMismatch("diagonal matrix is $(length(D.diag)) by $(length(D.diag)) but right hand side has $(size(A, 1)) rows"))
@@ -1877,47 +1877,204 @@ inv(A::AbstractSparseMatrixCSC) = error("The inverse of a sparse matrix can ofte
18771877
## scale methods
18781878

18791879
# Copy colptr and rowval from one sparse matrix to another
1880-
function copyinds!(C::AbstractSparseMatrixCSC, A::AbstractSparseMatrixCSC)
1881-
if getcolptr(C) !== getcolptr(A)
1880+
function copyinds!(C::AbstractSparseMatrixCSC, A::AbstractSparseMatrixCSC; copy_rows=true, copy_cols=true)
1881+
if copy_cols && getcolptr(C) !== getcolptr(A)
18821882
resize!(getcolptr(C), length(getcolptr(A)))
18831883
copyto!(getcolptr(C), getcolptr(A))
18841884
end
1885-
if rowvals(C) !== rowvals(A)
1885+
if copy_rows && rowvals(C) !== rowvals(A)
18861886
resize!(rowvals(C), length(rowvals(A)))
18871887
copyto!(rowvals(C), rowvals(A))
18881888
end
18891889
end
18901890

1891+
"""
1892+
rowcheck_index(A::AbstractSparseMatrixCSC, row::Integer, col::Integer)
1893+
1894+
Check if A[row, col] is a stored value, and return the index of the row in `rowvals(A)`.
1895+
Returns `(row_exists, row_ind)`, where `row_exists::Bool` signifies
1896+
whether the corresponding index is populated, and `row_ind` is the index.
1897+
If `row_exists` is `false`, the `row_ind` is the index where the value should be inserted into
1898+
`rowvals(A)` such that the subarray `@view rowvals(A)[nzrange(A, col)]` remains sorted.
1899+
"""
1900+
@inline function rowcheck_index(A::AbstractSparseMatrixCSC, row::Integer, col::Integer)
1901+
nzinds = nzrange(A, col)
1902+
rows_col = @view rowvals(A)[nzinds]
1903+
# faster implementation of row ∈ rows_col and obtaining the index,
1904+
# assuming that rows_col is sorted
1905+
row_ind_col = searchsortedfirst(rows_col, row)
1906+
row_exists = row_ind_col axes(rows_col,1) && rows_col[row_ind_col] == row
1907+
row_ind = row_ind_col + first(nzinds) - firstindex(nzinds)
1908+
row_exists, row_ind
1909+
end
1910+
1911+
"""
1912+
mergeinds!(C::AbstractSparseMatrixCSC, A::AbstractSparseMatrixCSC)
1913+
1914+
Update `C` to contain stored values corresponding to the stored indices of `A`.
1915+
Stored indices common to `C` and `A` are not touched. Indices of `A` at which
1916+
`C` did not have a stored value are populated with zeros after the call.
1917+
1918+
# Examples
1919+
```jldoctest
1920+
julia> A = spzeros(3,3);
1921+
1922+
julia> A[4:4:8] .= 1;
1923+
1924+
julia> A
1925+
3×3 SparseMatrixCSC{Float64, Int64} with 2 stored entries:
1926+
⋅ 1.0 ⋅
1927+
⋅ ⋅ 1.0
1928+
⋅ ⋅ ⋅
1929+
1930+
julia> C = spzeros(3,3);
1931+
1932+
julia> C[2:4:6] .= 2;
1933+
1934+
julia> C
1935+
3×3 SparseMatrixCSC{Float64, Int64} with 2 stored entries:
1936+
⋅ ⋅ ⋅
1937+
2.0 ⋅ ⋅
1938+
⋅ 2.0 ⋅
1939+
1940+
julia> SparseArrays.mergeinds!(C, A)
1941+
3×3 SparseMatrixCSC{Float64, Int64} with 4 stored entries:
1942+
⋅ 0.0 ⋅
1943+
2.0 ⋅ 0.0
1944+
⋅ 2.0 ⋅
1945+
```
1946+
"""
1947+
function mergeinds!(C::AbstractSparseMatrixCSC, A::AbstractSparseMatrixCSC)
1948+
C_colptr = getcolptr(C)
1949+
for col in axes(A,2)
1950+
n_extra = 0
1951+
for ind in nzrange(A, col)
1952+
row = rowvals(A)[ind]
1953+
row_exists, ind = rowcheck_index(C, row, col)
1954+
if !row_exists
1955+
n_extra += 1
1956+
insert!(rowvals(C), ind, row)
1957+
insert!(nonzeros(C), ind, zero(eltype(C)))
1958+
C_colptr[col+1] += 1
1959+
end
1960+
end
1961+
if !iszero(n_extra)
1962+
@views C_colptr[col+2:end] .+= n_extra
1963+
end
1964+
end
1965+
C
1966+
end
1967+
18911968
# multiply by diagonal matrix as vector
1892-
function mul!(C::AbstractSparseMatrixCSC, A::AbstractSparseMatrixCSC, D::Diagonal)
1969+
function mul!(C::AbstractSparseMatrixCSC, A::AbstractSparseMatrixCSC, D::Diagonal, alpha::Number, beta::Number)
18931970
m, n = size(A)
1894-
b = D.diag
1971+
b = D.diag
18951972
lb = length(b)
1896-
n == lb || throw(DimensionMismatch("A has size ($m, $n) but D has size ($lb, $lb)"))
1897-
size(A)==size(C) || throw(DimensionMismatch("A has size ($m, $n), D has size ($lb, $lb), C has size $(size(C))"))
1898-
copyinds!(C, A)
1973+
n == lb || throw(DimensionMismatch(lazy"A has size ($m, $n) but D has size ($lb, $lb)"))
1974+
size(A)==size(C) || throw(DimensionMismatch(lazy"A has size ($m, $n), D has size ($lb, $lb), C has size $(size(C))"))
1975+
beta_is_zero = iszero(beta)
1976+
rows_match = rowvals(C) == rowvals(A)
1977+
cols_match = getcolptr(C) == getcolptr(A)
1978+
identical_nzinds = rows_match && cols_match
18991979
Cnzval = nonzeros(C)
19001980
Anzval = nonzeros(A)
1901-
resize!(Cnzval, length(Anzval))
1902-
for col in axes(A,2), p in nzrange(A, col)
1903-
@inbounds Cnzval[p] = Anzval[p] * b[col]
1981+
if beta_is_zero || identical_nzinds
1982+
identical_nzinds || copyinds!(C, A, copy_rows = !rows_match, copy_cols = !cols_match)
1983+
resize!(Cnzval, length(Anzval))
1984+
if beta_is_zero
1985+
if isone(alpha)
1986+
for col in axes(A,2), p in nzrange(A, col)
1987+
@inbounds Cnzval[p] = Anzval[p] * b[col]
1988+
end
1989+
else
1990+
for col in axes(A,2), p in nzrange(A, col)
1991+
@inbounds Cnzval[p] = Anzval[p] * b[col] * alpha
1992+
end
1993+
end
1994+
else
1995+
if isone(alpha)
1996+
for col in axes(A,2), p in nzrange(A, col)
1997+
@inbounds Cnzval[p] = Anzval[p] * b[col] + Cnzval[p] * beta
1998+
end
1999+
else
2000+
for col in axes(A,2), p in nzrange(A, col)
2001+
@inbounds Cnzval[p] = Anzval[p] * b[col] * alpha + Cnzval[p] * beta
2002+
end
2003+
end
2004+
end
2005+
else
2006+
mergeinds!(C, A)
2007+
for col in axes(C,2), p in nzrange(C, col)
2008+
row = rowvals(C)[p]
2009+
# check if the index (row, col) is stored in A
2010+
row_exists, row_ind_A = rowcheck_index(A, row, col)
2011+
if row_exists
2012+
if isone(alpha)
2013+
@inbounds Cnzval[p] = Anzval[row_ind_A] * b[col] + Cnzval[p] * beta
2014+
else
2015+
@inbounds Cnzval[p] = Anzval[row_ind_A] * b[col] * alpha + Cnzval[p] * beta
2016+
end
2017+
else # A[row,col] == 0
2018+
@inbounds Cnzval[p] = Cnzval[p] * beta
2019+
end
2020+
end
19042021
end
19052022
C
19062023
end
19072024

1908-
function mul!(C::AbstractSparseMatrixCSC, D::Diagonal, A::AbstractSparseMatrixCSC)
2025+
function mul!(C::AbstractSparseMatrixCSC, D::Diagonal, A::AbstractSparseMatrixCSC, alpha::Number, beta::Number)
19092026
m, n = size(A)
19102027
b = D.diag
19112028
lb = length(b)
1912-
m == lb || throw(DimensionMismatch("D has size ($lb, $lb) but A has size ($m, $n)"))
1913-
size(A)==size(C) || throw(DimensionMismatch("A has size ($m, $n), D has size ($lb, $lb), C has size $(size(C))"))
1914-
copyinds!(C, A)
2029+
m == lb || throw(DimensionMismatch(lazy"D has size ($lb, $lb) but A has size ($m, $n)"))
2030+
size(A)==size(C) || throw(DimensionMismatch(lazy"A has size ($m, $n), D has size ($lb, $lb), C has size $(size(C))"))
2031+
beta_is_zero = iszero(beta)
2032+
rows_match = rowvals(C) == rowvals(A)
2033+
cols_match = getcolptr(C) == getcolptr(A)
2034+
identical_nzinds = rows_match && cols_match
19152035
Cnzval = nonzeros(C)
19162036
Anzval = nonzeros(A)
19172037
Arowval = rowvals(A)
1918-
resize!(Cnzval, length(Anzval))
1919-
for col in axes(A,2), p in nzrange(A, col)
1920-
@inbounds Cnzval[p] = b[Arowval[p]] * Anzval[p]
2038+
if beta_is_zero || identical_nzinds
2039+
identical_nzinds || copyinds!(C, A, copy_rows = !rows_match, copy_cols = !cols_match)
2040+
resize!(Cnzval, length(Anzval))
2041+
if beta_is_zero
2042+
if isone(alpha)
2043+
for col in axes(A,2), p in nzrange(A, col)
2044+
@inbounds Cnzval[p] = b[Arowval[p]] * Anzval[p]
2045+
end
2046+
else
2047+
for col in axes(A,2), p in nzrange(A, col)
2048+
@inbounds Cnzval[p] = b[Arowval[p]] * Anzval[p] * alpha
2049+
end
2050+
end
2051+
else
2052+
if isone(alpha)
2053+
for col in axes(A,2), p in nzrange(A, col)
2054+
@inbounds Cnzval[p] = b[Arowval[p]] * Anzval[p] + Cnzval[p] * beta
2055+
end
2056+
else
2057+
for col in axes(A,2), p in nzrange(A, col)
2058+
@inbounds Cnzval[p] = b[Arowval[p]] * Anzval[p] * alpha + Cnzval[p] * beta
2059+
end
2060+
end
2061+
end
2062+
else
2063+
mergeinds!(C, A)
2064+
for col in axes(C,2), p in nzrange(C, col)
2065+
row = rowvals(C)[p]
2066+
# check if the index (row, col) is stored in A
2067+
row_exists, row_ind_A = rowcheck_index(A, row, col)
2068+
if row_exists
2069+
if isone(alpha)
2070+
@inbounds Cnzval[p] = b[row] * Anzval[row_ind_A] + Cnzval[p] * beta
2071+
else
2072+
@inbounds Cnzval[p] = b[row] * Anzval[row_ind_A] * alpha + Cnzval[p] * beta
2073+
end
2074+
else # A[row,col] == 0
2075+
@inbounds Cnzval[p] = Cnzval[p] * beta
2076+
end
2077+
end
19212078
end
19222079
C
19232080
end

test/linalg.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,9 @@ end
338338
@test lmul!(transpose(copy(D)), copy(b)) transpose(MD)*bd
339339
@test lmul!(adjoint(copy(D)), copy(b)) MD'*bd
340340
end
341+
342+
v = sprand(eltype(D), size(D,1), 0.1)
343+
@test ldiv!(D, copy(v)) == D \ v
341344
end
342345
end
343346

0 commit comments

Comments
 (0)