diff --git a/src/abstractblocksparsearray/arraylayouts.jl b/src/abstractblocksparsearray/arraylayouts.jl
index fd3079d..d61c479 100644
--- a/src/abstractblocksparsearray/arraylayouts.jl
+++ b/src/abstractblocksparsearray/arraylayouts.jl
@@ -1,4 +1,4 @@
-using ArrayLayouts: ArrayLayouts, DualLayout, MemoryLayout, MulAdd
+using ArrayLayouts: ArrayLayouts, DenseColumnMajor, DualLayout, MemoryLayout, MulAdd
 using BlockArrays: BlockLayout
 using SparseArraysBase: SparseLayout
 using TypeParameterAccessors: parenttype, similartype
@@ -56,6 +56,22 @@ function Base.similar(
     return similar(BlockSparseArray{elt, length(axes), output_blocktype′}, axes)
 end
 
+# BlockSparseMatrix * dense Vector → dense Vector.
+# The destination is allocated from the dense operand, giving a plain (unblocked) vector.
+function Base.similar(
+        mul::MulAdd{
+            <:BlockLayout{<:SparseLayout},
+            <:DenseColumnMajor,
+            <:Any,
+        },
+        elt::Type,
+        axes::Tuple{<:AbstractUnitRange},
+    )
+    # Drop any block structure from the axis so the result is not a BlockedVector.
+    unblocked_axes = map(Base.OneTo ∘ length, axes)
+    return similar(mul.B, elt, unblocked_axes)
+end
+
 # Materialize a SubArray view.
 function ArrayLayouts.sub_materialize(layout::BlockLayout{<:SparseLayout}, a, axes)
     # TODO: Define `blocktype`/`blockstype` for `SubArray` wrapping `BlockSparseArray`.
diff --git a/src/blocksparsearrayinterface/arraylayouts.jl b/src/blocksparsearrayinterface/arraylayouts.jl
index 22d4845..7a37662 100644
--- a/src/blocksparsearrayinterface/arraylayouts.jl
+++ b/src/blocksparsearrayinterface/arraylayouts.jl
@@ -1,16 +1,25 @@
-using ArrayLayouts: ArrayLayouts, Dot, MatMulMatAdd, MatMulVecAdd, MulAdd
+using ArrayLayouts: ArrayLayouts, DenseColumnMajor, Dot, MatMulMatAdd, MatMulVecAdd, MulAdd
 using BlockArrays: BlockArrays, BlockLayout, muladd!
 using SparseArraysBase: SparseLayout
 using LinearAlgebra: LinearAlgebra, dot, mul!
 
 const muladd!_blocksparse = blocksparse_style(muladd!)
+# Matrix-matrix case: multiply block-wise via `mul!` on the `blocks` of each operand.
 function muladd!_blocksparse(
-        α::Number, a1::AbstractArray, a2::AbstractArray, β::Number, a_dest::AbstractArray
+        α::Number, a1::AbstractMatrix, a2::AbstractMatrix, β::Number, a_dest::AbstractMatrix
     )
     mul!(blocks(a_dest), blocks(a1), blocks(a2), α, β)
     return a_dest
 end
 
+# Matrix-vector case (BlockSparseMatrix * dense Vector): forwards to `mul!_blocksparse`.
+function muladd!_blocksparse(
+        α::Number, a1::AbstractMatrix, a2::AbstractVector, β::Number, a_dest::AbstractVector
+    )
+    mul!_blocksparse(a_dest, a1, a2, α, β)
+    return a_dest
+end
+
 function ArrayLayouts.materialize!(
         m::MatMulMatAdd{
             <:BlockLayout{<:SparseLayout},
@@ -21,6 +30,20 @@ function ArrayLayouts.materialize!(
     muladd!_blocksparse(m.α, m.A, m.B, m.β, m.C)
     return m.C
 end
+
+# BlockSparseMatrix * dense Vector → dense Vector (B and C are both DenseColumnMajor).
+function ArrayLayouts.materialize!(
+        m::MatMulVecAdd{
+            <:BlockLayout{<:SparseLayout},
+            <:DenseColumnMajor,
+            <:DenseColumnMajor,
+        },
+    )
+    muladd!_blocksparse(m.α, m.A, m.B, m.β, m.C)
+    return m.C
+end
+
+# BlockSparseMatrix * BlockSparseVector is not yet implemented; this method always throws.
 function ArrayLayouts.materialize!(
         m::MatMulVecAdd{
             <:BlockLayout{<:SparseLayout},
@@ -28,8 +51,7 @@ function ArrayLayouts.materialize!(
             <:BlockLayout{<:SparseLayout},
         },
     )
-    error("Not implemented.")
-    matmul!(m)
+    error("BlockSparseMatrix * BlockSparseVector not implemented.")
     return m.C
 end
 
diff --git a/src/blocksparsearrayinterface/linearalgebra.jl b/src/blocksparsearrayinterface/linearalgebra.jl
index e5f4313..b108068 100644
--- a/src/blocksparsearrayinterface/linearalgebra.jl
+++ b/src/blocksparsearrayinterface/linearalgebra.jl
@@ -1,3 +1,4 @@
+using BlockArrays: Block
 using LinearAlgebra: LinearAlgebra, mul!
 
 const mul!_blocksparse = blocksparse_style(mul!)
@@ -11,3 +12,45 @@ function mul!_blocksparse(
     mul!(blocks(a_dest), blocks(a1), blocks(a2), α, β)
     return a_dest
 end
+
+# Matrix-vector multiplication: BlockSparseMatrix * dense Vector.
+# Computes `a_dest = α * a1 * a2 + β * a_dest` in place, visiting only the stored blocks
+# of `a1`, and returns `a_dest`. Throws `DimensionMismatch` when `a2` or `a_dest` does not
+# match the corresponding dimension of `a1`.
+function mul!_blocksparse(
+        a_dest::AbstractVector,
+        a1::AbstractMatrix,
+        a2::AbstractVector,
+        α::Number = true,
+        β::Number = false,
+    )
+    # Check shapes up front, as `LinearAlgebra.mul!` does: the loop below only indexes the
+    # block ranges of `a1`, so an over-long `a2` or `a_dest` would otherwise go unnoticed.
+    nrows, ncols = size(a1)
+    length(a2) == ncols || throw(
+        DimensionMismatch("a1 has $ncols columns but a2 has length $(length(a2))")
+    )
+    length(a_dest) == nrows || throw(
+        DimensionMismatch("a1 has $nrows rows but a_dest has length $(length(a_dest))")
+    )
+
+    # Apply β to the destination. β == 0 overwrites with zeros rather than multiplying, so
+    # any NaN/Inf values already present in `a_dest` cannot propagate into the result.
+    if iszero(β)
+        fill!(a_dest, zero(eltype(a_dest)))
+    elseif !isone(β)
+        a_dest .*= β
+    end
+
+    # Accumulate A[i,j] * v[j_range] into C[i_range]
+    for I in eachblockstoredindex(a1)
+        i, j = Int.(Tuple(I))
+        block_A = a1[I]
+        row_range = axes(a1, 1)[Block(i)]
+        col_range = axes(a1, 2)[Block(j)]
+        v_block = @view a2[col_range]
+        c_block = @view a_dest[row_range]
+        mul!(c_block, block_A, v_block, α, true) # β=true to accumulate
+    end
+    return a_dest
+end
diff --git a/test/test_basics.jl b/test/test_basics.jl
index 2310ec5..69169ee 100644
--- a/test/test_basics.jl
+++ b/test/test_basics.jl
@@ -426,6 +426,22 @@ arrayts = (Array, JLArray)
         @test a1' * a2 ≈ Array(a1)' * Array(a2)
         @test dot(a1, a2) ≈ a1' * a2
     end
+    @testset "Matrix-vector multiplication" begin
+        a = dev(BlockSparseArray{elt}(undef, [2, 3], [2, 3]))
+        a[Block(1, 2)] = dev(randn(elt, size(@view(a[Block(1, 2)]))))
+        a[Block(2, 1)] = dev(randn(elt, size(@view(a[Block(2, 1)]))))
+        v = dev(randn(elt, 5))  # Dense vector matching column dimension
+
+        c = a * v
+        @test Array(c) ≈ Array(a) * Array(v)
+        @test c isa typeof(v)  # Result should match input vector type
+        @test length(c) == size(a, 1)
+
+        # Test with adjoint/transpose
+        v2 = dev(randn(elt, 5))  # Vector matching row dimension for transposed multiplication
+        @test Array(a' * v2) ≈ Array(a)' * Array(v2)
+        @test Array(transpose(a) * v2) ≈ transpose(Array(a)) * Array(v2)
+    end
     @testset "cat" begin
         a1 = dev(BlockSparseArray{elt}(undef, [2, 3], [2, 3]))
         a1[Block(2, 1)] = dev(randn(elt, size(@view(a1[Block(2, 1)]))))