Skip to content

Commit db120cf

Browse files
authored
GradedArray (#6)
1 parent ac926e3 commit db120cf

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+831
-221
lines changed

Project.toml

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,34 @@ version = "0.1.0"
55

66
[deps]
77
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
8+
BlockSparseArrays = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
89
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
10+
DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
911
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
1012
HalfIntegers = "f0d1745a-41c9-11e9-1dd9-e5d34d218721"
13+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1114
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1215
SplitApplyCombine = "03a91e81-4c3e-53e1-a0a4-9c0c8f19dd66"
1316
TensorProducts = "decf83d6-1968-43f4-96dc-fdb3fe15fc6d"
17+
TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138"
18+
19+
[weakdeps]
20+
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
21+
22+
[extensions]
23+
GradedArraysTensorAlgebraExt = "TensorAlgebra"
1424

1525
[compat]
1626
BlockArrays = "1.5.0"
27+
BlockSparseArrays = "0.4.0"
1728
Compat = "4.16.0"
29+
DerivableInterfaces = "0.4.4"
1830
FillArrays = "1.13.0"
1931
HalfIntegers = "1.6.0"
32+
LinearAlgebra = "1.10.0"
2033
Random = "1.10.0"
2134
SplitApplyCombine = "1.2.3"
35+
TensorAlgebra = "0.2.7"
2236
TensorProducts = "0.1.3"
37+
TypeParameterAccessors = "0.3.9"
2338
julia = "1.10"

docs/Project.toml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
11
[deps]
2-
GradedArrays = "bc96ca6e-b7c8-4bb6-888e-c93f838762c2"
32
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
3+
GradedArrays = "bc96ca6e-b7c8-4bb6-888e-c93f838762c2"
44
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
5+
6+
[compat]
7+
Documenter = "1"
8+
GradedArrays = "0.1"
9+
Literate = "2"

examples/Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,5 @@
11
[deps]
22
GradedArrays = "bc96ca6e-b7c8-4bb6-888e-c93f838762c2"
3+
4+
[compat]
5+
GradedArrays = "0.1"
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
module GradedArraysTensorAlgebraExt
2+
3+
using BlockArrays: Block, BlockIndexRange, blockedrange, blocks
4+
using BlockSparseArrays:
5+
BlockSparseArrays,
6+
AbstractBlockSparseArray,
7+
AbstractBlockSparseArrayInterface,
8+
BlockSparseArray,
9+
BlockSparseArrayInterface,
10+
BlockSparseMatrix,
11+
BlockSparseVector,
12+
block_merge
13+
using DerivableInterfaces: @interface
14+
using GradedArrays.GradedUnitRanges:
15+
GradedUnitRanges,
16+
AbstractGradedUnitRange,
17+
blockmergesortperm,
18+
blocksortperm,
19+
dual,
20+
invblockperm,
21+
nondual,
22+
unmerged_tensor_product
23+
using LinearAlgebra: Adjoint, Transpose
24+
using TensorAlgebra:
25+
TensorAlgebra, FusionStyle, BlockReshapeFusion, SectorFusion, fusedims, splitdims
26+
using TensorProducts: OneToOne
27+
28+
#=
29+
reducewhile(f, op, collection, state)
30+
31+
reducewhile(x -> length(x) < 3, vcat, ["a", "b", "c", "d"], 2; init=String[]) ==
32+
(["b", "c"], 4)
33+
=#
34+
function reducewhile(f, op, collection, state; init)
35+
prev_result = init
36+
prev_state = state
37+
result = prev_result
38+
while f(result)
39+
prev_result = result
40+
prev_state = state
41+
value_and_state = iterate(collection, state)
42+
isnothing(value_and_state) && break
43+
value, state = value_and_state
44+
result = op(result, value)
45+
end
46+
return prev_result, prev_state
47+
end
48+
49+
#=
50+
groupreducewhile(f, op, collection, ngroups)
51+
52+
groupreducewhile((i, x) -> length(x) ≤ i, vcat, ["a", "b", "c", "d", "e", "f"], 3; init=String[]) ==
53+
(["a"], ["b", "c"], ["d", "e", "f"])
54+
=#
55+
function groupreducewhile(f, op, collection, ngroups; init)
56+
state = firstindex(collection)
57+
return ntuple(ngroups) do group_number
58+
result, state = reducewhile(x -> f(group_number, x), op, collection, state; init)
59+
return result
60+
end
61+
end
62+
63+
TensorAlgebra.FusionStyle(::AbstractGradedUnitRange) = SectorFusion()
64+
65+
# Sort the blocks by sector and then merge the common sectors.
66+
function block_mergesort(a::AbstractArray)
67+
I = blockmergesortperm.(axes(a))
68+
return a[I...]
69+
end
70+
71+
function TensorAlgebra.fusedims(
72+
::SectorFusion, a::AbstractArray, merged_axes::AbstractUnitRange...
73+
)
74+
# First perform a fusion using a block reshape.
75+
# TODO avoid groupreducewhile. Require refactor of fusedims.
76+
unmerged_axes = groupreducewhile(
77+
unmerged_tensor_product, axes(a), length(merged_axes); init=OneToOne()
78+
) do i, axis
79+
return length(axis) length(merged_axes[i])
80+
end
81+
82+
a_reshaped = fusedims(BlockReshapeFusion(), a, unmerged_axes...)
83+
# Sort the blocks by sector and merge the equivalent sectors.
84+
return block_mergesort(a_reshaped)
85+
end
86+
87+
function TensorAlgebra.splitdims(
88+
::SectorFusion, a::AbstractArray, split_axes::AbstractUnitRange...
89+
)
90+
# First, fuse axes to get `blockmergesortperm`.
91+
# Then unpermute the blocks.
92+
axes_prod = groupreducewhile(
93+
unmerged_tensor_product, split_axes, ndims(a); init=OneToOne()
94+
) do i, axis
95+
return length(axis) length(axes(a, i))
96+
end
97+
blockperms = blocksortperm.(axes_prod)
98+
sorted_axes = map((r, I) -> only(axes(r[I])), axes_prod, blockperms)
99+
100+
# TODO: This is doing extra copies of the blocks,
101+
# use `@view a[axes_prod...]` instead.
102+
# That will require implementing some reindexing logic
103+
# for this combination of slicing.
104+
a_unblocked = a[sorted_axes...]
105+
a_blockpermed = a_unblocked[invblockperm.(blockperms)...]
106+
return splitdims(BlockReshapeFusion(), a_blockpermed, split_axes...)
107+
end
108+
109+
end

src/GradedArrays.jl

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,32 @@
11
module GradedArrays
22

3-
include("lib/LabelledNumbers/LabelledNumbers.jl")
3+
include("LabelledNumbers/LabelledNumbers.jl")
44
using .LabelledNumbers: LabelledNumbers
5-
include("lib/GradedUnitRanges/GradedUnitRanges.jl")
6-
using .GradedUnitRanges: GradedUnitRanges
7-
include("lib/SymmetrySectors/SymmetrySectors.jl")
5+
include("GradedUnitRanges/GradedUnitRanges.jl")
6+
# This makes the following names accessible
7+
# as `GradedArrays.x`.
8+
using .GradedUnitRanges:
9+
GradedUnitRanges,
10+
GradedOneTo,
11+
GradedUnitRange,
12+
GradedUnitRangeDual,
13+
LabelledUnitRangeDual,
14+
blocklabels,
15+
blockmergesortperm,
16+
blocksortperm,
17+
dag,
18+
dual,
19+
dual_type,
20+
flip,
21+
gradedrange,
22+
isdual,
23+
nondual,
24+
nondual_type,
25+
sector_type,
26+
space_isequal,
27+
unmerged_tensor_product
28+
include("SymmetrySectors/SymmetrySectors.jl")
829
using .SymmetrySectors: SymmetrySectors
30+
include("gradedarray.jl")
931

1032
end

src/lib/GradedUnitRanges/GradedUnitRanges.jl renamed to src/GradedUnitRanges/GradedUnitRanges.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ module GradedUnitRanges
22

33
export gradedrange
44

5-
include("blockedunitrange.jl")
65
include("gradedunitrange.jl")
76
include("dual.jl")
87
include("labelledunitrangedual.jl")
File renamed without changes.
File renamed without changes.

src/lib/GradedUnitRanges/gradedunitrange.jl renamed to src/GradedUnitRanges/gradedunitrange.jl

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
using BlockArrays:
22
BlockArrays,
3+
AbstractBlockVector,
34
AbstractBlockedUnitRange,
45
Block,
6+
BlockIndex,
57
BlockIndexRange,
68
BlockRange,
79
BlockSlice,
10+
BlockVector,
811
BlockedOneTo,
912
BlockedUnitRange,
1013
block,
@@ -15,7 +18,14 @@ using BlockArrays:
1518
blocks,
1619
blockindex,
1720
combine_blockaxes,
21+
mortar,
1822
sortedunion
23+
using BlockSparseArrays:
24+
BlockSparseArrays,
25+
_blocks,
26+
blockedunitrange_findblock,
27+
blockedunitrange_findblockindex,
28+
blockedunitrange_getindices
1929
using Compat: allequal
2030
using FillArrays: Fill
2131
using ..LabelledNumbers:
@@ -127,11 +137,15 @@ function BlockArrays.findblock(a::AbstractGradedUnitRange, index::Integer)
127137
return blockedunitrange_findblock(unlabel_blocks(a), index)
128138
end
129139

130-
function blockedunitrange_findblock(a::AbstractGradedUnitRange, index::Integer)
140+
function BlockSparseArrays.blockedunitrange_findblock(
141+
a::AbstractGradedUnitRange, index::Integer
142+
)
131143
return blockedunitrange_findblock(unlabel_blocks(a), index)
132144
end
133145

134-
function blockedunitrange_findblockindex(a::AbstractGradedUnitRange, index::Integer)
146+
function BlockSparseArrays.blockedunitrange_findblockindex(
147+
a::AbstractGradedUnitRange, index::Integer
148+
)
135149
return blockedunitrange_findblockindex(unlabel_blocks(a), index)
136150
end
137151

@@ -221,32 +235,36 @@ function firstblockindices(a::AbstractGradedUnitRange)
221235
return labelled.(firstblockindices(unlabel_blocks(a)), blocklabels(a))
222236
end
223237

224-
function blockedunitrange_getindices(a::AbstractGradedUnitRange, index::Block{1})
238+
function BlockSparseArrays.blockedunitrange_getindices(
239+
a::AbstractGradedUnitRange, index::Block{1}
240+
)
225241
return labelled(unlabel_blocks(a)[index], get_label(a, index))
226242
end
227243

228-
function blockedunitrange_getindices(a::AbstractGradedUnitRange, indices::Vector{<:Integer})
244+
function BlockSparseArrays.blockedunitrange_getindices(
245+
a::AbstractGradedUnitRange, indices::Vector{<:Integer}
246+
)
229247
return map(index -> a[index], indices)
230248
end
231249

232-
function blockedunitrange_getindices(
250+
function BlockSparseArrays.blockedunitrange_getindices(
233251
a::AbstractGradedUnitRange,
234252
indices::BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}},
235253
)
236254
return mortar(map(b -> a[b], blocks(indices)))
237255
end
238256

