11module BlockSparseArraysTensorAlgebraExt
22
33using BlockSparseArrays: AbstractBlockSparseArray, blockreshape
4- using TensorAlgebra:
5- TensorAlgebra,
6- BlockedTrivialPermutation,
7- BlockedTuple,
8- FusionStyle,
9- ReshapeFusion,
10- fuseaxes
4+ using TensorAlgebra: TensorAlgebra, BlockedTuple, FusionStyle, fuseaxes
115
126struct BlockReshapeFusion <: FusionStyle end
137
@@ -20,12 +14,12 @@ using BlockSparseArrays: blocksparse
2014using SparseArraysBase: eachstoredindex
2115using TensorAlgebra: TensorAlgebra, matricize, unmatricize
2216function TensorAlgebra. matricize (
23- :: BlockReshapeFusion , a:: AbstractArray , biperm :: BlockedTrivialPermutation{2}
17+ :: BlockReshapeFusion , a:: AbstractArray , length1 :: Val , length2 :: Val
2418 )
25- ax = fuseaxes (axes (a), biperm )
19+ ax = fuseaxes (axes (a), length1, length2 )
2620 reshaped_blocks_a = reshape (blocks (a), map (blocklength, ax))
2721 key (I) = Block (Tuple (I))
28- value (I) = matricize (reshaped_blocks_a[I], biperm )
22+ value (I) = matricize (reshaped_blocks_a[I], length1, length2 )
2923 Is = eachstoredindex (reshaped_blocks_a)
3024 bs = if isempty (Is)
3125 # Catch empty case and make sure the type is constrained properly.
@@ -45,16 +39,17 @@ using BlockArrays: blocklengths
4539function TensorAlgebra. unmatricize (
4640 :: BlockReshapeFusion ,
4741 m:: AbstractMatrix ,
48- blocked_ax:: BlockedTuple{2, <:Any, <:Tuple{Vararg{AbstractUnitRange}}} ,
42+ codomain_axes:: Tuple{Vararg{AbstractUnitRange}} ,
43+ domain_axes:: Tuple{Vararg{AbstractUnitRange}} ,
4944 )
50- ax = Tuple (blocked_ax )
45+ ax = (codomain_axes ... , domain_axes ... )
5146 reshaped_blocks_m = reshape (blocks (m), map (blocklength, ax))
5247 function f (I)
5348 block_axes_I = BlockedTuple (
5449 map (ntuple (identity, length (ax))) do i
5550 return Base. axes1 (ax[i][Block (I[i])])
5651 end ,
57- blocklengths (blocked_ax ),
52+ ( length (codomain_axes), length (domain_axes) ),
5853 )
5954 return unmatricize (reshaped_blocks_m[I], block_axes_I)
6055 end
0 commit comments