Skip to content

Commit 152ea71

Browse files
lkdvosJutho
andauthored
Improve DiagonalTensorMap constructors with similar_diagonal (#330)
* add `similar_diagonal` * update factorization * also update sectorvector * remove ambiguity * Update src/tensors/abstracttensor.jl Co-authored-by: Jutho <[email protected]> * fix import * purge some more constructors * Update src/tensors/abstracttensor.jl Co-authored-by: Jutho <[email protected]> * small oopsie --------- Co-authored-by: Jutho <[email protected]>
1 parent 06d32b0 commit 152ea71

File tree

7 files changed

+75
-16
lines changed

7 files changed

+75
-16
lines changed

src/factorizations/diagonal.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ function MAK.initialize_output(
7272
V_cod = fuse(codomain(t))
7373
V_dom = fuse(domain(t))
7474
U = similar(t, codomain(t) V_cod)
75-
S = DiagonalTensorMap{real(scalartype(t))}(undef, V_cod V_dom)
75+
S = similar_diagonal(t, real(scalartype(t)), V_cod)
7676
Vᴴ = similar(t, V_dom domain(t))
7777
return U, S, Vᴴ
7878
end

src/factorizations/factorizations.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,12 @@ module Factorizations
66
export copy_oftype, factorisation_scalartype, one!, truncspace
77

88
using ..TensorKit
9-
using ..TensorKit: AdjointTensorMap, SectorDict, SectorVector, blocktype, foreachblock, one!
9+
using ..TensorKit: AdjointTensorMap, SectorDict, SectorVector,
10+
blocktype, foreachblock, one!,
11+
similar_diagonal, similarstoragetype
1012

11-
using LinearAlgebra: LinearAlgebra, BlasFloat, Diagonal, svdvals, svdvals!, eigen, eigen!,
13+
using LinearAlgebra: LinearAlgebra, BlasFloat, Diagonal,
14+
svdvals, svdvals!, eigen, eigen!,
1215
isposdef, isposdef!, ishermitian
1316

1417
using TensorOperations: Index2Tuple

src/factorizations/matrixalgebrakit.jl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -74,45 +74,47 @@ end
7474
function MAK.initialize_output(::typeof(svd_compact!), t::AbstractTensorMap, ::AbstractAlgorithm)
7575
V_cod = V_dom = infimum(fuse(codomain(t)), fuse(domain(t)))
7676
U = similar(t, codomain(t) V_cod)
77-
S = DiagonalTensorMap{real(scalartype(t))}(undef, V_cod)
77+
S = similar_diagonal(t, real(scalartype(t)), V_cod)
7878
Vᴴ = similar(t, V_dom domain(t))
7979
return U, S, Vᴴ
8080
end
8181

8282
function MAK.initialize_output(::typeof(svd_vals!), t::AbstractTensorMap, alg::AbstractAlgorithm)
8383
V_cod = infimum(fuse(codomain(t)), fuse(domain(t)))
8484
T = real(scalartype(t))
85-
return SectorVector{T}(undef, V_cod)
85+
A = similarstoragetype(t, T)
86+
return SectorVector{T, sectortype(t), A}(undef, V_cod)
8687
end
8788

8889
# Eigenvalue decomposition
8990
# ------------------------
9091
function MAK.initialize_output(::typeof(eigh_full!), t::AbstractTensorMap, ::AbstractAlgorithm)
9192
V_D = fuse(domain(t))
92-
T = real(scalartype(t))
93-
D = DiagonalTensorMap{T}(undef, V_D)
93+
D = similar_diagonal(t, real(scalartype(t)), V_D)
9494
V = similar(t, codomain(t) V_D)
9595
return D, V
9696
end
9797

9898
function MAK.initialize_output(::typeof(eig_full!), t::AbstractTensorMap, ::AbstractAlgorithm)
9999
V_D = fuse(domain(t))
100100
Tc = complex(scalartype(t))
101-
D = DiagonalTensorMap{Tc}(undef, V_D)
101+
D = similar_diagonal(t, Tc, V_D)
102102
V = similar(t, Tc, codomain(t) V_D)
103103
return D, V
104104
end
105105

106106
function MAK.initialize_output(::typeof(eigh_vals!), t::AbstractTensorMap, alg::AbstractAlgorithm)
107107
V_D = fuse(domain(t))
108108
T = real(scalartype(t))
109-
return SectorVector{T}(undef, V_D)
109+
A = similarstoragetype(t, T)
110+
return SectorVector{T, sectortype(t), A}(undef, V_D)
110111
end
111112

112113
function MAK.initialize_output(::typeof(eig_vals!), t::AbstractTensorMap, alg::AbstractAlgorithm)
113114
V_D = fuse(domain(t))
114115
Tc = complex(scalartype(t))
115-
return SectorVector{Tc}(undef, V_D)
116+
A = similarstoragetype(t, Tc)
117+
return SectorVector{Tc, sectortype(t), A}(undef, V_D)
116118
end
117119

118120
# QR decomposition

src/factorizations/truncation.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ function MAK.truncate(
6767

6868
= similar(U, codomain(U) V_truncated)
6969
truncate_domain!(Ũ, U, ind)
70-
= DiagonalTensorMap{scalartype(S)}(undef, V_truncated)
70+
= similar_diagonal(S, V_truncated)
7171
truncate_diagonal!(S̃, S, ind)
7272
Ṽᴴ = similar(Vᴴ, V_truncated domain(Vᴴ))
7373
truncate_codomain!(Ṽᴴ, Vᴴ, ind)
@@ -132,7 +132,7 @@ for f! in (:eig_trunc!, :eigh_trunc!)
132132
ind = MAK.findtruncated(diagview(D), strategy)
133133
V_truncated = truncate_space(space(D, 1), ind)
134134

135-
= DiagonalTensorMap{scalartype(D)}(undef, V_truncated)
135+
= similar_diagonal(D, V_truncated)
136136
truncate_diagonal!(D̃, D, ind)
137137

138138
= similar(V, codomain(V) V_truncated)

src/tensors/abstracttensor.jl

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,8 @@ The structure may be specified either as a single `HomSpace` argument or as `cod
491491
492492
By default, this will result in `TensorMap{T}(undef, V)` when custom objects do not
493493
specialize this method.
494+
495+
See also [`similar_diagonal`](@ref).
494496
""" Base.similar(::AbstractTensorMap, args...)
495497

496498
function Base.similar(
@@ -543,6 +545,51 @@ function Base.similar(
543545
return TensorMap{scalartype(TT)}(undef, cod, dom)
544546
end
545547

548+
# similar diagonal
549+
# ----------------
550+
# The implementation is again written for similar_diagonal(t, TorA, V::ElementarySpace) -> DiagonalTensorMap
551+
# and all other methods are just filling in default arguments
552+
@doc """
553+
similar_diagonal(t::AbstractTensorMap, [AorT=storagetype(t)], [V::ElementarySpace])
554+
555+
Creates an uninitialized mutable diagonal tensor with the given scalar or storagetype `AorT` and
556+
structure `V ← V`, based on the source tensormap. The second argument is optional and defaults
557+
to the given tensor's `storagetype`, while the third argument can only be omitted for square
558+
input tensors of space `V ← V`, to conform with the diagonal structure.
559+
560+
By default, this will result in `DiagonalTensorMap{T}(undef, V)` when custom objects do not
561+
specialize this method. Furthermore, the method will throw if the provided space is not compatible
562+
with a diagonal structure.
563+
564+
See also [`Base.similar`](@ref).
565+
""" similar_diagonal(::AbstractTensorMap, args...)
566+
567+
# 3 arguments
568+
function similar_diagonal(t::AbstractTensorMap, ::Type{TorA}, V::ElementarySpace) where {TorA}
569+
if TorA <: Number
570+
T = TorA
571+
A = similarstoragetype(t, T)
572+
elseif TorA <: DenseVector
573+
A = TorA
574+
T = scalartype(A)
575+
else
576+
throw(ArgumentError("Type $TorA not supported for similar"))
577+
end
578+
579+
return DiagonalTensorMap{T, spacetype(V), A}(undef, V)
580+
end
581+
582+
similar_diagonal(t::AbstractTensorMap) = similar_diagonal(t, similarstoragetype(t), _diagspace(t))
583+
similar_diagonal(t::AbstractTensorMap, V::ElementarySpace) = similar_diagonal(t, similarstoragetype(t), V)
584+
similar_diagonal(t::AbstractTensorMap, T::Type) = similar_diagonal(t, T, _diagspace(t))
585+
586+
function _diagspace(t)
587+
cod, dom = codomain(t), domain(t)
588+
length(cod) == 1 && cod == dom ||
589+
throw(ArgumentError("space does not support a DiagonalTensorMap"))
590+
return only(cod)
591+
end
592+
546593
# Equality and approximality
547594
#----------------------------
548595
function Base.:(==)(t1::AbstractTensorMap, t2::AbstractTensorMap)

src/tensors/diagonal.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,12 @@ function DiagonalTensorMap(t::AbstractTensorMap{T, S, 1, 1}) where {T, S}
7878
return d
7979
end
8080

81-
Base.similar(d::DiagonalTensorMap) = DiagonalTensorMap(similar(d.data), d.domain)
82-
function Base.similar(d::DiagonalTensorMap, ::Type{T}) where {T <: Number}
83-
return DiagonalTensorMap(similar(d.data, T), d.domain)
84-
end
81+
Base.similar(d::DiagonalTensorMap) = similar_diagonal(d)
82+
Base.similar(d::DiagonalTensorMap, ::Type{T}) where {T} = similar_diagonal(d, T)
83+
84+
similar_diagonal(d::DiagonalTensorMap) = DiagonalTensorMap(similar(d.data), d.domain)
85+
similar_diagonal(d::DiagonalTensorMap, ::Type{T}) where {T <: Number} =
86+
DiagonalTensorMap(similar(d.data, T), d.domain)
8587

8688
# TODO: more constructors needed?
8789

src/tensors/sectorvector.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@ function SectorVector{T}(::UndefInitializer, V::ElementarySpace) where {T}
1616
structure = diagonalblockstructure(V V)
1717
return SectorVector(data, structure)
1818
end
19+
function SectorVector{T, I, A}(::UndefInitializer, V::ElementarySpace) where {T, I, A <: AbstractVector{T}}
20+
data = A(undef, reduceddim(V))
21+
structure = diagonalblockstructure(V V)
22+
return SectorVector{T, I, A}(data, structure)
23+
end
1924

2025
Base.parent(v::SectorVector) = v.data
2126

0 commit comments

Comments
 (0)