Skip to content

Commit 1493bac

Browse files
committed
rework tensor contructors
1 parent 06d32b0 commit 1493bac

File tree

1 file changed

+83
-69
lines changed

1 file changed

+83
-69
lines changed

src/tensors/tensor.jl

Lines changed: 83 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,22 @@ dim(t::TensorMap) = length(t.data)
7171

7272
# General TensorMap constructors
7373
#--------------------------------
74-
# undef constructors
74+
# hook for mapping input types to storage types -- to be implemented in extensions
75+
_tensormap_storagetype(::Type{A}) where {A <: AbstractArray} = _tensormap_storagetype(eltype(A))
76+
_tensormap_storagetype(::Type{T}) where {T <: Number} = Vector{T}
77+
78+
# utility type alias that makes constructors also work for type aliases that specify
79+
# different storage types. (i.e. CuTensorMap = _TensorMap{T, CuVector{T}, ...})
80+
# TODO: do we need a name for this and do we want to consider this as public?
81+
const _TensorMap{T, A <: DenseVector{T}, S, N₁, N₂} = TensorMap{T, S, N₁, N₂, A}
82+
const _Tensor{T, A <: DenseVector{T}, S, N} = Tensor{T, S, N, A}
83+
84+
# undef constructors:
85+
# - dispatch start with TensorMap{T}
86+
# - select A and map to _TensorMap{T, A}
87+
# - select S, N1, N2 and map to TensorMap{T,S,N1,N2,A}
7588
"""
76-
TensorMap{T}(undef, codomain::ProductSpace{S,N₁}, domain::ProductSpace{S,N₂})
77-
where {T,S,N₁,N₂}
89+
TensorMap{T}(undef, codomain::ProductSpace{S,N₁}, domain::ProductSpace{S,N₂}) where {T,S,N₁,N₂}
7890
TensorMap{T}(undef, codomain ← domain)
7991
TensorMap{T}(undef, domain → codomain)
8092
# expert mode: select storage type `A`
@@ -83,33 +95,35 @@ dim(t::TensorMap) = length(t.data)
8395
8496
Construct a `TensorMap` with uninitialized data.
8597
"""
86-
function TensorMap{T}(::UndefInitializer, V::TensorMapSpace{S, N₁, N₂}) where {T, S, N₁, N₂}
87-
return TensorMap{T, S, N₁, N₂, Vector{T}}(undef, V)
88-
end
89-
function TensorMap{T}(
90-
::UndefInitializer, codomain::TensorSpace{S}, domain::TensorSpace{S}
91-
) where {T, S}
92-
return TensorMap{T}(undef, codomain domain)
93-
end
94-
function Tensor{T}(::UndefInitializer, V::TensorSpace{S}) where {T, S}
95-
return TensorMap{T}(undef, V one(V))
96-
end
98+
TensorMap{T}(::UndefInitializer, V::TensorMapSpace) where {T} =
99+
_TensorMap{T, _tensormap_storagetype(T)}(undef, V)
100+
TensorMap{T}(::UndefInitializer, codomain::TensorSpace, domain::TensorSpace) where {T} =
101+
TensorMap{T}(undef, codomain domain)
102+
Tensor{T}(::UndefInitializer, V::TensorSpace) where {T} = TensorMap{T}(undef, V one(V))
103+
104+
# specifying storagetype, fill in other parameters
105+
_TensorMap{T, A}(::UndefInitializer, V::TensorMapSpace) where {T, A} =
106+
TensorMap{T, spacetype(V), numout(V), numin(V), A}(undef, V)
107+
_TensorMap{T, A}(::UndefInitializer, codomain::TensorSpace, domain::TensorSpace) where {T, A} =
108+
_TensorMap{T, A}(undef, codomain domain)
109+
_Tensor{T, A}(::UndefInitializer, V::TensorSpace) where {T, A} = _TensorMap{T, A}(undef, V one(V))
97110

