Skip to content

Commit 1a05656

Browse files
committed
KroneckerArraysTensorAlgebraExt extension
1 parent 84c66af commit 1a05656

File tree

3 files changed

+43
-37
lines changed

3 files changed

+43
-37
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,16 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
1212
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1313
MapBroadcast = "ebd9b9da-f48d-417c-9660-449667d60261"
1414
MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
15-
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
1615

1716
[weakdeps]
1817
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
1918
BlockSparseArrays = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
19+
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
2020
TensorProducts = "decf83d6-1968-43f4-96dc-fdb3fe15fc6d"
2121

2222
[extensions]
2323
KroneckerArraysBlockSparseArraysExt = ["BlockArrays", "BlockSparseArrays"]
24+
KroneckerArraysTensorAlgebraExt = "TensorAlgebra"
2425
KroneckerArraysTensorProductsExt = "TensorProducts"
2526

2627
[compat]
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
module KroneckerArraysTensorAlgebraExt
2+
3+
using KroneckerArrays: KroneckerArrays, KroneckerArray
4+
using TensorAlgebra: TensorAlgebra, AbstractBlockPermutation, FusionStyle, matricize
5+
6+
struct KroneckerFusion{A<:FusionStyle,B<:FusionStyle} <: FusionStyle
7+
a::A
8+
b::B
9+
end
10+
KroneckerArrays.arg1(style::KroneckerFusion) = style.a
11+
KroneckerArrays.arg2(style::KroneckerFusion) = style.b
12+
function TensorAlgebra.FusionStyle(a::KroneckerArray)
13+
return KroneckerFusion(FusionStyle(arg1(a)), FusionStyle(arg2(a)))
14+
end
15+
function matricize_kronecker(
16+
style::KroneckerFusion, a::AbstractArray, biperm::AbstractBlockPermutation{2}
17+
)
18+
return matricize(arg1(style), arg1(a), biperm) matricize(arg2(style), arg2(a), biperm)
19+
end
20+
function TensorAlgebra.matricize(
21+
style::KroneckerFusion, a::AbstractArray, biperm::AbstractBlockPermutation{2}
22+
)
23+
return matricize_kronecker(style, a, biperm)
24+
end
25+
# Fix ambiguity error.
26+
# TODO: Investigate rewriting the logic in `TensorAlgebra.jl` to avoid this.
27+
using TensorAlgebra: BlockedTrivialPermutation, unmatricize
28+
function TensorAlgebra.matricize(
29+
style::KroneckerFusion, a::AbstractArray, biperm::BlockedTrivialPermutation{2}
30+
)
31+
return matricize_kronecker(style, a, biperm)
32+
end
33+
function unmatricize_kronecker(style::KroneckerFusion, a::AbstractArray, ax)
34+
return unmatricize(arg1(style), arg1(a), arg1.(ax))
35+
unmatricize(arg2(style), arg2(a), arg2.(ax))
36+
end
37+
function TensorAlgebra.unmatricize(style::KroneckerFusion, a::AbstractArray, ax)
38+
return unmatricize_kronecker(style, a, ax)
39+
end
40+
41+
end

src/kroneckerarray.jl

Lines changed: 0 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -494,39 +494,3 @@ function Base.broadcasted(
494494
)
495495
return broadcasted(style, /, a, f.args[2])
496496
end
497-
498-
using TensorAlgebra: TensorAlgebra, AbstractBlockPermutation, FusionStyle, matricize
499-
struct KroneckerFusion{A<:FusionStyle,B<:FusionStyle} <: FusionStyle
500-
a::A
501-
b::B
502-
end
503-
arg1(style::KroneckerFusion) = style.a
504-
arg2(style::KroneckerFusion) = style.b
505-
function TensorAlgebra.FusionStyle(a::KroneckerArray)
506-
return KroneckerFusion(FusionStyle(arg1(a)), FusionStyle(arg2(a)))
507-
end
508-
function matricize_kronecker(
509-
style::KroneckerFusion, a::AbstractArray, biperm::AbstractBlockPermutation{2}
510-
)
511-
return matricize(arg1(style), arg1(a), biperm) matricize(arg2(style), arg2(a), biperm)
512-
end
513-
function TensorAlgebra.matricize(
514-
style::KroneckerFusion, a::AbstractArray, biperm::AbstractBlockPermutation{2}
515-
)
516-
return matricize_kronecker(style, a, biperm)
517-
end
518-
# Fix ambiguity error.
519-
# TODO: Investigate rewriting the logic in `TensorAlgebra.jl` to avoid this.
520-
using TensorAlgebra: BlockedTrivialPermutation, unmatricize
521-
function TensorAlgebra.matricize(
522-
style::KroneckerFusion, a::AbstractArray, biperm::BlockedTrivialPermutation{2}
523-
)
524-
return matricize_kronecker(style, a, biperm)
525-
end
526-
function unmatricize_kronecker(style::KroneckerFusion, a::AbstractArray, ax)
527-
return unmatricize(arg1(style), arg1(a), arg1.(ax))
528-
unmatricize(arg2(style), arg2(a), arg2.(ax))
529-
end
530-
function TensorAlgebra.unmatricize(style::KroneckerFusion, a::AbstractArray, ax)
531-
return unmatricize_kronecker(style, a, ax)
532-
end

0 commit comments

Comments
 (0)