@@ -241,37 +241,46 @@ function Graphs.adjacency_matrix(g::GNNGraph{<:ADJMAT_T}, T::DataType = eltype(g
241241    return  dir ==  :out  ?  A :  A' 
242242end 
243243
244- function  ChainRulesCore . rrule (:: typeof (adjacency_matrix), g:: G , T:: DataType ; 
244+ function  CRC . rrule (:: typeof (adjacency_matrix), g:: G , T:: DataType ; 
245245            dir =  :out , weighted =  true ) where  {G <:  GNNGraph{<:ADJMAT_T} }
246246    A =  adjacency_matrix (g, T; dir, weighted)
247247    if  ! weighted
248248        function  adjacency_matrix_pullback_noweight (Δ)
249-             return  (NoTangent (), ZeroTangent (), NoTangent ())  
249+             return  (CRC . NoTangent (), CRC . ZeroTangent (), CRC . NoTangent ())  
250250        end 
251251        return  A, adjacency_matrix_pullback_noweight
252252    else 
253253        function  adjacency_matrix_pullback_weighted (Δ)
254-             dg =  Tangent {G} (; graph =  Δ .*  binarize (A))
255-             return  (NoTangent (), dg, NoTangent ())  
254+             dy =  CRC. unthunk (Δ)
255+             dg =  CRC. Tangent {G} (; graph =  dy .*  binarize (dy))
256+             return  (CRC. NoTangent (), dg, CRC. NoTangent ())  
256257        end 
257258        return  A, adjacency_matrix_pullback_weighted
258259    end 
259260end 
260261
261- function  ChainRulesCore . rrule (:: typeof (adjacency_matrix), g:: G , T:: DataType ; 
262+ function  CRC . rrule (:: typeof (adjacency_matrix), g:: G , T:: DataType ; 
262263            dir =  :out , weighted =  true ) where  {G <:  GNNGraph{<:COO_T} }
263264    A =  adjacency_matrix (g, T; dir, weighted)
264265    w =  get_edge_weight (g)
265266    if  ! weighted ||  w ===  nothing 
266267        function  adjacency_matrix_pullback_noweight (Δ)
267-             return  (NoTangent (), ZeroTangent (), NoTangent ())  
268+             return  (CRC . NoTangent (), CRC . ZeroTangent (), CRC . NoTangent ())  
268269        end 
269270        return  A, adjacency_matrix_pullback_noweight
270271    else 
271272        function  adjacency_matrix_pullback_weighted (Δ)
273+             dy =  CRC. unthunk (Δ)
272274            s, t =  edge_index (g)
273-             dg =  Tangent {G} (; graph =  (NoTangent (), NoTangent (), NNlib. gather (Δ, s, t)))
274-             return  (NoTangent (), dg, NoTangent ())  
275+             @show  dy s t
276+             # TODO  using CRC.@thunk gives an error
277+             # TODO  use gather when https://github.com/FluxML/NNlib.jl/issues/625 is fixed
278+             dw =  zeros_like (w)
279+             idx =  CartesianIndex .(s, t) # TODO  remove when https://github.com/FluxML/NNlib.jl/issues/626 is fixed
280+             NNlib. gather! (dw, dy, idx)
281+             @show  dw
282+             dg =  CRC. Tangent {G} (; graph =  (CRC. NoTangent (), CRC. NoTangent (), dw))
283+             return  (CRC. NoTangent (), dg, CRC. NoTangent ())  
275284        end 
276285        return  A, adjacency_matrix_pullback_weighted
277286    end 
@@ -378,34 +387,35 @@ function _degree(A::AbstractMatrix, T::Type, dir::Symbol, edge_weight::Bool, num
378387           vec (sum (A, dims =  1 )) .+  vec (sum (A, dims =  2 ))
379388end 
380389
381- function  ChainRulesCore . rrule (:: typeof (_degree), graph, T, dir, edge_weight:: Nothing , num_nodes)
390+ function  CRC . rrule (:: typeof (_degree), graph, T, dir, edge_weight:: Nothing , num_nodes)
382391    degs =  _degree (graph, T, dir, edge_weight, num_nodes)
383392    function  _degree_pullback (Δ)
384-         return  ( NoTangent (),  NoTangent (),  NoTangent (),  NoTangent (),  NoTangent ( ), NoTangent () )
393+         return  ntuple (i  ->  (CRC . NoTangent (),), 6 )
385394    end 
386395    return  degs, _degree_pullback
387396end 
388397
389- function  ChainRulesCore . rrule (:: typeof (_degree), A:: ADJMAT_T , T, dir, edge_weight:: Bool , num_nodes)
398+ function  CRC . rrule (:: typeof (_degree), A:: ADJMAT_T , T, dir, edge_weight:: Bool , num_nodes)
390399    degs =  _degree (A, T, dir, edge_weight, num_nodes)
391400    if  edge_weight ===  false 
392401        function  _degree_pullback_noweights (Δ)
393-             return  ( NoTangent (),  NoTangent (),  NoTangent (),  NoTangent (),  NoTangent ( ), NoTangent () )
402+             return  ntuple (i  ->  (CRC . NoTangent (),), 6 )
394403        end 
395404        return  degs, _degree_pullback_noweights
396405    else 
397406        function  _degree_pullback_weights (Δ)
407+             dy =  CRC. unthunk (Δ)
398408            #  We propagate the gradient only to the non-zero elements
399409            #  of the adjacency matrix.
400410            bA =  binarize (A)
401411            if  dir ==  :in 
402-                 dA =  bA .*  Δ ' 
412+                 dA =  bA .*  dy ' 
403413            elseif  dir ==  :out 
404-                 dA =  Δ  .*  bA
414+                 dA =  dy  .*  bA
405415            else  #  dir == :both
406-                 dA =  Δ  .*  bA +  Δ '  .*  bA
416+                 dA =  dy  .*  bA +  dy '  .*  bA
407417            end 
408-             return  (NoTangent (), dA, NoTangent (), NoTangent (), NoTangent (), NoTangent ())
418+             return  (CRC . NoTangent (), dA, CRC . NoTangent (), CRC . NoTangent (), CRC . NoTangent (), CRC . NoTangent ())
409419        end 
410420        return  degs, _degree_pullback_weights
411421    end 
@@ -452,7 +462,7 @@ function normalized_adjacency(g::GNNGraph, T::DataType = Float32;
452462        A =  A +  I
453463    end 
454464    degs =  vec (sum (A; dims =  2 ))
455-     ChainRulesCore . ignore_derivatives () do 
465+     CRC . ignore_derivatives () do 
456466        @assert  all (! iszero, degs) " Graph contains isolated nodes, cannot compute `normalized_adjacency`." 
457467    end 
458468    inv_sqrtD =  Diagonal (inv .(sqrt .(degs)))
@@ -609,12 +619,12 @@ function laplacian_lambda_max(g::GNNGraph, T::DataType = Float32;
609619    end 
610620end 
611621
612- @non_differentiable  edge_index (x... )
613- @non_differentiable  adjacency_list (x... )
614- @non_differentiable  graph_indicator (x... )
615- @non_differentiable  has_multi_edges (x... )
616- @non_differentiable  Graphs. has_self_loops (x... )
617- @non_differentiable  is_bidirected (x... )
618- @non_differentiable  normalized_adjacency (x... ) #  TODO  remove this in the future
619- @non_differentiable  normalized_laplacian (x... ) #  TODO  remove this in the future
620- @non_differentiable  scaled_laplacian (x... ) #  TODO  remove this in the future
622+ CRC . @non_differentiable  edge_index (x... )
623+ CRC . @non_differentiable  adjacency_list (x... )
624+ CRC . @non_differentiable  graph_indicator (x... )
625+ CRC . @non_differentiable  has_multi_edges (x... )
626+ CRC . @non_differentiable  Graphs. has_self_loops (x... )
627+ CRC . @non_differentiable  is_bidirected (x... )
628+ CRC . @non_differentiable  normalized_adjacency (x... ) #  TODO  remove this in the future
629+ CRC . @non_differentiable  normalized_laplacian (x... ) #  TODO  remove this in the future
630+ CRC . @non_differentiable  scaled_laplacian (x... ) #  TODO  remove this in the future
0 commit comments