Skip to content

Commit 08d000a

Browse files
committed
Incremental
1 parent 08bc705 commit 08d000a

File tree

5 files changed

+210
-103
lines changed

5 files changed

+210
-103
lines changed

ext/TensorKitCUDAExt/TensorKitCUDAExt.jl

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module TensorKitCUDAExt
22

3-
using CUDA, CUDA.CUBLAS, LinearAlgebra
3+
using CUDA, CUDA.CUBLAS, CUDA.CUSOLVER, LinearAlgebra
44
using CUDA: @allowscalar
55
using cuTENSOR: cuTENSOR
66
import CUDA: rand as curand, rand! as curand!, randn as curandn, randn! as curandn!
@@ -18,6 +18,70 @@ using Random
1818

1919
include("cutensormap.jl")
2020

21+
# for ambiguity
22+
# Zero-leg fusion tree: `f` only selects this method through its type. The
# result is a slice of the fusion tensor of three unit sectors of `I`,
# converted to the requested array type `A`.
function Base.convert(A::Type{CuArray}, f::TensorKit.FusionTree{I, 0}) where {I}
    u = unit(I)
    unittensor = convert(A, TensorKit.fusiontensor(u, u, u))
    return unittensor[1, 1, :]
end
25+
# Single-leg fusion tree. For a non-dual leg the result is a slice of the
# fusion tensor of the coupled sector with its unit; for a dual leg it is the
# complex conjugate of the rescaled fusion tensor of `dual(c)` with `c`.
function Base.convert(A::Type{CuArray}, f::TensorKit.FusionTree{I, 1}) where {I}
    c = f.coupled
    if !f.isdual[1]
        return convert(A, TensorKit.fusiontensor(c, unit(c), c))[:, 1, :, 1, 1]
    end
    scale = TensorKit.sqrtdim(c)
    Zslice = convert(A, TensorKit.fusiontensor(dual(c), c, unit(c)))[:, :, 1, 1]
    # we want Zcbar^†: the scaled slice is Zcbar transposed, so conjugating it
    # in place yields the adjoint
    return conj!(scale * Zslice)
end
36+
# needed because the Int eltype isn't supported by CuTENSOR
37+
# Two-leg fusion tree: take the slice of the fusion tensor of `(a, b, c)` at
# vertex label `μ`, then contract each dual uncoupled leg with the matrix
# obtained from the corresponding single-leg dual fusion tree.
#
# The contractions run on `Float64` reinterpretations of the arrays because
# cuTENSOR does not support Int64.
# NOTE(review): `reinterpret` relabels the element type without converting
# values, so this is only value-preserving when the arrays already hold
# `Float64` data — TODO confirm the eltype returned by `fusiontensor` here.
function Base.convert(A::Type{CuArray}, f::TensorKit.FusionTree{I, 2}) where {I}
    a, b = f.uncoupled
    isduala, isdualb = f.isdual
    c = f.coupled
    μ = (TensorKit.FusionStyle(I) isa TensorKit.GenericFusion) ? f.vertices[1] : 1
    X = convert(A, TensorKit.fusiontensor(a, b, c))[:, :, :, μ]
    # Nothing to contract when neither leg is dual.
    (isduala || isdualb) || return X

    T = eltype(X)
    fX = reinterpret(Float64, X)
    if isduala
        Za = convert(A, TensorKit.FusionTree((a,), a, (isduala,), ()))
        fZa = reinterpret(Float64, Za)
        # `:=` allocates a new array and rebinds `fX`; `X` is NOT updated.
        @tensor fX[a′, b, c] := fZa[a′, a] * fX[a, b, c]
    end
    if isdualb
        Zb = convert(A, TensorKit.FusionTree((b,), b, (isdualb,), ()))
        fZb = reinterpret(Float64, Zb)
        @tensor fX[a, b′, c] := fZb[b′, b] * fX[a, b, c]
    end
    # Return the contracted data (previously the untouched `X` was returned,
    # silently dropping the dual-leg contractions), viewed with the original
    # element type.
    return reinterpret(T, fX)
