Skip to content

Commit c2b7d3f

Browse files
committed
more careful with import and exports
1 parent 9f6761f commit c2b7d3f

File tree

8 files changed

+217
-225
lines changed

8 files changed

+217
-225
lines changed

src/TensorKit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ using TensorOperations: TensorOperations, @tensor, @tensoropt, @ncon, ncon
109109
using TensorOperations: IndexTuple, Index2Tuple, linearize, AbstractBackend
110110
const TO = TensorOperations
111111

112-
using MatrixAlgebraKit: MatrixAlgebraKit as MAK
112+
using MatrixAlgebraKit
113113

114114
using LRUCache
115115
using OhMyThreads

src/tensors/factorizations/adjoint.jl

Lines changed: 36 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2,96 +2,99 @@
22
# ----------------
33
# map algorithms to their adjoint counterpart
44
# TODO: this probably belongs in MatrixAlgebraKit
5-
_adjoint(alg::LAPACK_HouseholderQR) = LAPACK_HouseholderLQ(; alg.kwargs...)
6-
_adjoint(alg::LAPACK_HouseholderLQ) = LAPACK_HouseholderQR(; alg.kwargs...)
7-
_adjoint(alg::LAPACK_HouseholderQL) = LAPACK_HouseholderRQ(; alg.kwargs...)
8-
_adjoint(alg::LAPACK_HouseholderRQ) = LAPACK_HouseholderQL(; alg.kwargs...)
9-
_adjoint(alg::PolarViaSVD) = PolarViaSVD(_adjoint(alg.svdalg))
5+
_adjoint(alg::MAK.LAPACK_HouseholderQR) = MAK.LAPACK_HouseholderLQ(; alg.kwargs...)
6+
_adjoint(alg::MAK.LAPACK_HouseholderLQ) = MAK.LAPACK_HouseholderQR(; alg.kwargs...)
7+
_adjoint(alg::MAK.LAPACK_HouseholderQL) = MAK.LAPACK_HouseholderRQ(; alg.kwargs...)
8+
_adjoint(alg::MAK.LAPACK_HouseholderRQ) = MAK.LAPACK_HouseholderQL(; alg.kwargs...)
9+
_adjoint(alg::MAK.PolarViaSVD) = MAK.PolarViaSVD(_adjoint(alg.svdalg))
1010
_adjoint(alg::AbstractAlgorithm) = alg
1111

