Skip to content

Commit d6165f0

Browse files
committed
Remove unnecessary fusiontree conversions
1 parent 3341852 commit d6165f0

File tree

1 file changed

+0
-64
lines changed

1 file changed

+0
-64
lines changed

ext/TensorKitCUDAExt/TensorKitCUDAExt.jl

Lines changed: 0 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -18,70 +18,6 @@ using Random
1818

1919
include("cutensormap.jl")
2020

21-
# for ambiguity
22-
function Base.convert(A::Type{CuArray}, f::TensorKit.FusionTree{I, 0}) where {I}
23-
return convert(A, TensorKit.fusiontensor(unit(I), unit(I), unit(I)))[1, 1, :]
24-
end
25-
function Base.convert(A::Type{CuArray}, f::TensorKit.FusionTree{I, 1}) where {I}
26-
c = f.coupled
27-
if f.isdual[1]
28-
sqrtdc = TensorKit.sqrtdim(c)
29-
Zcbartranspose = sqrtdc * convert(A, TensorKit.fusiontensor(dual(c), c, unit(c)))[:, :, 1, 1]
30-
X = conj!(Zcbartranspose) # we want Zcbar^†
31-
else
32-
X = convert(A, TensorKit.fusiontensor(c, unit(c), c))[:, 1, :, 1, 1]
33-
end
34-
return X
35-
end
36-
# needed because the Int eltype isn't supported by CuTENSOR
37-
function Base.convert(A::Type{CuArray}, f::TensorKit.FusionTree{I, 2}) where {I}
38-
a, b = f.uncoupled
39-
isduala, isdualb = f.isdual
40-
c = f.coupled
41-
μ = (TensorKit.FusionStyle(I) isa TensorKit.GenericFusion) ? f.vertices[1] : 1
42-
C = convert(A, TensorKit.fusiontensor(a, b, c))[:, :, :, μ]
43-
X = C
44-
fX = reinterpret(Float64, X)
45-
if isduala
46-
Za = convert(A, TensorKit.FusionTree((a,), a, (isduala,), ()))
47-
# reinterpret all these as Float64 since cuTENSOR does not support Int64
48-
fZa = reinterpret(Float64, Za)
49-
@tensor fX[a′, b, c] := fZa[a′, a] * fX[a, b, c]
50-
end
51-
if isdualb
52-
Zb = convert(A, TensorKit.FusionTree((b,), b, (isdualb,), ()))
53-
fZb = reinterpret(Float64, Zb)
54-
@tensor fX[a, b′, c] := fZb[b′, b] * fX[a, b, c]
55-
end
56-
return X
57-
end
58-
59-
function Base.convert(A::Type{CuArray}, f::TensorKit.FusionTree{I, N}) where {I, N}
60-
tailout = (f.innerlines[1], TensorKit.TupleTools.tail2(f.uncoupled)...)
61-
isdualout = (false, TensorKit.TupleTools.tail2(f.isdual)...)
62-
ftail = TensorKit.FusionTree(tailout, f.coupled, isdualout, Base.tail(f.innerlines), Base.tail(f.vertices))
63-
Ctail = convert(A, ftail)
64-
f₁ = TensorKit.FusionTree(
65-
(f.uncoupled[1], f.uncoupled[2]), f.innerlines[1],
66-
(f.isdual[1], f.isdual[2]), (), (f.vertices[1],)
67-
)
68-
C1 = convert(A, f₁)
69-
dtail = size(Ctail)
70-
d1 = size(C1)
71-
X = similar(C1, (d1[1], d1[2], Base.tail(dtail)...))
72-
trivialtuple = ntuple(identity, Val(N))
73-
# reinterpret all these as Float64 since cuTENSOR does not support Int64
74-
fX = reinterpret(Float64, X)
75-
fC1 = reinterpret(Float64, C1)
76-
fCtail = reinterpret(Float64, Ctail)
77-
TensorKit.TensorOperations.tensorcontract!(
78-
fX,
79-
fC1, ((1, 2), (3,)), false,
80-
fCtail, ((1,), Base.tail(trivialtuple)), false,
81-
((trivialtuple..., N + 1), ())
82-
)
83-
return X
84-
end
8521
# TODO
8622
# add VectorInterface extensions for proper CUDA promotion
8723
function TensorKit.VectorInterface.promote_add(TA::Type{<:CUDA.StridedCuMatrix{Tx}}, TB::Type{<:CUDA.StridedCuMatrix{Ty}}, α::Tα = TensorKit.VectorInterface.One(), β::Tβ = TensorKit.VectorInterface.One()) where {Tx, Ty, Tα, Tβ}

0 commit comments

Comments
 (0)