Skip to content

Commit 08bc705

Browse files
committed
Start on CUDA extension
1 parent 4e763ba commit 08bc705

File tree

11 files changed

+962
-24
lines changed

11 files changed

+962
-24
lines changed

.buildkite/pipeline.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@ steps:
1515
queue: "juliagpu"
1616
cuda: "*"
1717
if: build.message !~ /\[skip tests\]/
18-
timeout_in_minutes: 30
18+
timeout_in_minutes: 60
1919
matrix:
2020
setup:
2121
julia:
2222
- "1.10"
23-
- "1.11"
23+
- "1.12"
2424

2525
- label: "Julia {{matrix.julia}} -- AMDGPU"
2626
plugins:
@@ -36,9 +36,9 @@ steps:
3636
rocm: "*"
3737
rocmgpu: "*"
3838
if: build.message !~ /\[skip tests\]/
39-
timeout_in_minutes: 30
39+
timeout_in_minutes: 60
4040
matrix:
4141
setup:
4242
julia:
4343
- "1.10"
44-
- "1.11"
44+
- "1.12"

Project.toml

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,20 +18,30 @@ TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
1818
VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"
1919

2020
[weakdeps]
21+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
2122
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
2223
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
24+
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
25+
26+
[sources]
27+
GPUArrays = {rev = "master", url = "https://github.com/JuliaGPU/GPUArrays.jl"}
28+
MatrixAlgebraKit = {rev = "main", url = "https://github.com/QuantumKitHub/MatrixAlgebraKit.jl"}
2329

2430
[extensions]
31+
TensorKitCUDAExt = ["CUDA", "cuTENSOR"]
2532
TensorKitChainRulesCoreExt = "ChainRulesCore"
2633
TensorKitFiniteDifferencesExt = "FiniteDifferences"
2734

2835
[compat]
36+
Adapt = "4"
2937
Aqua = "0.6, 0.7, 0.8"
3038
ArgParse = "1.2.0"
39+
CUDA = "5.9"
3140
ChainRulesCore = "1"
3241
ChainRulesTestUtils = "1"
3342
Combinatorics = "1"
3443
FiniteDifferences = "0.12"
44+
GPUArrays = "11.3.1"
3545
LRUCache = "1.0.2"
3646
LinearAlgebra = "1"
3747
MatrixAlgebraKit = "0.6.0"
@@ -48,21 +58,26 @@ TestExtras = "0.2,0.3"
4858
TupleTools = "1.1"
4959
VectorInterface = "0.4.8, 0.5"
5060
Zygote = "0.7"
61+
cuTENSOR = "2"
5162
julia = "1.10"
5263

5364
[extras]
54-
ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63"
65+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
5566
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
67+
ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63"
68+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
5669
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
5770
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
5871
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
5972
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
73+
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
6074
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
6175
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
6276
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
6377
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
6478
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
6579
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
80+
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
6681