end
58+
59+
# General N-leg fusion tree: split the tree into its first vertex (a two-leg
# tree coupling the first two uncoupled sectors to the first inner line) and
# the remaining tail tree, convert both, and contract them over the shared
# inner line.
function Base.convert(A::Type{CuArray}, f::TensorKit.FusionTree{I, N}) where {I, N}
    firstinner = f.innerlines[1]

    # Tail tree: the first inner line replaces the first two uncoupled legs
    # and is never dual.
    tailtree = TensorKit.FusionTree(
        (firstinner, TensorKit.TupleTools.tail2(f.uncoupled)...),
        f.coupled,
        (false, TensorKit.TupleTools.tail2(f.isdual)...),
        Base.tail(f.innerlines),
        Base.tail(f.vertices)
    )
    Ctail = convert(A, tailtree)

    # Head tree: the first fusion vertex only.
    headtree = TensorKit.FusionTree(
        (f.uncoupled[1], f.uncoupled[2]), firstinner,
        (f.isdual[1], f.isdual[2]), (), (f.vertices[1],)
    )
    C1 = convert(A, headtree)

    headsize = size(C1)
    X = similar(C1, (headsize[1], headsize[2], Base.tail(size(Ctail))...))
    legs = ntuple(identity, Val(N))

    # reinterpret all these as Float64 since cuTENSOR does not support Int64;
    # the contraction writes into the reinterpreted view of `X`
    TensorKit.TensorOperations.tensorcontract!(
        reinterpret(Float64, X),
        reinterpret(Float64, C1), ((1, 2), (3,)), false,
        reinterpret(Float64, Ctail), ((1,), Base.tail(legs)), false,
        ((legs..., N + 1), ())
    )
    return X
end
2185
# TODO
2286
# add VectorInterface extensions for proper CUDA promotion
2387
function TensorKit.VectorInterface.promote_add(TA::Type{<:CUDA.StridedCuMatrix{Tx}}, TB::Type{<:CUDA.StridedCuMatrix{Ty}}, α::Tα = TensorKit.VectorInterface.One(), β::Tβ = TensorKit.VectorInterface.One()) where {Tx, Ty, Tα, Tβ}