1212
# 1-arg functions
13-
function initialize_output(::typeof(left_null!), t::AdjointTensorMap,
13+
function MAK.initialize_output(::typeof(left_null!), t::AdjointTensorMap,
1414
alg::AbstractAlgorithm)
15-
return adjoint(initialize_output(right_null!, adjoint(t), _adjoint(alg)))
15+
return adjoint(MAK.initialize_output(right_null!, adjoint(t), _adjoint(alg)))
1616
end
17-
function initialize_output(::typeof(right_null!), t::AdjointTensorMap,
17+
function MAK.initialize_output(::typeof(right_null!), t::AdjointTensorMap,
1818
alg::AbstractAlgorithm)
19-
return adjoint(initialize_output(left_null!, adjoint(t), _adjoint(alg)))
19+
return adjoint(MAK.initialize_output(left_null!, adjoint(t), _adjoint(alg)))
2020
end
2121

22-
function left_null!(t::AdjointTensorMap, N, alg::AbstractAlgorithm)
22+
function MAK.left_null!(t::AdjointTensorMap, N, alg::AbstractAlgorithm)
2323
right_null!(adjoint(t), adjoint(N), _adjoint(alg))
2424
return N
2525
end
26-
function right_null!(t::AdjointTensorMap, N, alg::AbstractAlgorithm)
26+
function MAK.right_null!(t::AdjointTensorMap, N, alg::AbstractAlgorithm)
2727
left_null!(adjoint(t), adjoint(N), _adjoint(alg))
2828
return N
2929
end
3030

31-
function MatrixAlgebraKit.is_left_isometry(t::AdjointTensorMap; kwargs...)
31+
function MAK.is_left_isometry(t::AdjointTensorMap; kwargs...)
3232
return is_right_isometry(adjoint(t); kwargs...)
3333
end
34-
function MatrixAlgebraKit.is_right_isometry(t::AdjointTensorMap; kwargs...)
34+
function MAK.is_right_isometry(t::AdjointTensorMap; kwargs...)
3535
return is_left_isometry(adjoint(t); kwargs...)
3636
end
3737

3838
# 2-arg functions
3939
for (left_f!, right_f!) in zip((:qr_full!, :qr_compact!, :left_polar!, :left_orth!),
4040
(:lq_full!, :lq_compact!, :right_polar!, :right_orth!))
41-
@eval function copy_input(::typeof($left_f!), t::AdjointTensorMap)
42-
return adjoint(copy_input($right_f!, adjoint(t)))
41+
@eval function MAK.copy_input(::typeof($left_f!), t::AdjointTensorMap)
42+
return adjoint(MAK.copy_input($right_f!, adjoint(t)))
4343
end
44-
@eval function copy_input(::typeof($right_f!), t::AdjointTensorMap)
45-
return adjoint(copy_input($left_f!, adjoint(t)))
44+
@eval function MAK.copy_input(::typeof($right_f!), t::AdjointTensorMap)
45+
return adjoint(MAK.copy_input($left_f!, adjoint(t)))
4646
end
4747

48-
@eval function initialize_output(::typeof($left_f!), t::AdjointTensorMap,
48+
@eval function MAK.initialize_output(::typeof($left_f!), t::AdjointTensorMap,
4949
alg::AbstractAlgorithm)
50-
return reverse(adjoint.(initialize_output($right_f!, adjoint(t), _adjoint(alg))))
50+
return reverse(adjoint.(MAK.initialize_output($right_f!, adjoint(t), _adjoint(alg))))
5151
end
52-
@eval function initialize_output(::typeof($right_f!), t::AdjointTensorMap,
52+
@eval function MAK.initialize_output(::typeof($right_f!), t::AdjointTensorMap,
5353
alg::AbstractAlgorithm)
54-
return reverse(adjoint.(initialize_output($left_f!, adjoint(t), _adjoint(alg))))
54+
return reverse(adjoint.(MAK.initialize_output($left_f!, adjoint(t), _adjoint(alg))))
5555
end
5656

57-
@eval function $left_f!(t::AdjointTensorMap, F, alg::AbstractAlgorithm)
57+
@eval function MAK.$left_f!(t::AdjointTensorMap, F, alg::AbstractAlgorithm)
5858
$right_f!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg))
5959
return F
6060
end
61-
@eval function $right_f!(t::AdjointTensorMap, F, alg::AbstractAlgorithm)
61+
@eval function MAK.$right_f!(t::AdjointTensorMap, F, alg::AbstractAlgorithm)
6262
$left_f!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg))
6363
return F
6464
end
6565
end
6666

6767
# 3-arg functions
6868
for f! in (:svd_full!, :svd_compact!, :svd_trunc!)
69-
@eval function copy_input(::typeof($f!), t::AdjointTensorMap)
70-
return adjoint(copy_input($f!, adjoint(t)))
69+
@eval function MAK.copy_input(::typeof($f!), t::AdjointTensorMap)
70+
return adjoint(MAK.copy_input($f!, adjoint(t)))
7171
end
7272

73-
@eval function initialize_output(::typeof($f!), t::AdjointTensorMap,
73+
@eval function MAK.initialize_output(::typeof($f!), t::AdjointTensorMap,
7474
alg::AbstractAlgorithm)
75-
return reverse(adjoint.(initialize_output($f!, adjoint(t), _adjoint(alg))))
75+
return reverse(adjoint.(MAK.initialize_output($f!, adjoint(t), _adjoint(alg))))
7676
end
77-
@eval function $f!(t::AdjointTensorMap, F, alg::AbstractAlgorithm)
77+
@eval function MAK.$f!(t::AdjointTensorMap, F, alg::AbstractAlgorithm)
7878
$f!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg))
7979
return F
8080
end
8181

