Skip to content

Commit 82a7673

Browse files
committed
add mat * dense_vec operation
1 parent 10e69a2 commit 82a7673

File tree

3 files changed

+72
-4
lines changed

3 files changed

+72
-4
lines changed

src/blocksparsearrayinterface/arraylayouts.jl

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,25 @@
1-
using ArrayLayouts: ArrayLayouts, Dot, MatMulMatAdd, MatMulVecAdd, MulAdd
1+
using ArrayLayouts: ArrayLayouts, DenseColumnMajor, Dot, MatMulMatAdd, MatMulVecAdd, MulAdd
22
using BlockArrays: BlockArrays, BlockLayout, muladd!
33
using SparseArraysBase: SparseLayout
44
using LinearAlgebra: LinearAlgebra, dot, mul!
55

66
const muladd!_blocksparse = blocksparse_style(muladd!)
7+
# Matrix-matrix case
78
function muladd!_blocksparse(
8-
α::Number, a1::AbstractArray, a2::AbstractArray, β::Number, a_dest::AbstractArray
9+
α::Number, a1::AbstractMatrix, a2::AbstractMatrix, β::Number, a_dest::AbstractMatrix
910
)
1011
mul!(blocks(a_dest), blocks(a1), blocks(a2), α, β)
1112
return a_dest
1213
end
1314

15+
# Matrix-vector case: BlockSparseMatrix * dense Vector
16+
function muladd!_blocksparse(
17+
α::Number, a1::AbstractMatrix, a2::AbstractVector, β::Number, a_dest::AbstractVector
18+
)
19+
mul!_blocksparse(a_dest, a1, a2, α, β)
20+
return a_dest
21+
end
22+
1423
function ArrayLayouts.materialize!(
1524
m::MatMulMatAdd{
1625
<:BlockLayout{<:SparseLayout},
@@ -21,15 +30,28 @@ function ArrayLayouts.materialize!(
2130
muladd!_blocksparse(m.α, m.A, m.B, m.β, m.C)
2231
return m.C
2332
end
33+
34+
# BlockSparseMatrix * dense Vector → dense Vector
35+
function ArrayLayouts.materialize!(
36+
m::MatMulVecAdd{
37+
<:BlockLayout{<:SparseLayout},
38+
<:DenseColumnMajor,
39+
<:DenseColumnMajor,
40+
},
41+
)
42+
muladd!_blocksparse(m.α, m.A, m.B, m.β, m.C)
43+
return m.C
44+
end
45+
46+
# BlockSparseMatrix * BlockSparseVector (not yet implemented)
2447
function ArrayLayouts.materialize!(
2548
m::MatMulVecAdd{
2649
<:BlockLayout{<:SparseLayout},
2750
<:BlockLayout{<:SparseLayout},
2851
<:BlockLayout{<:SparseLayout},
2952
},
3053
)
31-
error("Not implemented.")
32-
matmul!(m)
54+
error("BlockSparseMatrix * BlockSparseVector not implemented.")
3355
return m.C
3456
end
3557

src/blocksparsearrayinterface/linearalgebra.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
using BlockArrays: Block
12
using LinearAlgebra: LinearAlgebra, mul!
23

34
const mul!_blocksparse = blocksparse_style(mul!)
@@ -11,3 +12,31 @@ function mul!_blocksparse(
1112
mul!(blocks(a_dest), blocks(a1), blocks(a2), α, β)
1213
return a_dest
1314
end
15+
16+
# Matrix-vector multiplication: BlockSparseMatrix * dense Vector
17+
function mul!_blocksparse(
18+
a_dest::AbstractVector,
19+
a1::AbstractMatrix,
20+
a2::AbstractVector,
21+
α::Number = true,
22+
β::Number = false,
23+
)
24+
# Scale destination by β (or zero it if β == 0)
25+
if iszero(β)
26+
fill!(a_dest, zero(eltype(a_dest)))
27+
elseif !isone(β)
28+
a_dest .*= β
29+
end
30+
31+
# Accumulate A[i,j] * v[j_range] into C[i_range]
32+
for I in eachblockstoredindex(a1)
33+
i, j = Int.(Tuple(I))
34+
block_A = a1[I]
35+
row_range = axes(a1, 1)[Block(i)]
36+
col_range = axes(a1, 2)[Block(j)]
37+
v_block = @view a2[col_range]
38+
c_block = @view a_dest[row_range]
39+
mul!(c_block, block_A, v_block, α, true) # β=true to accumulate
40+
end
41+
return a_dest
42+
end

test/test_basics.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,23 @@ arrayts = (Array, JLArray)
426426
@test a1' * a2 ≈ Array(a1)' * Array(a2)
427427
@test dot(a1, a2) a1' * a2
428428
end
429+
@testset "Matrix-vector multiplication" begin
430+
a = dev(BlockSparseArray{elt}(undef, [2, 3], [2, 3]))
431+
a[Block(1, 2)] = dev(randn(elt, size(@view(a[Block(1, 2)]))))
432+
a[Block(2, 1)] = dev(randn(elt, size(@view(a[Block(2, 1)]))))
433+
v = dev(randn(elt, 5)) # Dense vector matching column dimension
434+
435+
# Matrix-vector multiplication uses scalar indexing for block iteration
436+
c = @allowscalar a * v
437+
@allowscalar @test Array(c) ≈ Array(a) * Array(v)
438+
@test eltype(c) == elt
439+
@test length(c) == size(a, 1)
440+
441+
# Test with adjoint/transpose
442+
v2 = dev(randn(elt, 5)) # Vector matching row dimension for transposed multiplication
443+
@allowscalar @test Array(a' * v2) Array(a)' * Array(v2)
444+
@allowscalar @test Array(transpose(a) * v2) transpose(Array(a)) * Array(v2)
445+
end
429446
@testset "cat" begin
430447
a1 = dev(BlockSparseArray{elt}(undef, [2, 3], [2, 3]))
431448
a1[Block(2, 1)] = dev(randn(elt, size(@view(a1[Block(2, 1)]))))

0 commit comments

Comments
 (0)