Skip to content

Commit 8bbc3ff

Browse files
committed
Start on GPU extensions
1 parent b32f1f4 commit 8bbc3ff

File tree

19 files changed

+2947
-38
lines changed

.buildkite/pipeline.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ steps:
1515
queue: "juliagpu"
1616
cuda: "*"
1717
if: build.message !~ /\[skip tests\]/
18-
timeout_in_minutes: 30
18+
timeout_in_minutes: 60
1919
matrix:
2020
setup:
2121
julia:
@@ -36,7 +36,7 @@ steps:
3636
rocm: "*"
3737
rocmgpu: "*"
3838
if: build.message !~ /\[skip tests\]/
39-
timeout_in_minutes: 30
39+
timeout_in_minutes: 60
4040
matrix:
4141
setup:
4242
julia:

Project.toml

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,20 +18,33 @@ TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
1818
VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"
1919

2020
[weakdeps]
21+
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
22+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
2123
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
2224
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
25+
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
26+
27+
[sources]
28+
GPUArrays = {rev = "master", url = "https://github.com/JuliaGPU/GPUArrays.jl"}
29+
MatrixAlgebraKit = {rev = "main", url = "https://github.com/QuantumKitHub/MatrixAlgebraKit.jl"}
2330

2431
[extensions]
32+
TensorKitAMDGPUExt = "AMDGPU"
33+
TensorKitCUDAExt = ["CUDA", "cuTENSOR"]
2534
TensorKitChainRulesCoreExt = "ChainRulesCore"
2635
TensorKitFiniteDifferencesExt = "FiniteDifferences"
2736

2837
[compat]
38+
AMDGPU = "2"
39+
Adapt = "4"
2940
Aqua = "0.6, 0.7, 0.8"
3041
ArgParse = "1.2.0"
42+
CUDA = "5.9"
3143
ChainRulesCore = "1"
3244
ChainRulesTestUtils = "1"
3345
Combinatorics = "1"
3446
FiniteDifferences = "0.12"
47+
GPUArrays = "11.3.1"
3548
LRUCache = "1.0.2"
3649
LinearAlgebra = "1"
3750
MatrixAlgebraKit = "0.6.0"
@@ -48,21 +61,27 @@ TestExtras = "0.2,0.3"
4861
TupleTools = "1.1"
4962
VectorInterface = "0.4.8, 0.5"
5063
Zygote = "0.7"
64+
cuTENSOR = "2"
5165
julia = "1.10"
5266

5367
[extras]
54-
ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63"
68+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
69+
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
5570
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
71+
ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63"
72+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
5673
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
5774
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
5875
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
5976
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
77+
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
6078
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
6179
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
6280
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
6381
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
6482
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
6583
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
84+
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
6685

