@@ -18,70 +18,6 @@ using Random
1818
1919include (" cutensormap.jl" )
2020
21- # for ambiguity
22- function Base. convert (A:: Type{CuArray} , f:: TensorKit.FusionTree{I, 0} ) where {I}
23- return convert (A, TensorKit. fusiontensor (unit (I), unit (I), unit (I)))[1 , 1 , :]
24- end
25- function Base. convert (A:: Type{CuArray} , f:: TensorKit.FusionTree{I, 1} ) where {I}
26- c = f. coupled
27- if f. isdual[1 ]
28- sqrtdc = TensorKit. sqrtdim (c)
29- Zcbartranspose = sqrtdc * convert (A, TensorKit. fusiontensor (dual (c), c, unit (c)))[:, :, 1 , 1 ]
30- X = conj! (Zcbartranspose) # we want Zcbar^†
31- else
32- X = convert (A, TensorKit. fusiontensor (c, unit (c), c))[:, 1 , :, 1 , 1 ]
33- end
34- return X
35- end
36- # needed because the Int eltype isn't supported by CuTENSOR
37- function Base. convert (A:: Type{CuArray} , f:: TensorKit.FusionTree{I, 2} ) where {I}
38- a, b = f. uncoupled
39- isduala, isdualb = f. isdual
40- c = f. coupled
41- μ = (TensorKit. FusionStyle (I) isa TensorKit. GenericFusion) ? f. vertices[1 ] : 1
42- C = convert (A, TensorKit. fusiontensor (a, b, c))[:, :, :, μ]
43- X = C
44- fX = reinterpret (Float64, X)
45- if isduala
46- Za = convert (A, TensorKit. FusionTree ((a,), a, (isduala,), ()))
47- # reinterpret all these as Float64 since cuTENSOR does not support Int64
48- fZa = reinterpret (Float64, Za)
49- @tensor fX[a′, b, c] := fZa[a′, a] * fX[a, b, c]
50- end
51- if isdualb
52- Zb = convert (A, TensorKit. FusionTree ((b,), b, (isdualb,), ()))
53- fZb = reinterpret (Float64, Zb)
54- @tensor fX[a, b′, c] := fZb[b′, b] * fX[a, b, c]
55- end
56- return X
57- end
58-
59- function Base. convert (A:: Type{CuArray} , f:: TensorKit.FusionTree{I, N} ) where {I, N}
60- tailout = (f. innerlines[1 ], TensorKit. TupleTools. tail2 (f. uncoupled)... )
61- isdualout = (false , TensorKit. TupleTools. tail2 (f. isdual)... )
62- ftail = TensorKit. FusionTree (tailout, f. coupled, isdualout, Base. tail (f. innerlines), Base. tail (f. vertices))
63- Ctail = convert (A, ftail)
64- f₁ = TensorKit. FusionTree (
65- (f. uncoupled[1 ], f. uncoupled[2 ]), f. innerlines[1 ],
66- (f. isdual[1 ], f. isdual[2 ]), (), (f. vertices[1 ],)
67- )
68- C1 = convert (A, f₁)
69- dtail = size (Ctail)
70- d1 = size (C1)
71- X = similar (C1, (d1[1 ], d1[2 ], Base. tail (dtail)... ))
72- trivialtuple = ntuple (identity, Val (N))
73- # reinterpret all these as Float64 since cuTENSOR does not support Int64
74- fX = reinterpret (Float64, X)
75- fC1 = reinterpret (Float64, C1)
76- fCtail = reinterpret (Float64, Ctail)
77- TensorKit. TensorOperations. tensorcontract! (
78- fX,
79- fC1, ((1 , 2 ), (3 ,)), false ,
80- fCtail, ((1 ,), Base. tail (trivialtuple)), false ,
81- ((trivialtuple... , N + 1 ), ())
82- )
83- return X
84- end
8521# TODO
8622# add VectorInterface extensions for proper CUDA promotion
8723function TensorKit. VectorInterface. promote_add (TA:: Type{<:CUDA.StridedCuMatrix{Tx}} , TB:: Type{<:CUDA.StridedCuMatrix{Ty}} , α:: T α = TensorKit. VectorInterface. One (), β:: T β = TensorKit. VectorInterface. One ()) where {Tx, Ty, Tα, Tβ}
0 commit comments