Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion src/abstractblocksparsearray/arraylayouts.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using ArrayLayouts: ArrayLayouts, DualLayout, MemoryLayout, MulAdd
using ArrayLayouts: ArrayLayouts, DenseColumnMajor, DualLayout, MemoryLayout, MulAdd
using BlockArrays: BlockLayout
using SparseArraysBase: SparseLayout
using TypeParameterAccessors: parenttype, similartype
Expand Down Expand Up @@ -56,6 +56,22 @@ function Base.similar(
return similar(BlockSparseArray{elt, length(axes), output_blocktype′}, axes)
end

# BlockSparseMatrix * dense Vector → dense Vector
# Returns a plain Vector instead of a BlockedVector
function Base.similar(
mul::MulAdd{
<:BlockLayout{<:SparseLayout},
<:DenseColumnMajor,
<:Any,
},
elt::Type,
axes::Tuple{<:AbstractUnitRange},
)
# Convert blocked axes to plain axes to avoid creating BlockedVector
plain_axes = map(ax -> Base.OneTo(length(ax)), axes)
return similar(mul.B, elt, plain_axes)
end

# Materialize a SubArray view.
function ArrayLayouts.sub_materialize(layout::BlockLayout{<:SparseLayout}, a, axes)
# TODO: Define `blocktype`/`blockstype` for `SubArray` wrapping `BlockSparseArray`.
Expand Down
30 changes: 26 additions & 4 deletions src/blocksparsearrayinterface/arraylayouts.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,25 @@
using ArrayLayouts: ArrayLayouts, Dot, MatMulMatAdd, MatMulVecAdd, MulAdd
using ArrayLayouts: ArrayLayouts, DenseColumnMajor, Dot, MatMulMatAdd, MatMulVecAdd, MulAdd
using BlockArrays: BlockArrays, BlockLayout, muladd!
using SparseArraysBase: SparseLayout
using LinearAlgebra: LinearAlgebra, dot, mul!

const muladd!_blocksparse = blocksparse_style(muladd!)
# Matrix-matrix case
function muladd!_blocksparse(
α::Number, a1::AbstractArray, a2::AbstractArray, β::Number, a_dest::AbstractArray
α::Number, a1::AbstractMatrix, a2::AbstractMatrix, β::Number, a_dest::AbstractMatrix
)
mul!(blocks(a_dest), blocks(a1), blocks(a2), α, β)
return a_dest
end

# Matrix-vector case: BlockSparseMatrix * dense Vector
function muladd!_blocksparse(
α::Number, a1::AbstractMatrix, a2::AbstractVector, β::Number, a_dest::AbstractVector
)
mul!_blocksparse(a_dest, a1, a2, α, β)
return a_dest
end

function ArrayLayouts.materialize!(
m::MatMulMatAdd{
<:BlockLayout{<:SparseLayout},
Expand All @@ -21,15 +30,28 @@ function ArrayLayouts.materialize!(
muladd!_blocksparse(m.α, m.A, m.B, m.β, m.C)
return m.C
end

# BlockSparseMatrix * dense Vector → dense Vector
function ArrayLayouts.materialize!(
m::MatMulVecAdd{
<:BlockLayout{<:SparseLayout},
<:DenseColumnMajor,
<:DenseColumnMajor,
},
)
muladd!_blocksparse(m.α, m.A, m.B, m.β, m.C)
return m.C
end

# BlockSparseMatrix * BlockSparseVector (not yet implemented)
function ArrayLayouts.materialize!(
m::MatMulVecAdd{
<:BlockLayout{<:SparseLayout},
<:BlockLayout{<:SparseLayout},
<:BlockLayout{<:SparseLayout},
},
)
error("Not implemented.")
matmul!(m)
error("BlockSparseMatrix * BlockSparseVector not implemented.")
return m.C
end

Expand Down
29 changes: 29 additions & 0 deletions src/blocksparsearrayinterface/linearalgebra.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using BlockArrays: Block
using LinearAlgebra: LinearAlgebra, mul!

const mul!_blocksparse = blocksparse_style(mul!)
Expand All @@ -11,3 +12,31 @@ function mul!_blocksparse(
mul!(blocks(a_dest), blocks(a1), blocks(a2), α, β)
return a_dest
end

# Matrix-vector multiplication: BlockSparseMatrix * dense Vector
function mul!_blocksparse(
a_dest::AbstractVector,
a1::AbstractMatrix,
a2::AbstractVector,
α::Number = true,
β::Number = false,
)
# Scale destination by β (or zero it if β == 0)
if iszero(β)
fill!(a_dest, zero(eltype(a_dest)))
elseif !isone(β)
a_dest .*= β
end

# Accumulate A[i,j] * v[j_range] into C[i_range]
for I in eachblockstoredindex(a1)
i, j = Int.(Tuple(I))
block_A = a1[I]
row_range = axes(a1, 1)[Block(i)]
col_range = axes(a1, 2)[Block(j)]
v_block = @view a2[col_range]
c_block = @view a_dest[row_range]
mul!(c_block, block_A, v_block, α, true) # β=true to accumulate
end
return a_dest
end
16 changes: 16 additions & 0 deletions test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,22 @@ arrayts = (Array, JLArray)
@test a1' * a2 ≈ Array(a1)' * Array(a2)
@test dot(a1, a2) ≈ a1' * a2
end
@testset "Matrix-vector multiplication" begin
a = dev(BlockSparseArray{elt}(undef, [2, 3], [2, 3]))
a[Block(1, 2)] = dev(randn(elt, size(@view(a[Block(1, 2)]))))
a[Block(2, 1)] = dev(randn(elt, size(@view(a[Block(2, 1)]))))
v = dev(randn(elt, 5)) # Dense vector matching column dimension

c = a * v
@test Array(c) ≈ Array(a) * Array(v)
@test c isa typeof(v) # Result should match input vector type
@test length(c) == size(a, 1)

# Test with adjoint/transpose
v2 = dev(randn(elt, 5)) # Vector matching row dimension for transposed multiplication
@test Array(a' * v2) ≈ Array(a)' * Array(v2)
@test Array(transpose(a) * v2) ≈ transpose(Array(a)) * Array(v2)
end
@testset "cat" begin
a1 = dev(BlockSparseArray{elt}(undef, [2, 3], [2, 3]))
a1[Block(2, 1)] = dev(randn(elt, size(@view(a1[Block(2, 1)]))))
Expand Down
Loading