Commit bcbe1eb
1 parent d6aa119 commit bcbe1eb

6 files changed: +79 -145 lines changed

ext/TensorKitCUDAExt/TensorKitCUDAExt.jl
Lines changed: 0 additions & 6 deletions

@@ -18,10 +18,4 @@ using Random
 
 include("cutensormap.jl")
 
-# TODO
-# add VectorInterface extensions for proper CUDA promotion
-function TensorKit.VectorInterface.promote_add(TA::Type{<:CUDA.StridedCuMatrix{Tx}}, TB::Type{<:CUDA.StridedCuMatrix{Ty}}, α::Tα = TensorKit.VectorInterface.One(), β::Tβ = TensorKit.VectorInterface.One()) where {Tx, Ty, Tα, Tβ}
-    return Base.promote_op(add, Tx, Ty, Tα, Tβ)
-end
-
 end

ext/TensorKitCUDAExt/cutensormap.jl
Lines changed: 6 additions & 66 deletions

@@ -4,7 +4,11 @@ const CuTensor{T, S, N} = CuTensorMap{T, S, N, 0}
 const AdjointCuTensorMap{T, S, N₁, N₂} = AdjointTensorMap{T, S, N₁, N₂, CuTensorMap{T, S, N₁, N₂}}
 
 function CuTensorMap(t::TensorMap{T, S, N₁, N₂, A}) where {T, S, N₁, N₂, A}
-    return CuTensorMap{T, S, N₁, N₂}(CuArray(t.data), t.space)
+    return CuTensorMap{T, S, N₁, N₂}(CuArray{T}(t.data), space(t))
+end
+
+function Base.collect(t::CuTensorMap{T}) where {T}
+    return convert(TensorKit.TensorMapWithStorage{T, Vector{T}}, t)
 end
 
 # project_symmetric! doesn't yet work for GPU types, so do this on the host, then copy
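
As a usage sketch of the constructor/collect pair above (hypothetical session: it assumes the TensorKitCUDAExt extension is loaded and a working CUDA device is available):

    using TensorKit, CUDA

    t = rand(Float64, ℂ^2 ⊗ ℂ^3 ← ℂ^2)  # host tensor, Vector{Float64} storage
    ct = CuTensorMap(t)                  # upload: data becomes a CuArray{Float64}
    t′ = collect(ct)                     # download: back to Vector-backed storage
    @assert t′ ≈ t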
@@ -87,75 +91,11 @@ for randfun in (:curand, :curandn)
     end
 end
 
-for randfun in (:rand, :randn, :randisometry)
-    randfun! = Symbol(randfun, :!)
-    @eval begin
-        # converting `codomain` and `domain` into `HomSpace`
-        function $randfun(
-                ::Type{A}, codomain::TensorSpace{S},
-                domain::TensorSpace{S}
-            ) where {A <: CuArray, S <: IndexSpace}
-            return $randfun(A, codomain ← domain)
-        end
-        function $randfun(
-                ::Type{T}, ::Type{A}, codomain::TensorSpace{S},
-                domain::TensorSpace{S}
-            ) where {T, S <: IndexSpace, A <: CuArray{T}}
-            return $randfun(T, A, codomain ← domain)
-        end
-        function $randfun(
-                rng::Random.AbstractRNG, ::Type{T}, ::Type{A},
-                codomain::TensorSpace{S},
-                domain::TensorSpace{S}
-            ) where {T, S <: IndexSpace, A <: CuArray{T}}
-            return $randfun(rng, T, A, codomain ← domain)
-        end
-
-        # accepting single `TensorSpace`
-        $randfun(::Type{A}, codomain::TensorSpace) where {A <: CuArray} = $randfun(A, codomain ← one(codomain))
-        function $randfun(::Type{T}, ::Type{A}, codomain::TensorSpace) where {T, A <: CuArray{T}}
-            return $randfun(T, A, codomain ← one(codomain))
-        end
-        function $randfun(
-                rng::Random.AbstractRNG, ::Type{T},
-                ::Type{A}, codomain::TensorSpace
-            ) where {T, A <: CuArray{T}}
-            return $randfun(rng, T, A, codomain ← one(domain))
-        end
-
-        # filling in default eltype
-        $randfun(::Type{A}, V::TensorMapSpace) where {A <: CuArray} = $randfun(eltype(A), A, V)
-        function $randfun(rng::Random.AbstractRNG, ::Type{A}, V::TensorMapSpace) where {A <: CuArray}
-            return $randfun(rng, eltype(A), A, V)
-        end
-
-        # filling in default rng
-        function $randfun(::Type{T}, ::Type{A}, V::TensorMapSpace) where {T, A <: CuArray{T}}
-            return $randfun(Random.default_rng(), T, A, V)
-        end
-
-        # implementation
-        function $randfun(
-                rng::Random.AbstractRNG, ::Type{T},
-                ::Type{A}, V::TensorMapSpace
-            ) where {T, A <: CuArray{T}}
-            t = CuTensorMap{T}(undef, V)
-            $randfun!(rng, t)
-            return t
-        end
-    end
-end
-
-function Base.convert(::Type{CuTensorMap}, t::AbstractTensorMap{T, S}) where {T, S}
-    d_data = CuArray(t.data)
-    return TensorMap{T}(d_data, t.space)
-end
-
 # Scalar implementation
 #-----------------------
 function TensorKit.scalar(t::CuTensorMap{T, S, 0, 0}) where {T, S}
     inds = findall(!iszero, t.data)
-    return isempty(inds) ? zero(scalartype(t)) : @allowscalar @inbounds t.data[only(inds)
+    return isempty(inds) ? zero(scalartype(t)) : @allowscalar @inbounds t.data[only(inds)]
 end
 
 function Base.convert(

src/tensors/diagonal.jl
Lines changed: 5 additions & 7 deletions

@@ -78,12 +78,10 @@ function DiagonalTensorMap(t::AbstractTensorMap{T, S, 1, 1}) where {T, S}
     return d
 end
 
-Base.similar(d::DiagonalTensorMap) = similar_diagonal(d)
-Base.similar(d::DiagonalTensorMap, ::Type{T}) where {T} = similar_diagonal(d, T)
-
-similar_diagonal(d::DiagonalTensorMap) = DiagonalTensorMap(similar(d.data), d.domain)
-similar_diagonal(d::DiagonalTensorMap, ::Type{T}) where {T <: Number} =
-    DiagonalTensorMap(similar(d.data, T), d.domain)
+Base.similar(d::DiagonalTensorMap) = DiagonalTensorMap(similar(d.data), d.domain)
+function Base.similar(d::DiagonalTensorMap, ::Type{T}) where {T <: Number}
+    return DiagonalTensorMap(similar(d.data, T), d.domain)
+end
 
 # TODO: more constructors needed?
 
@@ -273,7 +271,7 @@ function LinearAlgebra.mul!(
         dC::DiagonalTensorMap, dA::DiagonalTensorMap, dB::DiagonalTensorMap, α::Number, β::Number
     )
     dC.domain == dA.domain == dB.domain || throw(SpaceMismatch())
-    @. dC.data = (α * dA.data * dB.data) + β * dC.data
+    mul!(Diagonal(dC.data), Diagonal(dA.data), Diagonal(dB.data), α, β)
     return dC
 end
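
The replaced broadcast and the new mul! call compute the same update dC = α * (dA * dB) + β * dC; a small standalone check with plain vectors (not TensorKit types), relying only on LinearAlgebra's five-argument mul! for Diagonal:

    using LinearAlgebra

    a, b, c = [1.0, 2.0], [3.0, 4.0], [5.0, 6.0]
    α, β = 2.0, 0.5

    expected = @. (α * a * b) + β * c
    mul!(Diagonal(c), Diagonal(a), Diagonal(b), α, β)  # mutates c in place
    @assert c ≈ expected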

src/tensors/linalg.jl
Lines changed: 6 additions & 15 deletions

@@ -270,20 +270,11 @@ function _norm(blockiter, p::Real, init::Real)
         return mapreduce(max, blockiter; init = init) do (c, b)
             return isempty(b) ? init : oftype(init, LinearAlgebra.normInf(b))
         end
-    elseif p == 2
-        n² = mapreduce(+, blockiter; init = init) do (c, b)
-            return isempty(b) ? init : oftype(init, dim(c) * LinearAlgebra.norm(b, 2)^2)
+    elseif p > 0 # finite positive p
+        np = sum(blockiter; init) do (c, b)
+            return oftype(init, dim(c) * norm(b, p)^p)
         end
-        return sqrt(n²)
-    elseif p == 1
-        return mapreduce(+, blockiter; init = init) do (c, b)
-            return isempty(b) ? init : oftype(init, dim(c) * sum(abs, b))
-        end
-    elseif p > 0
-        nᵖ = mapreduce(+, blockiter; init = init) do (c, b)
-            return isempty(b) ? init : oftype(init, dim(c) * LinearAlgebra.norm(b, p)^p)
-        end
-        return (nᵖ)^inv(oftype(nᵖ, p))
+        return np^(inv(oftype(np, p)))
     else
         msg = "Norm with non-positive p is not defined for `AbstractTensorMap`"
         throw(ArgumentError(msg))
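
The consolidated branch computes ‖t‖_p = (Σ_c dim(c) · ‖b_c‖_p^p)^(1/p) for any finite p > 0, which subsumes the former p == 1 and p == 2 special cases. A self-contained sketch of the same reduction over a plain dictionary of (block, dimension) pairs (illustrative data, not TensorKit API):

    using LinearAlgebra

    # block data keyed by sector label, paired with that sector's dimension
    blocks = Dict("a" => ([1.0, -2.0], 1), "b" => ([3.0], 2))

    function block_pnorm(blocks, p::Real)
        np = sum(values(blocks); init = 0.0) do (b, d)
            d * norm(b, p)^p
        end
        return np^(1 / p)
    end

    block_pnorm(blocks, 2)  # sqrt(1 * (1 + 4) + 2 * 9) == sqrt(23)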
@@ -317,8 +308,8 @@ function LinearAlgebra.cond(t::AbstractTensorMap, p::Real = 2)
             return zero(real(float(scalartype(t))))
         end
         S = LinearAlgebra.svdvals(t)
-        maxS = maximum(S.data)
-        minS = minimum(S.data)
+        maxS = maximum(parent(S))
+        minS = minimum(parent(S))
         return iszero(maxS) ? oftype(maxS, Inf) : (maxS / minS)
     else
         throw(ArgumentError("cond currently only defined for p=2"))

src/tensors/tensor.jl
Lines changed: 35 additions & 23 deletions

@@ -316,11 +316,14 @@ end
 
 for randf in (:rand, :randn, :randexp, :randisometry)
     _docstr = """
-        $randf([rng=default_rng()], [T=Float64], codomain::ProductSpace{S,N₁},
+        $randf([rng=default_rng()], [TorA=Float64], codomain::ProductSpace{S,N₁},
                 domain::ProductSpace{S,N₂}) where {S,N₁,N₂,T} -> t
-        $randf([rng=default_rng()], [T=Float64], codomain ← domain) -> t
+        $randf([rng=default_rng()], [TorA=Float64], codomain ← domain) -> t
 
     Generate a tensor `t` with entries generated by `$randf`.
+    The type `TorA` can be used to control the element type and
+    data type generated. For example, if `TorA` is a `SparseVector{ComplexF32}`,
+    then the final output `TensorMap` will have that as its storage type.
 
     See also [`Random.$(randf)!`](@ref).
     """
@@ -349,25 +352,25 @@ for randf in (:rand, :randn, :randexp, :randisometry)
         return $randfun(codomain ← domain)
     end
     function $randfun(
-            ::Type{T}, codomain::TensorSpace{S}, domain::TensorSpace{S}
-        ) where {T, S <: IndexSpace}
-        return $randfun(T, codomain ← domain)
+            ::Type{TorA}, codomain::TensorSpace{S}, domain::TensorSpace{S}
+        ) where {TorA, S <: IndexSpace}
+        return $randfun(TorA, codomain ← domain)
     end
     function $randfun(
-            rng::Random.AbstractRNG, ::Type{T}, codomain::TensorSpace{S}, domain::TensorSpace{S}
-        ) where {T, S <: IndexSpace}
-        return $randfun(rng, T, codomain ← domain)
+            rng::Random.AbstractRNG, ::Type{TorA}, codomain::TensorSpace{S}, domain::TensorSpace{S}
+        ) where {TorA, S <: IndexSpace}
+        return $randfun(rng, TorA, codomain ← domain)
     end
 
     # accepting single `TensorSpace`
     $randfun(codomain::TensorSpace) = $randfun(codomain ← one(codomain))
-    function $randfun(::Type{T}, codomain::TensorSpace) where {T}
-        return $randfun(T, codomain ← one(codomain))
+    function $randfun(::Type{TorA}, codomain::TensorSpace) where {TorA}
+        return $randfun(TorA, codomain ← one(codomain))
     end
     function $randfun(
-            rng::Random.AbstractRNG, ::Type{T}, codomain::TensorSpace
-        ) where {T}
-        return $randfun(rng, T, codomain ← one(domain))
+            rng::Random.AbstractRNG, ::Type{TorA}, codomain::TensorSpace
+        ) where {TorA}
+        return $randfun(rng, TorA, codomain ← one(domain))
     end
 
     # filling in default eltype
@@ -377,16 +380,16 @@ for randf in (:rand, :randn, :randexp, :randisometry)
     end
 
     # filling in default rng
-    function $randfun(::Type{T}, V::TensorMapSpace) where {T}
-        return $randfun(Random.default_rng(), T, V)
+    function $randfun(::Type{TorA}, V::TensorMapSpace) where {TorA}
+        return $randfun(Random.default_rng(), TorA, V)
     end
     $randfun!(t::AbstractTensorMap) = $randfun!(Random.default_rng(), t)
 
     # implementation
     function $randfun(
-            rng::Random.AbstractRNG, ::Type{T}, V::TensorMapSpace
-        ) where {T}
-        t = TensorMap{T}(undef, V)
+            rng::Random.AbstractRNG, ::Type{TorA}, V::TensorMapSpace
+        ) where {TorA}
+        t = tensormaptype(spacetype(V), numout(V), numin(V), TorA)(undef, V)
         $randfun!(rng, t)
         return t
     end
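
For reference, a minimal sketch of what the new implementation line does, assuming TensorKit's tensormaptype helper and the Random.randn! extension for tensors:

    using TensorKit, Random

    V = ℂ^2 ⊗ ℂ^3 ← ℂ^2   # a TensorMapSpace
    TT = tensormaptype(spacetype(V), numout(V), numin(V), Vector{Float64})
    t = TT(undef, V)       # uninitialized TensorMap with Vector{Float64} storage
    randn!(t)              # filled in place, as in the generated methods above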
@@ -406,18 +409,22 @@ Base.copy(t::TensorMap) = typeof(t)(copy(t.data), t.space)
 
 # Conversion between TensorMap and Dict, for read and write purpose
 #------------------------------------------------------------------
+# We want to store the block data using simple data types,
+# rather than reshaped views or some other wrapped array type.
+# Since this method is meant for storing data on disk, we can
+# freely collect data to the CPU
 function Base.convert(::Type{Dict}, t::AbstractTensorMap)
     d = Dict{Symbol, Any}()
     d[:codomain] = repr(codomain(t))
     d[:domain] = repr(domain(t))
     data = Dict{String, Any}()
     for (c, b) in blocks(t)
-        data[repr(c)] = b
+        data[repr(c)] = Array(b)
     end
     d[:data] = data
     return d
 end
-function Base.convert(::Type{<:TensorMap}, d::Dict{Symbol, Any})
+function Base.convert(::Type{TensorMap}, d::Dict{Symbol, Any})
     try
         codomain = eval(Meta.parse(d[:codomain]))
         domain = eval(Meta.parse(d[:domain]))
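
A round-trip sketch for the Dict conversion (hedged: the reverse direction parses and evals the stored space strings, so it assumes the relevant space types are in scope):

    using TensorKit

    t = rand(Float64, ℂ^2 ⊗ ℂ^2 ← ℂ^2)
    d = convert(Dict, t)        # block data stored as plain Arrays
    t′ = convert(TensorMap, d)  # rebuilt from the stored representation
    @assert t′ ≈ t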
@@ -522,6 +529,11 @@ function Base.convert(::Type{TensorMap}, t::AbstractTensorMap)
     return copy!(TensorMap{scalartype(t)}(undef, space(t)), t)
 end
 
+function Base.convert(::Type{TensorMapWithStorage{T, A}}, t::TensorMap) where {T, A}
+    d_data = convert(A, t.data)
+    return TensorMapWithStorage{T, A}(d_data, space(t))
+end
+
 function Base.convert(
     TT::Type{TensorMap{T, S, N₁, N₂, A}}, t::AbstractTensorMap{<:Any, S, N₁, N₂}
 ) where {T, S, N₁, N₂, A}
@@ -536,7 +548,7 @@ end
 function Base.promote_rule(
         ::Type{<:TT₁}, ::Type{<:TT₂}
     ) where {S, N₁, N₂, TT₁ <: TensorMap{<:Any, S, N₁, N₂}, TT₂ <: TensorMap{<:Any, S, N₁, N₂}}
-    A = VectorInterface.promote_add(storagetype(TT₁), storagetype(TT₂))
-    T = scalartype(A)
-    return TensorMap{T, S, N₁, N₂, A}
+    T = VectorInterface.promote_add(scalartype(TT₁), scalartype(TT₂))
+    A = promote_storagetype(similarstoragetype(TT₁, T), similarstoragetype(TT₂, T))
+    return tensormaptype(S, N₁, N₂, A)
 end
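
A hedged illustration of the intended promotion behavior (promote_storagetype and similarstoragetype are internal helpers this hunk relies on, so exact results may differ):

    using TensorKit

    TT₁ = tensormaptype(ComplexSpace, 1, 1, Vector{Float32})
    TT₂ = tensormaptype(ComplexSpace, 1, 1, Vector{ComplexF64})
    # expected: a TensorMap type with ComplexF64 scalars and Vector{ComplexF64} storage
    promote_type(TT₁, TT₂)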
