Skip to content

Commit 37f2d28

Browse files
amdgpu extension
1 parent 3475aec commit 37f2d28

File tree

2 files changed

+37
-9
lines changed

2 files changed

+37
-9
lines changed

GNNlib/Project.toml

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,15 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1414
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1515

1616
[weakdeps]
17+
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
1718
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
1819

1920
[extensions]
2021
GNNlibCUDAExt = "CUDA"
22+
GNNlibAMDGPUExt = "AMDGPU"
2123

2224
[compat]
25+
AMDGPU = "1"
2326
CUDA = "4, 5"
2427
ChainRulesCore = "1.24"
2528
DataStructures = "0.18"
@@ -30,12 +33,3 @@ NNlib = "0.9"
3033
Random = "1"
3134
Statistics = "1"
3235
julia = "1.10"
33-
34-
[extras]
35-
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
36-
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
37-
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
38-
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
39-
40-
[targets]
41-
test = ["Test", "ReTestItems", "Reexport", "SparseArrays"]

GNNlib/ext/GNNlibAMDGPUExt.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
module GNNlibAMDGPUExt
2+
3+
using AMDGPU
4+
using Random, Statistics, LinearAlgebra
5+
using GNNlib: GNNlib, propagate, copy_xj, e_mul_xj, w_mul_xj
6+
using GNNGraphs: GNNGraph, COO_T, SPARSE_T
7+
8+
###### PROPAGATE SPECIALIZATIONS ####################
9+
10+
## COPY_XJ
11+
12+
## avoid the fast path on gpu until we have better cuda support
13+
function GNNlib.propagate(::typeof(copy_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+),
14+
xi, xj::AnyROCMatrix, e)
15+
propagate((xi, xj, e) -> copy_xj(xi, xj, e), g, +, xi, xj, e)
16+
end
17+
18+
## E_MUL_XJ
19+
20+
## avoid the fast path on gpu until we have better cuda support
21+
function GNNlib.propagate(::typeof(e_mul_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+),
22+
xi, xj::AnyROCMatrix, e::AbstractVector)
23+
propagate((xi, xj, e) -> e_mul_xj(xi, xj, e), g, +, xi, xj, e)
24+
end
25+
26+
## W_MUL_XJ
27+
28+
## avoid the fast path on gpu until we have better support
29+
function GNNlib.propagate(::typeof(w_mul_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+),
30+
xi, xj::AnyROCMatrix, e::Nothing)
31+
propagate((xi, xj, e) -> w_mul_xj(xi, xj, e), g, +, xi, xj, e)
32+
end
33+
34+
end #module

0 commit comments

Comments
 (0)