ext/TensorKitCUDAExt/cutensormap.jl

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@ function TensorKit.tensormaptype(S::Type{<:IndexSpace}, N₁, N₂, TorA::Type{<
1111
end
1212
end
1313

14-
TensorKit.matrixtype(::Type{<:TensorMap{T, S, N₁, N₂, A}}) where {T, S, N₁, N₂, A <: CuVector{T}} = CuMatrix{T}
15-
1614
function CuTensorMap{T}(::UndefInitializer, V::TensorMapSpace{S, N₁, N₂}) where {T, S, N₁, N₂}
1715
return CuTensorMap{T, S, N₁, N₂}(undef, V)
1816
end
@@ -213,6 +211,10 @@ end
213211
TensorKit.scalartype(A::StridedCuArray{T}) where {T} = T
214212
TensorKit.scalartype(::Type{<:CuTensorMap{T}}) where {T} = T
215213
TensorKit.scalartype(::Type{<:CuArray{T}}) where {T} = T
214+
# Vector and matrix storage types associated with GPU-backed tensors and arrays.
# For a tensor map whose storage `A` is a `CuVector{T}`, the dense vector type is
# `A` itself and the matrix type is `CuMatrix{T}`.
TensorKit.densevectortype(::Type{<:TensorMap{T, S, N₁, N₂, A}}) where {T, S, N₁, N₂, A <: CuVector{T}} = A
TensorKit.densevectortype(::Type{<:CuArray{T}}) where {T} = CuVector{T}
TensorKit.matrixtype(::Type{<:TensorMap{T, S, N₁, N₂, A}}) where {T, S, N₁, N₂, A <: CuVector{T}} = CuMatrix{T}
# Dispatch on `Type{<:CuArray{T}}` (not `Type{CuArray{T}}`) so that concrete
# array types such as `CuVector{T}` or `CuMatrix{T}` are accepted too, matching
# the `densevectortype` method above.
TensorKit.matrixtype(::Type{<:CuArray{T}}) where {T} = CuMatrix{T}
216218

217219
function TensorKit.similarstoragetype(TT::Type{<:CuTensorMap{TTT, S, N₁, N₂}}, ::Type{T}) where {TTT, T, S, N₁, N₂}
218220
return CuVector{T, CUDA.DeviceMemory}
@@ -261,7 +263,7 @@ end
261263
function Base.convert(::Type{CuArray}, t::AbstractTensorMap)
262264
I = sectortype(t)
263265
if I === Trivial
264-
convert(CuArray, t[])
266+
CUDA.@allowscalar convert(CuArray, t[])
265267
else
266268
cod = codomain(t)
267269
dom = domain(t)
@@ -271,8 +273,33 @@ function Base.convert(::Type{CuArray}, t::AbstractTensorMap)
271273
for (f₁, f₂) in fusiontrees(t)
272274
F = convert(CuArray, (f₁, f₂))
273275
Aslice = StridedView(A)[axes(cod, f₁.uncoupled)..., axes(dom, f₂.uncoupled)...]
274-
add!(Aslice, StridedView(TensorKit._kron(convert(CuArray, t[f₁, f₂]), F)))
276+
CUDA.@allowscalar add!(Aslice, StridedView(TensorKit._kron(convert(CuArray, t[f₁, f₂]), F)))
275277
end
276278
return A
277279
end
278280
end
281+
282+
# CuTensorMap exponentiation:
283+
# In-place exponential of a `CuTensorMap`: every block of `t` is overwritten
# with the exponential of that block wrapped in `Hermitian`. Errors when the
# domain and codomain of `t` differ. Returns `t`.
# NOTE(review): the `Hermitian` wrapper means only one triangle of each block
# is presumably used — confirm callers only pass Hermitian tensors.
function TensorKit.exp!(t::CuTensorMap)
    if domain(t) != codomain(t)
        error("Exponential of a tensor only exist when domain == codomain.")
    end
    for (_, blk) in blocks(t)
        expblk = Base.exp(Hermitian(blk))
        copy!(blk, parent(expblk))
    end
    return t
end
291+
292+
# functions that don't map ℝ to (a subset of) ℝ
293+
# Generate one `Base` method per listed function for `CuTensorMap`. Each method
# requires domain == codomain, allocates a result with a complex floating-point
# scalar type, and fills every block with the function applied to the
# `Hermitian`-wrapped input block.
for fname in (:sqrt, :log, :asin, :acos, :acosh, :atanh, :acoth)
    label = string(fname)
    @eval function Base.$fname(t::CuTensorMap)
        if domain(t) != codomain(t)
            throw(SpaceMismatch("`$($label)` of a tensor only exist when domain == codomain"))
        end
        result = similar(t, complex(float(scalartype(t))))
        for (sector, blk) in blocks(t)
            copy!(block(result, sector), parent($fname(Hermitian(blk))))
        end
        return result
    end
end

src/tensors/abstracttensor.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,10 @@ Return the type of vector that stores the data of a tensor.
4747

4848
@doc """
4949
matrixtype(t::AbstractTensorMap) -> Type{A<:AbstractMatrix}
50-
matrixtype(T::Type{<:AbstractTensorMap}) -> Type{A<:AbstractVector}
50+
matrixtype(T::Type{<:AbstractTensorMap}) -> Type{A<:AbstractMatrix}
5151
52-
Return the type of **matrix** that stores the data of a tensor.
52+
Return the type of matrix that stores the data of a tensor, for conversion
53+
to/from dictionaries.
5354
""" matrixtype
5455

5556
similarstoragetype(TT::Type{<:AbstractTensorMap}) = similarstoragetype(TT, scalartype(TT))
@@ -181,8 +182,8 @@ end
181182
#------------------------------------------------------------
182183
InnerProductStyle(t::AbstractTensorMap) = InnerProductStyle(typeof(t))
183184
storagetype(t::AbstractTensorMap) = storagetype(typeof(t))
184-
matrixtype(t::AbstractTensorMap) = matrixtype(typeof(t))
185185
blocktype(t::AbstractTensorMap) = blocktype(typeof(t))
186+
matrixtype(t::AbstractTensorMap) = matrixtype(typeof(t))
186187
similarstoragetype(t::AbstractTensorMap, T = scalartype(t)) = similarstoragetype(typeof(t), T)
187188

188189
numout(t::AbstractTensorMap) = numout(typeof(t))
@@ -633,7 +634,8 @@ function Base.convert(::Type{Array}, t::AbstractTensorMap)
633634
for (f₁, f₂) in fusiontrees(t)
634635
F = convert(Array, (f₁, f₂))
635636
Aslice = StridedView(A)[axes(cod, f₁.uncoupled)..., axes(dom, f₂.uncoupled)...]
636-
add!(Aslice, StridedView(_kron(convert(Array, t[f₁, f₂]), F)))
637+
tf₁f₂ = convert(Array, t[f₁, f₂])
638+
add!(Aslice, StridedView(_kron(tf₁f₂, F)))
637639
end
638640
return A
639641
end

src/tensors/tensor.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,17 @@ Return the type of the storage `A` of the tensor map.
6767
"""
6868
storagetype(::Type{<:TensorMap{T, S, N₁, N₂, A}}) where {T, S, N₁, N₂, A <: DenseVector{T}} = A
6969
"""
70-
matrixtype(::Union{T,Type{T}}) where {T<:TensorMap} -> Type{A<:Vector}
70+
densevectortype(::Union{T,Type{T}}) where {T<:TensorMap} -> Type{A<:Vector}
7171
7272
Return the type of the storage `A` of the tensor map.
7373
"""
74+
densevectortype(::Type{<:TensorMap{T, S, N₁, N₂, A}}) where {T, S, N₁, N₂, A <: Vector{T}} = A
75+
densevectortype(::Type{<:Array{T}}) where {T} = Vector{T}
76+
77+
"""
78+
matrixtype(::Union{T,Type{T}}) where {T<:TensorMap} -> Type{A<:Matrix}
79+
Return the matrix analogue type of the storage `A` of the tensor map.
80+
"""
7481
matrixtype(::Type{<:TensorMap{T, S, N₁, N₂, A}}) where {T, S, N₁, N₂, A <: Vector{T}} = Matrix{T}
7582

7683
dim(t::TensorMap) = length(t.data)

0 commit comments

Comments
 (0)