8282
# disambiguate by prohibition
83-
@eval function initialize_output(::typeof($f!), t::AdjointTensorMap,
83+
@eval function MAK.initialize_output(::typeof($f!), t::AdjointTensorMap,
8484
alg::DiagonalAlgorithm)
8585
throw(MethodError($f!, (t, alg)))
8686
end
8787
end
8888
# avoid amgiguity
89-
function initialize_output(::typeof(svd_trunc!), t::AdjointTensorMap,
89+
function MAK.initialize_output(::typeof(svd_trunc!), t::AdjointTensorMap,
9090
alg::TruncatedAlgorithm)
91-
return initialize_output(svd_compact!, t, alg.alg)
91+
return MAK.initialize_output(svd_compact!, t, alg.alg)
9292
end
9393
# to fix ambiguity
94-
function svd_trunc!(t::AdjointTensorMap, USVᴴ, alg::TruncatedAlgorithm)
94+
function MAK.svd_trunc!(t::AdjointTensorMap, USVᴴ, alg::TruncatedAlgorithm)
9595
USVᴴ′ = svd_compact!(t, USVᴴ, alg.alg)
96-
return truncate(svd_trunc!, USVᴴ′, alg.trunc)
96+
return MAK.truncate(svd_trunc!, USVᴴ′, alg.trunc)
97+
end
98+
function MAK.svd_compact!(t::AdjointTensorMap, USVᴴ, alg::DiagonalAlgorithm)
99+
return MAK.svd_compact!(t, USVᴴ, alg.alg)
97100
end

src/tensors/factorizations/diagonal.jl

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,18 @@ _repack_diagonal(d::DiagonalTensorMap) = Diagonal(d.data)
55
for f in (:svd_compact, :svd_full, :svd_trunc, :svd_vals, :qr_compact, :qr_full, :qr_null,
66
:lq_compact, :lq_full, :lq_null, :eig_full, :eig_trunc, :eig_vals, :eigh_full,
77
:eigh_trunc, :eigh_vals, :left_polar, :right_polar)
8-
@eval copy_input(::typeof($f), d::DiagonalTensorMap) = copy(d)
8+
@eval MAK.copy_input(::typeof($f), d::DiagonalTensorMap) = copy(d)
99
end
1010

1111
for f! in (:eig_full!, :eig_trunc!)
12-
@eval function initialize_output(::typeof($f!), d::AbstractTensorMap,
12+
@eval function MAK.initialize_output(::typeof($f!), d::AbstractTensorMap,
1313
::DiagonalAlgorithm)
1414
return d, similar(d)
1515
end
1616
end
1717

1818
for f! in (:eigh_full!, :eigh_trunc!)
19-
@eval function initialize_output(::typeof($f!), d::AbstractTensorMap,
19+
@eval function MAK.initialize_output(::typeof($f!), d::AbstractTensorMap,
2020
::DiagonalAlgorithm)
2121
if scalartype(d) <: Real
2222
return d, similar(d)
@@ -27,36 +27,36 @@ for f! in (:eigh_full!, :eigh_trunc!)
2727
end
2828

2929
for f! in (:qr_full!, :qr_compact!)
30-
@eval function initialize_output(::typeof($f!), d::AbstractTensorMap,
30+
@eval function MAK.initialize_output(::typeof($f!), d::AbstractTensorMap,
3131
::DiagonalAlgorithm)
3232
return d, similar(d)
3333
end
3434
# to avoid ambiguities
35-
@eval function initialize_output(::typeof($f!), d::AdjointTensorMap,
35+
@eval function MAK.initialize_output(::typeof($f!), d::AdjointTensorMap,
3636
::DiagonalAlgorithm)
3737
return d, similar(d)
3838
end
3939
end
4040
for f! in (:lq_full!, :lq_compact!)
41-
@eval function initialize_output(::typeof($f!), d::AbstractTensorMap,
41+
@eval function MAK.initialize_output(::typeof($f!), d::AbstractTensorMap,
4242
::DiagonalAlgorithm)
4343
return similar(d), d
4444
end
4545
# to avoid ambiguities
46-
@eval function initialize_output(::typeof($f!), d::AdjointTensorMap,
46+
@eval function MAK.initialize_output(::typeof($f!), d::AdjointTensorMap,
4747
::DiagonalAlgorithm)
4848
return similar(d), d
4949
end
5050
end
5151

52-
function initialize_output(::typeof(left_orth!), d::DiagonalTensorMap)
52+
function MAK.initialize_output(::typeof(left_orth!), d::DiagonalTensorMap)
5353
return d, similar(d)
5454
end
55-
function initialize_output(::typeof(right_orth!), d::DiagonalTensorMap)
55+
function MAK.initialize_output(::typeof(right_orth!), d::DiagonalTensorMap)
5656
return similar(d), d
5757
end
5858

59-
function initialize_output(::typeof(svd_full!), t::AbstractTensorMap, ::DiagonalAlgorithm)
59+
function MAK.initialize_output(::typeof(svd_full!), t::AbstractTensorMap, ::DiagonalAlgorithm)
6060
V_cod = fuse(codomain(t))
6161
V_dom = fuse(domain(t))
6262
U = similar(t, codomain(t) V_cod)
@@ -68,15 +68,15 @@ end
6868
for f! in
6969
(:qr_full!, :qr_compact!, :lq_full!, :lq_compact!, :eig_full!, :eig_trunc!, :eigh_full!,
7070
:eigh_trunc!, :right_orth!, :left_orth!)
71-
@eval function $f!(d::DiagonalTensorMap, F, alg::DiagonalAlgorithm)
72-
check_input($f!, d, F, alg)
71+
@eval function MAK.$f!(d::DiagonalTensorMap, F, alg::DiagonalAlgorithm)
72+
MAK.check_input($f!, d, F, alg)
7373
$f!(_repack_diagonal(d), _repack_diagonal.(F), alg)
7474
return F
7575
end
7676
end
7777

7878
for f! in (:qr_full!, :qr_compact!)
79-
@eval function check_input(::typeof($f!), d::AbstractTensorMap, QR,
79+
@eval function MAK.check_input(::typeof($f!), d::AbstractTensorMap, QR,
8080
::DiagonalAlgorithm)
8181
Q, R = QR
8282
@assert d isa DiagonalTensorMap
@@ -91,7 +91,7 @@ for f! in (:qr_full!, :qr_compact!)
9191
end
9292

9393
for f! in (:lq_full!, :lq_compact!)
94-
@eval function check_input(::typeof($f!), d::AbstractTensorMap, LQ,
94+
@eval function MAK.check_input(::typeof($f!), d::AbstractTensorMap, LQ,
9595
::DiagonalAlgorithm)
9696
L, Q = LQ
9797
@assert d isa DiagonalTensorMap
@@ -106,25 +106,25 @@ for f! in (:lq_full!, :lq_compact!)
106106
end
107107

108108
# disambiguate
109-
svd_compact!(t::AbstractTensorMap, USVᴴ, alg::DiagonalAlgorithm) = svd_full!(t, USVᴴ, alg)
109+
MAK.svd_compact!(t::AbstractTensorMap, USVᴴ, alg::DiagonalAlgorithm) = svd_full!(t, USVᴴ, alg)
110110

111111
# f_vals
112112
# ------
113113

114114
for f! in (:eig_vals!, :eigh_vals!, :svd_vals!)
115-
@eval function $f!(d::AbstractTensorMap, V, alg::DiagonalAlgorithm)
116-
check_input($f!, d, V, alg)
115+
@eval function MAK.$f!(d::AbstractTensorMap, V, alg::DiagonalAlgorithm)
116+
MAK.check_input($f!, d, V, alg)
117117
$f!(_repack_diagonal(d), diagview(_repack_diagonal(V)), alg)
118118
return V
119119
end
120-
@eval function initialize_output(::typeof($f!), d::DiagonalTensorMap,
120+
@eval function MAK.initialize_output(::typeof($f!), d::DiagonalTensorMap,
121121
alg::DiagonalAlgorithm)
122-
data = initialize_output($f!, _repack_diagonal(d), alg)
122+
data = MAK.initialize_output($f!, _repack_diagonal(d), alg)
123123
return DiagonalTensorMap(data, d.domain)
124124
end
125125
end
126126

127-
function check_input(::typeof(eig_full!), t::DiagonalTensorMap, DV, ::DiagonalAlgorithm)
127+
function MAK.check_input(::typeof(eig_full!), t::DiagonalTensorMap, DV, ::DiagonalAlgorithm)
128128
domain(t) == codomain(t) ||
129129
throw(ArgumentError("Eigenvalue decomposition requires square input tensor"))
130130

@@ -144,7 +144,7 @@ function check_input(::typeof(eig_full!), t::DiagonalTensorMap, DV, ::DiagonalAl
144144
return nothing
145145
end
146146

147-
function check_input(::typeof(eigh_full!), t::DiagonalTensorMap, DV, ::DiagonalAlgorithm)
147+
function MAK.check_input(::typeof(eigh_full!), t::DiagonalTensorMap, DV, ::DiagonalAlgorithm)
148148
domain(t) == codomain(t) ||
149149
throw(ArgumentError("Eigenvalue decomposition requires square input tensor"))
150150

@@ -164,21 +164,21 @@ function check_input(::typeof(eigh_full!), t::DiagonalTensorMap, DV, ::DiagonalA
164164
return nothing
165165
end
166166

167-
function check_input(::typeof(eig_vals!), t::AbstractTensorMap, D, ::DiagonalAlgorithm)
167+
function MAK.check_input(::typeof(eig_vals!), t::AbstractTensorMap, D, ::DiagonalAlgorithm)
168168
@assert D isa DiagonalTensorMap
169169
@check_scalar D t
170170
@check_space D space(t)
171171
return nothing
172172
end
173173

174-
function check_input(::typeof(eigh_vals!), t::AbstractTensorMap, D, ::DiagonalAlgorithm)
174+
function MAK.check_input(::typeof(eigh_vals!), t::AbstractTensorMap, D, ::DiagonalAlgorithm)
175175
@assert D isa DiagonalTensorMap
176176
@check_scalar D t real
177177
@check_space D space(t)
178178
return nothing
179179
end
180180

181-
function check_input(::typeof(svd_vals!), t::AbstractTensorMap, D, ::DiagonalAlgorithm)
181+
function MAK.check_input(::typeof(svd_vals!), t::AbstractTensorMap, D, ::DiagonalAlgorithm)
182182
@assert D isa DiagonalTensorMap
183183
@check_scalar D t real
184184
@check_space D space(t)

src/tensors/factorizations/factorizations.jl

Lines changed: 9 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -3,51 +3,24 @@
33
# using submodule here to import MatrixAlgebraKit functions without polluting namespace
44
module Factorizations
55

6-
export eig, eig!, eigh, eigh!
7-
export tsvd, tsvd!, svdvals, svdvals!
8-
export leftorth, leftorth!, rightorth, rightorth!
9-
export leftnull, leftnull!, rightnull, rightnull!
10-
export qr_full, qr_compact, qr_null
11-
export qr_full!, qr_compact!, qr_null!
12-
export lq_full, lq_compact, lq_null
13-
export lq_full!, lq_compact!, lq_null!
14-
export copy_oftype, factorisation_scalartype, one!
15-
export TruncationScheme, notrunc, trunctol, truncerror, truncrank, truncspace, truncfilter,
16-
PolarViaSVD
6+
export copy_oftype, factorisation_scalartype, one!, truncspace
177

188
using ..TensorKit
199
using ..TensorKit: AdjointTensorMap, SectorDict, blocktype, foreachblock, one!
2010

21-
using LinearAlgebra: LinearAlgebra, BlasFloat, Diagonal, svdvals, svdvals!
22-
import LinearAlgebra: eigen, eigen!, isposdef, isposdef!, ishermitian
11+
using LinearAlgebra: LinearAlgebra, BlasFloat, Diagonal, svdvals, svdvals!, eigen, eigen!,
12+
isposdef, isposdef!, ishermitian
2313

2414
using TensorOperations: Index2Tuple
2515

2616
using MatrixAlgebraKit
17+
import MatrixAlgebraKit as MAK
2718
using MatrixAlgebraKit: AbstractAlgorithm, TruncatedAlgorithm, DiagonalAlgorithm
2819
using MatrixAlgebraKit: TruncationStrategy, NoTruncation, TruncationByValue,
2920
TruncationByError, TruncationIntersection, TruncationByFilter,
3021
TruncationByOrder
31-
using MatrixAlgebraKit: PolarViaSVD
32-
using MatrixAlgebraKit: LAPACK_SVDAlgorithm, LAPACK_QRIteration, LAPACK_HouseholderQR,
33-
LAPACK_HouseholderLQ, LAPACK_HouseholderQL, LAPACK_HouseholderRQ
34-
import MatrixAlgebraKit: default_algorithm,
35-
copy_input, check_input, initialize_output,
36-
qr_compact!, qr_full!, qr_null!, lq_compact!, lq_full!, lq_null!,
37-
svd_compact!, svd_full!, svd_trunc!, svd_vals!,
38-
eigh_full!, eigh_trunc!, eigh_vals!,
39-
eig_full!, eig_trunc!, eig_vals!,
40-
left_polar!, left_orth_polar!, right_polar!, right_orth_polar!,
41-
left_null_svd!, right_null_svd!, left_orth_svd!, right_orth_svd!,
42-
left_orth!, right_orth!, left_null!, right_null!,
43-
truncate, findtruncated, findtruncated_svd,
44-
diagview, isisometry
45-
using MatrixAlgebraKit: qr_pullback!, qr_null_pullback!,
46-
lq_pullback!, lq_null_pullback!,
47-
svd_pullback!, svd_trunc_pullback!,
48-
eig_pullback!, eig_trunc_pullback!,
49-
eigh_pullback!, eigh_trunc_pullback!,
50-
left_polar_pullback!, right_polar_pullback!
22+
using MatrixAlgebraKit: left_orth_polar!, right_orth_polar!, left_orth_svd!,
23+
right_orth_svd!, left_null_svd!, right_null_svd!, diagview
5124

5225
include("utility.jl")
5326
include("matrixalgebrakit.jl")
@@ -58,7 +31,7 @@ include("pullbacks.jl")
5831

5932
TensorKit.one!(A::AbstractMatrix) = MatrixAlgebraKit.one!(A)
6033

61-
function isisometry(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple)
34+
function MatrixAlgebraKit.isisometry(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple)
6235
t = permute(t, (p₁, p₂); copy=false)
6336
return isisometry(t)
6437
end
@@ -67,10 +40,10 @@ end
6740
# LinearAlgebra overloads
6841
#------------------------------#
6942

70-
function eigen(t::AbstractTensorMap; kwargs...)
43+
function LinearAlgebra.eigen(t::AbstractTensorMap; kwargs...)
7144
return ishermitian(t) ? eigh_full(t; kwargs...) : eig_full(t; kwargs...)
7245
end
73-
function eigen!(t::AbstractTensorMap; kwargs...)
46+
function LinearAlgebra.eigen!(t::AbstractTensorMap; kwargs...)
7447
return ishermitian(t) ? eigh_full!(t; kwargs...) : eig_full!(t; kwargs...)
7548
end
7649

0 commit comments

Comments
 (0)