11module BlockSparseArraysTensorAlgebraExt
2- using BlockArrays: AbstractBlockedUnitRange
3-
4- using TensorAlgebra: TensorAlgebra, FusionStyle, BlockReshapeFusion
5- using TensorProducts: OneToOne
62
3+ using BlockArrays: AbstractBlockedUnitRange
74using BlockSparseArrays: AbstractBlockSparseArray, blockreshape
5+ using TensorAlgebra: TensorAlgebra, FusionStyle, BlockReshapeFusion
86
97TensorAlgebra. FusionStyle (:: AbstractBlockedUnitRange ) = BlockReshapeFusion ()
108
@@ -20,99 +18,4 @@ function TensorAlgebra.splitdims(
2018 return blockreshape (a, axes)
2119end
2220
23- using BlockArrays:
24- AbstractBlockVector,
25- AbstractBlockedUnitRange,
26- Block,
27- BlockIndexRange,
28- blockedrange,
29- blocks
30- using BlockSparseArrays:
31- BlockSparseArrays,
32- AbstractBlockSparseArray,
33- AbstractBlockSparseArrayInterface,
34- AbstractBlockSparseMatrix,
35- BlockSparseArray,
36- BlockSparseArrayInterface,
37- BlockSparseMatrix,
38- BlockSparseVector,
39- block_merge
40- using DerivableInterfaces: @interface
41- using GradedUnitRanges:
42- GradedUnitRanges,
43- AbstractGradedUnitRange,
44- blockmergesortperm,
45- blocksortperm,
46- dual,
47- invblockperm,
48- nondual,
49- unmerged_tensor_product
50- using LinearAlgebra: Adjoint, Transpose
51- using TensorAlgebra:
52- TensorAlgebra, FusionStyle, BlockReshapeFusion, SectorFusion, fusedims, splitdims
53-
54- # TODO : Make a `ReduceWhile` library.
55- include (" reducewhile.jl" )
56-
57- TensorAlgebra. FusionStyle (:: AbstractGradedUnitRange ) = SectorFusion ()
58-
59- # TODO : Need to implement this! Will require implementing
60- # `block_merge(a::AbstractUnitRange, blockmerger::BlockedUnitRange)`.
61- function BlockSparseArrays. block_merge (
62- a:: AbstractGradedUnitRange , blockmerger:: AbstractBlockedUnitRange
63- )
64- return a
65- end
66-
67- # Sort the blocks by sector and then merge the common sectors.
68- function block_mergesort (a:: AbstractArray )
69- I = blockmergesortperm .(axes (a))
70- return a[I... ]
71- end
72-
73- function TensorAlgebra. fusedims (
74- :: SectorFusion , a:: AbstractArray , merged_axes:: AbstractUnitRange...
75- )
76- # First perform a fusion using a block reshape.
77- # TODO avoid groupreducewhile. Require refactor of fusedims.
78- unmerged_axes = groupreducewhile (
79- unmerged_tensor_product, axes (a), length (merged_axes); init= OneToOne ()
80- ) do i, axis
81- return length (axis) ≤ length (merged_axes[i])
82- end
83-
84- a_reshaped = fusedims (BlockReshapeFusion (), a, unmerged_axes... )
85- # Sort the blocks by sector and merge the equivalent sectors.
86- return block_mergesort (a_reshaped)
87- end
88-
89- function TensorAlgebra. splitdims (
90- :: SectorFusion , a:: AbstractArray , split_axes:: AbstractUnitRange...
91- )
92- # First, fuse axes to get `blockmergesortperm`.
93- # Then unpermute the blocks.
94- axes_prod = groupreducewhile (
95- unmerged_tensor_product, split_axes, ndims (a); init= OneToOne ()
96- ) do i, axis
97- return length (axis) ≤ length (axes (a, i))
98- end
99- blockperms = blocksortperm .(axes_prod)
100- sorted_axes = map ((r, I) -> only (axes (r[I])), axes_prod, blockperms)
101-
102- # TODO : This is doing extra copies of the blocks,
103- # use `@view a[axes_prod...]` instead.
104- # That will require implementing some reindexing logic
105- # for this combination of slicing.
106- a_unblocked = a[sorted_axes... ]
107- a_blockpermed = a_unblocked[invblockperm .(blockperms)... ]
108- return splitdims (BlockReshapeFusion (), a_blockpermed, split_axes... )
109- end
110-
111- # TODO : Handle this through some kind of trait dispatch, maybe
112- # a `SymmetryStyle`-like trait to check if the block sparse
113- # matrix has graded axes.
114- function Base. axes (a:: Adjoint{<:Any,<:AbstractBlockSparseMatrix} )
115- return dual .(reverse (axes (a' )))
116- end
117-
11821end
0 commit comments