11module TensorKitCUDAExt
22
3- using CUDA, CUDA. CUBLAS, LinearAlgebra
3+ using CUDA, CUDA. CUBLAS, CUDA . CUSOLVER, LinearAlgebra
44using CUDA: @allowscalar
55using cuTENSOR: cuTENSOR
66import CUDA: rand as curand, rand! as curand!, randn as curandn, randn! as curandn!
@@ -18,6 +18,70 @@ using Random
1818
1919include (" cutensormap.jl" )
2020
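# The following `Type{CuArray}`-specific `convert` methods exist to resolve method
# ambiguities (presumably with TensorKit's generic `convert` methods for `FusionTree`).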
# Convert a fusion tree without uncoupled legs to a `CuArray`: take the fusion tensor
# built from three copies of the unit sector of `I`, move it to the GPU, and keep the
# slice `[1, 1, :]`.
function Base.convert(A::Type{CuArray}, f::TensorKit.FusionTree{I, 0}) where {I}
    u = unit(I)
    vacuum = convert(A, TensorKit.fusiontensor(u, u, u))
    return vacuum[1, 1, :]
end
# Convert a fusion tree with a single uncoupled leg to a `CuArray` matrix.
#
# Non-dual leg: the `[:, 1, :, 1, 1]` slice of the fusion tensor of `(c, unit(c), c)`.
# Dual leg: the `[:, :, 1, 1]` slice of the fusion tensor of `(dual(c), c, unit(c))`,
# scaled by `sqrtdim(c)` and then complex-conjugated in place.
function Base.convert(A::Type{CuArray}, f::TensorKit.FusionTree{I, 1}) where {I}
    c = f.coupled
    if !f.isdual[1]
        return convert(A, TensorKit.fusiontensor(c, unit(c), c))[:, 1, :, 1, 1]
    end
    scale = TensorKit.sqrtdim(c)
    Zt = scale * convert(A, TensorKit.fusiontensor(dual(c), c, unit(c)))[:, :, 1, 1]
    # the product above is a freshly allocated array, so conjugating in place is safe;
    # the conjugate of the (already transposed) slice is the desired Zcbar^†
    return conj!(Zt)
end
36+ # needed because the Int eltype isn't supported by CuTENSOR
# Convert a fusion tree with two uncoupled legs `(a, b)` and coupled sector `c` to a
# rank-3 `CuArray`.
#
# The base tensor is the `μ`-th slice (along the 4th axis) of `fusiontensor(a, b, c)`,
# where `μ` is the vertex label for `GenericFusion` sectors and `1` otherwise. For each
# uncoupled leg flagged as dual, the matrix obtained by converting the corresponding
# single-leg dual fusion tree is contracted onto that leg.
#
# Returns the base tensor unchanged when no leg is dual; otherwise returns the
# contracted tensor, whose element type is the floating-point promotion of the operand
# element types.
function Base.convert(A::Type{CuArray}, f::TensorKit.FusionTree{I, 2}) where {I}
    a, b = f.uncoupled
    isduala, isdualb = f.isdual
    c = f.coupled
    μ = (TensorKit.FusionStyle(I) isa TensorKit.GenericFusion) ? f.vertices[1] : 1
    X = convert(A, TensorKit.fusiontensor(a, b, c))[:, :, :, μ]
    if isduala
        Za = convert(A, TensorKit.FusionTree((a,), a, (isduala,), ()))
        # cuTENSOR does not support Int64, so convert the operands *by value* to a
        # floating-point eltype. `reinterpret(Float64, ...)` must not be used here: it
        # reinterprets the raw bits (Int64 1 becomes 5.0e-324) rather than the values.
        T = float(promote_type(eltype(Za), eltype(X)))
        fZa = T.(Za)
        fX = T.(X)
        # `:=` allocates a new array, so bind the result to `X` itself; otherwise the
        # contraction would be discarded when `X` is returned.
        @tensor X[a′, b, c] := fZa[a′, a] * fX[a, b, c]
    end
    if isdualb
        Zb = convert(A, TensorKit.FusionTree((b,), b, (isdualb,), ()))
        T = float(promote_type(eltype(Zb), eltype(X)))
        fZb = T.(Zb)
        fX = T.(X)
        @tensor X[a, b′, c] := fZb[b′, b] * fX[a, b, c]
    end
    return X
end
58+
# Convert a fusion tree with `N` uncoupled legs to a `CuArray` by peeling off the first
# vertex: the two-leg tree `f₁` (first two uncoupled sectors fusing to the first inner
# line) is contracted with the remaining tree `ftail` over that shared inner line.
#
# The result has `N + 1` axes: the first two come from `f₁`, the rest from `ftail`
# (excluding its first axis, which is contracted away). Its element type is the
# floating-point promotion of the element types of the two converted sub-trees.
function Base.convert(A::Type{CuArray}, f::TensorKit.FusionTree{I, N}) where {I, N}
    # Remaining tree: the first inner line replaces the first two uncoupled sectors and
    # enters as a non-dual leg.
    tailout = (f.innerlines[1], TensorKit.TupleTools.tail2(f.uncoupled)...)
    isdualout = (false, TensorKit.TupleTools.tail2(f.isdual)...)
    ftail = TensorKit.FusionTree(
        tailout, f.coupled, isdualout, Base.tail(f.innerlines), Base.tail(f.vertices)
    )
    Ctail = convert(A, ftail)
    # First vertex as a stand-alone two-leg tree.
    f₁ = TensorKit.FusionTree(
        (f.uncoupled[1], f.uncoupled[2]), f.innerlines[1],
        (f.isdual[1], f.isdual[2]), (), (f.vertices[1],)
    )
    C1 = convert(A, f₁)
    # cuTENSOR does not support Int64, so promote both operands *by value* to a common
    # floating-point eltype. `reinterpret(Float64, ...)` must not be used here: it
    # reinterprets the raw bits (Int64 1 becomes 5.0e-324) rather than the values.
    T = float(promote_type(eltype(C1), eltype(Ctail)))
    fC1 = eltype(C1) === T ? C1 : T.(C1)
    fCtail = eltype(Ctail) === T ? Ctail : T.(Ctail)
    dtail = size(fCtail)
    d1 = size(fC1)
    X = similar(fC1, (d1[1], d1[2], Base.tail(dtail)...))
    trivialtuple = ntuple(identity, Val(N))
    # Contract axis 3 of `fC1` with axis 1 of `fCtail`; output axes keep natural order.
    TensorKit.TensorOperations.tensorcontract!(
        X,
        fC1, ((1, 2), (3,)), false,
        fCtail, ((1,), Base.tail(trivialtuple)), false,
        ((trivialtuple..., N + 1), ())
    )
    return X
end
2185# TODO
2286# add VectorInterface extensions for proper CUDA promotion
2387function TensorKit. VectorInterface. promote_add (TA:: Type{<:CUDA.StridedCuMatrix{Tx}} , TB:: Type{<:CUDA.StridedCuMatrix{Ty}} , α:: T α = TensorKit. VectorInterface. One (), β:: T β = TensorKit. VectorInterface. One ()) where {Tx, Ty, Tα, Tβ}
0 commit comments