Skip to content

Commit ce5aae5

Browse files
committed
add similar_diagonal
1 parent 06d32b0 commit ce5aae5

File tree

2 files changed

+72
-4
lines changed

2 files changed

+72
-4
lines changed

src/tensors/abstracttensor.jl

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,8 @@ The structure may be specified either as a single `HomSpace` argument or as `cod
491491
492492
By default, this will result in `TensorMap{T}(undef, V)` when custom objects do not
493493
specialize this method.
494+
495+
See also [`similar_diagonal`](@ref).
494496
""" Base.similar(::AbstractTensorMap, args...)
495497

496498
function Base.similar(
@@ -543,6 +545,70 @@ function Base.similar(
543545
return TensorMap{scalartype(TT)}(undef, cod, dom)
544546
end
545547

548+
# similar diagonal
549+
# ----------------
550+
# The implementation is again written for similar_diagonal(t, TorA, V::TensorSpace) -> DiagonalTensorMap
551+
# and all other methods are just filling in default arguments
552+
@doc """
553+
similar_diagonal(t::AbstractTensorMap, [AorT=storagetype(t)], [V=space(t)])
554+
similar_diagonal(t::AbstractTensorMap, [AorT=storagetype(t)], codomain, domain)
555+
556+
Creates an uninitialized mutable diagonal tensor with the given scalar or storagetype `AorT` and
557+
structure `V` or `codomain ← domain`, based on the source tensormap. The second and third
558+
arguments are both optional, defaulting to the given tensor's `storagetype` and `space`.
559+
The structure may be specified either as a single `HomSpace` argument or as `codomain` and
560+
`domain`.
561+
562+
By default, this will result in `DiagonalTensorMap{T}(undef, V)` when custom objects do not
563+
specialize this method. Furthermore, the method will throw if the provided space is not compatible
564+
with a diagonal structure.
565+
566+
See also [`Base.similar`](@ref).
567+
""" similar_diagonal(::AbstractTensorMap, args...)
568+
569+
# 4 arguments
570+
function similar_diagonal(t::AbstractTensorMap, ::Type{T}, cod::TensorSpace, dom::TensorSpace) where {T}
571+
length(cod) == length(dom) == 1 && cod == dom ||
572+
throw(ArgumentError("requested space is not square"))
573+
return similar_diagonal(t, T, cod)
574+
end
575+
576+
# 3 arguments
577+
function similar_diagonal(t::AbstractTensorMap, ::Type{T}, V::TensorMapSpace) where {T}
578+
numout(V) == numin(V) == 1 && domain(V) == codomain(V) ||
579+
throw(ArgumentError("requested space is not square"))
580+
return similar_diagonal(t, T, codomain(V))
581+
end
582+
function similar_diagonal(t::AbstractTensorMap, ::Type{T}, V::ProductSpace) where {T}
583+
length(V) == 1 || throw(ArgumentError())
584+
return similar_diagonal(t, T, only(V.spaces))
585+
end
586+
function similar_diagonal(t::AbstractTensorMap, ::Type{TorA}, V::ElementarySpace) where {TorA}
587+
if TorA <: Number
588+
T = TorA
589+
A = similarstoragetype(t, T)
590+
elseif TorA <: DenseVector
591+
A = TorA
592+
T = scalartype(A)
593+
else
594+
throw(ArgumentError("Type $TorA not supported for similar"))
595+
end
596+
597+
return DiagonalTensorMap{T, spacetype(V), A}(undef, V)
598+
end
599+
600+
# 2 arguments
601+
similar_diagonal(t::AbstractTensorMap, ::Type{T}) where {T} =
602+
similar_diagonal(t, T, space(t))
603+
similar_diagonal(t::AbstractTensorMap, P::TensorMapSpace) =
604+
similar_diagonal(t, similarstoragetype(t), P)
605+
similar_diagonal(t::AbstractTensorMap, P::TensorSpace) =
606+
similar_diagonal(t, similarstoragetype(t), P)
607+
608+
# 1 argument
609+
similar_diagonal(t::AbstractTensorMap) =
610+
similar_diagonal(t, similarstoragetype(t), space(t))
611+
546612
# Equality and approximality
547613
#----------------------------
548614
function Base.:(==)(t1::AbstractTensorMap, t2::AbstractTensorMap)

src/tensors/diagonal.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,12 @@ function DiagonalTensorMap(t::AbstractTensorMap{T, S, 1, 1}) where {T, S}
7878
return d
7979
end
8080

81-
Base.similar(d::DiagonalTensorMap) = DiagonalTensorMap(similar(d.data), d.domain)
82-
function Base.similar(d::DiagonalTensorMap, ::Type{T}) where {T <: Number}
83-
return DiagonalTensorMap(similar(d.data, T), d.domain)
84-
end
81+
Base.similar(d::DiagonalTensorMap) = similar_diagonal(d)
82+
Base.similar(d::DiagonalTensorMap, ::Type{T}) where {T} = similar_diagonal(d, T)
83+
84+
similar_diagonal(d::DiagonalTensorMap) = DiagonalTensorMap(similar(d.data), d.domain)
85+
similar_diagonal(d::DiagonalTensorMap, ::Type{T}) where {T <: Number} =
86+
DiagonalTensorMap(similar(d.data, T), d.domain)
8587

8688
# TODO: more constructors needed?
8789

0 commit comments

Comments
 (0)