98111
# constructor starting from vector = independent data (N₁ + N₂ = 1 is special cased below)
99112
# documentation is captured by the case where `data` is a general array
100-
# here, we force the `T` argument to distinguish it from the more general constructor below
101-
function TensorMap{T}(
102-
data::A, V::TensorMapSpace{S, N₁, N₂}
103-
) where {T, S, N₁, N₂, A <: DenseVector{T}}
104-
return TensorMap{T, S, N₁, N₂, A}(data, V)
105-
end
106-
function TensorMap{T}(
107-
data::DenseVector{T}, codomain::TensorSpace{S}, domain::TensorSpace{S}
108-
) where {T, S}
109-
return TensorMap(data, codomain domain)
110-
end
113+
# here, we force the `T` and/or `A` argument to distinguish it from the more general constructor below
114+
TensorMap{T}(data::DenseVector{T}, V::TensorMapSpace) where {T} =
115+
_TensorMap{T, typeof(data)}(data, V)
116+
TensorMap{T}(data::DenseVector{T}, codomain::TensorSpace, domain::TensorSpace) where {T} =
117+
TensorMap{T}(data, codomain domain)
118+
119+
_TensorMap{T, A}(data::DenseVector{T}, V::TensorMapSpace) where {T, A} =
120+
TensorMap{T, spacetype(V), numout(V), numin(V), A}(data, V)
121+
_TensorMap{T, A}(data::DenseVector{T}, codomain::TensorSpace, domain::TensorSpace) where {T, A} =
122+
_TensorMap{T, A}(data, codomain domain)
111123

112124
# constructor starting from block data
125+
const _BlockData{I <: Sector, A <: AbstractMatrix} = AbstractDict{I, A}
126+
113127
"""
114128
TensorMap(data::AbstractDict{<:Sector,<:AbstractMatrix}, codomain::ProductSpace{S,N₁},
115129
domain::ProductSpace{S,N₂}) where {S<:ElementarySpace,N₁,N₂}
@@ -129,29 +143,33 @@ Construct a `TensorMap` by explicitly specifying its block data.
129143
Alternatively, the domain and codomain can be specified by passing a [`HomSpace`](@ref)
130144
using the syntax `codomain ← domain` or `domain → codomain`.
131145
"""
132-
function TensorMap(
133-
data::AbstractDict{<:Sector, <:AbstractMatrix}, V::TensorMapSpace{S, N₁, N₂}
134-
) where {S, N₁, N₂}
135-
T = eltype(valtype(data))
136-
t = TensorMap{T}(undef, V)
146+
function TensorMap(data::_BlockData, V::TensorMapSpace)
147+
A = _tensormap_storagetype(valtype(data))
148+
return _TensorMap{scalartype(A), A}(data, V)
149+
end
150+
TensorMap(data::_BlockData, codom::TensorSpace, dom::TensorSpace) =
151+
TensorMap(data, codom dom)
152+
153+
function _TensorMap{T, A}(data::_BlockData, V::TensorMapSpace) where {T, A}
154+
t = _TensorMap{T, A}(undef, V)
155+
156+
# check that there aren't too many blocks
157+
for (c, b) in data
158+
c blocksectors(t) || isempty(b) || throw(SectorMismatch("data for block sector $c not expected"))
159+
end
160+
161+
# fill in the blocks -- rely on conversion in copy
137162
for (c, b) in blocks(t)
138163
haskey(data, c) || throw(SectorMismatch("no data for block sector $c"))
139164
datac = data[c]
140-
size(datac) == size(b) ||
141-
throw(DimensionMismatch("wrong size of block for sector $c"))
165+
size(datac) == size(b) || throw(DimensionMismatch("wrong size of block for sector $c"))
142166
copy!(b, datac)
143167
end
144-
for (c, b) in data
145-
c blocksectors(t) || isempty(b) ||
146-
throw(SectorMismatch("data for block sector $c not expected"))
147-
end
168+
148169
return t
149170
end
150-
function TensorMap(
151-
data::AbstractDict{<:Sector, <:AbstractMatrix}, codom::TensorSpace{S}, dom::TensorSpace{S}
152-
) where {S}
153-
return TensorMap(data, codom dom)
154-
end
171+
_TensorMap{T, A}(data::_BlockData, codom::TensorSpace, dom::TensorSpace) where {T, A} =
172+
_TensorMap{T, A}(data, codom dom)
155173

