Skip to content

Commit ad8c6b2

Browse files
committed
Revert changes to SVD
1 parent d4365f0 commit ad8c6b2

File tree

4 files changed

+23
-8
lines changed

4 files changed

+23
-8
lines changed

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,11 @@ TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138"
2222

2323
[weakdeps]
2424
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
25+
TensorProducts = "decf83d6-1968-43f4-96dc-fdb3fe15fc6d"
2526

2627
[extensions]
2728
BlockSparseArraysTensorAlgebraExt = "TensorAlgebra"
29+
BlockSparseArraysTensorProductsExt = "TensorProducts"
2830

2931
[compat]
3032
Adapt = "4.1.1"
@@ -43,6 +45,7 @@ MatrixAlgebraKit = "0.2.2"
4345
SparseArraysBase = "0.7.1"
4446
SplitApplyCombine = "1.2.3"
4547
TensorAlgebra = "0.3.2"
48+
TensorProducts = "0.1.7"
4649
Test = "1.10"
4750
TypeParameterAccessors = "0.4.1"
4851
julia = "1.10"
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
module BlockSparseArraysTensorProductsExt
2+
3+
using BlockSparseArrays: BlockUnitRange, blockrange, eachblockaxis
4+
using TensorProducts: TensorProducts, tensor_product
5+
# TODO: Dispatch on `FusionStyle` to allow different kinds of products,
6+
# for example to allow merging common symmetry sectors.
7+
function TensorProducts.tensor_product(a1::BlockUnitRange, a2::BlockUnitRange)
8+
new_blockaxes = vec(
9+
map(splat(tensor_product), Iterators.product(eachblockaxis(a1), eachblockaxis(a2)))
10+
)
11+
return blockrange(new_blockaxes)
12+
end
13+
14+
end

src/abstractblocksparsearray/abstractblocksparsearray.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ function Base.setindex!(
9393
# https://github.com/JuliaLang/julia/pull/52487).
9494
# TODO: Delete once we drop support for Julia v1.10.
9595
aI = @view! a[I...]
96-
copyto!(aI, value)
96+
aI .= value
9797
return a
9898
end
9999

src/factorizations/svd.jl

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,22 +49,21 @@ function MatrixAlgebraKit.initialize_output(
4949
::typeof(svd_compact!), A::AbstractBlockSparseMatrix, alg::BlockPermutedDiagonalAlgorithm
5050
)
5151
bm, bn = blocksize(A)
52-
(bmn, mindim) = findmin((bm, bn))
52+
bmn = min(bm, bn)
5353

5454
brows = eachblockaxis(axes(A, 1))
5555
bcols = eachblockaxis(axes(A, 2))
5656
u_axes = similar(brows, bmn)
57-
v_axes = similar(bcols, bmn)
57+
v_axes = similar(brows, bmn)
5858

5959
# fill in values for blocks that are present
6060
bIs = collect(eachblockstoredindex(A))
6161
browIs = Int.(first.(Tuple.(bIs)))
6262
bcolIs = Int.(last.(Tuple.(bIs)))
6363
for bI in eachblockstoredindex(A)
6464
row, col = Int.(Tuple(bI))
65-
dim = (row, col)[mindim]
66-
u_axes[dim] = infimum(brows[row], bcols[col])
67-
v_axes[dim] = infimum(bcols[col], brows[row])
65+
u_axes[col] = infimum(brows[row], bcols[col])
66+
v_axes[col] = infimum(bcols[col], brows[row])
6867
end
6968

7069
# fill in values for blocks that aren't present, pairing them in order of occurence
@@ -84,10 +83,9 @@ function MatrixAlgebraKit.initialize_output(
8483
# allocate output
8584
for bI in eachblockstoredindex(A)
8685
brow, bcol = Tuple(bI)
87-
bdim = (brow, bcol)[mindim]
8886
block = @view!(A[bI])
8987
block_alg = block_algorithm(alg, block)
90-
U[brow, bdim], S[bdim, bdim], Vt[bdim, bcol] = MatrixAlgebraKit.initialize_output(
88+
U[brow, bcol], S[bcol, bcol], Vt[bcol, bcol] = MatrixAlgebraKit.initialize_output(
9189
svd_compact!, block, block_alg
9290
)
9391
end

0 commit comments

Comments
 (0)