Skip to content

Commit a2a3e9a

Browse files
committed
CUDA tensors updates
1 parent ec11c6b commit a2a3e9a

File tree

8 files changed

+415
-165
lines changed

8 files changed

+415
-165
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
2828
GPUArrays = {rev = "master", url = "https://github.com/JuliaGPU/GPUArrays.jl"}
2929
MatrixAlgebraKit = {rev = "ksh/tk", url = "https://github.com/QuantumKitHub/MatrixAlgebraKit.jl"}
3030
AMDGPU = {rev = "master", url = "https://github.com/JuliaGPU/AMDGPU.jl"}
31-
cuTENSOR = {subdir = "lib/cutensor", url = "https://github.com/JuliaGPU/CUDA.jl"}
31+
cuTENSOR = {subdir = "lib/cutensor", url = "https://github.com/JuliaGPU/CUDA.jl", rev="master"}
3232

3333
[extensions]
3434
TensorKitAMDGPUExt = "AMDGPU"

ext/TensorKitCUDAExt/TensorKitCUDAExt.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@ import CUDA: rand as curand, rand! as curand!, randn as curandn, randn! as curan
88
using TensorKit
99
import TensorKit.VectorInterface: scalartype as vi_scalartype
1010
using TensorKit.Factorizations
11+
using TensorKit.Strided
1112
using TensorKit.Factorizations: AbstractAlgorithm
1213
using TensorKit: SectorDict, tensormaptype, scalar, similarstoragetype, AdjointTensorMap
14+
import TensorKit: randisometry
1315

1416
using TensorKit.MatrixAlgebraKit
1517

ext/TensorKitCUDAExt/cutensormap.jl

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,12 @@ function CuTensorMap(
7575
) where {S}
7676
return CuTensorMap(data, codom dom)
7777
end
78+
function CuTensorMap(data::DenseVector{T}, V::TensorMapSpace{S, N₁, N₂}) where {T,S,N₁,N₂}
79+
return CuTensorMap{T, S, N₁, N₂}(data, V)
80+
end
81+
function CuTensorMap(data::CuArray{T}, V::TensorMapSpace{S, N₁, N₂}) where {T,S,N₁,N₂}
82+
return CuTensorMap{T, S, N₁, N₂}(vec(data), V)
83+
end
7884

7985
for (fname, felt) in ((:zeros, :zero), (:ones, :one))
8086
@eval begin
@@ -158,6 +164,55 @@ for randfun in (:curand, :curandn)
158164
end
159165
end
160166

167+
for randfun in (:rand, :randn, :randisometry)
168+
randfun! = Symbol(randfun, :!)
169+
@eval begin
170+
# converting `codomain` and `domain` into `HomSpace`
171+
function $randfun(::Type{A}, codomain::TensorSpace{S},
172+
domain::TensorSpace{S}) where {A <: CuArray, S<:IndexSpace}
173+
return $randfun(A, codomain domain)
174+
end
175+
function $randfun(::Type{T}, ::Type{A}, codomain::TensorSpace{S},
176+
domain::TensorSpace{S}) where {T,S<:IndexSpace, A<:CuArray{T}}
177+
return $randfun(T, A, codomain domain)
178+
end
179+
function $randfun(rng::Random.AbstractRNG, ::Type{T}, ::Type{A},
180+
codomain::TensorSpace{S},
181+
domain::TensorSpace{S}) where {T,S<:IndexSpace, A<:CuArray{T}}
182+
return $randfun(rng, T, A, codomain domain)
183+
end
184+
185+
# accepting single `TensorSpace`
186+
$randfun(::Type{A}, codomain::TensorSpace) where {A<:CuArray} = $randfun(A, codomain one(codomain))
187+
function $randfun(::Type{T}, ::Type{A}, codomain::TensorSpace) where {T, A<:CuArray{T}}
188+
return $randfun(T, A, codomain one(codomain))
189+
end
190+
function $randfun(rng::Random.AbstractRNG, ::Type{T},
191+
::Type{A}, codomain::TensorSpace) where {T, A<:CuArray{T}}
192+
return $randfun(rng, T, A, codomain one(domain))
193+
end
194+
195+
# filling in default eltype
196+
$randfun(::Type{A}, V::TensorMapSpace) where {A<:CuArray} = $randfun(eltype(A), A, V)
197+
function $randfun(rng::Random.AbstractRNG, ::Type{A}, V::TensorMapSpace) where {A<:CuArray}
198+
return $randfun(rng, eltype(A), A, V)
199+
end
200+
201+
# filling in default rng
202+
function $randfun(::Type{T}, ::Type{A}, V::TensorMapSpace) where {T, A<:CuArray{T}}
203+
return $randfun(Random.default_rng(), T, A, V)
204+
end
205+
206+
# implementation
207+
function $randfun(rng::Random.AbstractRNG, ::Type{T},
208+
::Type{A}, V::TensorMapSpace) where {T, A<:CuArray{T}}
209+
t = CuTensorMap{T}(undef, V)
210+
$randfun!(rng, t)
211+
return t
212+
end
213+
end
214+
end
215+
161216
# converters
162217
# ----------
163218
function Base.convert(::Type{CuTensorMap}, d::Dict{Symbol, Any})
@@ -250,3 +305,26 @@ function LinearAlgebra.isposdef(t::CuTensorMap)
250305
end
251306
return true
252307
end
308+
309+
310+
# Conversion to CuArray:
311+
#----------------------
312+
# probably not optimized for speed, only for checking purposes
313+
function Base.convert(::Type{CuArray}, t::AbstractTensorMap)
314+
I = sectortype(t)
315+
if I === Trivial
316+
convert(CuArray, t[])
317+
else
318+
cod = codomain(t)
319+
dom = domain(t)
320+
T = sectorscalartype(I) <: Complex ? complex(scalartype(t)) :
321+
sectorscalartype(I) <: Integer ? scalartype(t) : float(scalartype(t))
322+
A = CUDA.zeros(T, dims(cod)..., dims(dom)...)
323+
for (f₁, f₂) in fusiontrees(t)
324+
F = convert(CuArray, (f₁, f₂))
325+
Aslice = StridedView(A)[axes(cod, f₁.uncoupled)..., axes(dom, f₂.uncoupled)...]
326+
add!(Aslice, StridedView(TensorKit._kron(convert(CuArray, t[f₁, f₂]), F)))
327+
end
328+
return A
329+
end
330+
end

