@@ -241,37 +241,46 @@ function Graphs.adjacency_matrix(g::GNNGraph{<:ADJMAT_T}, T::DataType = eltype(g
241
241
return dir == :out ? A : A'
242
242
end
243
243
244
- function ChainRulesCore . rrule (:: typeof (adjacency_matrix), g:: G , T:: DataType ;
244
+ function CRC . rrule (:: typeof (adjacency_matrix), g:: G , T:: DataType ;
245
245
dir = :out , weighted = true ) where {G <: GNNGraph{<:ADJMAT_T} }
246
246
A = adjacency_matrix (g, T; dir, weighted)
247
247
if ! weighted
248
248
function adjacency_matrix_pullback_noweight (Δ)
249
- return (NoTangent (), ZeroTangent (), NoTangent ())
249
+ return (CRC . NoTangent (), CRC . ZeroTangent (), CRC . NoTangent ())
250
250
end
251
251
return A, adjacency_matrix_pullback_noweight
252
252
else
253
253
function adjacency_matrix_pullback_weighted (Δ)
254
- dg = Tangent {G} (; graph = Δ .* binarize (A))
255
- return (NoTangent (), dg, NoTangent ())
254
+ dy = CRC. unthunk (Δ)
255
+ dg = CRC. Tangent {G} (; graph = dy .* binarize (dy))
256
+ return (CRC. NoTangent (), dg, CRC. NoTangent ())
256
257
end
257
258
return A, adjacency_matrix_pullback_weighted
258
259
end
259
260
end
260
261
261
- function ChainRulesCore . rrule (:: typeof (adjacency_matrix), g:: G , T:: DataType ;
262
+ function CRC . rrule (:: typeof (adjacency_matrix), g:: G , T:: DataType ;
262
263
dir = :out , weighted = true ) where {G <: GNNGraph{<:COO_T} }
263
264
A = adjacency_matrix (g, T; dir, weighted)
264
265
w = get_edge_weight (g)
265
266
if ! weighted || w === nothing
266
267
function adjacency_matrix_pullback_noweight (Δ)
267
- return (NoTangent (), ZeroTangent (), NoTangent ())
268
+ return (CRC . NoTangent (), CRC . ZeroTangent (), CRC . NoTangent ())
268
269
end
269
270
return A, adjacency_matrix_pullback_noweight
270
271
else
271
272
function adjacency_matrix_pullback_weighted (Δ)
273
+ dy = CRC. unthunk (Δ)
272
274
s, t = edge_index (g)
273
- dg = Tangent {G} (; graph = (NoTangent (), NoTangent (), NNlib. gather (Δ, s, t)))
274
- return (NoTangent (), dg, NoTangent ())
275
+ @show dy s t
276
+ # TODO using CRC.@thunk gives an error
277
+ # TODO use gather when https://github.com/FluxML/NNlib.jl/issues/625 is fixed
278
+ dw = zeros_like (w)
279
+ idx = CartesianIndex .(s, t) # TODO remove when https://github.com/FluxML/NNlib.jl/issues/626 is fixed
280
+ NNlib. gather! (dw, dy, idx)
281
+ @show dw
282
+ dg = CRC. Tangent {G} (; graph = (CRC. NoTangent (), CRC. NoTangent (), dw))
283
+ return (CRC. NoTangent (), dg, CRC. NoTangent ())
275
284
end
276
285
return A, adjacency_matrix_pullback_weighted
277
286
end
@@ -378,34 +387,35 @@ function _degree(A::AbstractMatrix, T::Type, dir::Symbol, edge_weight::Bool, num
378
387
vec (sum (A, dims = 1 )) .+ vec (sum (A, dims = 2 ))
379
388
end
380
389
381
- function ChainRulesCore . rrule (:: typeof (_degree), graph, T, dir, edge_weight:: Nothing , num_nodes)
390
+ function CRC . rrule (:: typeof (_degree), graph, T, dir, edge_weight:: Nothing , num_nodes)
382
391
degs = _degree (graph, T, dir, edge_weight, num_nodes)
383
392
function _degree_pullback (Δ)
384
- return ( NoTangent (), NoTangent (), NoTangent (), NoTangent (), NoTangent ( ), NoTangent () )
393
+ return ntuple (i -> (CRC . NoTangent (),), 6 )
385
394
end
386
395
return degs, _degree_pullback
387
396
end
388
397
389
- function ChainRulesCore . rrule (:: typeof (_degree), A:: ADJMAT_T , T, dir, edge_weight:: Bool , num_nodes)
398
+ function CRC . rrule (:: typeof (_degree), A:: ADJMAT_T , T, dir, edge_weight:: Bool , num_nodes)
390
399
degs = _degree (A, T, dir, edge_weight, num_nodes)
391
400
if edge_weight === false
392
401
function _degree_pullback_noweights (Δ)
393
- return ( NoTangent (), NoTangent (), NoTangent (), NoTangent (), NoTangent ( ), NoTangent () )
402
+ return ntuple (i -> (CRC . NoTangent (),), 6 )
394
403
end
395
404
return degs, _degree_pullback_noweights
396
405
else
397
406
function _degree_pullback_weights (Δ)
407
+ dy = CRC. unthunk (Δ)
398
408
# We propagate the gradient only to the non-zero elements
399
409
# of the adjacency matrix.
400
410
bA = binarize (A)
401
411
if dir == :in
402
- dA = bA .* Δ '
412
+ dA = bA .* dy '
403
413
elseif dir == :out
404
- dA = Δ .* bA
414
+ dA = dy .* bA
405
415
else # dir == :both
406
- dA = Δ .* bA + Δ ' .* bA
416
+ dA = dy .* bA + dy ' .* bA
407
417
end
408
- return (NoTangent (), dA, NoTangent (), NoTangent (), NoTangent (), NoTangent ())
418
+ return (CRC . NoTangent (), dA, CRC . NoTangent (), CRC . NoTangent (), CRC . NoTangent (), CRC . NoTangent ())
409
419
end
410
420
return degs, _degree_pullback_weights
411
421
end
@@ -452,7 +462,7 @@ function normalized_adjacency(g::GNNGraph, T::DataType = Float32;
452
462
A = A + I
453
463
end
454
464
degs = vec (sum (A; dims = 2 ))
455
- ChainRulesCore . ignore_derivatives () do
465
+ CRC . ignore_derivatives () do
456
466
@assert all (! iszero, degs) " Graph contains isolated nodes, cannot compute `normalized_adjacency`."
457
467
end
458
468
inv_sqrtD = Diagonal (inv .(sqrt .(degs)))
@@ -609,12 +619,12 @@ function laplacian_lambda_max(g::GNNGraph, T::DataType = Float32;
609
619
end
610
620
end
611
621
612
- @non_differentiable edge_index (x... )
613
- @non_differentiable adjacency_list (x... )
614
- @non_differentiable graph_indicator (x... )
615
- @non_differentiable has_multi_edges (x... )
616
- @non_differentiable Graphs. has_self_loops (x... )
617
- @non_differentiable is_bidirected (x... )
618
- @non_differentiable normalized_adjacency (x... ) # TODO remove this in the future
619
- @non_differentiable normalized_laplacian (x... ) # TODO remove this in the future
620
- @non_differentiable scaled_laplacian (x... ) # TODO remove this in the future
622
+ CRC . @non_differentiable edge_index (x... )
623
+ CRC . @non_differentiable adjacency_list (x... )
624
+ CRC . @non_differentiable graph_indicator (x... )
625
+ CRC . @non_differentiable has_multi_edges (x... )
626
+ CRC . @non_differentiable Graphs. has_self_loops (x... )
627
+ CRC . @non_differentiable is_bidirected (x... )
628
+ CRC . @non_differentiable normalized_adjacency (x... ) # TODO remove this in the future
629
+ CRC . @non_differentiable normalized_laplacian (x... ) # TODO remove this in the future
630
+ CRC . @non_differentiable scaled_laplacian (x... ) # TODO remove this in the future
0 commit comments