156174
@doc """
157175
zeros([T=Float64,], codomain::ProductSpace{S,N₁}, domain::ProductSpace{S,N₂}) where {S,N₁,N₂,T}
@@ -317,49 +335,45 @@ cases.
317335
to a plain array is possible, and only in the case where the `data` actually respects
318336
the specified symmetry structure, up to a tolerance `tol`.
319337
"""
320-
function TensorMap(
321-
data::AbstractArray, V::TensorMapSpace{S, N₁, N₂};
322-
tol = sqrt(eps(real(float(eltype(data)))))
323-
) where {S <: IndexSpace, N₁, N₂}
338+
function TensorMap(data::AbstractArray, V::TensorMapSpace; tol = sqrt(eps(real(float(eltype(data))))))
324339
T = eltype(data)
325-
if ndims(data) == 1 && length(data) == dim(V)
326-
if data isa DenseVector # refer to specific data-capturing constructor
327-
return TensorMap{T}(data, V)
328-
else
329-
return TensorMap{T}(collect(data), V)
330-
end
331-
end
340+
A = _tensormap_storagetype(typeof(data))
341+
return _TensorMap{T, A}(data, V; tol)
342+
end
343+
function _TensorMap{T, A}(
344+
data::AbstractArray, V::TensorMapSpace; tol = sqrt(eps(real(float(eltype(data)))))
345+
) where {T, A}
346+
# refer to specific data-capturing constructors if input is a vector of the correct length
347+
ndims(data) == 1 && length(data) == dim(V) && return _TensorMap{T, A}(data, V)
332348

333349
# dimension check
334-
codom = codomain(V)
335-
dom = domain(V)
350+
codom, dom = codomain(V), domain(V)
336351
arraysize = dims(V)
337352
matsize = (dim(codom), dim(dom))
353+
(size(data) == arraysize || size(data) == matsize) || throw(DimensionMismatch())
338354

339-
if !(size(data) == arraysize || size(data) == matsize)
340-
throw(DimensionMismatch())
341-
end
342-
343-
if sectortype(S) === Trivial # refer to same method, but now with vector argument
344-
return TensorMap(reshape(data, length(data)), V)
355+
if sectortype(V) === Trivial # refer to same method, but now with vector argument
356+
return _TensorMap{T, A}(reshape(data, length(data)), V)
345357
end
346358

347-
t = TensorMap{T}(undef, codom, dom)
359+
t = _TensorMap{T, A}(undef, V)
348360
arraydata = reshape(collect(data), arraysize)
349361
t = project_symmetric!(t, arraydata)
350362
if !isapprox(arraydata, convert(Array, t); atol = tol)
351363
throw(ArgumentError("Data has non-zero elements at incompatible positions"))
352364
end
353365
return t
354366
end
355-
function TensorMap(
356-
data::AbstractArray, codom::TensorSpace{S}, dom::TensorSpace{S}; kwargs...
357-
) where {S}
358-
return TensorMap(data, codom dom; kwargs...)
359-
end
360-
function Tensor(data::AbstractArray, codom::TensorSpace, ; kwargs...)
361-
return TensorMap(data, codom one(codom); kwargs...)
362-
end
367+
368+
TensorMap(data::AbstractArray, codom::TensorSpace, dom::TensorSpace; kwargs...) =
369+
TensorMap(data, codom dom; kwargs...)
370+
_TensorMap{T, A}(data::AbstractArray, codom::TensorSpace, dom::TensorSpace; kwargs...) where {T, A} =
371+
_TensorMap(data, codom dom; kwargs...)
372+
373+
Tensor(data::AbstractArray, codom::TensorSpace; kwargs...) =
374+
TensorMap(data, codom one(codom); kwargs...)
375+
_Tensor{T, A}(data::AbstractArray, codom::TensorSpace; kwargs...) where {T, A} =
376+
_TensorMap{T, A}(data, codom one(codom); kwargs...)
363377

364378
"""
365379
project_symmetric!(t::TensorMap, data::DenseArray) -> TensorMap

0 commit comments

Comments
 (0)