Skip to content

Commit 1f10edc

Browse files
committed
Tensors tests passing
1 parent e1dbf05 commit 1f10edc

File tree

2 files changed

+8
-9
lines changed

2 files changed

+8
-9
lines changed

ext/TensorKitCUDAExt/cutensormap.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@ function TensorKit.tensormaptype(S::Type{<:IndexSpace}, N₁, N₂, TorA::Type{<
1111
end
1212
end
1313

14+
function TensorKit.TensorMap{T, S, N₁, N₂, <:CuVector{T}}(t::TensorMap{T, S, N₁, N₂, A}) where {T, S, N₁, N₂, A}
15+
return CuTensorMap{T, S, N₁, N₂}(CuArray(t.data), t.space)
16+
end
17+
1418
function CuTensorMap{T}(::UndefInitializer, V::TensorMapSpace{S, N₁, N₂}) where {T, S, N₁, N₂}
1519
return CuTensorMap{T, S, N₁, N₂}(undef, V)
1620
end

src/tensors/tensor.jl

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -361,19 +361,14 @@ function TensorMap(
361361
end
362362

363363
arraydata = reshape(data, arraysize)
364-
if data isa DenseVector
365-
t = TensorMap{T, S, N₁, N₂, typeof(data)}(undef, codom dom) # to ensure the correct storage type
366-
t = project_symmetric!(t, arraydata)
367-
if !isapprox(Array(arraydata), convert(Array, t); atol = tol)
368-
throw(ArgumentError("Data has non-zero elements at incompatible positions"))
369-
end
370-
elseif data isa DenseArray # can be reshaped into a DenseVector
364+
if data isa DenseArray # can be reshaped into a DenseVector
371365
A = densevectortype(typeof(data))
372-
t = TensorMap{T, S, N₁, N₂, A}(undef, codom dom) # to ensure the correct storage type
373-
t = project_symmetric!(t, arraydata)
366+
t = TensorMap{T, S, N₁, N₂, Vector{T}}(undef, codom dom) # to ensure the correct storage type
367+
t = project_symmetric!(t, Array(arraydata))
374368
if !isapprox(Array(arraydata), convert(Array, t); atol = tol)
375369
throw(ArgumentError("Data has non-zero elements at incompatible positions"))
376370
end
371+
t = TensorMap(A(t.data), t.space)
377372
else
378373
t = TensorMap{T}(undef, codom dom) # to ensure the correct storage type
379374
t = project_symmetric!(t, arraydata)

0 commit comments

Comments
 (0)