Skip to content

Commit e1edd71

Browse files
committed
Simplifications based on Lukas's comments
1 parent 14c971d commit e1edd71

File tree

3 files changed

+15
-65
lines changed

3 files changed

+15
-65
lines changed

src/MatrixAlgebra.jl

Lines changed: 1 addition & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,7 @@
11
module MatrixAlgebra
22

33
using LinearAlgebra: LinearAlgebra
4-
using MatrixAlgebraKit:
5-
eig_full,
6-
eig_full!,
7-
eig_trunc,
8-
eig_trunc!,
9-
eig_vals,
10-
eig_vals!,
11-
eigh_full,
12-
eigh_full!,
13-
eigh_trunc,
14-
eigh_trunc!,
15-
eigh_vals,
16-
eigh_vals!,
17-
left_orth,
18-
left_orth!,
19-
left_polar,
20-
left_polar!,
21-
lq_full,
22-
lq_full!,
23-
lq_compact,
24-
lq_compact!,
25-
qr_full,
26-
qr_full!,
27-
qr_compact,
28-
qr_compact!,
29-
right_orth,
30-
right_orth!,
31-
right_polar,
32-
right_polar!,
33-
svd_full,
34-
svd_full!,
35-
svd_compact,
36-
svd_compact!,
37-
svd_trunc,
38-
svd_trunc!
4+
using MatrixAlgebraKit
395

406
for (f, f_full, f_compact) in (
417
(:qr, :qr_full, :qr_compact),

src/TensorAlgebra.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ module TensorAlgebra
33
export contract, contract!, eigen, eigvals, lq, left_null, qr, right_null, svd, svdvals
44

55
include("MatrixAlgebra.jl")
6-
using .MatrixAlgebra: MatrixAlgebra
76
include("blockedtuple.jl")
87
include("blockedpermutation.jl")
98
include("BaseExtensions/BaseExtensions.jl")

src/factorizations.jl

Lines changed: 14 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,26 @@
11
using LinearAlgebra: LinearAlgebra
2-
using .MatrixAlgebra: MatrixAlgebra
32
using MatrixAlgebraKit: MatrixAlgebraKit
43

5-
function factorize_with(f, A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...)
6-
# tensor to matrix
7-
A_mat = fusedims(A, biperm)
8-
9-
# factorization
10-
X, Y = f(A_mat; kwargs...)
11-
12-
# matrix to tensor
13-
axes_codomain, axes_domain = blockpermute(axes(A), biperm)
14-
axes_X = (axes_codomain..., axes(X, 2))
15-
axes_Y = (axes(Y, 1), axes_domain...)
16-
return splitdims(X, axes_X), splitdims(Y, axes_Y)
17-
end
18-
19-
for (f, f_mat) in (
20-
(:qr, :(MatrixAlgebra.qr)),
21-
(:lq, :(MatrixAlgebra.lq)),
22-
(:left_polar, :(MatrixAlgebra.left_polar)),
23-
(:right_polar, :(MatrixAlgebra.right_polar)),
24-
(:polar, :(MatrixAlgebra.polar)),
25-
(:left_orth, :(MatrixAlgebra.left_orth)),
26-
(:right_orth, :(MatrixAlgebra.right_orth)),
27-
(:orth, :(MatrixAlgebra.orth)),
28-
(:factorize, :(MatrixAlgebra.factorize)),
4+
for f in (
5+
:qr, :lq, :left_polar, :right_polar, :polar, :left_orth, :right_orth, :orth, :factorize
296
)
307
@eval begin
318
function $f(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...)
329
biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...)
3310
return $f(A, biperm; kwargs...)
3411
end
3512
function $f(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...)
36-
return factorize_with($f_mat, A, biperm; kwargs...)
13+
# tensor to matrix
14+
A_mat = fusedims(A, biperm)
15+
16+
# factorization
17+
X, Y = MatrixAlgebra.$f(A_mat; kwargs...)
18+
19+
# matrix to tensor
20+
axes_codomain, axes_domain = blockpermute(axes(A), biperm)
21+
axes_X = (axes_codomain..., axes(X, 2))
22+
axes_Y = (axes(Y, 1), axes_domain...)
23+
return splitdims(X, axes_X), splitdims(Y, axes_Y)
3724
end
3825
end
3926
end
@@ -209,9 +196,7 @@ function eigvals(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwa
209196
biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...)
210197
return eigvals(A, biperm; kwargs...)
211198
end
212-
function eigvals(
213-
A::AbstractArray, biperm::BlockedPermutation{2}; ishermitian=nothing, kwargs...
214-
)
199+
function eigvals(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...)
215200
A_mat = fusedims(A, biperm)
216201
return MatrixAlgebra.eigvals!(A_mat; kwargs...)
217202
end

0 commit comments

Comments
 (0)