Skip to content

Commit 4b8a894

Browse files
committed
adapt to permuteblockeddims
1 parent d466ca5 commit 4b8a894

File tree

3 files changed

+6
-7
lines changed

3 files changed

+6
-7
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ MacroTools = "0.5.13"
4040
MapBroadcast = "0.1.5"
4141
SparseArraysBase = "0.5"
4242
SplitApplyCombine = "1.2.3"
43-
TensorAlgebra = "0.3"
43+
TensorAlgebra = "0.3.1"
4444
Test = "1.10"
4545
TypeParameterAccessors = "0.2.0, 0.3"
4646
julia = "1.10"

ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module BlockSparseArraysTensorAlgebraExt
33
using BlockSparseArrays: AbstractBlockSparseArray, blockreshape
44
using TensorAlgebra:
55
TensorAlgebra,
6-
AbstractBlockPermutation,
6+
BlockedTrivialPermutation,
77
BlockedTuple,
88
FusionStyle,
99
ReshapeFusion,
@@ -16,11 +16,10 @@ function TensorAlgebra.FusionStyle(::AbstractBlockSparseArray, ::ReshapeFusion)
1616
end
1717

1818
function TensorAlgebra.matricize(
19-
::BlockReshapeFusion, a::AbstractArray, biperm::AbstractBlockPermutation{2}
19+
::BlockReshapeFusion, a::AbstractArray, biperm::BlockedTrivialPermutation{2}
2020
)
21-
a_perm = permutedims(a, Tuple(biperm))
22-
new_axes = fuseaxes(axes(a_perm), biperm)
23-
return blockreshape(a_perm, new_axes)
21+
new_axes = fuseaxes(axes(a), biperm)
22+
return blockreshape(a, new_axes)
2423
end
2524

2625
function TensorAlgebra.unmatricize(

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ Random = "1"
3131
SafeTestsets = "0.1"
3232
SparseArraysBase = "0.5"
3333
Suppressor = "0.2"
34-
TensorAlgebra = "0.3"
34+
TensorAlgebra = "0.3.1"
3535
Test = "1"
3636
TestExtras = "0.3"
3737
TypeParameterAccessors = "0.3"

0 commit comments

Comments
 (0)