Skip to content

Commit 3341852

Browse files
committed
Try using the new constructors PR
1 parent d63caf7 commit 3341852

File tree

4 files changed

+5
-117
lines changed

4 files changed

+5
-117
lines changed

ext/TensorKitCUDAExt/cutensormap.jl

Lines changed: 2 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -3,74 +3,12 @@ const CuTensor{T, S, N} = CuTensorMap{T, S, N, 0}
33

44
const AdjointCuTensorMap{T, S, N₁, N₂} = AdjointTensorMap{T, S, N₁, N₂, CuTensorMap{T, S, N₁, N₂}}
55

6-
function TensorKit.tensormaptype(S::Type{<:IndexSpace}, N₁, N₂, TorA::Type{<:StridedCuArray})
7-
if TorA <: CuArray
8-
return TensorMap{eltype(TorA), S, N₁, N₂, CuVector{eltype(TorA), CUDA.DeviceMemory}}
9-
else
10-
throw(ArgumentError("argument $TorA should specify a scalar type (`<:Number`) or a storage type `<:CuVector{<:Number}`"))
11-
end
12-
end
6+
TensorKit._tensormap_storagetype(::Type{A}) where {T, A <: CuArray{T}} = CuVector{T, CUDA.DeviceMemory}
137

14-
function TensorKit.TensorMap{T, S, N₁, N₂, <:CuVector{T}}(t::TensorMap{T, S, N₁, N₂, A}) where {T, S, N₁, N₂, A}
8+
function CuTensorMap{T, S, N₁, N₂}(t::TensorMap{T, S, N₁, N₂, A}) where {T, S, N₁, N₂, A}
159
return CuTensorMap{T, S, N₁, N₂}(CuArray(t.data), t.space)
1610
end
1711

