@@ -75,6 +75,12 @@ function CuTensorMap(
7575 ) where {S}
7676 return CuTensorMap(data, codom ← dom)
7777end
78+ function CuTensorMap(data:: DenseVector{T} , V:: TensorMapSpace{S, N₁, N₂} ) where {T,S,N₁,N₂}
79+ return CuTensorMap{T, S, N₁, N₂}(data, V)
80+ end
81+ function CuTensorMap(data:: CuArray{T} , V:: TensorMapSpace{S, N₁, N₂} ) where {T,S,N₁,N₂}
82+ return CuTensorMap{T, S, N₁, N₂}(vec(data), V)
83+ end
7884
7985for (fname, felt) in ((:zeros, :zero), (:ones, :one))
8086 @eval begin
@@ -158,6 +164,55 @@ for randfun in (:curand, :curandn)
158164 end
159165end
160166
167+ for randfun in (:rand, :randn, :randisometry)
168+ randfun! = Symbol(randfun, :! )
169+ @eval begin
170+ # converting `codomain` and `domain` into `HomSpace`
171+ function $ randfun(:: Type{A} , codomain:: TensorSpace{S} ,
172+ domain:: TensorSpace{S} ) where {A <: CuArray , S<: IndexSpace }
173+ return $ randfun(A, codomain ← domain)
174+ end
175+ function $ randfun(:: Type{T} , :: Type{A} , codomain:: TensorSpace{S} ,
176+ domain:: TensorSpace{S} ) where {T,S<: IndexSpace , A<: CuArray{T} }
177+ return $ randfun(T, A, codomain ← domain)
178+ end
179+ function $ randfun(rng:: Random.AbstractRNG , :: Type{T} , :: Type{A} ,
180+ codomain:: TensorSpace{S} ,
181+ domain:: TensorSpace{S} ) where {T,S<: IndexSpace , A<: CuArray{T} }
182+ return $ randfun(rng, T, A, codomain ← domain)
183+ end
184+
185+ # accepting single `TensorSpace`
186+ $ randfun(:: Type{A} , codomain:: TensorSpace ) where {A<: CuArray } = $ randfun(A, codomain ← one(codomain))
187+ function $ randfun(:: Type{T} , :: Type{A} , codomain:: TensorSpace ) where {T, A<: CuArray{T} }
188+ return $ randfun(T, A, codomain ← one(codomain))
189+ end
190+ function $ randfun(rng:: Random.AbstractRNG , :: Type{T} ,
191+ :: Type{A} , codomain:: TensorSpace ) where {T, A<: CuArray{T} }
192+ return $ randfun(rng, T, A, codomain ← one(domain))
193+ end
194+
195+ # filling in default eltype
196+ $ randfun(:: Type{A} , V:: TensorMapSpace ) where {A<: CuArray } = $ randfun(eltype(A), A, V)
197+ function $ randfun(rng:: Random.AbstractRNG , :: Type{A} , V:: TensorMapSpace ) where {A<: CuArray }
198+ return $ randfun(rng, eltype(A), A, V)
199+ end
200+
201+ # filling in default rng
202+ function $ randfun(:: Type{T} , :: Type{A} , V:: TensorMapSpace ) where {T, A<: CuArray{T} }
203+ return $ randfun(Random. default_rng(), T, A, V)
204+ end
205+
206+ # implementation
207+ function $ randfun(rng:: Random.AbstractRNG , :: Type{T} ,
208+ :: Type{A} , V:: TensorMapSpace ) where {T, A<: CuArray{T} }
209+ t = CuTensorMap{T}(undef, V)
210+ $ randfun!(rng, t)
211+ return t
212+ end
213+ end
214+ end
215+
161216# converters
162217# ----------
163218function Base. convert(:: Type{CuTensorMap} , d:: Dict{Symbol, Any} )
@@ -250,3 +305,26 @@ function LinearAlgebra.isposdef(t::CuTensorMap)
250305 end
251306 return true
252307end
308+
309+
310+ # Conversion to CuArray:
311+ # ----------------------
312+ # probably not optimized for speed, only for checking purposes
313+ function Base. convert(:: Type{CuArray} , t:: AbstractTensorMap )
314+ I = sectortype(t)
315+ if I === Trivial
316+ convert(CuArray, t[])
317+ else
318+ cod = codomain(t)
319+ dom = domain(t)
320+ T = sectorscalartype(I) <: Complex ? complex(scalartype(t)) :
321+ sectorscalartype(I) <: Integer ? scalartype(t) : float(scalartype(t))
322+ A = CUDA. zeros(T, dims(cod). .. , dims(dom). .. )
323+ for (f₁, f₂) in fusiontrees(t)
324+ F = convert(CuArray, (f₁, f₂))
325+ Aslice = StridedView(A)[axes(cod, f₁. uncoupled). .. , axes(dom, f₂. uncoupled). .. ]
326+ add!(Aslice, StridedView(TensorKit. _kron(convert(CuArray, t[f₁, f₂]), F)))
327+ end
328+ return A
329+ end
330+ end
0 commit comments