Skip to content

Commit f8eda16

Browse files
committed
fix dependencies
1 parent b42cb13 commit f8eda16

File tree

4 files changed

+24
-12
lines changed

4 files changed

+24
-12
lines changed

Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,11 @@ TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138"
2323
[weakdeps]
2424
LabelledNumbers = "f856a3a6-4152-4ec4-b2a7-02c1a55d7993"
2525
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
26+
TensorProducts = "decf83d6-1968-43f4-96dc-fdb3fe15fc6d"
2627

2728
[extensions]
2829
BlockSparseArraysGradedUnitRangesExt = "GradedUnitRanges"
29-
BlockSparseArraysTensorAlgebraExt = ["LabelledNumbers", "TensorAlgebra"]
30+
BlockSparseArraysTensorAlgebraExt = ["LabelledNumbers", "TensorProducts", "TensorAlgebra"]
3031

3132
[compat]
3233
Adapt = "4.1.1"
@@ -38,14 +39,15 @@ DiagonalArrays = "0.3"
3839
Dictionaries = "0.4.3"
3940
FillArrays = "1.13.0"
4041
GPUArraysCore = "0.1.0, 0.2"
41-
GradedUnitRanges = "0.2"
42+
GradedUnitRanges = "0.2.2"
4243
LabelledNumbers = "0.1.0"
4344
LinearAlgebra = "1.10"
4445
MacroTools = "0.5.13"
4546
MapBroadcast = "0.1.5"
4647
SparseArraysBase = "0.5"
4748
SplitApplyCombine = "1.2.3"
4849
TensorAlgebra = "0.2.4"
50+
TensorProducts = "0.1.2"
4951
Test = "1.10"
5052
TypeParameterAccessors = "0.2.0, 0.3"
5153
julia = "1.10"

ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
module BlockSparseArraysTensorAlgebraExt
22
using BlockArrays: AbstractBlockedUnitRange
3+
34
using TensorAlgebra: TensorAlgebra, FusionStyle, BlockReshapeFusion
5+
using TensorProducts: OneToOne
46

5-
using BlockArrays: AbstractBlockedUnitRange
67
using BlockSparseArrays: AbstractBlockSparseArray, blockreshape
7-
using TensorAlgebra: TensorAlgebra, FusionStyle, BlockReshapeFusion
88

99
TensorAlgebra.FusionStyle(::AbstractBlockedUnitRange) = BlockReshapeFusion()
1010

@@ -45,7 +45,8 @@ using GradedUnitRanges:
4545
blocksortperm,
4646
dual,
4747
invblockperm,
48-
nondual
48+
nondual,
49+
unmerged_tensor_product
4950
using LinearAlgebra: Adjoint, Transpose
5051
using TensorAlgebra:
5152
TensorAlgebra, FusionStyle, BlockReshapeFusion, SectorFusion, fusedims, splitdims
@@ -70,10 +71,17 @@ function block_mergesort(a::AbstractArray)
7071
end
7172

7273
function TensorAlgebra.fusedims(
73-
::SectorFusion, a::AbstractArray, axes::AbstractUnitRange...
74+
::SectorFusion, a::AbstractArray, merged_axes::AbstractUnitRange...
7475
)
7576
# First perform a fusion using a block reshape.
76-
a_reshaped = fusedims(BlockReshapeFusion(), a, axes...)
77+
# TODO avoid groupreducewhile. Require refactor of fusedims.
78+
unmerged_axes = groupreducewhile(
79+
unmerged_tensor_product, axes(a), length(merged_axes); init=OneToOne()
80+
) do i, axis
81+
return length(axis) length(merged_axes[i])
82+
end
83+
84+
a_reshaped = fusedims(BlockReshapeFusion(), a, unmerged_axes...)
7785
# Sort the blocks by sector and merge the equivalent sectors.
7886
return block_mergesort(a_reshaped)
7987
end
@@ -83,10 +91,11 @@ function TensorAlgebra.splitdims(
8391
)
8492
# First, fuse axes to get `blockmergesortperm`.
8593
# Then unpermute the blocks.
86-
axes_prod =
87-
groupreducewhile(tensor_product, split_axes, ndims(a); init=OneToOne()) do i, axis
88-
return length(axis) length(axes(a, i))
89-
end
94+
axes_prod = groupreducewhile(
95+
unmerged_tensor_product, split_axes, ndims(a); init=OneToOne()
96+
) do i, axis
97+
return length(axis) length(axes(a, i))
98+
end
9099
blockperms = blocksortperm.(axes_prod)
91100
sorted_axes = map((r, I) -> only(axes(r[I])), axes_prod, blockperms)
92101

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ BlockArrays = "1"
2929
BlockSparseArrays = "0.3"
3030
DiagonalArrays = "0.3"
3131
GPUArraysCore = "0.2"
32-
GradedUnitRanges = "0.2"
32+
GradedUnitRanges = "0.2.2"
3333
JLArrays = "0.2"
3434
LabelledNumbers = "0.1"
3535
LinearAlgebra = "1"

test/test_gradedunitrangesext.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ using GradedUnitRanges:
99
GradedUnitRange,
1010
GradedUnitRangeDual,
1111
blocklabels,
12+
dag,
1213
dual,
1314
gradedrange,
1415
isdual

0 commit comments

Comments
 (0)