module GNNlibAMDGPUExt

using AMDGPU
using Random, Statistics, LinearAlgebra
using GNNlib: GNNlib, propagate, copy_xj, e_mul_xj, w_mul_xj
using GNNGraphs: GNNGraph, COO_T, SPARSE_T

###### PROPAGATE SPECIALIZATIONS ####################
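# These methods sidestep GNNlib's specialized fast path by wrapping the
# message function in an anonymous closure, so that dispatch falls back to
# the generic gather/scatter implementation of `propagate`.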

## COPY_XJ

## avoid the fast path on gpu until we have better ROCm support
function GNNlib.propagate(::typeof(copy_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+),
                      xi, xj::AnyROCMatrix, e)
    propagate((xi, xj, e) -> copy_xj(xi, xj, e), g, +, xi, xj, e)
end

## E_MUL_XJ

## avoid the fast path on gpu until we have better ROCm support
function GNNlib.propagate(::typeof(e_mul_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+),
                      xi, xj::AnyROCMatrix, e::AbstractVector)
    propagate((xi, xj, e) -> e_mul_xj(xi, xj, e), g, +, xi, xj, e)
end

## W_MUL_XJ

## avoid the fast path on gpu until we have better support
function GNNlib.propagate(::typeof(w_mul_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+),
                      xi, xj::AnyROCMatrix, e::Nothing)
    propagate((xi, xj, e) -> w_mul_xj(xi, xj, e), g, +, xi, xj, e)
end

end #module
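For context, here is a minimal usage sketch (not part of the diff) of how these specializations get exercised. It assumes a working AMD GPU, GNNlib's keyword form propagate(f, g, aggr; xi, xj, e), and a toy three-node graph; in practice the graph would typically be moved to the device as well (e.g. with Flux's gpu).

using AMDGPU, GNNGraphs
using GNNlib: propagate, copy_xj

# toy directed graph in COO form with edges 1->2, 2->3, 3->1
s, t = [1, 2, 3], [2, 3, 1]
g = GNNGraph(s, t)

# node features stored as a ROCMatrix, so the xj::AnyROCMatrix methods above apply
x = ROCArray(rand(Float32, 4, 3))   # 4 features for each of the 3 nodes

# sum-aggregate neighbor features; routed through the generic gather/scatter path
y = propagate(copy_xj, g, +; xj = x)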