Skip to content

Commit db13df8

Browse files
committed
Merge branch 'main' into lb/add_tutorial
2 parents 8a2c34f + 3a06898 commit db13df8

File tree

23 files changed

+510
-842
lines changed

23 files changed

+510
-842
lines changed

.github/workflows/CompatCheck.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ jobs:
2525
# https://github.com/julia-actions/julia-downgrade-compat/issues/25
2626
julia-version: ['1.10', '1.12']
2727
steps:
28-
- uses: actions/checkout@v5
28+
- uses: actions/checkout@v6
2929
- uses: julia-actions/setup-julia@v2
3030
with:
3131
version: ${{ matrix.julia-version }}

.github/workflows/Documentation.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ jobs:
2020
os:
2121
- ubuntu-latest
2222
steps:
23-
- uses: actions/checkout@v5
23+
- uses: actions/checkout@v6
2424
- uses: julia-actions/setup-julia@latest
2525
with:
2626
version: ${{ matrix.version }}

.github/workflows/DocumentationCleanup.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ jobs:
1616
contents: write
1717
steps:
1818
- name: Checkout gh-pages branch
19-
uses: actions/checkout@v5
19+
uses: actions/checkout@v6
2020
with:
2121
ref: gh-pages
2222
- name: Delete preview and history + push changes

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ Combinatorics = "1"
3535
FiniteDifferences = "0.12"
3636
LRUCache = "1.0.2"
3737
LinearAlgebra = "1"
38-
MatrixAlgebraKit = "0.5.0"
38+
MatrixAlgebraKit = "0.6.0"
3939
OhMyThreads = "0.8.0"
4040
PackageExtensionCompat = "1"
4141
Printf = "1"

docs/src/lib/spaces.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ the resuling `HomSpace` after applying certain tensor operations.
123123

