Skip to content

Commit d60855e

Browse files
lkdvosJutho
andauthored
rework tensor contructors to allow storagetype specification (#327)
* rework tensor contructors try handle PtrArrays fix docstrings * reorganize constructors fix length check fix reshape Fix typos [skip ci] * rename `TensorMapWithStorage` * more strict with provided data * extra comments in docstrings * unify tensortype usage also update `tensoralloc` * uniformize into `similarstoragetype` fix ambiguity more careful with storagetypes more careful with tensoroperations even more careful the carefulest! * Apply suggestions from code review Co-authored-by: Jutho <[email protected]> * remove unused code * add note about 1 vs 2-arg versions * update similar_diagonal to same logic --------- Co-authored-by: Jutho <[email protected]>
1 parent 152ea71 commit d60855e

File tree

4 files changed

+331
-256
lines changed

4 files changed

+331
-256
lines changed

docs/src/lib/tensors.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ Tensor
3434

3535
A `TensorMap` with undefined data can be constructed by specifying its domain and codomain:
3636
```@docs
37-
TensorMap{T}(::UndefInitializer, V::TensorMapSpace{S,N₁,N₂}) where {T,S,N₁,N₂}
37+
TensorMap{T}(::UndefInitializer, V::TensorMapSpace)
3838
```
3939

4040
The resulting object can then be filled with data using the `setindex!` method as discussed
@@ -45,8 +45,8 @@ in an `@tensor output[...] = ...` expression.
4545
Alternatively, a `TensorMap` can be constructed by specifying its data, codmain and domain
4646
in one of the following ways:
4747
```@docs
48-
TensorMap(data::AbstractDict{<:Sector,<:AbstractMatrix}, V::TensorMapSpace{S,N₁,N₂}) where {S,N₁,N₂}
49-
TensorMap(data::AbstractArray, V::TensorMapSpace{S,N₁,N₂}; tol) where {S<:IndexSpace,N₁,N₂}
48+
TensorMap(data::AbstractDict{<:Sector,<:AbstractMatrix}, V::TensorMapSpace)
49+
TensorMap(data::AbstractArray, V::TensorMapSpace; tol)
5050
```
5151

5252
Finally, we also support the following `Array`-like constructors

src/tensors/abstracttensor.jl

Lines changed: 131 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,63 @@ end
4545
Return the type of vector that stores the data of a tensor.
4646
""" storagetype
4747

48-
similarstoragetype(TT::Type{<:AbstractTensorMap}) = similarstoragetype(TT, scalartype(TT))
48+
# storage type determination and promotion - hooks for specializing
49+
# the default implementation tries to leverarge inference and `similar`
50+
@doc """
51+
similarstoragetype(t, [T = scalartype(t)]) -> Type{<:DenseVector{T}}
52+
similarstoragetype(TT, [T = scalartype(TT)]) -> Type{<:DenseVector{T}}
53+
similarstoragetype(A, [T = scalartype(A)]) -> Type{<:DenseVector{T}}
54+
similarstoragetype(D, [T = scalartype(D)]) -> Type{<:DenseVector{T}}
55+
56+
similarstoragetype(T::Type{<:Number}) -> Vector{T}
57+
58+
For a given tensor `t`, tensor type `TT <: AbstractTensorMap`, array type `A <: AbstractArray`,
59+
or sector dictionary type `D <: AbstractDict{<:Sector, <:AbstractMatrix}`, compute an appropriate
60+
storage type for tensors. Optionally, a different scalar type `T` can be supplied as well.
61+
62+
This function determines the type of newly allocated `TensorMap`s throughout TensorKit.jl.
63+
It does so by leveraging type inference and calls to `Base.similar` for automatically determining
64+
appropriate storage types. Additionally this registers the default storage type when only a type
65+
`T <: Number` is provided, which is `Vector{T}`.
66+
67+
!!! note
68+
There is a slight semantic difference in the single and two-argument version. The former is
69+
used in constructor-like calls, and therefore will return the exact same type for a `DenseVector`
70+
input. The latter is used in `similar`-like calls, and therefore will return the type of calling
71+
`similar` on the given `DenseVector`, which need not coincide with the original type.
72+
""" similarstoragetype
73+
74+
# implement in type domain
75+
similarstoragetype(t) = similarstoragetype(typeof(t))
76+
similarstoragetype(t, ::Type{T}) where {T <: Number} = similarstoragetype(typeof(t), T)
77+
78+
# avoid infinite recursion
79+
similarstoragetype(X::Type) =
80+
throw(ArgumentError("Cannot determine a storagetype for tensor / array type `$X`"))
81+
similarstoragetype(X::Type, ::Type{T}) where {T <: Number} =
82+
throw(ArgumentError("Cannot determine a storagetype for tensor / array type `$X` and/or scalar type `$T`"))
83+
84+
# implement on tensors
85+
similarstoragetype(::Type{TT}) where {TT <: AbstractTensorMap} = similarstoragetype(storagetype(TT))
86+
similarstoragetype(::Type{TT}, ::Type{T}) where {TT <: AbstractTensorMap, T <: Number} =
87+
similarstoragetype(storagetype(TT), T)
88+
89+
# implement on arrays
90+
similarstoragetype(::Type{A}) where {A <: DenseVector{<:Number}} = A
91+
Base.@assume_effects :foldable similarstoragetype(::Type{A}) where {A <: AbstractArray{<:Number}} =
92+
Core.Compiler.return_type(similar, Tuple{A, Int})
93+
Base.@assume_effects :foldable similarstoragetype(::Type{A}, ::Type{T}) where {A <: AbstractArray, T <: Number} =
94+
Core.Compiler.return_type(similar, Tuple{A, Type{T}, Int})
95+
96+
# implement on sectordicts
97+
similarstoragetype(::Type{D}) where {D <: AbstractDict{<:Sector, <:AbstractMatrix}} =
98+
similarstoragetype(valtype(D))
99+
similarstoragetype(::Type{D}, ::Type{T}) where {D <: AbstractDict{<:Sector, <:AbstractMatrix}, T <: Number} =
100+
similarstoragetype(valtype(D), T)
101+
102+
# default storage type for numbers
103+
similarstoragetype(::Type{T}) where {T <: Number} = Vector{T}
49104

50-
function similarstoragetype(TT::Type{<:AbstractTensorMap}, ::Type{T}) where {T}
51-
return Core.Compiler.return_type(similar, Tuple{storagetype(TT), Type{T}})
52-
end
53105

54106
# tensor characteristics: space and index information
55107
#-----------------------------------------------------
@@ -175,7 +227,6 @@ end
175227
InnerProductStyle(t::AbstractTensorMap) = InnerProductStyle(typeof(t))
176228
storagetype(t::AbstractTensorMap) = storagetype(typeof(t))
177229
blocktype(t::AbstractTensorMap) = blocktype(typeof(t))
178-
similarstoragetype(t::AbstractTensorMap, T = scalartype(t)) = similarstoragetype(typeof(t), T)
179230

180231
numout(t::AbstractTensorMap) = numout(typeof(t))
181232
numin(t::AbstractTensorMap) = numin(typeof(t))
@@ -496,61 +547,46 @@ See also [`similar_diagonal`](@ref).
496547
""" Base.similar(::AbstractTensorMap, args...)
497548

498549
function Base.similar(
499-
t::AbstractTensorMap, ::Type{T}, codomain::TensorSpace{S}, domain::TensorSpace{S}
500-
) where {T, S}
550+
t::AbstractTensorMap, ::Type{T}, codomain::TensorSpace, domain::TensorSpace
551+
) where {T}
501552
return similar(t, T, codomain domain)
502553
end
554+
503555
# 3 arguments
504-
function Base.similar(
505-
t::AbstractTensorMap, codomain::TensorSpace{S}, domain::TensorSpace{S}
506-
) where {S}
507-
return similar(t, similarstoragetype(t), codomain domain)
508-
end
509-
function Base.similar(t::AbstractTensorMap, ::Type{T}, codomain::TensorSpace) where {T}
510-
return similar(t, T, codomain one(codomain))
511-
end
556+
Base.similar(t::AbstractTensorMap, codomain::TensorSpace, domain::TensorSpace) =
557+
similar(t, similarstoragetype(t, scalartype(t)), codomain domain)
558+
Base.similar(t::AbstractTensorMap, ::Type{T}, codomain::TensorSpace) where {T} =
559+
similar(t, T, codomain one(codomain))
560+
512561
# 2 arguments
513-
function Base.similar(t::AbstractTensorMap, codomain::TensorSpace)
514-
return similar(t, similarstoragetype(t), codomain one(codomain))
515-
end
516-
Base.similar(t::AbstractTensorMap, P::TensorMapSpace) = similar(t, storagetype(t), P)
562+
Base.similar(t::AbstractTensorMap, codomain::TensorSpace) =
563+
similar(t, codomain one(codomain))
564+
Base.similar(t::AbstractTensorMap, V::TensorMapSpace) = similar(t, scalartype(t), V)
517565
Base.similar(t::AbstractTensorMap, ::Type{T}) where {T} = similar(t, T, space(t))
518566
# 1 argument
519-
Base.similar(t::AbstractTensorMap) = similar(t, similarstoragetype(t), space(t))
567+
Base.similar(t::AbstractTensorMap) = similar(t, scalartype(t), space(t))
520568