src/auxiliary/random.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,20 @@
11
"""
2-
randisometry([::Type{T}=Float64], dims::Dims{2}) -> Array{T,2}
3-
randhaar([::Type{T}=Float64], dims::Dims{2}) -> Array{T,2}
2+
randisometry([::Type{T}=Float64], [::Type{A}=Matrix{T}], dims::Dims{2}) -> A
3+
randhaar([::Type{T}=Float64], [::Type{A}=Matrix{T}], dims::Dims{2}) -> A
44
55
Create a random isometry of size `dims`, uniformly distributed according to the Haar measure.
66
77
See also [`randuniform`](@ref) and [`randnormal`](@ref).
88
"""
9-
randisometry(dims::Base.Dims{2}) = randisometry(Float64, dims)
9+
randisometry(dims::Base.Dims{2}) = randisometry(Float64, Matrix{Float64}, dims)
1010
function randisometry(::Type{T}, dims::Base.Dims{2}) where {T <: Number}
1111
return randisometry(Random.default_rng(), T, dims)
1212
end
13-
function randisometry(rng::Random.AbstractRNG, ::Type{T}, dims::Base.Dims{2}) where {T <: Number}
14-
return randisometry!(rng, Matrix{T}(undef, dims))
13+
function randisometry(::Type{T}, ::Type{A}, dims::Base.Dims{2}) where {T <: Number, A<:AbstractArray{T}}
14+
return randisometry(Random.default_rng(), T, A, dims)
15+
end
16+
function randisometry(rng::Random.AbstractRNG, ::Type{T}, ::Type{A}, dims::Base.Dims{2}) where {T <: Number, A<:AbstractArray{T}}
17+
return randisometry!(rng, A(undef, dims))
1518
end
1619

1720
randisometry!(A::AbstractMatrix) = randisometry!(Random.default_rng(), A)

src/tensors/linalg.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@ function _norm(blockiter, p::Real, init::Real)
281281
end
282282
elseif p > 0
283283
nᵖ = mapreduce(+, blockiter; init = init) do (c, b)
284-
return isempty(b) ? init : oftype(init, dim(c) * LinearAlgebra.normp(b, p)^p)
284+
return isempty(b) ? init : oftype(init, dim(c) * LinearAlgebra.norm(b, p)^p)
285285
end
286286
return (nᵖ)^inv(oftype(nᵖ, p))
287287
else
@@ -431,7 +431,7 @@ function exp!(t::TensorMap)
431431
domain(t) == codomain(t) ||
432432
error("Exponential of a tensor only exist when domain == codomain.")
433433
for (c, b) in blocks(t)
434-
copy!(b, LinearAlgebra.exp!(b))
434+
copy!(b, exp!(b))
435435
end
436436
return t
437437
end

src/tensors/tensor.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,7 @@ function TensorMap(
357357
) where {S}
358358
return TensorMap(data, codom dom; kwargs...)
359359
end
360-
function Tensor(data::AbstractArray, codom::TensorSpace, ; kwargs...)
360+
function Tensor(data::AbstractArray, codom::TensorSpace; kwargs...)
361361
return TensorMap(data, codom one(codom); kwargs...)
362362
end
363363

0 commit comments

Comments
 (0)