6786
[targets]
68-
test = ["ArgParse", "Aqua", "Combinatorics", "LinearAlgebra", "TensorOperations", "Test", "TestExtras", "SafeTestsets", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Zygote"]
87+
test = ["ArgParse", "Adapt", "AMDGPU", "Aqua", "Combinatorics", "CUDA", "cuTENSOR", "GPUArrays", "LinearAlgebra", "SafeTestsets", "TensorOperations", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Zygote"]
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
module TensorKitAMDGPUExt
2+
3+
using AMDGPU, AMDGPU.rocBLAS, LinearAlgebra
4+
using AMDGPU: @allowscalar
5+
import AMDGPU: rand as rocrand, rand! as rocrand!, randn as rocrandn, randn! as rocrandn!
6+
7+
using TensorKit
8+
using TensorKit.Factorizations
9+
using TensorKit.Strided
10+
using TensorKit.Factorizations: AbstractAlgorithm
11+
using TensorKit: SectorDict, tensormaptype, scalar, similarstoragetype, AdjointTensorMap, scalartype
12+
13+
using TensorKit.MatrixAlgebraKit
14+
15+
using Random
16+
17+
include("roctensormap.jl")
18+
19+
# Convenience alias: a `DiagonalTensorMap` whose diagonal data lives in AMDGPU
# (ROC) device memory.
const ROCDiagonalTensorMap{T, S} = DiagonalTensorMap{T, S, ROCVector{T, AMDGPU.Mem.HIPBuffer}}

"""
    ROCDiagonalTensorMap{T}(undef, domain::S) where {T,S<:IndexSpace}
    # expert mode: select storage type `A`
    DiagonalTensorMap{T,S,A}(undef, domain::S) where {T,S<:IndexSpace,A<:DenseVector{T}}

Construct a `DiagonalTensorMap` with uninitialized data.
"""
function ROCDiagonalTensorMap{T}(::UndefInitializer, V::TensorMapSpace) where {T}
    # A diagonal map must be square: one incoming and one outgoing index over
    # the same space.
    if !(numin(V) == numout(V) == 1 && domain(V) == codomain(V))
        throw(ArgumentError("DiagonalTensorMap requires a space with equal domain and codomain and 2 indices"))
    end
    return ROCDiagonalTensorMap{T}(undef, domain(V))
end
function ROCDiagonalTensorMap{T}(::UndefInitializer, P::ProductSpace) where {T}
    length(P) == 1 || throw(ArgumentError("DiagonalTensorMap requires `numin(d) == numout(d) == 1`"))
    return ROCDiagonalTensorMap{T}(undef, only(P))
end
function ROCDiagonalTensorMap{T}(::UndefInitializer, V::S) where {T, S <: IndexSpace}
    # Base case: forward to the parametric constructor with the storage type
    # fixed by the alias.
    return ROCDiagonalTensorMap{T, S}(undef, V)
end
# Default element type is Float64, mirroring `DiagonalTensorMap`.
ROCDiagonalTensorMap(::UndefInitializer, V::IndexSpace) = ROCDiagonalTensorMap{Float64}(undef, V)

# Wrap an existing device vector without copying.
ROCDiagonalTensorMap(data::ROCVector{T}, V::S) where {T, S} = ROCDiagonalTensorMap{T, S}(data, V)

# Host `Vector` input: upload to the device first.
ROCDiagonalTensorMap(data::Vector{T}, V::S) where {T, S} = ROCDiagonalTensorMap{T, S}(ROCVector{T}(data), V)
50+
51+
# MatrixAlgebraKit hook: allocate the output triple (U, S, Vᴴ) for a full SVD
# of a diagonal tensor map, keeping all factors in ROC (AMDGPU) storage.
# NOTE(review): the `←` arrows (HomSpace constructors) were lost in transit
# (`codomain(t) V_cod` is not valid Julia); restored here.
function TensorKit.Factorizations.MAK.initialize_output(::typeof(svd_full!), t::ROCDiagonalTensorMap, alg::DiagonalAlgorithm)
    V_cod = fuse(codomain(t))
    V_dom = fuse(domain(t))
    U = similar(t, codomain(t) ← V_cod)
    # Singular values are real regardless of the scalar type of `t`.
    S = ROCDiagonalTensorMap{real(scalartype(t))}(undef, V_cod ← V_dom)
    Vᴴ = similar(t, V_dom ← domain(t))
    return U, S, Vᴴ
end
59+
60+
# MatrixAlgebraKit hook: allocate the singular-value container for `svd_vals!`
# of a ROC tensor map.
function TensorKit.Factorizations.MAK.initialize_output(::typeof(svd_vals!), t::ROCTensorMap, alg::AbstractAlgorithm)
    # The number of singular values is bounded by the smaller fused space.
    V = infimum(fuse(codomain(t)), fuse(domain(t)))
    return ROCDiagonalTensorMap{real(scalartype(t))}(undef, V)
end
64+
65+
# MatrixAlgebraKit hook: allocate (U, S, Vᴴ) for a compact SVD of a ROC tensor
# map; both inner spaces equal the smaller fused space.
# NOTE(review): the `←` arrows (HomSpace constructors) were lost in transit;
# restored here.
function TensorKit.Factorizations.MAK.initialize_output(::typeof(svd_compact!), t::ROCTensorMap, ::AbstractAlgorithm)
    V_cod = V_dom = infimum(fuse(codomain(t)), fuse(domain(t)))
    U = similar(t, codomain(t) ← V_cod)
    # Singular values are real regardless of the scalar type of `t`.
    S = ROCDiagonalTensorMap{real(scalartype(t))}(undef, V_cod)
    Vᴴ = similar(t, V_dom ← domain(t))
    return U, S, Vᴴ
end
72+
73+
# MatrixAlgebraKit hook: allocate (D, V) for a hermitian eigendecomposition of
# a ROC tensor map.
# NOTE(review): the `←` arrow (HomSpace constructor) was lost in transit;
# restored here.
function TensorKit.Factorizations.MAK.initialize_output(::typeof(eigh_full!), t::ROCTensorMap, ::AbstractAlgorithm)
    V_D = fuse(domain(t))
    # Hermitian input ⇒ real eigenvalues.
    T = real(scalartype(t))
    D = ROCDiagonalTensorMap{T}(undef, V_D)
    V = similar(t, codomain(t) ← V_D)
    return D, V
end
80+
81+
# MatrixAlgebraKit hook: allocate (D, V) for a general (non-hermitian)
# eigendecomposition of a ROC tensor map.
# NOTE(review): the `←` arrow (HomSpace constructor) was lost in transit;
# restored here.
function TensorKit.Factorizations.MAK.initialize_output(::typeof(eig_full!), t::ROCTensorMap, ::AbstractAlgorithm)
    V_D = fuse(domain(t))
    # General eigenvalues/-vectors may be complex even for real input.
    Tc = complex(scalartype(t))
    D = ROCDiagonalTensorMap{Tc}(undef, V_D)
    V = similar(t, Tc, codomain(t) ← V_D)
    return D, V
end
88+
89+
# MatrixAlgebraKit hook: allocate the eigenvalue container for `eigh_vals!` of
# a ROC tensor map.
# Fix: the original computed `T = real(scalartype(t))` but then referenced an
# undefined `Tc`, which would raise UndefVarError at runtime.
function TensorKit.Factorizations.MAK.initialize_output(::typeof(eigh_vals!), t::ROCTensorMap, alg::AbstractAlgorithm)
    V_D = fuse(domain(t))
    # Hermitian input ⇒ real eigenvalues.
    T = real(scalartype(t))
    return ROCDiagonalTensorMap{T}(undef, V_D)
end
94+
95+
# MatrixAlgebraKit hook: allocate the eigenvalue container for `eig_vals!` of
# a ROC tensor map.
function TensorKit.Factorizations.MAK.initialize_output(::typeof(eig_vals!), t::ROCTensorMap, alg::AbstractAlgorithm)
    # General (non-hermitian) eigenvalues may be complex even for real input.
    return ROCDiagonalTensorMap{complex(scalartype(t))}(undef, fuse(domain(t)))
end
100+
101+
102+
# TODO
# add VectorInterface extensions for proper AMDGPU promotion
# Scalar-type promotion rule for `add` on strided ROC matrices, so that
# VectorInterface can allocate outputs of the correct element type.
function TensorKit.VectorInterface.promote_add(TA::Type{<:AMDGPU.StridedROCMatrix{Tx}}, TB::Type{<:AMDGPU.StridedROCMatrix{Ty}}, α::Tα = TensorKit.VectorInterface.One(), β::Tβ = TensorKit.VectorInterface.One()) where {Tx, Ty, Tα, Tβ}
    # Fix: the bare name `add` is not imported into this module and would
    # raise UndefVarError; qualify it with its owning module.
    return Base.promote_op(TensorKit.VectorInterface.add, Tx, Ty, Tα, Tβ)
end
107+
108+
end

0 commit comments

Comments
 (0)