Skip to content

Commit 2ee3e2f

Browse files
committed
Further updates for GPU tests
1 parent e8b3e76 commit 2ee3e2f

File tree

3 files changed

+8
-5
lines changed

3 files changed

+8
-5
lines changed

ext/TensorKitCUDAExt/TensorKitCUDAExt.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
11
module TensorKitCUDAExt
22

3-
using LinearAlgebra
4-
5-
using CUDA
3+
using CUDA, CUDA.CUBLAS, LinearAlgebra
64
using CUDA: @allowscalar
7-
using CUDA.CUBLAS # for LinearAlgebra tie-ins
85
using cuTENSOR: cuTENSOR
96

107
using TensorKit
@@ -81,4 +78,8 @@ function TensorKit.Factorizations.initialize_output(::typeof(eig_vals!), t::CuTe
8178
Tc = complex(scalartype(t))
8279
return D = CuDiagonalTensorMap{Tc}(undef, V_D)
8380
end
81+
82+
83+
# TODO
84+
# add VectorInterface extensions for proper CUDA promotion
8485
end

ext/TensorKitCUDAExt/cutensormap.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,8 @@ function TensorKit.scalar(t::CuTensorMap)
160160
first(blocks(t))[2][1, 1] : throw(DimensionMismatch())
161161
end
162162

163+
TensorKit.scalartype(A::CuArray{T}) where {T} = T
164+
163165
function TensorKit.similarstoragetype(TT::Type{<:CuTensorMap}, ::Type{T}) where {T}
164166
return CuVector{T}
165167
end

test/cuda.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -445,7 +445,7 @@ for V in spacelist
445445
d = LinearAlgebra.eigvals(t; sortby=nothing)
446446
d′ = LinearAlgebra.diag(D)
447447
for (c, b) in d
448-
@test b d′[c]
448+
@test b d′[c]
449449
end
450450

451451
# Somehow moving these test before the previous one gives rise to errors

0 commit comments

Comments
 (0)