239-
function blockedunitrange_getindices(a::AbstractGradedUnitRange, index)
257+
function BlockSparseArrays.blockedunitrange_getindices(a::AbstractGradedUnitRange, index)
240258
return labelled(unlabel_blocks(a)[index], get_label(a, index))
241259
end
242260

243-
function blockedunitrange_getindices(
261+
function BlockSparseArrays.blockedunitrange_getindices(
244262
a::AbstractGradedUnitRange, indices::BlockIndexRange{1}
245263
)
246264
return a[block(indices)][only(indices.indices)]
247265
end
248266

249-
function blockedunitrange_getindices(
267+
function BlockSparseArrays.blockedunitrange_getindices(
250268
a::AbstractGradedUnitRange, indices::AbstractVector{<:Union{Block{1},BlockIndexRange{1}}}
251269
)
252270
# Without converting `indices` to `Vector`,
@@ -268,22 +286,28 @@ function blocklabels(a::AbstractUnitRange, indices)
268286
end
269287
end
270288

271-
function blockedunitrange_getindices(
289+
function BlockSparseArrays.blockedunitrange_getindices(
272290
ga::AbstractGradedUnitRange, indices::AbstractUnitRange{<:Integer}
273291
)
274292
a_indices = blockedunitrange_getindices(unlabel_blocks(ga), indices)
275293
return labelled_blocks(a_indices, blocklabels(ga, indices))
276294
end
277295

278-
function blockedunitrange_getindices(a::AbstractGradedUnitRange, indices::BlockSlice)
296+
function BlockSparseArrays.blockedunitrange_getindices(
297+
a::AbstractGradedUnitRange, indices::BlockSlice
298+
)
279299
return a[indices.block]
280300
end
281301

282-
function blockedunitrange_getindices(ga::AbstractGradedUnitRange, indices::BlockRange)
302+
function BlockSparseArrays.blockedunitrange_getindices(
303+
ga::AbstractGradedUnitRange, indices::BlockRange
304+
)
283305
return labelled_blocks(unlabel_blocks(ga)[indices], blocklabels(ga, indices))
284306
end
285307

286-
function blockedunitrange_getindices(a::AbstractGradedUnitRange, indices::BlockIndex{1})
308+
function BlockSparseArrays.blockedunitrange_getindices(
309+
a::AbstractGradedUnitRange, indices::BlockIndex{1}
310+
)
287311
return a[block(indices)][blockindex(indices)]
288312
end
289313

@@ -398,7 +422,7 @@ end
398422
# blocklengths = map(bs -> sum(b -> length(a[b]), bs), blocks(indices))
399423
# return blockedrange(blocklengths)
400424
# ```
401-
function blockedunitrange_getindices(
425+
function BlockSparseArrays.blockedunitrange_getindices(
402426
a::AbstractGradedUnitRange, indices::AbstractBlockVector{<:Block{1}}
403427
)
404428
blks = map(bs -> mortar(map(b -> a[b], bs)), blocks(indices))

0 commit comments

Comments
 (0)