@@ -3,74 +3,12 @@ const CuTensor{T, S, N} = CuTensorMap{T, S, N, 0}
33
44const AdjointCuTensorMap{T, S, N₁, N₂} = AdjointTensorMap{T, S, N₁, N₂, CuTensorMap{T, S, N₁, N₂}}
55
6- function TensorKit. tensormaptype (S:: Type{<:IndexSpace} , N₁, N₂, TorA:: Type{<:StridedCuArray} )
7- if TorA <: CuArray
8- return TensorMap{eltype (TorA), S, N₁, N₂, CuVector{eltype (TorA), CUDA. DeviceMemory}}
9- else
10- throw (ArgumentError (" argument $TorA should specify a scalar type (`<:Number`) or a storage type `<:CuVector{<:Number}`" ))
11- end
12- end
6+ TensorKit. _tensormap_storagetype (:: Type{A} ) where {T, A <: CuArray{T} } = CuVector{T, CUDA. DeviceMemory}
137
14- function TensorKit . TensorMap {T, S, N₁, N₂, <:CuVector{T} } (t:: TensorMap{T, S, N₁, N₂, A} ) where {T, S, N₁, N₂, A}
8+ function CuTensorMap {T, S, N₁, N₂} (t:: TensorMap{T, S, N₁, N₂, A} ) where {T, S, N₁, N₂, A}
159 return CuTensorMap {T, S, N₁, N₂} (CuArray (t. data), t. space)
1610end
1711
18- function CuTensorMap {T} (:: UndefInitializer , V:: TensorMapSpace{S, N₁, N₂} ) where {T, S, N₁, N₂}
19- return CuTensorMap {T, S, N₁, N₂} (undef, V)
20- end
21-
22- function CuTensorMap {T} (
23- :: UndefInitializer , codomain:: TensorSpace{S} ,
24- domain:: TensorSpace{S}
25- ) where {T, S}
26- return CuTensorMap {T} (undef, codomain ← domain)
27- end
28- function CuTensor {T} (:: UndefInitializer , V:: TensorSpace{S} ) where {T, S}
29- return CuTensorMap {T} (undef, V ← one (V))
30- end
31- # constructor starting from block data
32- """
33- CuTensorMap(data::AbstractDict{<:Sector,<:CuMatrix}, codomain::ProductSpace{S,N₁},
34- domain::ProductSpace{S,N₂}) where {S<:ElementarySpace,N₁,N₂}
35- CuTensorMap(data, codomain ← domain)
36- CuTensorMap(data, domain → codomain)
37-
38- Construct a `CuTensorMap` by explicitly specifying its block data.
39-
40- ## Arguments
41- - `data::AbstractDict{<:Sector,<:CuMatrix}`: dictionary containing the block data for
42- each coupled sector `c` as a matrix of size `(blockdim(codomain, c), blockdim(domain, c))`.
43- - `codomain::ProductSpace{S,N₁}`: the codomain as a `ProductSpace` of `N₁` spaces of type
44- `S<:ElementarySpace`.
45- - `domain::ProductSpace{S,N₂}`: the domain as a `ProductSpace` of `N₂` spaces of type
46- `S<:ElementarySpace`.
47-
48- Alternatively, the domain and codomain can be specified by passing a [`HomSpace`](@ref)
49- using the syntax `codomain ← domain` or `domain → codomain`.
50- """
51- function CuTensorMap (
52- data:: AbstractDict{<:Sector, <:CuArray} ,
53- V:: TensorMapSpace{S, N₁, N₂}
54- ) where {S, N₁, N₂}
55- T = eltype (valtype (data))
56- t = CuTensorMap {T} (undef, V)
57- for (c, b) in blocks (t)
58- haskey (data, c) || throw (SectorMismatch (" no data for block sector $c " ))
59- datac = data[c]
60- size (datac) == size (b) ||
61- throw (DimensionMismatch (" wrong size of block for sector $c " ))
62- copy! (b, datac)
63- end
64- for (c, b) in data
65- c ∈ blocksectors (t) || isempty (b) ||
66- throw (SectorMismatch (" data for block sector $c not expected" ))
67- end
68- return t
69- end
70- function CuTensorMap (data:: CuArray{T} , V:: TensorMapSpace{S, N₁, N₂} ) where {T, S, N₁, N₂}
71- return CuTensorMap {T, S, N₁, N₂} (vec (data), V)
72- end
73-
7412for (fname, felt) in ((:zeros , :zero ), (:ones , :one ))
7513 @eval begin
7614 function CUDA. $fname (
215153TensorKit. scalartype (A:: StridedCuArray{T} ) where {T} = T
216154TensorKit. scalartype (:: Type{<:CuTensorMap{T}} ) where {T} = T
217155TensorKit. scalartype (:: Type{<:CuArray{T}} ) where {T} = T
218- TensorKit. densevectortype (:: Type{<:TensorMap{T, S, N₁, N₂, A}} ) where {T, S, N₁, N₂, A <: CuVector{T} } = A
219- TensorKit. densevectortype (:: Type{<:CuArray{T}} ) where {T} = CuVector{T}
220- TensorKit. matrixtype (:: Type{<:TensorMap{T, S, N₁, N₂, A}} ) where {T, S, N₁, N₂, A <: CuVector{T} } = CuMatrix{T}
221- TensorKit. matrixtype (:: Type{CuArray{T}} ) where {T} = CuMatrix{T}
222156
223157function TensorKit. similarstoragetype (TT:: Type{<:CuTensorMap{TTT, S, N₁, N₂}} , :: Type{T} ) where {TTT, T, S, N₁, N₂}
224158 return CuVector{T, CUDA. DeviceMemory}
@@ -261,28 +195,6 @@ function Base.promote_rule(
261195 return CuTensorMap{T, S, N₁, N₂}
262196end
263197
264- # Conversion to CuArray:
265- # ----------------------
266- # probably not optimized for speed, only for checking purposes
267- function Base. convert (:: Type{CuArray} , t:: AbstractTensorMap )
268- I = sectortype (t)
269- if I === Trivial
270- CUDA. @allowscalar convert (CuArray, t[])
271- else
272- cod = codomain (t)
273- dom = domain (t)
274- T = sectorscalartype (I) <: Complex ? complex (scalartype (t)) :
275- sectorscalartype (I) <: Integer ? scalartype (t) : float (scalartype (t))
276- A = CUDA. zeros (T, dims (cod)... , dims (dom)... )
277- for (f₁, f₂) in fusiontrees (t)
278- F = convert (CuArray, (f₁, f₂))
279- Aslice = StridedView (A)[axes (cod, f₁. uncoupled)... , axes (dom, f₂. uncoupled)... ]
280- CUDA. @allowscalar add! (Aslice, StridedView (TensorKit. _kron (convert (CuArray, t[f₁, f₂]), F)))
281- end
282- return A
283- end
284- end
285-
286198# CuTensorMap exponentation:
287199function TensorKit. exp! (t:: CuTensorMap )
288200 domain (t) == codomain (t) ||
0 commit comments