Skip to content

Commit 7274e04

Browse files
committed
Add @non_differentiable annotation to _adjacency_matrix function
1 parent a2d7a62 commit 7274e04

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

GNNlib/ext/GNNlibCUDAExt.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using CUDA
44
using Random, Statistics, LinearAlgebra
55
using GNNlib: GNNlib, propagate, copy_xj, e_mul_xj, w_mul_xj
66
using GNNGraphs: GNNGraph, COO_T, SPARSE_T, to_dense, to_sparse
7+
using ChainRulesCore: @non_differentiable
78

89
const CUDA_COO_T = Tuple{T, T, V} where {T <: AnyCuArray{<:Integer}, V <: Union{Nothing, AnyCuArray}}
910

@@ -63,4 +64,6 @@ function _adjacency_matrix(g::GNNGraph{<:CUDA_COO_T}, T::DataType = eltype(g); d
6364
return dir == :out ? A : A'
6465
end
6566

67+
@non_differentiable _adjacency_matrix(x...)
68+
6669
end #module

0 commit comments

Comments
 (0)