Skip to content

Commit ce23b52

Browse files
committed
unify tensortype usage
1 parent beeefe4 commit ce23b52

File tree

3 files changed

+41
-51
lines changed

3 files changed

+41
-51
lines changed

src/auxiliary/auxiliary.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,3 +86,6 @@ end
8686
else
8787
_allequal(f, xs) = allequal(f, xs)
8888
end
89+
90+
Base.@assume_effects :foldable parenttype(::Type{T}) where {T} =
91+
Core.Compiler.return_type(parent, Tuple{T})

src/tensors/abstracttensor.jl

Lines changed: 20 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -494,54 +494,39 @@ specialize this method.
494494
""" Base.similar(::AbstractTensorMap, args...)
495495

496496
function Base.similar(
497-
t::AbstractTensorMap, ::Type{T}, codomain::TensorSpace{S}, domain::TensorSpace{S}
498-
) where {T, S}
497+
t::AbstractTensorMap, ::Type{T}, codomain::TensorSpace, domain::TensorSpace
498+
) where {T}
499499
return similar(t, T, codomain domain)
500500
end
501+
501502
# 3 arguments
502-
function Base.similar(
503-
t::AbstractTensorMap, codomain::TensorSpace{S}, domain::TensorSpace{S}
504-
) where {S}
505-
return similar(t, similarstoragetype(t), codomain domain)
506-
end
507-
function Base.similar(t::AbstractTensorMap, ::Type{T}, codomain::TensorSpace) where {T}
508-
return similar(t, T, codomain one(codomain))
509-
end
503+
Base.similar(t::AbstractTensorMap, codomain::TensorSpace, domain::TensorSpace) =
504+
similar(t, similarstoragetype(t), codomain domain)
505+
Base.similar(t::AbstractTensorMap, ::Type{T}, codomain::TensorSpace) where {T} =
506+
similar(t, T, codomain one(codomain))
507+
510508
# 2 arguments
511-
function Base.similar(t::AbstractTensorMap, codomain::TensorSpace)
512-
return similar(t, similarstoragetype(t), codomain one(codomain))
513-
end
514-
Base.similar(t::AbstractTensorMap, P::TensorMapSpace) = similar(t, storagetype(t), P)
509+
Base.similar(t::AbstractTensorMap, codomain::TensorSpace) =
510+
similar(t, similarstoragetype(t), codomain one(codomain))
511+
Base.similar(t::AbstractTensorMap, V::TensorMapSpace) = similar(t, similarstoragetype(t), V)
515512
Base.similar(t::AbstractTensorMap, ::Type{T}) where {T} = similar(t, T, space(t))
516513
# 1 argument
517514
Base.similar(t::AbstractTensorMap) = similar(t, similarstoragetype(t), space(t))
518515

519516
# generic implementation for AbstractTensorMap -> returns `TensorMap`
520-
function Base.similar(t::AbstractTensorMap, ::Type{TorA}, P::TensorMapSpace{S}) where {TorA, S}
521-
if TorA <: Number
522-
T = TorA
523-
A = similarstoragetype(t, T)
524-
elseif TorA <: DenseVector
525-
A = TorA
526-
T = scalartype(A)
527-
else
528-
throw(ArgumentError("Type $TorA not supported for similar"))
529-
end
530-
531-
N₁ = length(codomain(P))
532-
N₂ = length(domain(P))
533-
return TensorMap{T, S, N₁, N₂, A}(undef, P)
517+
function Base.similar(t::AbstractTensorMap, ::Type{TorA}, V::TensorMapSpace) where {TorA}
518+
A = TorA <: Number ? similarstoragetype(t, TorA) : TorA
519+
TT = tensormaptype(spacetype(V), numout(V), numin(V), A)
520+
return TT(undef, V)
534521
end
535522

536523
# implementation in type-domain
537-
function Base.similar(::Type{TT}, P::TensorMapSpace) where {TT <: AbstractTensorMap}
538-
return TensorMap{scalartype(TT)}(undef, P)
539-
end
540-
function Base.similar(
541-
::Type{TT}, cod::TensorSpace{S}, dom::TensorSpace{S}
542-
) where {TT <: AbstractTensorMap, S}
543-
return TensorMap{scalartype(TT)}(undef, cod, dom)
524+
function Base.similar(::Type{TT}, V::TensorMapSpace) where {TT <: AbstractTensorMap}
525+
TT′ = tensormaptype(spacetype(V), numout(V), numin(V), similarstoragetype(TT))
526+
return TT′(undef, V)
544527
end
528+
Base.similar(::Type{TT}, cod::TensorSpace, dom::TensorSpace) where {TT <: AbstractTensorMap} =
529+
similar(TT, cod dom)
545530

546531
# Equality and approximality
547532
#----------------------------

src/tensors/tensor.jl

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ struct TensorMap{T, S <: IndexSpace, N₁, N₂, A <: DenseVector{T}} <: Abstrac
3131
I = sectortype(S)
3232
T <: Real && !(sectorscalartype(I) <: Real) &&
3333
@warn("Tensors with real data might be incompatible with sector type $I", maxlog = 1)
34+
d = fusionblockstructure(space).totaldim
35+
length(data) == d || throw(DimensionMismatch("invalid length of data"))
3436
return new{T, S, N₁, N₂, A}(data, space)
3537
end
3638
end
@@ -47,19 +49,20 @@ i.e. a tensor map with only a non-trivial output space.
4749
const Tensor{T, S, N, A} = TensorMap{T, S, N, 0, A}
4850

4951
function tensormaptype(S::Type{<:IndexSpace}, N₁, N₂, TorA::Type)
50-
if TorA <: Number
51-
return TensorMap{TorA, S, N₁, N₂, Vector{TorA}}
52-
elseif TorA <: DenseVector
53-
return TensorMap{scalartype(TorA), S, N₁, N₂, TorA}
54-
else
55-
throw(ArgumentError("argument $TorA should specify a scalar type (`<:Number`) or a storage type `<:DenseVector{<:Number}`"))
56-
end
52+
A = _tensormap_storagetype(TorA)
53+
A <: DenseVector || throw(ArgumentError("Cannot determine a valid storage type from argument $TorA"))
54+
return TensorMap{scalartype(A), S, N₁, N₂, A}
5755
end
5856

5957
# hook for mapping input types to storage types -- to be implemented in extensions
60-
_tensormap_storagetype(::Type{A}) where {A <: AbstractArray} = _tensormap_storagetype(scalartype(A))
6158
_tensormap_storagetype(::Type{A}) where {A <: DenseVector{<:Number}} = A
59+
_tensormap_storagetype(::Type{A}) where {A <: Array} = _tensormap_storagetype(scalartype(A))
6260
_tensormap_storagetype(::Type{T}) where {T <: Number} = Vector{T}
61+
function _tensormap_storagetype(::Type{A}) where {A <: AbstractArray}
62+
PA = parenttype(A)
63+
PA === A && throw(MethodError(_tensormap_storagetype, A)) # avoid infinite recursion
64+
return _tensormap_storagetype(PA)
65+
end
6366

6467
# Basic methods for characterising a tensor:
6568
#--------------------------------------------
@@ -95,7 +98,7 @@ const TensorWithStorage{T, A <: DenseVector{T}, S, N} = Tensor{T, S, N, A}
9598
Construct a `TensorMap` with uninitialized data with elements of type `T`.
9699
"""
97100
TensorMap{T}(::UndefInitializer, V::TensorMapSpace) where {T} =
98-
TensorMapWithStorage{T, _tensormap_storagetype(T)}(undef, V)
101+
tensormaptype(spacetype(V), numout(V), numin(V), T)(undef, V)
99102
TensorMap{T}(::UndefInitializer, codomain::TensorSpace, domain::TensorSpace) where {T} =
100103
TensorMap{T}(undef, codomain domain)
101104
Tensor{T}(::UndefInitializer, V::TensorSpace) where {T} = TensorMap{T}(undef, V one(V))
@@ -108,7 +111,7 @@ Tensor{T}(::UndefInitializer, V::TensorSpace) where {T} = TensorMap{T}(undef, V
108111
Construct a `TensorMap` with uninitialized data stored as `A <: DenseVector{T}`.
109112
"""
110113
TensorMapWithStorage{T, A}(::UndefInitializer, V::TensorMapSpace) where {T, A} =
111-
TensorMap{T, spacetype(V), numout(V), numin(V), A}(undef, V)
114+
tensormaptype(spacetype(V), numout(V), numin(V), A)(undef, V)
112115
TensorMapWithStorage{T, A}(::UndefInitializer, codomain::TensorSpace, domain::TensorSpace) where {T, A} =
113116
TensorMapWithStorage{T, A}(undef, codomain domain)
114117
TensorWithStorage{T, A}(::UndefInitializer, V::TensorSpace) where {T, A} = TensorMapWithStorage{T, A}(undef, V one(V))
@@ -128,7 +131,7 @@ Construct a `TensorMap` from the given raw data.
128131
This constructor takes ownership of the provided vector, and will not make an independent copy.
129132
"""
130133
TensorMap{T}(data::DenseVector{T}, V::TensorMapSpace) where {T} =
131-
TensorMapWithStorage{T, typeof(data)}(data, V)
134+
tensormaptype(spacetype(V), numout(V), numin(V), typeof(data))(data, V)
132135
TensorMap{T}(data::DenseVector{T}, codomain::TensorSpace, domain::TensorSpace) where {T} =
133136
TensorMap{T}(data, codomain domain)
134137

@@ -141,8 +144,7 @@ Construct a `TensorMap` from the given raw data.
141144
This constructor takes ownership of the provided vector, and will not make an independent copy.
142145
"""
143146
function TensorMapWithStorage{T, A}(data::A, V::TensorMapSpace) where {T, A}
144-
length(data) == dim(V) || throw(DimensionMismatch("invalid length of data"))
145-
return TensorMap{T, spacetype(V), numout(V), numin(V), A}(data, V)
147+
return tensormaptype(spacetype(V), numout(V), numin(V), typeof(data))(data, V)
146148
end
147149
TensorMapWithStorage{T, A}(data::A, codomain::TensorSpace, domain::TensorSpace) where {T, A} =
148150
TensorMapWithStorage{T, A}(data, codomain domain)
@@ -213,11 +215,11 @@ function TensorMapWithStorage{T, A}(
213215
) where {T, A}
214216
# refer to specific raw data constructors if input is a vector of the correct length
215217
ndims(data) == 1 && length(data) == dim(V) &&
216-
return TensorMap{T, spacetype(V), numout(V), numin(V), A}(data, V)
218+
return tensormaptype(spacetype(V), numout(V), numin(V), A)(data, V)
217219

218220
# special case trivial: refer to same method, but now with vector argument
219221
sectortype(V) === Trivial &&
220-
return TensorMap{T, spacetype(V), numout(V), numin(V), A}(reshape(data, length(data)), V)
222+
return tensormaptype(spacetype(V), numout(V), numin(V), A)(reshape(data, length(data)), V)
221223

222224
# do projection
223225
t = TensorMapWithStorage{T, A}(undef, V)
@@ -230,7 +232,7 @@ function TensorMapWithStorage{T, A}(
230232
return t
231233
end
232234
TensorMapWithStorage{T, A}(data::AbstractArray, codom::TensorSpace, dom::TensorSpace; kwargs...) where {T, A} =
233-
TensorMapWithStorage(data, codom dom; kwargs...)
235+
TensorMapWithStorage{T, A}(data, codom dom; kwargs...)
234236
TensorWithStorage{T, A}(data::AbstractArray, codom::TensorSpace; kwargs...) where {T, A} =
235237
TensorMapWithStorage{T, A}(data, codom one(codom); kwargs...)
236238

0 commit comments

Comments
 (0)