Skip to content

Commit dba83c7

Browse files
committed
More specializations
1 parent 2ee3e2f commit dba83c7

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

ext/TensorKitCUDAExt/cutensormap.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,3 +165,25 @@ TensorKit.scalartype(A::CuArray{T}) where {T} = T
165165
function TensorKit.similarstoragetype(TT::Type{<:CuTensorMap}, ::Type{T}) where {T}
166166
return CuVector{T}
167167
end
168+
169+
function Base.convert(::Type{CuTensorMap}, t::AbstractTensorMap)
170+
return copy!(CuTensorMap{scalartype(t)}(undef, space(t)), t)
171+
end
172+
173+
function Base.convert(TT::Type{CuTensorMap{T,S,N₁,N₂,A}},
174+
t::AbstractTensorMap{<:Any,S,N₁,N₂}) where {T,S,N₁,N₂,A<:CuVector{T}}
175+
if typeof(t) === TT
176+
return t
177+
else
178+
tnew = TT(undef, space(t))
179+
return copy!(tnew, t)
180+
end
181+
end
182+
183+
function Base.copy!(tdst::CuTensorMap{T, S, N₁, N₂, A}, tsrc::CuTensorMap{T, S, N₁, N₂, A}) where {T, S, N₁, N₂, A}
184+
space(tdst) == space(tsrc) || throw(SpaceMismatch("$(space(tdst))$(space(tsrc))"))
185+
for ((c, bdst), (_, bsrc)) in zip(blocks(tdst), blocks(tsrc))
186+
copy!(bdst, bsrc)
187+
end
188+
return tdst
189+
end

0 commit comments

Comments
 (0)