18-
function CuTensorMap{T}(::UndefInitializer, V::TensorMapSpace{S, N₁, N₂}) where {T, S, N₁, N₂}
19-
return CuTensorMap{T, S, N₁, N₂}(undef, V)
20-
end
21-
22-
function CuTensorMap{T}(
23-
::UndefInitializer, codomain::TensorSpace{S},
24-
domain::TensorSpace{S}
25-
) where {T, S}
26-
return CuTensorMap{T}(undef, codomain domain)
27-
end
28-
function CuTensor{T}(::UndefInitializer, V::TensorSpace{S}) where {T, S}
29-
return CuTensorMap{T}(undef, V one(V))
30-
end
31-
# constructor starting from block data
32-
"""
33-
CuTensorMap(data::AbstractDict{<:Sector,<:CuMatrix}, codomain::ProductSpace{S,N₁},
34-
domain::ProductSpace{S,N₂}) where {S<:ElementarySpace,N₁,N₂}
35-
CuTensorMap(data, codomain ← domain)
36-
CuTensorMap(data, domain → codomain)
37-
38-
Construct a `CuTensorMap` by explicitly specifying its block data.
39-
40-
## Arguments
41-
- `data::AbstractDict{<:Sector,<:CuMatrix}`: dictionary containing the block data for
42-
each coupled sector `c` as a matrix of size `(blockdim(codomain, c), blockdim(domain, c))`.
43-
- `codomain::ProductSpace{S,N₁}`: the codomain as a `ProductSpace` of `N₁` spaces of type
44-
`S<:ElementarySpace`.
45-
- `domain::ProductSpace{S,N₂}`: the domain as a `ProductSpace` of `N₂` spaces of type
46-
`S<:ElementarySpace`.
47-
48-
Alternatively, the domain and codomain can be specified by passing a [`HomSpace`](@ref)
49-
using the syntax `codomain ← domain` or `domain → codomain`.
50-
"""
51-
function CuTensorMap(
52-
data::AbstractDict{<:Sector, <:CuArray},
53-
V::TensorMapSpace{S, N₁, N₂}
54-
) where {S, N₁, N₂}
55-
T = eltype(valtype(data))
56-
t = CuTensorMap{T}(undef, V)
57-
for (c, b) in blocks(t)
58-
haskey(data, c) || throw(SectorMismatch("no data for block sector $c"))
59-
datac = data[c]
60-
size(datac) == size(b) ||
61-
throw(DimensionMismatch("wrong size of block for sector $c"))
62-
copy!(b, datac)
63-
end
64-
for (c, b) in data
65-
c blocksectors(t) || isempty(b) ||
66-
throw(SectorMismatch("data for block sector $c not expected"))
67-
end
68-
return t
69-
end
70-
function CuTensorMap(data::CuArray{T}, V::TensorMapSpace{S, N₁, N₂}) where {T, S, N₁, N₂}
71-
return CuTensorMap{T, S, N₁, N₂}(vec(data), V)
72-
end
73-
7412
for (fname, felt) in ((:zeros, :zero), (:ones, :one))
7513
@eval begin
7614
function CUDA.$fname(
@@ -215,10 +153,6 @@ end
215153
TensorKit.scalartype(A::StridedCuArray{T}) where {T} = T
216154
TensorKit.scalartype(::Type{<:CuTensorMap{T}}) where {T} = T
217155
TensorKit.scalartype(::Type{<:CuArray{T}}) where {T} = T
218-
TensorKit.densevectortype(::Type{<:TensorMap{T, S, N₁, N₂, A}}) where {T, S, N₁, N₂, A <: CuVector{T}} = A
219-
TensorKit.densevectortype(::Type{<:CuArray{T}}) where {T} = CuVector{T}
220-
TensorKit.matrixtype(::Type{<:TensorMap{T, S, N₁, N₂, A}}) where {T, S, N₁, N₂, A <: CuVector{T}} = CuMatrix{T}
221-
TensorKit.matrixtype(::Type{CuArray{T}}) where {T} = CuMatrix{T}
222156

223157
function TensorKit.similarstoragetype(TT::Type{<:CuTensorMap{TTT, S, N₁, N₂}}, ::Type{T}) where {TTT, T, S, N₁, N₂}
224158
return CuVector{T, CUDA.DeviceMemory}
@@ -261,28 +195,6 @@ function Base.promote_rule(
261195
return CuTensorMap{T, S, N₁, N₂}
262196
end
263197

264-
# Conversion to CuArray:
265-
#----------------------
266-
# probably not optimized for speed, only for checking purposes
267-
function Base.convert(::Type{CuArray}, t::AbstractTensorMap)
268-
I = sectortype(t)
269-
if I === Trivial
270-
CUDA.@allowscalar convert(CuArray, t[])
271-
else
272-
cod = codomain(t)
273-
dom = domain(t)
274-
T = sectorscalartype(I) <: Complex ? complex(scalartype(t)) :
275-
sectorscalartype(I) <: Integer ? scalartype(t) : float(scalartype(t))
276-
A = CUDA.zeros(T, dims(cod)..., dims(dom)...)
277-
for (f₁, f₂) in fusiontrees(t)
278-
F = convert(CuArray, (f₁, f₂))
279-
Aslice = StridedView(A)[axes(cod, f₁.uncoupled)..., axes(dom, f₂.uncoupled)...]
280-
CUDA.@allowscalar add!(Aslice, StridedView(TensorKit._kron(convert(CuArray, t[f₁, f₂]), F)))
281-
end
282-
return A
283-
end
284-
end
285-
286198
# CuTensorMap exponentation:
287199
function TensorKit.exp!(t::CuTensorMap)
288200
domain(t) == codomain(t) ||

src/tensors/abstracttensor.jl

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -45,14 +45,6 @@ end
4545
Return the type of vector that stores the data of a tensor.
4646
""" storagetype
4747

48-
@doc """
49-
matrixtype(t::AbstractTensorMap) -> Type{A<:AbstractVector}
50-
matrixtrype(T::Type{<:AbstractTensorMap}) -> Type{A<:AbstractVector}
51-
52-
Return the type of matrix that stores the data of a tensor, for conversion
53-
to/from dictionaries.
54-
""" matrixtype
55-
5648
similarstoragetype(TT::Type{<:AbstractTensorMap}) = similarstoragetype(TT, scalartype(TT))
5749

5850
function similarstoragetype(TT::Type{<:AbstractTensorMap}, ::Type{T}) where {T}
@@ -183,7 +175,6 @@ end
183175
InnerProductStyle(t::AbstractTensorMap) = InnerProductStyle(typeof(t))
184176
storagetype(t::AbstractTensorMap) = storagetype(typeof(t))
185177
blocktype(t::AbstractTensorMap) = blocktype(typeof(t))
186-
matrixtype(t::AbstractTensorMap) = matrixtype(typeof(t))
187178
similarstoragetype(t::AbstractTensorMap, T = scalartype(t)) = similarstoragetype(typeof(t), T)
188179

189180
numout(t::AbstractTensorMap) = numout(typeof(t))
@@ -634,8 +625,7 @@ function Base.convert(::Type{Array}, t::AbstractTensorMap)
634625
for (f₁, f₂) in fusiontrees(t)
635626
F = convert(Array, (f₁, f₂))
636627
Aslice = StridedView(A)[axes(cod, f₁.uncoupled)..., axes(dom, f₂.uncoupled)...]
637-
tf₁f₂ = convert(Array, t[f₁, f₂])
638-
add!(Aslice, StridedView(_kron(tf₁f₂, F)))
628+
add!(Aslice, StridedView(_kron(convert(Array, t[f₁, f₂]), F)))
639629
end
640630
return A
641631
end

src/tensors/tensor.jl

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -66,19 +66,6 @@ space(t::TensorMap) = t.space
6666
Return the type of the storage `A` of the tensor map.
6767
"""
6868
storagetype(::Type{<:TensorMap{T, S, N₁, N₂, A}}) where {T, S, N₁, N₂, A <: DenseVector{T}} = A
69-
"""
70-
densevectortype(::Union{T,Type{T}}) where {T<:TensorMap} -> Type{A<:Vector}
71-
72-
Return the type of the storage `A` of the tensor map.
73-
"""
74-
densevectortype(::Type{<:TensorMap{T, S, N₁, N₂, A}}) where {T, S, N₁, N₂, A <: Vector{T}} = A
75-
densevectortype(::Type{<:Array{T}}) where {T} = Vector{T}
76-
77-
"""
78-
matrixtype(::Union{T,Type{T}}) where {T<:TensorMap} -> Type{A<:Vector}
79-
Return the matrix analogue type of the storage `A` of the tensor map.
80-
"""
81-
matrixtype(::Type{<:TensorMap{T, S, N₁, N₂, A}}) where {T, S, N₁, N₂, A <: Vector{T}} = Matrix{T}
8269

8370
dim(t::TensorMap) = length(t.data)
8471

@@ -439,7 +426,7 @@ function Base.convert(::Type{Dict}, t::AbstractTensorMap)
439426
d[:domain] = repr(domain(t))
440427
data = Dict{String, Any}()
441428
for (c, b) in blocks(t)
442-
data[repr(c)] = matrixtype(t)(b)
429+
data[repr(c)] = b
443430
end
444431
d[:data] = data
445432
return d

test/cuda/tensors.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ using TensorKit, Combinatorics
44
ad = adapt(Array)
55
const CUDAExt = Base.get_extension(TensorKit, :TensorKitCUDAExt)
66
@assert !isnothing(CUDAExt)
7-
const CuTensorMap = getglobal(CUDAExt, :CuTensorMap)
87

98
@isdefined(TestSetup) || include("../setup.jl")
109
using .TestSetup
@@ -29,8 +28,8 @@ spacelist = try
2928
(Vtr, VU₁, VSU₂, Vfℤ₂)
3029
end
3130
catch
31+
(Vtr, Vℤ₂, Vfℤ₂, Vℤ₃, VU₁, VfU₁) #, VSU₃)
3232
#(Vtr, Vℤ₂, Vfℤ₂, Vℤ₃, VU₁, VfU₁, VCU₁, VSU₂, VfSU₂) #, VSU₃)
33-
(Vℤ₂, Vfℤ₂, Vℤ₃, VU₁, VfU₁, VCU₁, VSU₂, VfSU₂) #, VSU₃)
3433
end
3534

3635
for V in spacelist

0 commit comments

Comments
 (0)