1- const CuTensorMap{T,S,N₁,N₂,A <: CuVector{T} } = TensorMap{T,S,N₁,N₂,A }
2- const CuTensor{T, S, N, A <: CuVector{T} } = CuTensorMap{T, S, N, 0 , A }
1+ const CuTensorMap{T,S,N₁,N₂} = TensorMap{T,S,N₁,N₂, CuVector{T,CUDA . DeviceMemory} }
2+ const CuTensor{T, S, N} = CuTensorMap{T, S, N, 0 }
33
44function TensorKit. tensormaptype (S:: Type{<:IndexSpace} , N₁, N₂, TorA:: Type{<:StridedCuArray} )
55 if TorA <: CuArray
6- return TensorMap{eltype (TorA),S,N₁,N₂,CuVector{eltype (TorA)}}
6+ return TensorMap{eltype (TorA),S,N₁,N₂,CuVector{eltype (TorA), CUDA . DeviceMemory }}
77 else
88 throw (ArgumentError (" argument $TorA should specify a scalar type (`<:Number`) or a storage type `<:CuVector{<:Number}`" ))
99 end
1010end
1111
1212function CuTensorMap {T} (:: UndefInitializer , V:: TensorMapSpace{S, N₁, N₂} ) where {T, S, N₁, N₂}
13- return CuTensorMap {T,S,N₁,N₂,CuVector{T} } (undef, V)
13+ return CuTensorMap {T,S,N₁,N₂} (undef, V)
1414end
1515
1616function CuTensorMap {T} (:: UndefInitializer , codomain:: TensorSpace{S} ,
@@ -164,14 +164,16 @@ function TensorKit.scalar(t::CuTensorMap)
164164 first (blocks (t))[2 ][1 , 1 ] : throw (DimensionMismatch ())
165165end
166166
167- TensorKit. scalartype (A:: CuArray{T} ) where {T} = T
167+ TensorKit. scalartype (A:: StridedCuArray{T} ) where {T} = T
168+ vi_scalartype (:: Type{<:CuTensorMap{T}} ) where {T} = T
169+ vi_scalartype (:: Type{<:CuArray{T}} ) where {T} = T
168170
169- function TensorKit. similarstoragetype (TT:: Type{<:CuTensorMap} , :: Type{T} ) where {T }
170- return CuVector{T}
171+ function TensorKit. similarstoragetype (TT:: Type{<:CuTensorMap{TTT,S,N₁,N₂}} , :: Type{T} ) where {TTT,T,S,N₁,N₂ }
172+ return CuVector{T, CUDA . DeviceMemory }
171173end
172174
173- function Base. convert (TT:: Type{CuTensorMap{T,S,N₁,N₂,A }} ,
174- t:: AbstractTensorMap{<:Any,S,N₁,N₂} ) where {T,S,N₁,N₂,A <: CuVector{T} }
175+ function Base. convert (TT:: Type{CuTensorMap{T,S,N₁,N₂}} ,
176+ t:: AbstractTensorMap{<:Any,S,N₁,N₂} ) where {T,S,N₁,N₂}
175177 if typeof (t) === TT
176178 return t
177179 else
@@ -180,7 +182,7 @@ function Base.convert(TT::Type{CuTensorMap{T,S,N₁,N₂,A}},
180182 end
181183end
182184
183- function Base. copy! (tdst:: CuTensorMap{T, S, N₁, N₂, A } , tsrc:: CuTensorMap{T, S, N₁, N₂, A } ) where {T, S, N₁, N₂, A }
185+ function Base. copy! (tdst:: CuTensorMap{T, S, N₁, N₂} , tsrc:: CuTensorMap{T, S, N₁, N₂} ) where {T, S, N₁, N₂}
184186 space (tdst) == space (tsrc) || throw (SpaceMismatch (" $(space (tdst)) ≠ $(space (tsrc)) " ))
185187 for ((c, bdst), (_, bsrc)) in zip (blocks (tdst), blocks (tsrc))
186188 copy! (bdst, bsrc)
@@ -195,3 +197,11 @@ function Base.copy!(tdst::CuTensorMap, tsrc::TensorKit.AdjointTensorMap)
195197 end
196198 return tdst
197199end
200+
201+ function Base. promote_rule (:: Type{<:TT₁} ,
202+ :: Type{<:TT₂} ) where {S,N₁,N₂, TTT₁, TTT₂,
203+ TT₁<: CuTensorMap{TTT₁,S,N₁,N₂} ,
204+ TT₂<: CuTensorMap{TTT₂,S,N₁,N₂} }
205+ T = TensorKit. VectorInterface. promote_add (TTT₁, TTT₂)
206+ return CuTensorMap{T,S,N₁,N₂}
207+ end
0 commit comments