6782
[targets]
68-
test = ["ArgParse", "Aqua", "Combinatorics", "LinearAlgebra", "TensorOperations", "Test", "TestExtras", "SafeTestsets", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Zygote"]
83+
test = ["ArgParse", "Adapt", "Aqua", "Combinatorics", "CUDA", "cuTENSOR", "GPUArrays", "LinearAlgebra", "SafeTestsets", "TensorOperations", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Zygote"]
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
module TensorKitCUDAExt

using CUDA, CUDA.CUBLAS, LinearAlgebra
using CUDA: @allowscalar
using cuTENSOR: cuTENSOR
import CUDA: rand as curand, rand! as curand!, randn as curandn, randn! as curandn!

using TensorKit
using TensorKit.Factorizations
using TensorKit.Strided
using TensorKit.Factorizations: AbstractAlgorithm
using TensorKit: SectorDict, tensormaptype, scalar, similarstoragetype, AdjointTensorMap,
    scalartype
import TensorKit: randisometry

using TensorKit.MatrixAlgebraKit

using Random

include("cutensormap.jl")

# TODO
# add VectorInterface extensions for proper CUDA promotion

# Scalar type resulting from adding two CUDA matrices (possibly with scalar
# coefficients `α`, `β`), so mixed-precision additions promote correctly.
function TensorKit.VectorInterface.promote_add(
        TA::Type{<:CUDA.StridedCuMatrix{Tx}}, TB::Type{<:CUDA.StridedCuMatrix{Ty}},
        α::Tα = TensorKit.VectorInterface.One(), β::Tβ = TensorKit.VectorInterface.One()
    ) where {Tx, Ty, Tα, Tβ}
    # `add` was unqualified and is not imported into this module, which would
    # raise an `UndefVarError` at call time; qualify it explicitly.
    return Base.promote_op(TensorKit.VectorInterface.add, Tx, Ty, Tα, Tβ)
end

end
Lines changed: 278 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,278 @@
1+
# `TensorMap` whose data is stored in a CUDA device-memory vector.
const CuTensorMap{T, S, N₁, N₂} = TensorMap{T, S, N₁, N₂, CuVector{T, CUDA.DeviceMemory}}
# A `CuTensorMap` with empty domain (N₂ == 0), i.e. a plain tensor.
const CuTensor{T, S, N} = CuTensorMap{T, S, N, 0}

# Lazy adjoint wrapper around a `CuTensorMap`.
const AdjointCuTensorMap{T, S, N₁, N₂} = AdjointTensorMap{T, S, N₁, N₂, CuTensorMap{T, S, N₁, N₂}}
5+
6+
"""
    TensorKit.tensormaptype(S, N₁, N₂, TorA::Type{<:StridedCuArray})

Return the concrete `TensorMap` subtype with `N₁` codomain and `N₂` domain spaces of
type `S`, backed by a CUDA device-memory vector with the element type of `TorA`.
"""
function TensorKit.tensormaptype(S::Type{<:IndexSpace}, N₁, N₂, TorA::Type{<:StridedCuArray})
    if TorA <: CuArray
        return TensorMap{eltype(TorA), S, N₁, N₂, CuVector{eltype(TorA), CUDA.DeviceMemory}}
    else
        # The previous message was copied from the generic method and wrongly
        # suggested that a scalar type is accepted here; this method only
        # handles CUDA array storage types.
        throw(ArgumentError("argument $TorA should specify a storage type `<:CuArray{<:Number}`"))
    end
end
13+
14+
# Matrix type used for the block data of CUDA-backed tensor maps.
TensorKit.matrixtype(::Type{<:TensorMap{T, S, N₁, N₂, A}}) where {T, S, N₁, N₂, A <: CuVector{T}} = CuMatrix{T}
15+
16+
# Uninitialized constructors, mirroring the `TensorMap{T}(undef, ...)` API.
function CuTensorMap{T}(::UndefInitializer, V::TensorMapSpace{S, N₁, N₂}) where {T, S, N₁, N₂}
    return CuTensorMap{T, S, N₁, N₂}(undef, V)
end

function CuTensorMap{T}(
        ::UndefInitializer, codomain::TensorSpace{S},
        domain::TensorSpace{S}
    ) where {T, S}
    # Restored the `←` operator (lost in transcription) combining codomain and
    # domain into a `HomSpace`, as in `TensorMap{T}(undef, codomain ← domain)`.
    return CuTensorMap{T}(undef, codomain ← domain)
end
function CuTensor{T}(::UndefInitializer, V::TensorSpace{S}) where {T, S}
    # A `Tensor` is a `TensorMap` with trivial domain.
    return CuTensorMap{T}(undef, V ← one(V))
end
29+
# constructor starting from block data
"""
    CuTensorMap(data::AbstractDict{<:Sector,<:CuMatrix}, codomain::ProductSpace{S,N₁},
                domain::ProductSpace{S,N₂}) where {S<:ElementarySpace,N₁,N₂}
    CuTensorMap(data, codomain ← domain)
    CuTensorMap(data, domain → codomain)

Construct a `CuTensorMap` by explicitly specifying its block data.

## Arguments
- `data::AbstractDict{<:Sector,<:CuMatrix}`: dictionary containing the block data for
  each coupled sector `c` as a matrix of size `(blockdim(codomain, c), blockdim(domain, c))`.
- `codomain::ProductSpace{S,N₁}`: the codomain as a `ProductSpace` of `N₁` spaces of type
  `S<:ElementarySpace`.
- `domain::ProductSpace{S,N₂}`: the domain as a `ProductSpace` of `N₂` spaces of type
  `S<:ElementarySpace`.

Alternatively, the domain and codomain can be specified by passing a [`HomSpace`](@ref)
using the syntax `codomain ← domain` or `domain → codomain`.
"""
function CuTensorMap(
        data::AbstractDict{<:Sector, <:CuArray},
        V::TensorMapSpace{S, N₁, N₂}
    ) where {S, N₁, N₂}
    T = eltype(valtype(data))
    t = CuTensorMap{T}(undef, V)
    # Copy the supplied block data into the freshly allocated tensor,
    # validating presence and shape of each expected block.
    for (c, b) in blocks(t)
        haskey(data, c) || throw(SectorMismatch("no data for block sector $c"))
        datac = data[c]
        size(datac) == size(b) ||
            throw(DimensionMismatch("wrong size of block for sector $c"))
        copy!(b, datac)
    end
    # Reject any nonempty block that does not correspond to a valid sector.
    # (Restored the `∈` operator, lost in transcription.)
    for (c, b) in data
        c ∈ blocksectors(t) || isempty(b) ||
            throw(SectorMismatch("data for block sector $c not expected"))
    end
    return t
end
function CuTensorMap(data::CuArray{T}, V::TensorMapSpace{S, N₁, N₂}) where {T, S, N₁, N₂}
    return CuTensorMap{T, S, N₁, N₂}(vec(data), V)
end
71+
72+
# `CUDA.zeros` / `CUDA.ones` constructors for tensor maps, mirroring the
# `Base.zeros`/`Base.ones` methods of TensorKit. The `←` operators (lost in
# transcription) combining codomain and domain into a `HomSpace` are restored.
for (fname, felt) in ((:zeros, :zero), (:ones, :one))
    @eval begin
        function CUDA.$fname(
                codomain::TensorSpace{S},
                domain::TensorSpace{S} = one(codomain)
            ) where {S <: IndexSpace}
            return CUDA.$fname(codomain ← domain)
        end
        function CUDA.$fname(
                ::Type{T}, codomain::TensorSpace{S},
                domain::TensorSpace{S} = one(codomain)
            ) where {T, S <: IndexSpace}
            return CUDA.$fname(T, codomain ← domain)
        end
        # default element type
        CUDA.$fname(V::TensorMapSpace) = CUDA.$fname(Float64, V)
        # implementation: allocate and fill with zero/one
        function CUDA.$fname(::Type{T}, V::TensorMapSpace) where {T}
            t = CuTensorMap{T}(undef, V)
            fill!(t, $felt(T))
            return t
        end
    end
end
94+
95+
# `curand`/`curandn` (aliases of `CUDA.rand`/`CUDA.randn`) constructors for
# tensor maps. The `←` operators (lost in transcription) are restored.
for randfun in (:curand, :curandn)
    randfun! = Symbol(randfun, :!)
    @eval begin
        # converting `codomain` and `domain` into `HomSpace`
        function $randfun(
                codomain::TensorSpace{S},
                domain::TensorSpace{S} = one(codomain),
            ) where {S <: IndexSpace}
            return $randfun(codomain ← domain)
        end
        function $randfun(
                ::Type{T}, codomain::TensorSpace{S},
                domain::TensorSpace{S} = one(codomain),
            ) where {T, S <: IndexSpace}
            return $randfun(T, codomain ← domain)
        end
        function $randfun(
                rng::Random.AbstractRNG, ::Type{T},
                codomain::TensorSpace{S},
                domain::TensorSpace{S} = one(codomain),
            ) where {T, S <: IndexSpace}
            return $randfun(rng, T, codomain ← domain)
        end

        # filling in default eltype
        $randfun(V::TensorMapSpace) = $randfun(Float64, V)
        function $randfun(rng::Random.AbstractRNG, V::TensorMapSpace)
            return $randfun(rng, Float64, V)
        end

        # filling in default rng
        function $randfun(::Type{T}, V::TensorMapSpace) where {T}
            return $randfun(Random.default_rng(), T, V)
        end

        # implementation: allocate, then fill in place with the rng
        function $randfun(
                rng::Random.AbstractRNG, ::Type{T},
                V::TensorMapSpace
            ) where {T}
            t = CuTensorMap{T}(undef, V)
            $randfun!(rng, t)
            return t
        end
    end
end
141+
142+
# `rand`/`randn`/`randisometry` methods that additionally accept a CUDA storage
# type `A` to select device-backed tensors. The `←` operators (lost in
# transcription) are restored.
# NOTE(review): `rand`/`randn` are only in scope here via `using`; defining
# methods on them unqualified may fail — verify whether `Random.$randfun` (or an
# explicit import) is intended.
for randfun in (:rand, :randn, :randisometry)
    randfun! = Symbol(randfun, :!)
    @eval begin
        # converting `codomain` and `domain` into `HomSpace`
        function $randfun(
                ::Type{A}, codomain::TensorSpace{S},
                domain::TensorSpace{S}
            ) where {A <: CuArray, S <: IndexSpace}
            return $randfun(A, codomain ← domain)
        end
        function $randfun(
                ::Type{T}, ::Type{A}, codomain::TensorSpace{S},
                domain::TensorSpace{S}
            ) where {T, S <: IndexSpace, A <: CuArray{T}}
            return $randfun(T, A, codomain ← domain)
        end
        function $randfun(
                rng::Random.AbstractRNG, ::Type{T}, ::Type{A},
                codomain::TensorSpace{S},
                domain::TensorSpace{S}
            ) where {T, S <: IndexSpace, A <: CuArray{T}}
            return $randfun(rng, T, A, codomain ← domain)
        end

        # accepting single `TensorSpace`
        $randfun(::Type{A}, codomain::TensorSpace) where {A <: CuArray} = $randfun(A, codomain ← one(codomain))
        function $randfun(::Type{T}, ::Type{A}, codomain::TensorSpace) where {T, A <: CuArray{T}}
            return $randfun(T, A, codomain ← one(codomain))
        end
        function $randfun(
                rng::Random.AbstractRNG, ::Type{T},
                ::Type{A}, codomain::TensorSpace
            ) where {T, A <: CuArray{T}}
            # BUG FIX: this method previously used `one(domain)`, but `domain`
            # is not a parameter of this method (only `codomain` is), which
            # would raise an `UndefVarError` at call time.
            return $randfun(rng, T, A, codomain ← one(codomain))
        end

        # filling in default eltype
        $randfun(::Type{A}, V::TensorMapSpace) where {A <: CuArray} = $randfun(eltype(A), A, V)
        function $randfun(rng::Random.AbstractRNG, ::Type{A}, V::TensorMapSpace) where {A <: CuArray}
            return $randfun(rng, eltype(A), A, V)
        end

        # filling in default rng
        function $randfun(::Type{T}, ::Type{A}, V::TensorMapSpace) where {T, A <: CuArray{T}}
            return $randfun(Random.default_rng(), T, A, V)
        end

        # implementation: allocate, then fill in place with the rng
        function $randfun(
                rng::Random.AbstractRNG, ::Type{T},
                ::Type{A}, V::TensorMapSpace
            ) where {T, A <: CuArray{T}}
            t = CuTensorMap{T}(undef, V)
            $randfun!(rng, t)
            return t
        end
    end
end
200+
201+
# Convert any tensor map into a CUDA-backed one with the same spaces and
# scalar type, copying the data to the device.
function Base.convert(::Type{CuTensorMap}, t::AbstractTensorMap)
    tdst = CuTensorMap{scalartype(t)}(undef, space(t))
    return copy!(tdst, t)
end
204+
205+
# Scalar implementation
#-----------------------
"""
    TensorKit.scalar(t::CuTensorMap)

Return the unique entry of `t` when both codomain and domain are
one-dimensional; otherwise throw a `DimensionMismatch`.
"""
function TensorKit.scalar(t::CuTensorMap)
    # TODO: should scalar only work if N₁ == N₂ == 0?
    dim(codomain(t)) == dim(domain(t)) == 1 || throw(DimensionMismatch())
    # scalar indexing into device memory requires @allowscalar
    return @allowscalar first(blocks(t))[2][1, 1]
end
212+
213+
# Scalar-type extraction for CUDA arrays (instances and types) and for
# CUDA-backed tensor map types.
TensorKit.scalartype(A::StridedCuArray{T}) where {T} = T
TensorKit.scalartype(::Type{<:CuTensorMap{T}}) where {T} = T
TensorKit.scalartype(::Type{<:CuArray{T}}) where {T} = T
216+
217+
# Storage type of a `similar` tensor with scalar type `T`: always a CUDA
# device-memory vector, independent of the original scalar type `TTT`.
function TensorKit.similarstoragetype(TT::Type{<:CuTensorMap{TTT, S, N₁, N₂}}, ::Type{T}) where {TTT, T, S, N₁, N₂}
    return CuVector{T, CUDA.DeviceMemory}
end
220+
221+
# Convert a tensor map into the fully specified CUDA-backed type `TT`,
# returning `t` unchanged when it already has exactly that type.
function Base.convert(
        TT::Type{CuTensorMap{T, S, N₁, N₂}},
        t::AbstractTensorMap{<:Any, S, N₁, N₂}
    ) where {T, S, N₁, N₂}
    # identity conversion: nothing to copy
    typeof(t) === TT && return t
    return copy!(TT(undef, space(t)), t)
end
232+
233+
# Positive-definiteness of a CUDA-backed tensor map, checked block by block.
function LinearAlgebra.isposdef(t::CuTensorMap)
    # only endomorphisms can be positive definite
    domain(t) == codomain(t) ||
        throw(SpaceMismatch("`isposdef` requires domain and codomain to be the same"))
    # only meaningful for a Euclidean inner product
    InnerProductStyle(spacetype(t)) === EuclideanInnerProduct() || return false
    for (c, b) in blocks(t)
        # do our own hermitian check
        isherm = TensorKit.MatrixAlgebraKit.ishermitian(b; atol = eps(real(eltype(b))), rtol = eps(real(eltype(b))))
        isherm || return false
        # wrap in `Hermitian` so the positive-definiteness test uses the
        # Hermitian code path (e.g. Cholesky) on the GPU block
        isposdef(Hermitian(b)) || return false
    end
    return true
end
245+
246+
# Promotion between two CUDA tensor maps over the same spaces but with
# different scalar types: promote the scalar types via
# `VectorInterface.promote_add` and keep the CUDA storage.
function Base.promote_rule(
        ::Type{<:TT₁},
        ::Type{<:TT₂}
    ) where {
        S, N₁, N₂, TTT₁, TTT₂,
        TT₁ <: CuTensorMap{TTT₁, S, N₁, N₂},
        TT₂ <: CuTensorMap{TTT₂, S, N₁, N₂},
    }
    T = TensorKit.VectorInterface.promote_add(TTT₁, TTT₂)
    return CuTensorMap{T, S, N₁, N₂}
end
257+
258+
# Conversion to CuArray:
#----------------------
# probably not optimized for speed, only for checking purposes
function Base.convert(::Type{CuArray}, t::AbstractTensorMap)
    I = sectortype(t)
    if I === Trivial
        # no symmetry structure: the single block is already the dense data
        convert(CuArray, t[])
    else
        cod = codomain(t)
        dom = domain(t)
        # scalar type of the dense output: complexify for complex sectors,
        # keep as-is for integer sectors, otherwise promote to float
        T = sectorscalartype(I) <: Complex ? complex(scalartype(t)) :
            sectorscalartype(I) <: Integer ? scalartype(t) : float(scalartype(t))
        A = CUDA.zeros(T, dims(cod)..., dims(dom)...)
        # accumulate each fusion-tree pair's contribution into the slice of `A`
        # selected by the uncoupled sectors of (f₁, f₂)
        for (f₁, f₂) in fusiontrees(t)
            F = convert(CuArray, (f₁, f₂))
            Aslice = StridedView(A)[axes(cod, f₁.uncoupled)..., axes(dom, f₂.uncoupled)...]
            add!(Aslice, StridedView(TensorKit._kron(convert(CuArray, t[f₁, f₂]), F)))
        end
        return A
    end
end

0 commit comments

Comments
 (0)