@@ -71,10 +71,22 @@ dim(t::TensorMap) = length(t.data)
7171
7272# General TensorMap constructors
7373# --------------------------------
74- # undef constructors
74+ # hook for mapping input types to storage types -- to be implemented in extensions
75+ _tensormap_storagetype (:: Type{A} ) where {A <: AbstractArray } = _tensormap_storagetype (eltype (A))
76+ _tensormap_storagetype (:: Type{T} ) where {T <: Number } = Vector{T}
77+
78+ # utility type alias that makes constructors also work for type aliases that specify
79+ # different storage types. (i.e. CuTensorMap = _TensorMap{T, CuVector{T}, ...})
80+ # TODO : do we need a name for this and do we want to consider this as public?
81+ const _TensorMap{T, A <: DenseVector{T} , S, N₁, N₂} = TensorMap{T, S, N₁, N₂, A}
82+ const _Tensor{T, A <: DenseVector{T} , S, N} = Tensor{T, S, N, A}
83+
84+ # undef constructors:
85+ # - dispatch start with TensorMap{T}
86+ # - select A and map to _TensorMap{T, A}
87+ # - select S, N1, N2 and map to TensorMap{T,S,N1,N2,A}
7588"""
76- TensorMap{T}(undef, codomain::ProductSpace{S,N₁}, domain::ProductSpace{S,N₂})
77- where {T,S,N₁,N₂}
89+ TensorMap{T}(undef, codomain::ProductSpace{S,N₁}, domain::ProductSpace{S,N₂}) where {T,S,N₁,N₂}
7890 TensorMap{T}(undef, codomain ← domain)
7991 TensorMap{T}(undef, domain → codomain)
8092 # expert mode: select storage type `A`
@@ -83,33 +95,35 @@ dim(t::TensorMap) = length(t.data)
8395
8496Construct a `TensorMap` with uninitialized data.
8597"""
86- function TensorMap {T} (:: UndefInitializer , V:: TensorMapSpace{S, N₁, N₂} ) where {T, S, N₁, N₂}
87- return TensorMap {T, S, N₁, N₂, Vector{T}} (undef, V)
88- end
89- function TensorMap {T} (
90- :: UndefInitializer , codomain:: TensorSpace{S} , domain:: TensorSpace{S}
91- ) where {T, S}
92- return TensorMap {T} (undef, codomain ← domain)
93- end
94- function Tensor {T} (:: UndefInitializer , V:: TensorSpace{S} ) where {T, S}
95- return TensorMap {T} (undef, V ← one (V))
96- end
98+ TensorMap {T} (:: UndefInitializer , V:: TensorMapSpace ) where {T} =
99+ _TensorMap {T, _tensormap_storagetype(T)} (undef, V)
100+ TensorMap {T} (:: UndefInitializer , codomain:: TensorSpace , domain:: TensorSpace ) where {T} =
101+ TensorMap {T} (undef, codomain ← domain)
102+ Tensor {T} (:: UndefInitializer , V:: TensorSpace ) where {T} = TensorMap {T} (undef, V ← one (V))
103+
104+ # specifying storagetype, fill in other parameters
105+ _TensorMap {T, A} (:: UndefInitializer , V:: TensorMapSpace ) where {T, A} =
106+ TensorMap {T, spacetype(V), numout(V), numin(V), A} (undef, V)
107+ _TensorMap {T, A} (:: UndefInitializer , codomain:: TensorSpace , domain:: TensorSpace ) where {T, A} =
108+ _TensorMap {T, A} (undef, codomain ← domain)
109+ _Tensor {T, A} (:: UndefInitializer , V:: TensorSpace ) where {T, A} = _TensorMap {T, A} (undef, V ← one (V))
97110
98111# constructor starting from vector = independent data (N₁ + N₂ = 1 is special cased below)
99112# documentation is captured by the case where `data` is a general array
100- # here, we force the `T` argument to distinguish it from the more general constructor below
101- function TensorMap {T} (
102- data:: A , V:: TensorMapSpace{S, N₁, N₂}
103- ) where {T, S, N₁, N₂, A <: DenseVector{T} }
104- return TensorMap {T, S, N₁, N₂, A} (data, V)
105- end
106- function TensorMap {T} (
107- data:: DenseVector{T} , codomain:: TensorSpace{S} , domain:: TensorSpace{S}
108- ) where {T, S}
109- return TensorMap (data, codomain ← domain)
110- end
113+ # here, we force the `T` and/or `A` argument to distinguish it from the more general constructor below
114+ TensorMap {T} (data:: DenseVector{T} , V:: TensorMapSpace ) where {T} =
115+ _TensorMap {T, typeof(data)} (data, V)
116+ TensorMap {T} (data:: DenseVector{T} , codomain:: TensorSpace , domain:: TensorSpace ) where {T} =
117+ TensorMap {T} (data, codomain ← domain)
118+
119+ _TensorMap {T, A} (data:: DenseVector{T} , V:: TensorMapSpace ) where {T, A} =
120+ TensorMap {T, spacetype(V), numout(V), numin(V), A} (data, V)
121+ _TensorMap {T, A} (data:: DenseVector{T} , codomain:: TensorSpace , domain:: TensorSpace ) where {T, A} =
122+ _TensorMap {T, A} (data, codomain ← domain)
111123
112124# constructor starting from block data
125+ const _BlockData{I <: Sector , A <: AbstractMatrix } = AbstractDict{I, A}
126+
113127"""
114128 TensorMap(data::AbstractDict{<:Sector,<:AbstractMatrix}, codomain::ProductSpace{S,N₁},
115129 domain::ProductSpace{S,N₂}) where {S<:ElementarySpace,N₁,N₂}
@@ -129,29 +143,33 @@ Construct a `TensorMap` by explicitly specifying its block data.
129143Alternatively, the domain and codomain can be specified by passing a [`HomSpace`](@ref)
130144using the syntax `codomain ← domain` or `domain → codomain`.
131145"""
132- function TensorMap (
133- data:: AbstractDict{<:Sector, <:AbstractMatrix} , V:: TensorMapSpace{S, N₁, N₂}
134- ) where {S, N₁, N₂}
135- T = eltype (valtype (data))
136- t = TensorMap {T} (undef, V)
146+ function TensorMap (data:: _BlockData , V:: TensorMapSpace )
147+ A = _tensormap_storagetype (valtype (data))
148+ return _TensorMap {scalartype(A), A} (data, V)
149+ end
150+ TensorMap (data:: _BlockData , codom:: TensorSpace , dom:: TensorSpace ) =
151+ TensorMap (data, codom ← dom)
152+
153+ function _TensorMap {T, A} (data:: _BlockData , V:: TensorMapSpace ) where {T, A}
154+ t = _TensorMap {T, A} (undef, V)
155+
156+ # check that there aren't too many blocks
157+ for (c, b) in data
158+ c ∈ blocksectors (t) || isempty (b) || throw (SectorMismatch (" data for block sector $c not expected" ))
159+ end
160+
161+ # fill in the blocks -- rely on conversion in copy
137162 for (c, b) in blocks (t)
138163 haskey (data, c) || throw (SectorMismatch (" no data for block sector $c " ))
139164 datac = data[c]
140- size (datac) == size (b) ||
141- throw (DimensionMismatch (" wrong size of block for sector $c " ))
165+ size (datac) == size (b) || throw (DimensionMismatch (" wrong size of block for sector $c " ))
142166 copy! (b, datac)
143167 end
144- for (c, b) in data
145- c ∈ blocksectors (t) || isempty (b) ||
146- throw (SectorMismatch (" data for block sector $c not expected" ))
147- end
168+
148169 return t
149170end
150- function TensorMap (
151- data:: AbstractDict{<:Sector, <:AbstractMatrix} , codom:: TensorSpace{S} , dom:: TensorSpace{S}
152- ) where {S}
153- return TensorMap (data, codom ← dom)
154- end
171+ _TensorMap {T, A} (data:: _BlockData , codom:: TensorSpace , dom:: TensorSpace ) where {T, A} =
172+ _TensorMap {T, A} (data, codom ← dom)
155173
156174@doc """
157175 zeros([T=Float64,], codomain::ProductSpace{S,N₁}, domain::ProductSpace{S,N₂}) where {S,N₁,N₂,T}
@@ -317,49 +335,45 @@ cases.
317335 to a plain array is possible, and only in the case where the `data` actually respects
318336 the specified symmetry structure, up to a tolerance `tol`.
319337"""
320- function TensorMap (
321- data:: AbstractArray , V:: TensorMapSpace{S, N₁, N₂} ;
322- tol = sqrt (eps (real (float (eltype (data)))))
323- ) where {S <: IndexSpace , N₁, N₂}
338+ function TensorMap (data:: AbstractArray , V:: TensorMapSpace ; tol = sqrt (eps (real (float (eltype (data))))))
324339 T = eltype (data)
325- if ndims (data) == 1 && length (data) == dim (V)
326- if data isa DenseVector # refer to specific data-capturing constructor
327- return TensorMap {T} (data, V)
328- else
329- return TensorMap {T} (collect (data), V)
330- end
331- end
340+ A = _tensormap_storagetype (typeof (data))
341+ return _TensorMap {T, A} (data, V; tol)
342+ end
343+ function _TensorMap {T, A} (
344+ data:: AbstractArray , V:: TensorMapSpace ; tol = sqrt (eps (real (float (eltype (data)))))
345+ ) where {T, A}
346+ # refer to specific data-capturing constructors if input is a vector of the correct length
347+ ndims (data) == 1 && length (data) == dim (V) && return _TensorMap {T, A} (data, V)
332348
333349 # dimension check
334- codom = codomain (V)
335- dom = domain (V)
350+ codom, dom = codomain (V), domain (V)
336351 arraysize = dims (V)
337352 matsize = (dim (codom), dim (dom))
353+ (size (data) == arraysize || size (data) == matsize) || throw (DimensionMismatch ())
338354
339- if ! (size (data) == arraysize || size (data) == matsize)
340- throw (DimensionMismatch ())
341- end
342-
343- if sectortype (S) === Trivial # refer to same method, but now with vector argument
344- return TensorMap (reshape (data, length (data)), V)
355+ if sectortype (V) === Trivial # refer to same method, but now with vector argument
356+ return _TensorMap {T, A} (reshape (data, length (data)), V)
345357 end
346358
347- t = TensorMap {T } (undef, codom, dom )
359+ t = _TensorMap {T, A } (undef, V )
348360 arraydata = reshape (collect (data), arraysize)
349361 t = project_symmetric! (t, arraydata)
350362 if ! isapprox (arraydata, convert (Array, t); atol = tol)
351363 throw (ArgumentError (" Data has non-zero elements at incompatible positions" ))
352364 end
353365 return t
354366end
355- function TensorMap (
356- data:: AbstractArray , codom:: TensorSpace{S} , dom:: TensorSpace{S} ; kwargs...
357- ) where {S}
358- return TensorMap (data, codom ← dom; kwargs... )
359- end
360- function Tensor (data:: AbstractArray , codom:: TensorSpace , ; kwargs... )
361- return TensorMap (data, codom ← one (codom); kwargs... )
362- end
367+
368+ TensorMap (data:: AbstractArray , codom:: TensorSpace , dom:: TensorSpace ; kwargs... ) =
369+ TensorMap (data, codom ← dom; kwargs... )
370+ _TensorMap {T, A} (data:: AbstractArray , codom:: TensorSpace , dom:: TensorSpace ; kwargs... ) where {T, A} =
371+ _TensorMap (data, codom ← dom; kwargs... )
372+
373+ Tensor (data:: AbstractArray , codom:: TensorSpace ; kwargs... ) =
374+ TensorMap (data, codom ← one (codom); kwargs... )
375+ _Tensor {T, A} (data:: AbstractArray , codom:: TensorSpace ; kwargs... ) where {T, A} =
376+ _TensorMap {T, A} (data, codom ← one (codom); kwargs... )
363377
364378"""
365379 project_symmetric!(t::TensorMap, data::DenseArray) -> TensorMap
0 commit comments