521569
# generic implementation for AbstractTensorMap -> returns `TensorMap`
522-
function Base.similar(t::AbstractTensorMap, ::Type{TorA}, P::TensorMapSpace{S}) where {TorA, S}
523-
if TorA <: Number
524-
T = TorA
525-
A = similarstoragetype(t, T)
526-
elseif TorA <: DenseVector
527-
A = TorA
528-
T = scalartype(A)
529-
else
530-
throw(ArgumentError("Type $TorA not supported for similar"))
531-
end
532-
533-
N₁ = length(codomain(P))
534-
N₂ = length(domain(P))
535-
return TensorMap{T, S, N₁, N₂, A}(undef, P)
570+
function Base.similar(t::AbstractTensorMap, ::Type{TorA}, V::TensorMapSpace) where {TorA}
571+
A = TorA <: Number ? similarstoragetype(t, TorA) : TorA
572+
TT = tensormaptype(spacetype(V), numout(V), numin(V), A)
573+
return TT(undef, V)
536574
end
537575

538576
# implementation in type-domain
539-
function Base.similar(::Type{TT}, P::TensorMapSpace) where {TT <: AbstractTensorMap}
540-
return TensorMap{scalartype(TT)}(undef, P)
541-
end
542-
function Base.similar(
543-
::Type{TT}, cod::TensorSpace{S}, dom::TensorSpace{S}
544-
) where {TT <: AbstractTensorMap, S}
545-
return TensorMap{scalartype(TT)}(undef, cod, dom)
577+
function Base.similar(::Type{TT}, V::TensorMapSpace) where {TT <: AbstractTensorMap}
578+
TT′ = tensormaptype(spacetype(V), numout(V), numin(V), similarstoragetype(TT, scalartype(TT)))
579+
return TT′(undef, V)
546580
end
581+
Base.similar(::Type{TT}, cod::TensorSpace, dom::TensorSpace) where {TT <: AbstractTensorMap} =
582+
similar(TT, cod dom)
547583

