Skip to content

Commit 1c023a1

Browse files
committed
Use only device memory for now
1 parent d0e8821 commit 1c023a1

File tree

3 files changed

+27
-12
lines changed

3 files changed

+27
-12
lines changed

ext/TensorKitCUDAExt/TensorKitCUDAExt.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using CUDA: @allowscalar
55
using cuTENSOR: cuTENSOR
66

77
using TensorKit
8+
import TensorKit.VectorInterface: scalartype as vi_scalartype
89
using TensorKit.Factorizations
910
using TensorKit.Factorizations: select_svd_algorithm, OFA, initialize_output, AbstractAlgorithm
1011
using TensorKit: SectorDict, tensormaptype, scalar, similarstoragetype, AdjointTensorMap
@@ -19,7 +20,7 @@ TensorKit.Factorizations.select_svd_algorithm(::CuTensorMap, ::TensorKit.Factori
1920
TensorKit.Factorizations.select_svd_algorithm(::CuTensorMap, ::TensorKit.Factorizations.SDD) = throw(ArgumentError("DivideAndConquer unavailable on CUDA"))
2021
TensorKit.Factorizations.select_svd_algorithm(::CuTensorMap, alg::OFA) = throw(ArgumentError(lazy"Unknown algorithm $alg"))
2122

22-
const CuDiagonalTensorMap{T, S} = DiagonalTensorMap{T, S, CuVector{T}}
23+
const CuDiagonalTensorMap{T, S} = DiagonalTensorMap{T, S, CuVector{T, CUDA.DeviceMemory}}
2324

2425
"""
2526
CuDiagonalTensorMap{T}(undef, domain::S) where {T,S<:IndexSpace}
@@ -82,4 +83,8 @@ end
8283

8384
# TODO
8485
# add VectorInterface extensions for proper CUDA promotion
86+
function TensorKit.VectorInterface.promote_add(TA::Type{<:CUDA.StridedCuMatrix{Tx}}, TB::Type{<:CUDA.StridedCuMatrix{Ty}}, α::Tα = TensorKit.VectorInterface.One(), β::Tβ = TensorKit.VectorInterface.One()) where {Tx, Ty, Tα, Tβ}
87+
return Base.promote_op(add, Tx, Ty, Tα, Tβ)
88+
end
89+
8590
end

ext/TensorKitCUDAExt/cutensormap.jl

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
1-
const CuTensorMap{T,S,N₁,N₂,A<:CuVector{T}} = TensorMap{T,S,N₁,N₂,A}
2-
const CuTensor{T, S, N, A<:CuVector{T}} = CuTensorMap{T, S, N, 0, A}
1+
const CuTensorMap{T,S,N₁,N₂} = TensorMap{T,S,N₁,N₂, CuVector{T,CUDA.DeviceMemory}}
2+
const CuTensor{T, S, N} = CuTensorMap{T, S, N, 0}
33

44
function TensorKit.tensormaptype(S::Type{<:IndexSpace}, N₁, N₂, TorA::Type{<:StridedCuArray})
55
if TorA <: CuArray
6-
return TensorMap{eltype(TorA),S,N₁,N₂,CuVector{eltype(TorA)}}
6+
return TensorMap{eltype(TorA),S,N₁,N₂,CuVector{eltype(TorA), CUDA.DeviceMemory}}
77
else
88
throw(ArgumentError("argument $TorA should specify a scalar type (`<:Number`) or a storage type `<:CuVector{<:Number}`"))
99
end
1010
end
1111

1212
function CuTensorMap{T}(::UndefInitializer, V::TensorMapSpace{S, N₁, N₂}) where {T, S, N₁, N₂}
13-
return CuTensorMap{T,S,N₁,N₂,CuVector{T}}(undef, V)
13+
return CuTensorMap{T,S,N₁,N₂}(undef, V)
1414
end
1515

1616
function CuTensorMap{T}(::UndefInitializer, codomain::TensorSpace{S},
@@ -164,14 +164,16 @@ function TensorKit.scalar(t::CuTensorMap)
164164
first(blocks(t))[2][1, 1] : throw(DimensionMismatch())
165165
end
166166

167-
TensorKit.scalartype(A::CuArray{T}) where {T} = T
167+
TensorKit.scalartype(A::StridedCuArray{T}) where {T} = T
168+
vi_scalartype(::Type{<:CuTensorMap{T}}) where {T} = T
169+
vi_scalartype(::Type{<:CuArray{T}}) where {T} = T
168170

169-
function TensorKit.similarstoragetype(TT::Type{<:CuTensorMap}, ::Type{T}) where {T}
170-
return CuVector{T}
171+
function TensorKit.similarstoragetype(TT::Type{<:CuTensorMap{TTT,S,N₁,N₂}}, ::Type{T}) where {TTT,T,S,N₁,N₂}
172+
return CuVector{T, CUDA.DeviceMemory}
171173
end
172174

173-
function Base.convert(TT::Type{CuTensorMap{T,S,N₁,N₂,A}},
174-
t::AbstractTensorMap{<:Any,S,N₁,N₂}) where {T,S,N₁,N₂,A<:CuVector{T}}
175+
function Base.convert(TT::Type{CuTensorMap{T,S,N₁,N₂}},
176+
t::AbstractTensorMap{<:Any,S,N₁,N₂}) where {T,S,N₁,N₂}
175177
if typeof(t) === TT
176178
return t
177179
else
@@ -180,7 +182,7 @@ function Base.convert(TT::Type{CuTensorMap{T,S,N₁,N₂,A}},
180182
end
181183
end
182184

183-
function Base.copy!(tdst::CuTensorMap{T, S, N₁, N₂, A}, tsrc::CuTensorMap{T, S, N₁, N₂, A}) where {T, S, N₁, N₂, A}
185+
function Base.copy!(tdst::CuTensorMap{T, S, N₁, N₂}, tsrc::CuTensorMap{T, S, N₁, N₂}) where {T, S, N₁, N₂}
184186
space(tdst) == space(tsrc) || throw(SpaceMismatch("$(space(tdst))$(space(tsrc))"))
185187
for ((c, bdst), (_, bsrc)) in zip(blocks(tdst), blocks(tsrc))
186188
copy!(bdst, bsrc)
@@ -195,3 +197,11 @@ function Base.copy!(tdst::CuTensorMap, tsrc::TensorKit.AdjointTensorMap)
195197
end
196198
return tdst
197199
end
200+
201+
function Base.promote_rule(::Type{<:TT₁},
202+
::Type{<:TT₂}) where {S,N₁,N₂, TTT₁, TTT₂,
203+
TT₁<:CuTensorMap{TTT₁,S,N₁,N₂},
204+
TT₂<:CuTensorMap{TTT₂,S,N₁,N₂}}
205+
T = TensorKit.VectorInterface.promote_add(TTT₁, TTT₂)
206+
return CuTensorMap{T,S,N₁,N₂}
207+
end

src/tensors/linalg.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -545,7 +545,7 @@ function ⊗(t1::AbstractTensorMap, t2::AbstractTensorMap)
545545
m1 = sreshape(t1[f1l, f1r], (d1, 1, d3, 1))
546546
m2 = sreshape(t2[f2l, f2r], (1, d2, 1, d4))
547547
m = sreshape(t[fl, fr], (d1, d2, d3, d4))
548-
@. m += coeff1 * conj(coeff2) * m1 * m2
548+
m .+= coeff1 .* conj.(coeff2) .* m1 .* m2
549549
end
550550
end
551551
end

0 commit comments

Comments
 (0)