Skip to content

Commit 2c6af42

Browse files
committed
Start on CUDA extension
1 parent 155aa89 commit 2c6af42

File tree

9 files changed

+860
-63
lines changed

9 files changed

+860
-63
lines changed

.buildkite/pipeline.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@ steps:
1515
queue: "juliagpu"
1616
cuda: "*"
1717
if: build.message !~ /\[skip tests\]/
18-
timeout_in_minutes: 30
18+
timeout_in_minutes: 60
1919
matrix:
2020
setup:
2121
julia:
2222
- "1.10"
23-
- "1.11"
23+
- "1.12"
2424

2525
- label: "Julia {{matrix.julia}} -- AMDGPU"
2626
plugins:
@@ -36,9 +36,9 @@ steps:
3636
rocm: "*"
3737
rocmgpu: "*"
3838
if: build.message !~ /\[skip tests\]/
39-
timeout_in_minutes: 30
39+
timeout_in_minutes: 60
4040
matrix:
4141
setup:
4242
julia:
4343
- "1.10"
44-
- "1.11"
44+
- "1.12"

Project.toml

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,20 +18,26 @@ TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
1818
VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"
1919

2020
[weakdeps]
21+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
2122
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
2223
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
24+
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
2325

2426
[extensions]
27+
TensorKitCUDAExt = ["CUDA", "cuTENSOR"]
2528
TensorKitChainRulesCoreExt = "ChainRulesCore"
2629
TensorKitFiniteDifferencesExt = "FiniteDifferences"
2730

2831
[compat]
32+
Adapt = "4"
2933
Aqua = "0.6, 0.7, 0.8"
3034
ArgParse = "1.2.0"
35+
CUDA = "5.9"
3136
ChainRulesCore = "1"
3237
ChainRulesTestUtils = "1"
3338
Combinatorics = "1"
3439
FiniteDifferences = "0.12"
40+
GPUArrays = "11.3.1"
3541
LRUCache = "1.0.2"
3642
LinearAlgebra = "1"
3743
MatrixAlgebraKit = "0.6.0"
@@ -48,21 +54,26 @@ TestExtras = "0.2,0.3"
4854
TupleTools = "1.1"
4955
VectorInterface = "0.4.8, 0.5"
5056
Zygote = "0.7"
57+
cuTENSOR = "2"
5158
julia = "1.10"
5259

5360
[extras]
54-
ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63"
61+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
5562
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
63+
ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63"
64+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
5665
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
5766
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
5867
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
5968
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
69+
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
6070
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
6171
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
6272
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
6373
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
6474
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
6575
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
76+
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
6677

