Skip to content

Commit 3079e69

Browse files
committed
More generalizations
1 parent 072502d commit 3079e69

File tree

3 files changed

+12
-9
lines changed

3 files changed

+12
-9
lines changed

src/blocksparsearray/blocksparsearray.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@ using BlockArrays:
66
blockedrange,
77
blocklength,
88
undef_blocks
9-
using DerivableInterfaces: @interface, similartype
9+
using DerivableInterfaces: @interface
1010
using Dictionaries: Dictionary
1111
using SparseArraysBase: SparseArrayDOK
12+
using TypeParameterAccessors: similartype
1213

1314
"""
1415
SparseArrayDOK{T}(undef_blocks, axes)

src/blocksparsearrayinterface/arraylayouts.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,18 @@ using LinearAlgebra: LinearAlgebra, dot, mul!
1111
return a_dest
1212
end
1313

14+
function DerivableInterfaces.interface(m::MulAdd)
15+
return interface(m.A, m.B, m.C)
16+
end
17+
1418
function ArrayLayouts.materialize!(
1519
m::MatMulMatAdd{
1620
<:BlockLayout{<:SparseLayout},
1721
<:BlockLayout{<:SparseLayout},
1822
<:BlockLayout{<:SparseLayout},
1923
},
2024
)
21-
α, a1, a2, β, a_dest = m.α, m.A, m.B, m.β, m.C
22-
@interface BlockSparseArrayInterface() muladd!(m.α, m.A, m.B, m.β, m.C)
25+
@interface interface(m) muladd!(m.α, m.A, m.B, m.β, m.C)
2326
return m.C
2427
end
2528
function ArrayLayouts.materialize!(
@@ -29,7 +32,7 @@ function ArrayLayouts.materialize!(
2932
<:BlockLayout{<:SparseLayout},
3033
},
3134
)
32-
@interface BlockSparseArrayInterface() matmul!(m)
35+
@interface interface(m) matmul!(m)
3336
return m.C
3437
end
3538

@@ -42,5 +45,5 @@ end
4245
end
4346

4447
function Base.copy(d::Dot{<:BlockLayout{<:SparseLayout},<:BlockLayout{<:SparseLayout}})
45-
return @interface BlockSparseArrayInterface() dot(d.A, d.B)
48+
return @interface interface(d.A, d.B) dot(d.A, d.B)
4649
end

src/factorizations/svd.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
using DiagonalArrays: diagonaltype
12
using MatrixAlgebraKit:
23
MatrixAlgebraKit, check_input, default_svd_algorithm, svd_compact!, svd_full!
4+
using TypeParameterAccessors: realtype
35

46
"""
57
BlockPermutedDiagonalAlgorithm(A::MatrixAlgebraKit.AbstractAlgorithm)
@@ -24,10 +26,7 @@ function similar_output(
2426
::typeof(svd_compact!), A, S_axes, alg::MatrixAlgebraKit.AbstractAlgorithm
2527
)
2628
U = similar(A, axes(A, 1), S_axes[1])
27-
T = real(eltype(A))
28-
# TODO: this should be replaced with a more general similar function that can handle setting
29-
# the blocktype and element type - something like S = similar(A, BlockType(...))
30-
S = BlockSparseMatrix{T,Diagonal{T,Vector{T}}}(undef, S_axes)
29+
S = similar(A, BlockType(diagonaltype(realtype(blocktype(A)))), S_axes)
3130
Vt = similar(A, S_axes[2], axes(A, 2))
3231
return U, S, Vt
3332
end

0 commit comments

Comments
 (0)