548584
# similar diagonal
549585
# ----------------
550586
# The implementation is again written for similar_diagonal(t, TorA, V::ElementarySpace) -> DiagonalTensorMap
551587
# and all other methods are just filling in default arguments
552588
@doc """
553-
similar_diagonal(t::AbstractTensorMap, [AorT=storagetype(t)], [V::ElementarySpace])
589+
similar_diagonal(t::AbstractTensorMap, [AorT=scalartype(t)], [V::ElementarySpace])
554590
555591
Creates an uninitialized mutable diagonal tensor with the given scalar or storagetype `AorT` and
556592
structure `V ← V`, based on the source tensormap. The second argument is optional and defaults
@@ -566,21 +602,12 @@ See also [`Base.similar`](@ref).
566602

567603
# 3 arguments
568604
function similar_diagonal(t::AbstractTensorMap, ::Type{TorA}, V::ElementarySpace) where {TorA}
569-
if TorA <: Number
570-
T = TorA
571-
A = similarstoragetype(t, T)
572-
elseif TorA <: DenseVector
573-
A = TorA
574-
T = scalartype(A)
575-
else
576-
throw(ArgumentError("Type $TorA not supported for similar"))
577-
end
578-
579-
return DiagonalTensorMap{T, spacetype(V), A}(undef, V)
605+
A = similarstoragetype(TorA <: Number ? similarstoragetype(t, TorA) : TorA)
606+
return DiagonalTensorMap{scalartype(A), spacetype(V), A}(undef, V)
580607
end
581608

582-
similar_diagonal(t::AbstractTensorMap) = similar_diagonal(t, similarstoragetype(t), _diagspace(t))
583-
similar_diagonal(t::AbstractTensorMap, V::ElementarySpace) = similar_diagonal(t, similarstoragetype(t), V)
609+
similar_diagonal(t::AbstractTensorMap) = similar_diagonal(t, scalartype(t), _diagspace(t))
610+
similar_diagonal(t::AbstractTensorMap, V::ElementarySpace) = similar_diagonal(t, scalartype(t), V)
584611
similar_diagonal(t::AbstractTensorMap, T::Type) = similar_diagonal(t, T, _diagspace(t))
585612

586613
function _diagspace(t)
@@ -656,8 +683,8 @@ function Base.imag(t::AbstractTensorMap)
656683
end
657684
end
658685

659-
# Conversion to Array:
660-
#----------------------
686+
# Conversion to/from Array:
687+
#--------------------------
661688
# probably not optimized for speed, only for checking purposes
662689
function Base.convert(::Type{Array}, t::AbstractTensorMap)
663690
I = sectortype(t)
@@ -678,9 +705,55 @@ function Base.convert(::Type{Array}, t::AbstractTensorMap)
678705
end
679706
end
680707

708+
"""
709+
project_symmetric!(t::AbstractTensorMap, data::AbstractArray) -> t
710+
711+
Project the data from a dense array `data` into the tensor map `t`. This function discards
712+
any data that does not fit the symmetry structure of `t`.
713+
"""
714+
function project_symmetric!(t::AbstractTensorMap, data::AbstractArray)
715+
# dimension check
716+
codom, dom = codomain(t), domain(t)
717+
arraysize = dims(t)
718+
matsize = (dim(codom), dim(dom))
719+
(size(data) == arraysize || size(data) == matsize) ||
720+
throw(DimensionMismatch("input data has incompatible size for the given tensor"))
721+
data = reshape(collect(data), arraysize)
722+
723+
I = sectortype(t)
724+
if I === Trivial && t isa TensorMap
725+
copy!(t.data, reshape(data, length(t.data)))
726+
return t
727+
end
728+
729+
for ((f₁, f₂), subblock) in subblocks(t)
730+
F = convert(Array, (f₁, f₂))
731+
dataslice = sview(
732+
data, axes(codomain(t), f₁.uncoupled)..., axes(domain(t), f₂.uncoupled)...
733+
)
734+
if FusionStyle(I) === UniqueFusion()
735+
Fscalar = only(F) # contains a single element
736+
scale!(subblock, dataslice, conj(Fscalar))
737+
else
738+
szbF = _interleave(size(F), size(subblock))
739+
indset1 = ntuple(identity, numind(t))
740+
indset2 = 2 .* indset1
741+
indset3 = indset2 .- 1
742+
TensorOperations.tensorcontract!(
743+
subblock,
744+
F, ((), indset1), true,
745+
sreshape(dataslice, szbF), (indset3, indset2), false,
746+
(indset1, ()),
747+
inv(dim(f₁.coupled)), false
748+
)
749+
end
750+
end
751+
752+
return t
753+
end
754+
681755
# Show and friends
682756
# ----------------
683-
684757
function Base.dims2string(V::HomSpace)
685758
str_cod = numout(V) == 0 ? "()" : join(dim.(codomain(V)), '×')
686759
str_dom = numin(V) == 0 ? "()" : join(dim.(domain(V)), '×')

0 commit comments

Comments
 (0)