Skip to content

Commit d41e8e7

Browse files
committed
Start fixing extensions
1 parent 9c64d2a commit d41e8e7

File tree

4 files changed

+102
-97
lines changed

4 files changed

+102
-97
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ TensorProducts = "decf83d6-1968-43f4-96dc-fdb3fe15fc6d"
2626

2727
[extensions]
2828
BlockSparseArraysGradedUnitRangesExt = "GradedUnitRanges"
29-
BlockSparseArraysTensorAlgebraExt = ["TensorProducts", "TensorAlgebra"]
29+
BlockSparseArraysGradedUnitRangesTensorAlgebraExt = ["GradedUnitRanges", "TensorAlgebra", "TensorProducts"]
30+
BlockSparseArraysTensorAlgebraExt = "TensorAlgebra"
3031

3132
[compat]
3233
Adapt = "4.1.1"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
module BlockSparseArraysGradedUnitRangesTensorAlgebraExt
2+
3+
using BlockArrays:
4+
AbstractBlockVector,
5+
AbstractBlockedUnitRange,
6+
Block,
7+
BlockIndexRange,
8+
blockedrange,
9+
blocks
10+
using BlockSparseArrays:
11+
BlockSparseArrays,
12+
AbstractBlockSparseArray,
13+
AbstractBlockSparseArrayInterface,
14+
AbstractBlockSparseMatrix,
15+
BlockSparseArray,
16+
BlockSparseArrayInterface,
17+
BlockSparseMatrix,
18+
BlockSparseVector,
19+
block_merge
20+
using DerivableInterfaces: @interface
21+
using GradedUnitRanges:
22+
GradedUnitRanges,
23+
AbstractGradedUnitRange,
24+
blockmergesortperm,
25+
blocksortperm,
26+
dual,
27+
invblockperm,
28+
nondual,
29+
unmerged_tensor_product
30+
using LinearAlgebra: Adjoint, Transpose
31+
using TensorAlgebra:
32+
TensorAlgebra, FusionStyle, BlockReshapeFusion, SectorFusion, fusedims, splitdims
33+
using TensorProducts: OneToOne
34+
35+
# TODO: Make a `ReduceWhile` library.
36+
include("reducewhile.jl")
37+
38+
TensorAlgebra.FusionStyle(::AbstractGradedUnitRange) = SectorFusion()
39+
40+
# TODO: Need to implement this! Will require implementing
41+
# `block_merge(a::AbstractUnitRange, blockmerger::BlockedUnitRange)`.
42+
function BlockSparseArrays.block_merge(
43+
a::AbstractGradedUnitRange, blockmerger::AbstractBlockedUnitRange
44+
)
45+
return a
46+
end
47+
48+
# Sort the blocks by sector and then merge the common sectors.
49+
function block_mergesort(a::AbstractArray)
50+
I = blockmergesortperm.(axes(a))
51+
return a[I...]
52+
end
53+
54+
function TensorAlgebra.fusedims(
55+
::SectorFusion, a::AbstractArray, merged_axes::AbstractUnitRange...
56+
)
57+
# First perform a fusion using a block reshape.
58+
# TODO avoid groupreducewhile. Require refactor of fusedims.
59+
unmerged_axes = groupreducewhile(
60+
unmerged_tensor_product, axes(a), length(merged_axes); init=OneToOne()
61+
) do i, axis
62+
return length(axis) length(merged_axes[i])
63+
end
64+
65+
a_reshaped = fusedims(BlockReshapeFusion(), a, unmerged_axes...)
66+
# Sort the blocks by sector and merge the equivalent sectors.
67+
return block_mergesort(a_reshaped)
68+
end
69+
70+
function TensorAlgebra.splitdims(
71+
::SectorFusion, a::AbstractArray, split_axes::AbstractUnitRange...
72+
)
73+
# First, fuse axes to get `blockmergesortperm`.
74+
# Then unpermute the blocks.
75+
axes_prod = groupreducewhile(
76+
unmerged_tensor_product, split_axes, ndims(a); init=OneToOne()
77+
) do i, axis
78+
return length(axis) length(axes(a, i))
79+
end
80+
blockperms = blocksortperm.(axes_prod)
81+
sorted_axes = map((r, I) -> only(axes(r[I])), axes_prod, blockperms)
82+
83+
# TODO: This is doing extra copies of the blocks,
84+
# use `@view a[axes_prod...]` instead.
85+
# That will require implementing some reindexing logic
86+
# for this combination of slicing.
87+
a_unblocked = a[sorted_axes...]
88+
a_blockpermed = a_unblocked[invblockperm.(blockperms)...]
89+
return splitdims(BlockReshapeFusion(), a_blockpermed, split_axes...)
90+
end
91+
92+
# TODO: Handle this through some kind of trait dispatch, maybe
93+
# a `SymmetryStyle`-like trait to check if the block sparse
94+
# matrix has graded axes.
95+
function Base.axes(a::Adjoint{<:Any,<:AbstractBlockSparseMatrix})
96+
return dual.(reverse(axes(a')))
97+
end
98+
99+
end
Lines changed: 1 addition & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
module BlockSparseArraysTensorAlgebraExt
2+
23
using BlockArrays: AbstractBlockedUnitRange
34

45
using TensorAlgebra: TensorAlgebra, FusionStyle, BlockReshapeFusion
5-
using TensorProducts: OneToOne
66

77
using BlockSparseArrays: AbstractBlockSparseArray, blockreshape
88

@@ -20,99 +20,4 @@ function TensorAlgebra.splitdims(
2020
return blockreshape(a, axes)
2121
end
2222

23-
using BlockArrays:
24-
AbstractBlockVector,
25-
AbstractBlockedUnitRange,
26-
Block,
27-
BlockIndexRange,
28-
blockedrange,
29-
blocks
30-
using BlockSparseArrays:
31-
BlockSparseArrays,
32-
AbstractBlockSparseArray,
33-
AbstractBlockSparseArrayInterface,
34-
AbstractBlockSparseMatrix,
35-
BlockSparseArray,
36-
BlockSparseArrayInterface,
37-
BlockSparseMatrix,
38-
BlockSparseVector,
39-
block_merge
40-
using DerivableInterfaces: @interface
41-
using GradedUnitRanges:
42-
GradedUnitRanges,
43-
AbstractGradedUnitRange,
44-
blockmergesortperm,
45-
blocksortperm,
46-
dual,
47-
invblockperm,
48-
nondual,
49-
unmerged_tensor_product
50-
using LinearAlgebra: Adjoint, Transpose
51-
using TensorAlgebra:
52-
TensorAlgebra, FusionStyle, BlockReshapeFusion, SectorFusion, fusedims, splitdims
53-
54-
# TODO: Make a `ReduceWhile` library.
55-
include("reducewhile.jl")
56-
57-
TensorAlgebra.FusionStyle(::AbstractGradedUnitRange) = SectorFusion()
58-
59-
# TODO: Need to implement this! Will require implementing
60-
# `block_merge(a::AbstractUnitRange, blockmerger::BlockedUnitRange)`.
61-
function BlockSparseArrays.block_merge(
62-
a::AbstractGradedUnitRange, blockmerger::AbstractBlockedUnitRange
63-
)
64-
return a
65-
end
66-
67-
# Sort the blocks by sector and then merge the common sectors.
68-
function block_mergesort(a::AbstractArray)
69-
I = blockmergesortperm.(axes(a))
70-
return a[I...]
71-
end
72-
73-
function TensorAlgebra.fusedims(
74-
::SectorFusion, a::AbstractArray, merged_axes::AbstractUnitRange...
75-
)
76-
# First perform a fusion using a block reshape.
77-
# TODO avoid groupreducewhile. Require refactor of fusedims.
78-
unmerged_axes = groupreducewhile(
79-
unmerged_tensor_product, axes(a), length(merged_axes); init=OneToOne()
80-
) do i, axis
81-
return length(axis) length(merged_axes[i])
82-
end
83-
84-
a_reshaped = fusedims(BlockReshapeFusion(), a, unmerged_axes...)
85-
# Sort the blocks by sector and merge the equivalent sectors.
86-
return block_mergesort(a_reshaped)
87-
end
88-
89-
function TensorAlgebra.splitdims(
90-
::SectorFusion, a::AbstractArray, split_axes::AbstractUnitRange...
91-
)
92-
# First, fuse axes to get `blockmergesortperm`.
93-
# Then unpermute the blocks.
94-
axes_prod = groupreducewhile(
95-
unmerged_tensor_product, split_axes, ndims(a); init=OneToOne()
96-
) do i, axis
97-
return length(axis) length(axes(a, i))
98-
end
99-
blockperms = blocksortperm.(axes_prod)
100-
sorted_axes = map((r, I) -> only(axes(r[I])), axes_prod, blockperms)
101-
102-
# TODO: This is doing extra copies of the blocks,
103-
# use `@view a[axes_prod...]` instead.
104-
# That will require implementing some reindexing logic
105-
# for this combination of slicing.
106-
a_unblocked = a[sorted_axes...]
107-
a_blockpermed = a_unblocked[invblockperm.(blockperms)...]
108-
return splitdims(BlockReshapeFusion(), a_blockpermed, split_axes...)
109-
end
110-
111-
# TODO: Handle this through some kind of trait dispatch, maybe
112-
# a `SymmetryStyle`-like trait to check if the block sparse
113-
# matrix has graded axes.
114-
function Base.axes(a::Adjoint{<:Any,<:AbstractBlockSparseMatrix})
115-
return dual.(reverse(axes(a')))
116-
end
117-
11823
end

0 commit comments

Comments
 (0)