6778
[targets]
68-
test = ["ArgParse", "Aqua", "Combinatorics", "LinearAlgebra", "TensorOperations", "Test", "TestExtras", "SafeTestsets", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Zygote"]
79+
test = ["ArgParse", "Adapt", "Aqua", "Combinatorics", "CUDA", "cuTENSOR", "GPUArrays", "LinearAlgebra", "SafeTestsets", "TensorOperations", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Zygote"]
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Extension module adding CUDA-backed tensor map support to TensorKit.
# It brings the CUDA, cuTENSOR, TensorKit, and Random names used by the included file into scope.
module TensorKitCUDAExt
2+
3+
using CUDA, CUDA.CUBLAS, CUDA.CUSOLVER, LinearAlgebra
4+
using CUDA: @allowscalar
5+
using cuTENSOR: cuTENSOR
6+
# CUDA's random functions are imported under `cu`-prefixed aliases
# (presumably to keep them apart from the `rand`/`randn` imported from TensorKit below).
import CUDA: rand as curand, rand! as curand!, randn as curandn, randn! as curandn!
7+
8+
using TensorKit
9+
using TensorKit.Factorizations
10+
using TensorKit.Strided
11+
using TensorKit.Factorizations: AbstractAlgorithm
12+
using TensorKit: SectorDict, tensormaptype, scalar, similarstoragetype, AdjointTensorMap, scalartype, project_symmetric_and_check
13+
# Imported (rather than `using`) so that methods can be added to these functions.
import TensorKit: randisometry, rand, randn
14+
15+
using TensorKit.MatrixAlgebraKit
16+
17+
using Random
18+
19+
# Pull in the remainder of the extension's code.
include("cutensormap.jl")
20+
21+
end
Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
# Alias for a `TensorMap` whose storage type is `CuVector{T, CUDA.DeviceMemory}`.
const CuTensorMap{T, S, N₁, N₂} = TensorMap{T, S, N₁, N₂, CuVector{T, CUDA.DeviceMemory}}
2+
# Alias for a `CuTensorMap` whose last index count `N₂` is fixed to 0.
const CuTensor{T, S, N} = CuTensorMap{T, S, N, 0}
3+
4+
# Alias for an `AdjointTensorMap` parameterized by the corresponding `CuTensorMap` type.
const AdjointCuTensorMap{T, S, N₁, N₂} = AdjointTensorMap{T, S, N₁, N₂, CuTensorMap{T, S, N₁, N₂}}
5+
6+
# Build a `CuTensorMap` over `space(t)` from an existing `TensorMap`,
# converting its data vector with `CuArray{T}`.
function CuTensorMap(t::TensorMap{T, S, N₁, N₂, A}) where {T, S, N₁, N₂, A}
    device_data = CuArray{T}(t.data)
    return CuTensorMap{T, S, N₁, N₂}(device_data, space(t))
end
9+
10+
# Convert `t` into a tensor map of the same scalar type backed by a `Vector{T}`.
Base.collect(t::CuTensorMap{T}) where {T} =
    convert(TensorKit.TensorMapWithStorage{T, Vector{T}}, t)
13+
14+
# project_symmetric! doesn't yet work for GPU types, so do this on the host, then copy
15+
# Project `data` onto the symmetric tensor structure of `V` and verify that nothing was lost.
#
# The projection is carried out on a `Vector{T}`-backed tensor map and the resulting
# data vector is then converted to the requested storage type `A`.
#
# Arguments:
# - `T`, `A`: scalar type and storage type (`A <: CuVector{T}`) of the result
# - `data`: array holding the dense tensor entries
# - `V`: target `TensorMapSpace`
# - `tol`: absolute tolerance for the verification; defaults to `sqrt(eps)` of the
#   real floating-point type matching `eltype(data)`
#
# Throws an `ArgumentError` when the projected tensor does not reproduce `data` to
# within `tol`.
function TensorKit.project_symmetric_and_check(
        ::Type{T}, ::Type{A}, data::AbstractArray, V::TensorMapSpace;
        tol = sqrt(eps(real(float(eltype(data)))))
    ) where {T, A <: CuVector{T}}
    # Convert `data` to an `Array` once and reuse it for both the projection and the
    # check, instead of converting it twice.
    # NOTE(review): relies on `project_symmetric!` only reading its second argument — confirm.
    h_data = Array(data)
    h_t = TensorKit.TensorMapWithStorage{T, Vector{T}}(undef, V)
    h_t = TensorKit.project_symmetric!(h_t, h_data)
    # verify result
    isapprox(reshape(h_data, dims(h_t)), convert(Array, h_t); atol = tol) ||
        throw(ArgumentError("Data has non-zero elements at incompatible positions"))
    return TensorKit.TensorMapWithStorage{T, A}(A(h_t.data), V)
end
23+
24+
# Generate `CUDA.zeros` and `CUDA.ones` methods that return a `CuTensorMap`
# filled with `zero(T)` or `one(T)` respectively.
for (fname, felt) in ((:zeros, :zero), (:ones, :one))
    @eval begin
        # converting `codomain` and `domain` into a single space via `←`
        function CUDA.$fname(
                codomain::TensorSpace{S},
                domain::TensorSpace{S} = one(codomain)
            ) where {S <: IndexSpace}
            return CUDA.$fname(codomain ← domain)
        end
        function CUDA.$fname(
                ::Type{T}, codomain::TensorSpace{S},
                domain::TensorSpace{S} = one(codomain)
            ) where {T, S <: IndexSpace}
            return CUDA.$fname(T, codomain ← domain)
        end

        # filling in default eltype
        CUDA.$fname(V::TensorMapSpace) = CUDA.$fname(Float64, V)

        # implementation: allocate uninitialized storage, then fill every entry
        function CUDA.$fname(::Type{T}, V::TensorMapSpace) where {T}
            t = CuTensorMap{T}(undef, V)
            fill!(t, $felt(T))
            return t
        end
    end
end
46+
47+
# Generate `curand` / `curandn` methods (the aliases of `CUDA.rand` / `CUDA.randn`)
# that return a `CuTensorMap` filled by the matching in-place function
# (`curand!` / `curandn!`).
for randfun in (:curand, :curandn)
    randfun! = Symbol(randfun, :!)
    @eval begin
        # converting `codomain` and `domain` into `HomSpace`
        function $randfun(
                codomain::TensorSpace{S},
                domain::TensorSpace{S} = one(codomain),
            ) where {S <: IndexSpace}
            return $randfun(codomain ← domain)
        end
        function $randfun(
                ::Type{T}, codomain::TensorSpace{S},
                domain::TensorSpace{S} = one(codomain),
            ) where {T, S <: IndexSpace}
            return $randfun(T, codomain ← domain)
        end
        function $randfun(
                rng::Random.AbstractRNG, ::Type{T},
                codomain::TensorSpace{S},
                domain::TensorSpace{S} = one(codomain),
            ) where {T, S <: IndexSpace}
            return $randfun(rng, T, codomain ← domain)
        end

        # filling in default eltype
        $randfun(V::TensorMapSpace) = $randfun(Float64, V)
        function $randfun(rng::Random.AbstractRNG, V::TensorMapSpace)
            return $randfun(rng, Float64, V)
        end

        # filling in default rng
        function $randfun(::Type{T}, V::TensorMapSpace) where {T}
            return $randfun(Random.default_rng(), T, V)
        end

        # implementation: allocate uninitialized storage, then fill it in place
        function $randfun(
                rng::Random.AbstractRNG, ::Type{T},
                V::TensorMapSpace
            ) where {T}
            t = CuTensorMap{T}(undef, V)
            $randfun!(rng, t)
            return t
        end
    end
end
93+
94+
# Scalar implementation
95+
#-----------------------
96+
# Return the single non-zero entry of a tensor map with no indices, or zero of its
# scalar type when all stored entries vanish. `only` throws if more than one entry
# is non-zero.
function TensorKit.scalar(t::CuTensorMap{T, S, 0, 0}) where {T, S}
    nonzero_inds = findall(!iszero, t.data)
    isempty(nonzero_inds) && return zero(scalartype(t))
    # element access is wrapped in `@allowscalar`
    return @allowscalar @inbounds t.data[only(nonzero_inds)]
end
100+
101+
# Convert any tensor map with matching space type and index counts to the
# requested `CuTensorMap` type. A tensor that already has exactly that type is
# returned as is; otherwise a new tensor is allocated and filled with `copy!`.
function Base.convert(
        TT::Type{CuTensorMap{T, S, N₁, N₂}},
        t::AbstractTensorMap{<:Any, S, N₁, N₂}
    ) where {T, S, N₁, N₂}
    typeof(t) === TT && return t
    return copy!(TT(undef, space(t)), t)
end
112+
113+
# Test whether `t` is positive definite: every block must pass a hermiticity
# check and a positive-definiteness check on its `Hermitian` wrapper.
# Throws `SpaceMismatch` when domain and codomain differ, and returns `false`
# when the inner product style of the space type is not `EuclideanInnerProduct()`.
function LinearAlgebra.isposdef(t::CuTensorMap)
    if !(domain(t) == codomain(t))
        throw(SpaceMismatch("`isposdef` requires domain and codomain to be the same"))
    end
    if InnerProductStyle(spacetype(t)) !== EuclideanInnerProduct()
        return false
    end
    for (_, blk) in blocks(t)
        # do our own hermitian check, with machine epsilon as both tolerances
        tol = eps(real(eltype(blk)))
        TensorKit.MatrixAlgebraKit.ishermitian(blk; atol = tol, rtol = tol) || return false
        isposdef(Hermitian(blk)) || return false
    end
    return true
end
125+
126+
# Promotion of two `CuTensorMap` types sharing space type and index counts:
# the result is the `CuTensorMap` whose scalar type is `promote_add` of the two
# scalar types.
function Base.promote_rule(
        ::Type{<:TM₁}, ::Type{<:TM₂}
    ) where {
        S, N₁, N₂, E₁, E₂,
        TM₁ <: CuTensorMap{E₁, S, N₁, N₂},
        TM₂ <: CuTensorMap{E₂, S, N₁, N₂},
    }
    return CuTensorMap{TensorKit.VectorInterface.promote_add(E₁, E₂), S, N₁, N₂}
end
137+
138+
# CuTensorMap exponentation:
139+
# Replace every block of `t` by its exponential, in place, and return `t`.
# Errors when domain and codomain differ.
# NOTE(review): each block is wrapped in `Hermitian` before exponentiating —
# confirm that callers only pass tensors with Hermitian blocks.
function TensorKit.exp!(t::CuTensorMap)
    if !(domain(t) == codomain(t))
        error("Exponential of a tensor only exist when domain == codomain.")
    end
    for (_, blk) in blocks(t)
        expblk = Base.exp(Hermitian(blk))
        copy!(blk, parent(expblk))
    end
    return t
end
147+
148+
# functions that don't map ℝ to (a subset of) ℝ
149+
# Generate `Base.sqrt`, `Base.log`, ... methods for `CuTensorMap`. The result is
# allocated with a complex floating-point scalar type, and each block is filled
# with the function applied to the `Hermitian` wrapper of the input block.
# Throws `SpaceMismatch` when domain and codomain differ.
for f in (:sqrt, :log, :asin, :acos, :acosh, :atanh, :acoth)
    sf = string(f)
    @eval begin
        function Base.$f(t::CuTensorMap)
            if !(domain(t) == codomain(t))
                throw(SpaceMismatch("`$($sf)` of a tensor only exist when domain == codomain"))
            end
            result = similar(t, complex(float(scalartype(t))))
            for (c, blk) in blocks(t)
                copy!(block(result, c), parent($f(Hermitian(blk))))
            end
            return result
        end
    end
end

src/tensors/diagonal.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -78,12 +78,10 @@ function DiagonalTensorMap(t::AbstractTensorMap{T, S, 1, 1}) where {T, S}
7878
return d
7979
end
8080

81-
Base.similar(d::DiagonalTensorMap) = similar_diagonal(d)
82-
Base.similar(d::DiagonalTensorMap, ::Type{T}) where {T} = similar_diagonal(d, T)
83-
84-
similar_diagonal(d::DiagonalTensorMap) = DiagonalTensorMap(similar(d.data), d.domain)
85-
similar_diagonal(d::DiagonalTensorMap, ::Type{T}) where {T <: Number} =
86-
DiagonalTensorMap(similar(d.data, T), d.domain)
81+
# `similar` for diagonal tensor maps: allocate a new data vector with `similar`
# (optionally with scalar type `T`) and reuse the domain of `d`.
function Base.similar(d::DiagonalTensorMap)
    return DiagonalTensorMap(similar(d.data), d.domain)
end
function Base.similar(d::DiagonalTensorMap, ::Type{T}) where {T <: Number}
    newdata = similar(d.data, T)
    return DiagonalTensorMap(newdata, d.domain)
end
8785

8886
# TODO: more constructors needed?
8987

src/tensors/linalg.jl

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -270,20 +270,11 @@ function _norm(blockiter, p::Real, init::Real)
270270
return mapreduce(max, blockiter; init = init) do (c, b)
271271
return isempty(b) ? init : oftype(init, LinearAlgebra.normInf(b))
272272
end
273-
elseif p == 2
274-
n² = mapreduce(+, blockiter; init = init) do (c, b)
275-
return isempty(b) ? init : oftype(init, dim(c) * LinearAlgebra.norm2(b)^2)
273+
elseif p > 0 # finite positive p
274+
np = sum(blockiter; init) do (c, b)
275+
return oftype(init, dim(c) * norm(b, p)^p)
276276
end
277-
return sqrt(n²)
278-
elseif p == 1
279-
return mapreduce(+, blockiter; init = init) do (c, b)
280-
return isempty(b) ? init : oftype(init, dim(c) * sum(abs, b))
281-
end
282-
elseif p > 0
283-
nᵖ = mapreduce(+, blockiter; init = init) do (c, b)
284-
return isempty(b) ? init : oftype(init, dim(c) * LinearAlgebra.normp(b, p)^p)
285-
end
286-
return (nᵖ)^inv(oftype(nᵖ, p))
277+
return np^(inv(oftype(np, p)))
287278
else
288279
msg = "Norm with non-positive p is not defined for `AbstractTensorMap`"
289280
throw(ArgumentError(msg))
@@ -299,7 +290,7 @@ function LinearAlgebra.rank(
299290
r = 0 * dim(first(allunits(sectortype(t))))
300291
dim(t) == 0 && return r
301292
S = LinearAlgebra.svdvals(t)
302-
tol = max(atol, rtol * maximum(first, values(S)))
293+
tol = max(atol, rtol * maximum(parent(S)))
303294
for (c, b) in pairs(S)
304295
if !isempty(b)
305296
r += dim(c) * count(>(tol), b)
@@ -317,8 +308,8 @@ function LinearAlgebra.cond(t::AbstractTensorMap, p::Real = 2)
317308
return zero(real(float(scalartype(t))))
318309
end
319310
S = LinearAlgebra.svdvals(t)
320-
maxS = maximum(first, values(S))
321-
minS = minimum(last, values(S))
311+
maxS = maximum(parent(S))
312+
minS = minimum(parent(S))
322313
return iszero(maxS) ? oftype(maxS, Inf) : (maxS / minS)
323314
else
324315
throw(ArgumentError("cond currently only defined for p=2"))

0 commit comments

Comments
 (0)