Skip to content

Commit ac01f19

Browse files
committed
the carefulest!
1 parent 9740cfb commit ac01f19

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

src/tensors/abstracttensor.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ similarstoragetype(::Type{TT}, ::Type{T}) where {TT <: AbstractTensorMap, T <: N
8282

8383
# implement on arrays
8484
similarstoragetype(::Type{A}) where {A <: DenseVector{<:Number}} = A
85+
Base.@assume_effects :foldable similarstoragetype(::Type{A}) where {A <: AbstractArray{<:Number}} =
86+
Core.Compiler.return_type(similar, Tuple{A, Int})
8587
Base.@assume_effects :foldable similarstoragetype(::Type{A}, ::Type{T}) where {A <: AbstractArray, T <: Number} =
8688
Core.Compiler.return_type(similar, Tuple{A, Type{T}, Int})
8789

@@ -544,17 +546,17 @@ end
544546

545547
# 3 arguments
546548
Base.similar(t::AbstractTensorMap, codomain::TensorSpace, domain::TensorSpace) =
547-
similar(t, similarstoragetype(t), codomain domain)
549+
similar(t, similarstoragetype(t, scalartype(t)), codomain domain)
548550
Base.similar(t::AbstractTensorMap, ::Type{T}, codomain::TensorSpace) where {T} =
549551
similar(t, T, codomain one(codomain))
550552

551553
# 2 arguments
552554
Base.similar(t::AbstractTensorMap, codomain::TensorSpace) =
553-
similar(t, similarstoragetype(t), codomain one(codomain))
554-
Base.similar(t::AbstractTensorMap, V::TensorMapSpace) = similar(t, similarstoragetype(t), V)
555+
similar(t, codomain one(codomain))
556+
Base.similar(t::AbstractTensorMap, V::TensorMapSpace) = similar(t, scalartype(t), V)
555557
Base.similar(t::AbstractTensorMap, ::Type{T}) where {T} = similar(t, T, space(t))
556558
# 1 argument
557-
Base.similar(t::AbstractTensorMap) = similar(t, similarstoragetype(t), space(t))
559+
Base.similar(t::AbstractTensorMap) = similar(t, scalartype(t), space(t))
558560

559561
# generic implementation for AbstractTensorMap -> returns `TensorMap`
560562
function Base.similar(t::AbstractTensorMap, ::Type{TorA}, V::TensorMapSpace) where {TorA}
@@ -565,7 +567,7 @@ end
565567

566568
# implementation in type-domain
567569
function Base.similar(::Type{TT}, V::TensorMapSpace) where {TT <: AbstractTensorMap}
568-
TT′ = tensormaptype(spacetype(V), numout(V), numin(V), similarstoragetype(TT))
570+
TT′ = tensormaptype(spacetype(V), numout(V), numin(V), similarstoragetype(TT, scalartype(TT)))
569571
return TT′(undef, V)
570572
end
571573
Base.similar(::Type{TT}, cod::TensorSpace, dom::TensorSpace) where {TT <: AbstractTensorMap} =

0 commit comments

Comments
 (0)