Skip to content

Commit f7f482b

Browse files
dkarraschViralBShah
authored andcommitted
add specific Adjoint{StridedMatrix} * SparseVector method (#32403)
add complex and adjoint tests
1 parent d97d399 commit f7f482b

File tree

2 files changed

+93
-37
lines changed

2 files changed

+93
-37
lines changed

stdlib/SparseArrays/src/sparsevector.jl

Lines changed: 56 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1494,7 +1494,7 @@ function (*)(A::StridedMatrix{Ta}, x::AbstractSparseVector{Tx}) where {Ta,Tx}
14941494
require_one_based_indexing(A, x)
14951495
m, n = size(A)
14961496
length(x) == n || throw(DimensionMismatch())
1497-
Ty = promote_type(Ta, Tx)
1497+
Ty = promote_op(matprod, Ta, Tx)
14981498
y = Vector{Ty}(undef, m)
14991499
mul!(y, A, x)
15001500
end
@@ -1531,23 +1531,62 @@ end
15311531

15321532
function *(transA::Transpose{<:Any,<:StridedMatrix{Ta}}, x::AbstractSparseVector{Tx}) where {Ta,Tx}
15331533
require_one_based_indexing(transA, x)
1534-
A = transA.parent
1535-
m, n = size(A)
1536-
length(x) == m || throw(DimensionMismatch())
1537-
Ty = promote_type(Ta, Tx)
1538-
y = Vector{Ty}(undef, n)
1539-
mul!(y, transpose(A), x)
1534+
m, n = size(transA)
1535+
length(x) == n || throw(DimensionMismatch())
1536+
Ty = promote_op(matprod, Ta, Tx)
1537+
y = Vector{Ty}(undef, m)
1538+
mul!(y, transA, x)
15401539
end
15411540

15421541
mul!(y::AbstractVector{Ty}, transA::Transpose{<:Any,<:StridedMatrix}, x::AbstractSparseVector{Tx}) where {Tx,Ty} =
1543-
(A = transA.parent; mul!(y, transpose(A), x, one(Tx), zero(Ty)))
1542+
mul!(y, transA, x, one(Tx), zero(Ty))
15441543

15451544
function mul!(y::AbstractVector, transA::Transpose{<:Any,<:StridedMatrix}, x::AbstractSparseVector, α::Number, β::Number)
1545+
require_one_based_indexing(y, transA, x)
1546+
m, n = size(transA)
1547+
length(x) == n && length(y) == m || throw(DimensionMismatch())
1548+
m == 0 && return y
1549+
if β != one(β)
1550+
β == zero(β) ? fill!(y, zero(eltype(y))) : rmul!(y, β)
1551+
end
1552+
α == zero(α) && return y
1553+
1554+
xnzind = nonzeroinds(x)
1555+
xnzval = nonzeros(x)
1556+
_nnz = length(xnzind)
1557+
_nnz == 0 && return y
1558+
15461559
A = transA.parent
1547-
require_one_based_indexing(y, A, x)
1548-
m, n = size(A)
1549-
length(x) == m && length(y) == n || throw(DimensionMismatch())
1550-
n == 0 && return y
1560+
Ty = promote_op(matprod, eltype(A), eltype(x))
1561+
@inbounds for j = 1:m
1562+
s = zero(Ty)
1563+
for i = 1:_nnz
1564+
s += transpose(A[xnzind[i], j]) * xnzval[i]
1565+
end
1566+
y[j] += s * α
1567+
end
1568+
return y
1569+
end
1570+
1571+
# * and mul!(C, adjoint(A), B)
1572+
1573+
function *(adjA::Adjoint{<:Any,<:StridedMatrix{Ta}}, x::AbstractSparseVector{Tx}) where {Ta,Tx}
1574+
require_one_based_indexing(adjA, x)
1575+
m, n = size(adjA)
1576+
length(x) == n || throw(DimensionMismatch())
1577+
Ty = promote_op(matprod, Ta, Tx)
1578+
y = Vector{Ty}(undef, m)
1579+
mul!(y, adjA, x)
1580+
end
1581+
1582+
mul!(y::AbstractVector{Ty}, adjA::Adjoint{<:Any,<:StridedMatrix}, x::AbstractSparseVector{Tx}) where {Tx,Ty} =
1583+
mul!(y, adjA, x, one(Tx), zero(Ty))
1584+
1585+
function mul!(y::AbstractVector, adjA::Adjoint{<:Any,<:StridedMatrix}, x::AbstractSparseVector, α::Number, β::Number)
1586+
require_one_based_indexing(y, adjA, x)
1587+
m, n = size(adjA)
1588+
length(x) == n && length(y) == m || throw(DimensionMismatch())
1589+
m == 0 && return y
15511590
if β != one(β)
15521591
β == zero(β) ? fill!(y, zero(eltype(y))) : rmul!(y, β)
15531592
end
@@ -1558,11 +1597,12 @@ function mul!(y::AbstractVector, transA::Transpose{<:Any,<:StridedMatrix}, x::Ab
15581597
_nnz = length(xnzind)
15591598
_nnz == 0 && return y
15601599

1561-
s0 = zero(eltype(A)) * zero(eltype(x))
1562-
@inbounds for j = 1:n
1563-
s = zero(s0)
1600+
A = adjA.parent
1601+
Ty = promote_op(matprod, eltype(A), eltype(x))
1602+
@inbounds for j = 1:m
1603+
s = zero(Ty)
15641604
for i = 1:_nnz
1565-
s += A[xnzind[i], j] * xnzval[i]
1605+
s += adjoint(A[xnzind[i], j]) * xnzval[i]
15661606
end
15671607
y[j] += s * α
15681608
end

stdlib/SparseArrays/test/sparsevector.jl

Lines changed: 37 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -841,30 +841,46 @@ end
841841

842842
@testset "BLAS Level-2" begin
843843
@testset "dense A * sparse x -> dense y" begin
844-
let A = randn(9, 16), x = sprand(16, 0.7)
845-
xf = Array(x)
846-
for α in [0.0, 1.0, 2.0], β in [0.0, 0.5, 1.0]
847-
y = rand(9)
848-
rr = α*A*xf + β*y
849-
@test mul!(y, A, x, α, β) === y
850-
@test y rr
844+
for TA in (Float64, ComplexF64), Tx in (Float64, ComplexF64)
845+
T = Base.promote_op(LinearAlgebra.matprod, TA, Tx)
846+
let A = randn(TA, 9, 16), x = sprand(Tx, 16, 0.7)
847+
xf = Array(x)
848+
for α in [0.0, 1.0, 2.0], β in [0.0, 0.5, 1.0]
849+
y = rand(T, 9)
850+
rr = α*A*xf + β*y
851+
@test mul!(y, A, x, α, β) === y
852+
@test y rr
853+
end
854+
y = A*x
855+
@test isa(y, Vector{T})
856+
@test A*x A*xf
851857
end
852-
y = A*x
853-
@test isa(y, Vector{Float64})
854-
@test A*x A*xf
855-
end
856858

857-
let A = randn(16, 9), x = sprand(16, 0.7)
858-
xf = Array(x)
859-
for α in [0.0, 1.0, 2.0], β in [0.0, 0.5, 1.0]
860-
y = rand(9)
861-
rr = α*A'xf + β*y
862-
@test mul!(y, transpose(A), x, α, β) === y
863-
@test y rr
859+
let A = randn(TA, 16, 9), x = sprand(Tx, 16, 0.7)
860+
xf = Array(x)
861+
for α in [0.0, 1.0, 2.0], β in [0.0, 0.5, 1.0]
862+
y = rand(T, 9)
863+
rr = α*transpose(A)*xf + β*y
864+
@test mul!(y, transpose(A), x, α, β) === y
865+
@test y rr
866+
end
867+
y = *(transpose(A), x)
868+
@test isa(y, Vector{T})
869+
@test y *(transpose(A), xf)
870+
end
871+
872+
let A = randn(TA, 16, 9), x = sprand(Tx, 16, 0.7)
873+
xf = Array(x)
874+
for α in [0.0, 1.0, 2.0], β in [0.0, 0.5, 1.0]
875+
y = rand(T, 9)
876+
rr = α*A'xf + β*y
877+
@test mul!(y, adjoint(A), x, α, β) === y
878+
@test y rr
879+
end
880+
y = *(adjoint(A), x)
881+
@test isa(y, Vector{T})
882+
@test y *(adjoint(A), xf)
864883
end
865-
y = *(transpose(A), x)
866-
@test isa(y, Vector{Float64})
867-
@test y *(transpose(A), xf)
868884
end
869885
end
870886
@testset "sparse A * sparse x -> dense y" begin

0 commit comments

Comments
 (0)