124124
```@docs
125125
flip(W::HomSpace{S}, I) where {S}
126-
TensorKit.permute(::HomSpace{S}, ::Index2Tuple{N₁,N₂}) where {S,N₁,N₂}
126+
TensorKit.permute(::HomSpace, ::Index2Tuple)
127127
TensorKit.select(::HomSpace{S}, ::Index2Tuple{N₁,N₂}) where {S,N₁,N₂}
128128
TensorKit.compose(::HomSpace{S}, ::HomSpace{S}) where {S}
129129
insertleftunit(::HomSpace, ::Val{i}) where {i}

ext/TensorKitChainRulesCoreExt/linalg.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,16 @@ function ChainRulesCore.rrule(
6969
return permute(tsrc, p; copy = true), permute_pullback
7070
end
7171

72+
function ChainRulesCore.rrule(
73+
::typeof(transpose), tsrc::AbstractTensorMap, p::Index2Tuple; copy::Bool = false
74+
)
75+
function transpose_pullback(Δtdst)
76+
invp = TensorKit._canonicalize(TupleTools.invperm(linearize(p)), tsrc)
77+
return NoTangent(), transpose(unthunk(Δtdst), invp; copy = true), NoTangent()
78+
end
79+
return transpose(tsrc, p; copy = true), transpose_pullback
80+
end
81+
7282
function ChainRulesCore.rrule(::typeof(tr), A::AbstractTensorMap)
7383
tr_pullback(Δtr) = NoTangent(), Δtr * id(domain(A))
7484
return tr(A), tr_pullback

src/TensorKit.jl

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,12 @@ export left_orth, right_orth, left_null, right_null,
7979
qr_full!, qr_compact!, qr_null!, lq_full!, lq_compact!, lq_null!,
8080
svd_compact!, svd_full!, svd_trunc!, svd_compact, svd_full, svd_trunc,
8181
exp, exp!,
82-
eigh_full!, eigh_full, eigh_trunc!, eigh_trunc, eig_full!, eig_full, eig_trunc!,
83-
eig_trunc,
84-
eigh_vals!, eigh_vals, eig_vals!, eig_vals,
85-
isposdef, isposdef!, ishermitian, isisometry, isunitary, sylvester, rank, cond
82+
eigh_full!, eigh_full, eigh_trunc!, eigh_trunc, eigh_vals!, eigh_vals,
83+
eig_full!, eig_full, eig_trunc!, eig_trunc, eig_vals!, eig_vals,
84+
ishermitian, project_hermitian, project_hermitian!,
85+
isantihermitian, project_antihermitian, project_antihermitian!,
86+
isisometric, isunitary, project_isometric, project_isometric!,
87+
isposdef, isposdef!, sylvester, rank, cond
8688

8789
export braid, braid!, permute, permute!, transpose, transpose!, twist, twist!, repartition,
8890
repartition!
@@ -135,7 +137,7 @@ using LinearAlgebra: norm, dot, normalize, normalize!, tr,
135137
adjoint, adjoint!, transpose, transpose!,
136138
lu, pinv, sylvester,
137139
eigen, eigen!, svd, svd!,
138-
isposdef, isposdef!, ishermitian, rank, cond,
140+
isposdef, isposdef!, rank, cond,
139141
Diagonal, Hermitian
140142
using MatrixAlgebraKit
141143

@@ -176,8 +178,11 @@ struct SpaceMismatch{S <: Union{Nothing, AbstractString}} <: TensorException
176178
message::S
177179
end
178180
SpaceMismatch() = SpaceMismatch{Nothing}(nothing)
179-
Base.showerror(io::IO, ::SpaceMismatch{Nothing}) = print(io, "SpaceMismatch()")
180-
Base.showerror(io::IO, e::SpaceMismatch) = print(io, "SpaceMismatch(\"", e.message, "\")")
181+
function Base.showerror(io::IO, err::SpaceMismatch)
182+
print(io, "SpaceMismatch: ")
183+
isnothing(err.message) || print(io, err.message)
184+
return nothing
185+
end
181186

182187
# Exception type for all errors related to invalid tensor index specification.
183188
struct IndexError{S <: Union{Nothing, AbstractString}} <: TensorException

src/auxiliary/deprecate.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ function tsvd(t::AbstractTensorMap; kwargs...)
186186
Base.depwarn("p is a deprecated kwarg, and should be specified through the truncation strategy", :tsvd)
187187
kwargs = _drop_p(; kwargs...)
188188
end
189-
return haskey(kwargs, :trunc) ? svd_trunc(t; kwargs...) : svd_compact(t; kwargs...)
189+
return haskey(kwargs, :trunc) ? svd_trunc(t; kwargs...) : (svd_compact(t; kwargs...)..., abs(zero(scalartype(t))))
190190
end
191191
function tsvd!(t::AbstractTensorMap; kwargs...)
192192
Base.depwarn("`tsvd!` is deprecated, use `svd_compact!`, `svd_full!` or `svd_trunc!` instead", :tsvd!)

src/factorizations/adjoint.jl

Lines changed: 54 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -6,46 +6,54 @@ _adjoint(alg::MAK.LAPACK_HouseholderQR) = MAK.LAPACK_HouseholderLQ(; alg.kwargs.
66
_adjoint(alg::MAK.LAPACK_HouseholderLQ) = MAK.LAPACK_HouseholderQR(; alg.kwargs...)
77
_adjoint(alg::MAK.LAPACK_HouseholderQL) = MAK.LAPACK_HouseholderRQ(; alg.kwargs...)
88
_adjoint(alg::MAK.LAPACK_HouseholderRQ) = MAK.LAPACK_HouseholderQL(; alg.kwargs...)
9-
_adjoint(alg::MAK.PolarViaSVD) = MAK.PolarViaSVD(_adjoint(alg.svdalg))
9+
_adjoint(alg::MAK.PolarViaSVD) = MAK.PolarViaSVD(_adjoint(alg.svd_alg))
1010
_adjoint(alg::AbstractAlgorithm) = alg
1111

12-
# 1-arg functions
13-
function MAK.initialize_output(::typeof(left_null!), t::AdjointTensorMap, alg::AbstractAlgorithm)
14-
return adjoint(MAK.initialize_output(right_null!, adjoint(t), _adjoint(alg)))
15-
end
16-
function MAK.initialize_output(
17-
::typeof(right_null!), t::AdjointTensorMap,
18-
alg::AbstractAlgorithm
19-
)
20-
return adjoint(MAK.initialize_output(left_null!, adjoint(t), _adjoint(alg)))
12+
for f in
13+
[
14+
:svd_compact, :svd_full, :svd_vals,
15+
:qr_compact, :qr_full, :qr_null,
16+
:lq_compact, :lq_full, :lq_null,
17+
:eig_full, :eig_vals, :eigh_full, :eigh_trunc, :eigh_vals,
18+
:left_polar, :right_polar,
19+
:project_hermitian, :project_antihermitian, :project_isometric,
20+
]
21+
f! = Symbol(f, :!)
22+
# just return the algorithm for the parent type since we are mapping this with
23+
# `_adjoint` afterwards anyways.
24+
# TODO: properly handle these cases
25+
@eval MAK.default_algorithm(::typeof($f!), ::Type{T}; kwargs...) where {T <: AdjointTensorMap} =
26+
MAK.default_algorithm($f!, TensorKit.parenttype(T); kwargs...)
2127
end
2228

23-
function MAK.left_null!(t::AdjointTensorMap, N, alg::AbstractAlgorithm)
24-
right_null!(adjoint(t), adjoint(N), _adjoint(alg))
25-
return N
26-
end
27-
function MAK.right_null!(t::AdjointTensorMap, N, alg::AbstractAlgorithm)
28-
left_null!(adjoint(t), adjoint(N), _adjoint(alg))
29-
return N
30-
end
29+
# 1-arg functions
30+
MAK.initialize_output(::typeof(qr_null!), t::AdjointTensorMap, alg::AbstractAlgorithm) =
31+
adjoint(MAK.initialize_output(lq_null!, adjoint(t), _adjoint(alg)))
32+
MAK.initialize_output(::typeof(lq_null!), t::AdjointTensorMap, alg::AbstractAlgorithm) =
33+
adjoint(MAK.initialize_output(qr_null!, adjoint(t), _adjoint(alg)))
3134

32-
function MAK.is_left_isometry(t::AdjointTensorMap; kwargs...)
33-
return is_right_isometry(adjoint(t); kwargs...)
34-
end
35-
function MAK.is_right_isometry(t::AdjointTensorMap; kwargs...)
36-
return is_left_isometry(adjoint(t); kwargs...)
37-
end
35+
MAK.qr_null!(t::AdjointTensorMap, N, alg::AbstractAlgorithm) =
36+
lq_null!(adjoint(t), adjoint(N), _adjoint(alg))
37+
MAK.lq_null!(t::AdjointTensorMap, N, alg::AbstractAlgorithm) =
38+
qr_null!(adjoint(t), adjoint(N), _adjoint(alg))
39+
40+
MAK.is_left_isometric(t::AdjointTensorMap; kwargs...) =
41+
MAK.is_right_isometric(adjoint(t); kwargs...)
42+
MAK.is_right_isometric(t::AdjointTensorMap; kwargs...) =
43+
MAK.is_left_isometric(adjoint(t); kwargs...)
3844

3945
# 2-arg functions
40-
for (left_f!, right_f!) in zip(
41-
(:qr_full!, :qr_compact!, :left_polar!, :left_orth!),
42-
(:lq_full!, :lq_compact!, :right_polar!, :right_orth!)
46+
for (left_f, right_f) in zip(
47+
(:qr_full, :qr_compact, :left_polar),
48+
(:lq_full, :lq_compact, :right_polar)
4349
)
44-
@eval function MAK.copy_input(::typeof($left_f!), t::AdjointTensorMap)
45-
return adjoint(MAK.copy_input($right_f!, adjoint(t)))
50+
left_f! = Symbol(left_f, :!)
51+
right_f! = Symbol(right_f, :!)
52+
@eval function MAK.copy_input(::typeof($left_f), t::AdjointTensorMap)
53+
return adjoint(MAK.copy_input($right_f, adjoint(t)))
4654
end
47-
@eval function MAK.copy_input(::typeof($right_f!), t::AdjointTensorMap)
48-
return adjoint(MAK.copy_input($left_f!, adjoint(t)))
55+
@eval function MAK.copy_input(::typeof($right_f), t::AdjointTensorMap)
56+
return adjoint(MAK.copy_input($left_f, adjoint(t)))
4957
end
5058

5159
@eval function MAK.initialize_output(
@@ -60,29 +68,31 @@ for (left_f!, right_f!) in zip(
6068
end
6169

6270
@eval function MAK.$left_f!(t::AdjointTensorMap, F, alg::AbstractAlgorithm)
63-
$right_f!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg))
64-
return F
71+
F′ = $right_f!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg))
72+
return reverse(adjoint.(F′))
6573
end
6674
@eval function MAK.$right_f!(t::AdjointTensorMap, F, alg::AbstractAlgorithm)
67-
$left_f!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg))
68-
return F
75+
F′ = $left_f!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg))
76+
return reverse(adjoint.(F′))
6977
end
7078
end
7179

7280
# 3-arg functions
73-
for f! in (:svd_full!, :svd_compact!, :svd_trunc!)
74-
@eval function MAK.copy_input(::typeof($f!), t::AdjointTensorMap)
75-
return adjoint(MAK.copy_input($f!, adjoint(t)))
81+
for f in (:svd_full, :svd_compact)
82+
f! = Symbol(f, :!)
83+
@eval function MAK.copy_input(::typeof($f), t::AdjointTensorMap)
84+
return adjoint(MAK.copy_input($f, adjoint(t)))
7685
end
7786

7887
@eval function MAK.initialize_output(
7988
::typeof($f!), t::AdjointTensorMap, alg::AbstractAlgorithm
8089
)
8190
return reverse(adjoint.(MAK.initialize_output($f!, adjoint(t), _adjoint(alg))))
8291
end
92+
8393
@eval function MAK.$f!(t::AdjointTensorMap, F, alg::AbstractAlgorithm)
84-
$f!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg))
85-
return F
94+
F′ = $f!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg))
95+
return reverse(adjoint.(F′))
8696
end
8797

8898
# disambiguate by prohibition
@@ -92,17 +102,9 @@ for f! in (:svd_full!, :svd_compact!, :svd_trunc!)
92102
throw(MethodError($f!, (t, alg)))
93103
end
94104
end
105+
95106
# avoid amgiguity
96-
function MAK.initialize_output(
97-
::typeof(svd_trunc!), t::AdjointTensorMap, alg::TruncatedAlgorithm
98-
)
99-
return MAK.initialize_output(svd_compact!, t, alg.alg)
100-
end
101-
# to fix ambiguity
102-
function MAK.svd_trunc!(t::AdjointTensorMap, USVᴴ, alg::TruncatedAlgorithm)
103-
USVᴴ′ = svd_compact!(t, USVᴴ, alg.alg)
104-
return MAK.truncate(svd_trunc!, USVᴴ′, alg.trunc)
105-
end
106-
function MAK.svd_compact!(t::AdjointTensorMap, USVᴴ, alg::DiagonalAlgorithm)
107-
return MAK.svd_compact!(t, USVᴴ, alg.alg)
107+
function MAK.svd_compact!(t::AdjointTensorMap, F, alg::DiagonalAlgorithm)
108+
F′ = svd_compact!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg))
109+
return reverse(adjoint.(F′))
108110
end

src/factorizations/diagonal.jl

Lines changed: 0 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -81,55 +81,20 @@ for f! in
8181
:eigh_trunc!, :right_orth!, :left_orth!,
8282
)
8383
@eval function MAK.$f!(d::DiagonalTensorMap, F, alg::DiagonalAlgorithm)
84-
MAK.check_input($f!, d, F, alg)
8584
$f!(_repack_diagonal(d), _repack_diagonal.(F), alg)
8685
return F
8786
end
8887
end
8988

90-
for f! in (:qr_full!, :qr_compact!)
91-
@eval function MAK.check_input(
92-
::typeof($f!), d::AbstractTensorMap, QR, ::DiagonalAlgorithm
93-
)
94-
Q, R = QR
95-
@assert d isa DiagonalTensorMap
96-
@assert Q isa DiagonalTensorMap && R isa DiagonalTensorMap
97-
@check_scalar Q d
98-
@check_scalar R d
99-
@check_space(Q, space(d))
100-
@check_space(R, space(d))
101-
102-
return nothing
103-
end
104-
end
105-
106-
for f! in (:lq_full!, :lq_compact!)
107-
@eval function MAK.check_input(
108-
::typeof($f!), d::AbstractTensorMap, LQ, ::DiagonalAlgorithm
109-
)
110-
L, Q = LQ
111-
@assert d isa DiagonalTensorMap
112-
@assert Q isa DiagonalTensorMap && L isa DiagonalTensorMap
113-
@check_scalar Q d
114-
@check_scalar L d
115-
@check_space(Q, space(d))
116-
@check_space(L, space(d))
117-
118-
return nothing
119-
end
120-
end
121-
12289
# disambiguate
12390
function MAK.svd_compact!(t::AbstractTensorMap, USVᴴ, alg::DiagonalAlgorithm)
12491
return svd_full!(t, USVᴴ, alg)
12592
end
12693

12794
# f_vals
12895
# ------
129-
13096
for f! in (:eig_vals!, :eigh_vals!, :svd_vals!)
13197
@eval function MAK.$f!(d::AbstractTensorMap, V, alg::DiagonalAlgorithm)
132-
MAK.check_input($f!, d, V, alg)
13398
$f!(_repack_diagonal(d), diagview(_repack_diagonal(V)), alg)
13499
return V
135100
end
@@ -140,64 +105,3 @@ for f! in (:eig_vals!, :eigh_vals!, :svd_vals!)
140105
return DiagonalTensorMap(data, d.domain)
141106
end
142107
end
143-
144-
function MAK.check_input(::typeof(eig_full!), t::AbstractTensorMap, DV, ::DiagonalAlgorithm)
145-
domain(t) == codomain(t) ||
146-
throw(ArgumentError("Eigenvalue decomposition requires square input tensor"))
147-
148-
D, V = DV
149-
150-
@assert D isa DiagonalTensorMap
151-
@assert V isa AbstractTensorMap
152-
153-
# scalartype checks
154-
@check_scalar D t
155-
@check_scalar V t
156-
157-
# space checks
158-
@check_space D space(t)
159-
@check_space V space(t)
160-
161-
return nothing
162-
end
163-
164-
function MAK.check_input(::typeof(eigh_full!), t::AbstractTensorMap, DV, ::DiagonalAlgorithm)
165-
domain(t) == codomain(t) ||
166-
throw(ArgumentError("Eigenvalue decomposition requires square input tensor"))
167-
168-
D, V = DV
169-
170-
@assert D isa DiagonalTensorMap
171-
@assert V isa AbstractTensorMap
172-
173-
# scalartype checks
174-
@check_scalar D t real
175-
@check_scalar V t
176-
177-
# space checks
178-
@check_space D space(t)
179-
@check_space V space(t)
180-
181-
return nothing
182-
end
183-
184-
function MAK.check_input(::typeof(eig_vals!), t::AbstractTensorMap, D, ::DiagonalAlgorithm)
185-
@assert D isa DiagonalTensorMap
186-
@check_scalar D t
187-
@check_space D space(t)
188-
return nothing
189-
end
190-
191-
function MAK.check_input(::typeof(eigh_vals!), t::AbstractTensorMap, D, ::DiagonalAlgorithm)
192-
@assert D isa DiagonalTensorMap
193-
@check_scalar D t real
194-
@check_space D space(t)
195-
return nothing
196-
end
197-
198-
function MAK.check_input(::typeof(svd_vals!), t::AbstractTensorMap, D, ::DiagonalAlgorithm)
199-
@assert D isa DiagonalTensorMap
200-
@check_scalar D t real
201-
@check_space D space(t)
202-
return nothing
203-
end

0 commit comments

Comments
 (0)