Skip to content

Commit 9740cfb

Browse files
committed
even more careful
1 parent 30f4988 commit 9740cfb

File tree

2 files changed

+16
-19
lines changed

2 files changed

+16
-19
lines changed

src/tensors/abstracttensor.jl

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -65,33 +65,35 @@ appropriate storage types. Additionally this registers the default storage type
6565
`T <: Number` is provided, which is `Vector{T}`.
6666
""" similarstoragetype
6767

68-
# implement in type domain and fill in default value
69-
@inline similarstoragetype(t, ::Type{T} = scalartype(t)) where {T <: Number} =
70-
similarstoragetype(typeof(t), T)
68+
# implement in type domain
69+
similarstoragetype(t) = similarstoragetype(typeof(t))
70+
similarstoragetype(t, ::Type{T}) where {T <: Number} = similarstoragetype(typeof(t), T)
71+
72+
# avoid infinite recursion
73+
similarstoragetype(X::Type) =
74+
throw(ArgumentError("Cannot determine a storagetype for tensor / array type `$X`"))
75+
similarstoragetype(X::Type, ::Type{T}) where {T <: Number} =
76+
throw(ArgumentError("Cannot determine a storagetype for tensor / array type `$X` and/or scalar type `$T`"))
7177

7278
# implement on tensors
79+
similarstoragetype(::Type{TT}) where {TT <: AbstractTensorMap} = similarstoragetype(storagetype(TT))
7380
similarstoragetype(::Type{TT}, ::Type{T}) where {TT <: AbstractTensorMap, T <: Number} =
7481
similarstoragetype(storagetype(TT), T)
7582

7683
# implement on arrays
84+
similarstoragetype(::Type{A}) where {A <: DenseVector{<:Number}} = A
7785
Base.@assume_effects :foldable similarstoragetype(::Type{A}, ::Type{T}) where {A <: AbstractArray, T <: Number} =
7886
Core.Compiler.return_type(similar, Tuple{A, Type{T}, Int})
7987

80-
# implement on sectordicts - intercept scalartype defaults!
81-
similarstoragetype(d::D, ::Type{T} = scalartype(valtype(d))) where {D <: AbstractDict{<:Sector}, T <: Number} =
82-
similarstoragetype(typeof(d), T)
83-
similarstoragetype(::Type{D}, ::Type{T} = scalartype(valtype(D))) where {D <: AbstractDict{<:Sector}, T <: Number} =
88+
# implement on sectordicts
89+
similarstoragetype(::Type{D}) where {D <: AbstractDict{<:Sector, <:AbstractMatrix}} =
90+
similarstoragetype(valtype(D))
91+
similarstoragetype(::Type{D}, ::Type{T}) where {D <: AbstractDict{<:Sector, <:AbstractMatrix}, T <: Number} =
8492
similarstoragetype(valtype(D), T)
8593

8694
# default storage type for numbers
8795
similarstoragetype(::Type{T}) where {T <: Number} = Vector{T}
8896

89-
# avoid infinite recursion
90-
similarstoragetype(X::Type, Y::Type) =
91-
throw(ArgumentError("Cannot determine a storagetype for tensor / array type `$X` and/or scalar type `$Y`"))
92-
# ambiguity
93-
similarstoragetype(X::Type, ::Type{Y}) where {Y <: Number} =
94-
throw(ArgumentError("Cannot determine a storagetype for tensor / array type `$X` and/or scalar type `$Y`"))
9597

9698
# tensor characteristics: space and index information
9799
#-----------------------------------------------------

src/tensors/tensoroperations.jl

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,7 @@ function TO.tensoralloc(
1111
A = storagetype(TT)
1212
dim = fusionblockstructure(structure).totaldim
1313
data = TO.tensoralloc(A, dim, istemp, allocator)
14-
15-
# the following doesn't work since data isn't actually restricted to <: A
16-
# so hardcode type instead
17-
# TT′ = tensormaptype(spacetype(structure), numout(structure), numin(structure), typeof(data))
18-
TT′ = TensorMap{scalartype(data), spacetype(structure), numout(structure), numin(structure), typeof(data)}
19-
14+
TT′ = tensormaptype(spacetype(structure), numout(structure), numin(structure), typeof(data))
2015
return TT′(data, structure)
2116
end
2217

0 commit comments

Comments
 (0)