Skip to content

Commit 32a93bf

Browse files
committed
uniformize into similarstoragetype
1 parent 2e9e075 commit 32a93bf

File tree

2 files changed

+49
-22
lines changed

2 files changed

+49
-22
lines changed

src/tensors/abstracttensor.jl

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,51 @@ end
4545
Return the type of vector that stores the data of a tensor.
4646
""" storagetype
4747

48-
similarstoragetype(TT::Type{<:AbstractTensorMap}) = similarstoragetype(TT, scalartype(TT))
48+
# storage type determination and promotion - hooks for specializing
49+
# the default implementation tries to leverarge inference and `similar`
50+
@doc """
51+
similarstoragetype(t, [T = scalartype(t)]) -> Type{<:DenseVector{T}}
52+
similarstoragetype(TT, [T = scalartype(t)]) -> Type{<:DenseVector{T}}
53+
similarstoragetype(A, [T = scalartype(t)]) -> Type{<:DenseVector{T}}
54+
similarstoragetype(D, [T = scalartype(t)]) -> Type{<:DenseVector{T}}
4955
50-
function similarstoragetype(TT::Type{<:AbstractTensorMap}, ::Type{T}) where {T}
51-
return Core.Compiler.return_type(similar, Tuple{storagetype(TT), Type{T}})
52-
end
56+
similarstoragetype(T::Type{<:Number}) -> Vector{T}
57+
58+
For a given tensor `t`, tensor type `TT <: AbstractTensorMap`, array type `A <: AbstractArray`,
59+
or sector dictionary type `D <: AbstractDict{<:Sector, <:AbstractMatrix}`, compute an appropriate
60+
storage type for tensors. Optionally, a different scalar type `T` can be supplied as well.
61+
62+
This function determines the type of newly allocated `TensorMap`s throughout TensorKit.jl.
63+
It does so by leveraging type inference and calls to `Base.similar` for automatically determining
64+
appropriate storage types. Additionally this registers the default storage type when only a type
65+
`T <: Number` is provided, which is `Vector{T}`.
66+
""" similarstoragetype
67+
68+
# implement in type domain and fill in default value
69+
@inline similarstoragetype(t, ::Type{T} = scalartype(t)) where {T <: Number} =
70+
similarstoragetype(typeof(t), T)
71+
72+
# implement on tensors
73+
similarstoragetype(::Type{TT}, ::Type{T}) where {TT <: AbstractTensorMap, T <: Number} =
74+
similarstoragetype(storagetype(TT), T)
75+
76+
# implement on arrays
77+
similarstoragetype(::Type{A}, ::Type{T}) where {T <: Number, A <: DenseVector{T}} = A
78+
Base.@assume_effects :foldable similarstoragetype(::Type{A}, ::Type{T}) where {A <: AbstractArray, T <: Number} =
79+
Core.Compiler.return_type(similar, Tuple{A, Type{T}, Int})
80+
81+
# implement on sectordicts - intercept scalartype defaults!
82+
similarstoragetype(d::D, ::Type{T} = scalartype(valtype(d))) where {D <: AbstractDict{<:Sector}, T <: Number} =
83+
similarstoragetype(typeof(d), T)
84+
similarstoragetype(::Type{D}, ::Type{T} = scalartype(valtype(D))) where {D <: AbstractDict{<:Sector}, T <: Number} =
85+
similarstoragetype(valtype(D), T)
86+
87+
# default storage type for numbers
88+
similarstoragetype(::Type{T}) where {T <: Number} = Vector{T}
89+
90+
# avoid infinite recursion
91+
similarstoragetype(X::Type, Y::Type) =
92+
throw(ArgumentError("Cannot determine a storagetype for tensor / array type `$X` and/or scalar type `$Y`"))
5393

5494
# tensor characteristics: space and index information
5595
#-----------------------------------------------------
@@ -175,7 +215,6 @@ end
175215
InnerProductStyle(t::AbstractTensorMap) = InnerProductStyle(typeof(t))
176216
storagetype(t::AbstractTensorMap) = storagetype(typeof(t))
177217
blocktype(t::AbstractTensorMap) = blocktype(typeof(t))
178-
similarstoragetype(t::AbstractTensorMap, T = scalartype(t)) = similarstoragetype(typeof(t), T)
179218

180219
numout(t::AbstractTensorMap) = numout(typeof(t))
181220
numin(t::AbstractTensorMap) = numin(typeof(t))

src/tensors/tensor.jl

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -48,22 +48,11 @@ i.e. a tensor map with only a non-trivial output space.
4848
"""
4949
const Tensor{T, S, N, A} = TensorMap{T, S, N, 0, A}
5050

51-
function tensormaptype(S::Type{<:IndexSpace}, N₁, N₂, TorA::Type)
52-
A = _tensormap_storagetype(TorA)
53-
A <: DenseVector || throw(ArgumentError("Cannot determine a valid storage type from argument $TorA"))
51+
function tensormaptype(::Type{S}, N₁, N₂, ::Type{TorA}) where {S <: IndexSpace, TorA}
52+
A = similarstoragetype(TorA)
5453
return TensorMap{scalartype(A), S, N₁, N₂, A}
5554
end
5655

57-
# hook for mapping input types to storage types -- to be implemented in extensions
58-
_tensormap_storagetype(::Type{A}) where {A <: DenseVector{<:Number}} = A
59-
_tensormap_storagetype(::Type{A}) where {A <: Array} = _tensormap_storagetype(scalartype(A))
60-
_tensormap_storagetype(::Type{T}) where {T <: Number} = Vector{T}
61-
function _tensormap_storagetype(::Type{A}) where {A <: AbstractArray}
62-
PA = parenttype(A)
63-
PA === A && throw(MethodError(_tensormap_storagetype, A)) # avoid infinite recursion
64-
return _tensormap_storagetype(PA)
65-
end
66-
6756
# Basic methods for characterising a tensor:
6857
#--------------------------------------------
6958
space(t::TensorMap) = t.space
@@ -201,9 +190,8 @@ cases.
201190
the specified symmetry structure, up to a tolerance `tol`.
202191
"""
203192
function TensorMap(data::AbstractArray, V::TensorMapSpace; tol = sqrt(eps(real(float(eltype(data))))))
204-
T = eltype(data)
205-
A = _tensormap_storagetype(typeof(data))
206-
return TensorMapWithStorage{T, A}(data, V; tol)
193+
A = similarstoragetype(data)
194+
return TensorMapWithStorage{scalartype(A), A}(data, V; tol)
207195
end
208196
TensorMap(data::AbstractArray, codom::TensorSpace, dom::TensorSpace; kwargs...) =
209197
TensorMap(data, codom dom; kwargs...)
@@ -259,7 +247,7 @@ Construct a `TensorMap` by explicitly specifying its block data.
259247
- `domain::ProductSpace{S, N₂}`: the domain as a `ProductSpace` of `N₂` spaces of type `S <: ElementarySpace`.
260248
"""
261249
function TensorMap(data::_BlockData, V::TensorMapSpace)
262-
A = _tensormap_storagetype(valtype(data))
250+
A = similarstoragetype(data)
263251
return TensorMapWithStorage{scalartype(A), A}(data, V)
264252
end
265253
TensorMap(data::_BlockData, codom::TensorSpace, dom::TensorSpace) =

0 commit comments

Comments
 (0)