Skip to content

Commit 00f5d71

Browse files
committed
convert(TensorMap, t) retains storagetype
1 parent b5a3ab5 commit 00f5d71

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

src/tensors/tensor.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -532,7 +532,8 @@ end
532532
#---------------------------
533533
Base.convert(::Type{TensorMap}, t::TensorMap) = t
534534
function Base.convert(::Type{TensorMap}, t::AbstractTensorMap)
535-
return copy!(TensorMap{scalartype(t)}(undef, space(t)), t)
535+
A = storagetype(t)
536+
return copy!(TensorMapWithStorage{scalartype(A), A}(undef, space(t)), t)
536537
end
537538

538539
function Base.convert(::Type{TensorMapWithStorage{T, A}}, t::TensorMap) where {T, A}

test/cuda/tensors.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -258,14 +258,14 @@ for V in spacelist
258258
end
259259
end
260260
end
261-
@timedtestset "Tensor conversion" begin # TODO adjoint conversion methods don't work yet
261+
@timedtestset "Tensor conversion" begin
262262
W = V1 V2
263263
t = @constinferred CUDA.randn(W W)
264-
#@test typeof(convert(TensorMap, t')) == typeof(t) # TODO Adjoint not supported yet
264+
@test typeof(convert(TensorMap, t')) == typeof(t)
265265
tc = complex(t)
266266
@test convert(typeof(tc), t) == tc
267267
@test typeof(convert(typeof(tc), t)) == typeof(tc)
268-
# @test typeof(convert(typeof(tc), t')) == typeof(tc) # TODO Adjoint not supported yet
268+
@test typeof(convert(typeof(tc), t')) == typeof(tc)
269269
@test Base.promote_typeof(t, tc) == typeof(tc)
270270
@test Base.promote_typeof(tc, t) == typeof(tc + t)
271271
end

0 commit comments

Comments
 (0)