11module BlockSparseArraysTensorAlgebraExt
22using BlockArrays: AbstractBlockedUnitRange
3- using GradedUnitRanges: tensor_product
4- using TensorAlgebra: TensorAlgebra, FusionStyle, BlockReshapeFusion
53
6- function TensorAlgebra.:⊗ (a1:: AbstractBlockedUnitRange , a2:: AbstractBlockedUnitRange )
7- return tensor_product (a1, a2)
8- end
4+ using TensorAlgebra: TensorAlgebra, FusionStyle, BlockReshapeFusion
5+ using TensorProducts: OneToOne
96
10- using BlockArrays: AbstractBlockedUnitRange
117using BlockSparseArrays: AbstractBlockSparseArray, blockreshape
12- using TensorAlgebra: TensorAlgebra, FusionStyle, BlockReshapeFusion
138
149TensorAlgebra. FusionStyle (:: AbstractBlockedUnitRange ) = BlockReshapeFusion ()
1510
@@ -46,13 +41,12 @@ using DerivableInterfaces: @interface
4641using GradedUnitRanges:
4742 GradedUnitRanges,
4843 AbstractGradedUnitRange,
49- OneToOne,
5044 blockmergesortperm,
5145 blocksortperm,
5246 dual,
5347 invblockperm,
5448 nondual,
55- tensor_product
49+ unmerged_tensor_product
5650using LinearAlgebra: Adjoint, Transpose
5751using TensorAlgebra:
5852 TensorAlgebra, FusionStyle, BlockReshapeFusion, SectorFusion, fusedims, splitdims
@@ -77,10 +71,17 @@ function block_mergesort(a::AbstractArray)
7771end
7872
7973function TensorAlgebra. fusedims (
80- :: SectorFusion , a:: AbstractArray , axes :: AbstractUnitRange...
74+ :: SectorFusion , a:: AbstractArray , merged_axes :: AbstractUnitRange...
8175)
8276 # First perform a fusion using a block reshape.
83- a_reshaped = fusedims (BlockReshapeFusion (), a, axes... )
77+ # TODO avoid groupreducewhile. Require refactor of fusedims.
78+ unmerged_axes = groupreducewhile (
79+ unmerged_tensor_product, axes (a), length (merged_axes); init= OneToOne ()
80+ ) do i, axis
81+ return length (axis) ≤ length (merged_axes[i])
82+ end
83+
84+ a_reshaped = fusedims (BlockReshapeFusion (), a, unmerged_axes... )
8485 # Sort the blocks by sector and merge the equivalent sectors.
8586 return block_mergesort (a_reshaped)
8687end
@@ -90,10 +91,11 @@ function TensorAlgebra.splitdims(
9091)
9192 # First, fuse axes to get `blockmergesortperm`.
9293 # Then unpermute the blocks.
93- axes_prod =
94- groupreducewhile (tensor_product, split_axes, ndims (a); init= OneToOne ()) do i, axis
95- return length (axis) ≤ length (axes (a, i))
96- end
94+ axes_prod = groupreducewhile (
95+ unmerged_tensor_product, split_axes, ndims (a); init= OneToOne ()
96+ ) do i, axis
97+ return length (axis) ≤ length (axes (a, i))
98+ end
9799 blockperms = blocksortperm .(axes_prod)
98100 sorted_axes = map ((r, I) -> only (axes (r[I])), axes_prod, blockperms)
99101
@@ -106,34 +108,11 @@ function TensorAlgebra.splitdims(
106108 return splitdims (BlockReshapeFusion (), a_blockpermed, split_axes... )
107109end
108110
109- # This is a temporary fix for `eachindex` being broken for BlockSparseArrays
110- # with mixed dual and non-dual axes. This shouldn't be needed once
111- # GradedUnitRanges is rewritten using BlockArrays v1.
112- # TODO : Delete this once GradedUnitRanges is rewritten.
113- function Base. eachindex (a:: AbstractBlockSparseArray )
114- return CartesianIndices (nondual .(axes (a)))
115- end
116-
117111# TODO : Handle this through some kind of trait dispatch, maybe
118112# a `SymmetryStyle`-like trait to check if the block sparse
119113# matrix has graded axes.
120114function Base. axes (a:: Adjoint{<:Any,<:AbstractBlockSparseMatrix} )
121115 return dual .(reverse (axes (a' )))
122116end
123117
124- # This definition is only needed since calls like
125- # `a[[Block(1), Block(2)]]` where `a isa AbstractGradedUnitRange`
126- # returns a `BlockSparseVector` instead of a `BlockVector`
127- # due to limitations in the `BlockArray` type not allowing
128- # axes with non-Int element types.
129- # TODO : Remove this once that issue is fixed,
130- # see https://github.com/JuliaArrays/BlockArrays.jl/pull/405.
131- using BlockArrays: BlockRange
132- using LabelledNumbers: label
133- function GradedUnitRanges. blocklabels (a:: BlockSparseVector )
134- return map (BlockRange (a)) do block
135- return label (blocks (a)[Int (block)])
136- end
137- end
138-
139118end
0 commit comments