Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion GNNGraphs/src/GNNGraphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import NearestNeighbors
import NNlib
import StatsBase
import KrylovKit
using ChainRulesCore
import ChainRulesCore as CRC
using LinearAlgebra, Random, Statistics
import MLUtils
using MLUtils: getobs, numobs, ones_like, zeros_like, chunk, batch, rand_like
Expand Down
11 changes: 6 additions & 5 deletions GNNGraphs/src/chainrules.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
# Taken from https://github.com/JuliaDiff/ChainRules.jl/pull/648
# Remove when merged

function ChainRulesCore.rrule(::Type{T}, ps::Pair...) where {T<:Dict}
function CRC.rrule(::Type{T}, ps::Pair...) where {T<:Dict}
ks = map(first, ps)
project_ks, project_vs = map(ProjectTo, ks), map(ProjectTo∘last, ps)
project_ks, project_vs = map(CRC.ProjectTo, ks), map(CRC.ProjectTo∘last, ps)
function Dict_pullback(ȳ)
dy = CRC.unthunk(ȳ)
dps = map(ks, project_ks, project_vs) do k, proj_k, proj_v
dk, dv = proj_k(getkey(ȳ, k, NoTangent())), proj_v(get(ȳ, k, NoTangent()))
Tangent{Pair{typeof(dk), typeof(dv)}}(first = dk, second = dv)
dk, dv = proj_k(getkey(dy, k, CRC.NoTangent())), proj_v(get(dy, k, CRC.NoTangent()))
CRC.Tangent{Pair{typeof(dk), typeof(dv)}}(first = dk, second = dv)
end
return (NoTangent(), dps...)
return (CRC.NoTangent(), dps...)
end
return T(ps...), Dict_pullback
end
2 changes: 1 addition & 1 deletion GNNGraphs/src/convert.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ function _findnz_idx(A)
return s, t, nz
end

@non_differentiable _findnz_idx(A)
CRC.@non_differentiable _findnz_idx(A)

function to_coo(A::ADJMAT_T; dir = :out, num_nodes = nothing, weighted = true)
s, t, nz = _findnz_idx(A)
Expand Down
4 changes: 2 additions & 2 deletions GNNGraphs/src/gnnheterograph/gnnheterograph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ end

# TODO this is not correct but Zygote cannot differentiate
# through dictionary generation
# @non_differentiable edge_type_subgraph(::Any...)
# CRC.@non_differentiable edge_type_subgraph(::Any...)

function _ntypes_from_edges(edge_ts::AbstractVector{<:EType})
ntypes = Symbol[]
Expand All @@ -285,7 +285,7 @@ function _ntypes_from_edges(edge_ts::AbstractVector{<:EType})
return ntypes
end

@non_differentiable _ntypes_from_edges(::Any...)
CRC.@non_differentiable _ntypes_from_edges(::Any...)

function Base.getindex(g::GNNHeteroGraph, node_t::NType)
return g.ndata[node_t]
Expand Down
62 changes: 36 additions & 26 deletions GNNGraphs/src/query.jl
Original file line number Diff line number Diff line change
Expand Up @@ -241,37 +241,46 @@ function Graphs.adjacency_matrix(g::GNNGraph{<:ADJMAT_T}, T::DataType = eltype(g
return dir == :out ? A : A'
end

function ChainRulesCore.rrule(::typeof(adjacency_matrix), g::G, T::DataType;
function CRC.rrule(::typeof(adjacency_matrix), g::G, T::DataType;
dir = :out, weighted = true) where {G <: GNNGraph{<:ADJMAT_T}}
A = adjacency_matrix(g, T; dir, weighted)
if !weighted
function adjacency_matrix_pullback_noweight(Δ)
return (NoTangent(), ZeroTangent(), NoTangent())
return (CRC.NoTangent(), CRC.ZeroTangent(), CRC.NoTangent())
end
return A, adjacency_matrix_pullback_noweight
else
function adjacency_matrix_pullback_weighted(Δ)
dg = Tangent{G}(; graph = Δ .* binarize(A))
return (NoTangent(), dg, NoTangent())
dy = CRC.unthunk(Δ)
dg = CRC.Tangent{G}(; graph = dy .* binarize(dy))
return (CRC.NoTangent(), dg, CRC.NoTangent())
end
return A, adjacency_matrix_pullback_weighted
end
end

function ChainRulesCore.rrule(::typeof(adjacency_matrix), g::G, T::DataType;
function CRC.rrule(::typeof(adjacency_matrix), g::G, T::DataType;
dir = :out, weighted = true) where {G <: GNNGraph{<:COO_T}}
A = adjacency_matrix(g, T; dir, weighted)
w = get_edge_weight(g)
if !weighted || w === nothing
function adjacency_matrix_pullback_noweight(Δ)
return (NoTangent(), ZeroTangent(), NoTangent())
return (CRC.NoTangent(), CRC.ZeroTangent(), CRC.NoTangent())
end
return A, adjacency_matrix_pullback_noweight
else
function adjacency_matrix_pullback_weighted(Δ)
dy = CRC.unthunk(Δ)
s, t = edge_index(g)
dg = Tangent{G}(; graph = (NoTangent(), NoTangent(), NNlib.gather(Δ, s, t)))
return (NoTangent(), dg, NoTangent())
@show dy s t
#TODO using CRC.@thunk gives an error
#TODO use gather when https://github.com/FluxML/NNlib.jl/issues/625 is fixed
dw = zeros_like(w)
idx = CartesianIndex.(s, t) #TODO remove when https://github.com/FluxML/NNlib.jl/issues/626 is fixed
NNlib.gather!(dw, dy, idx)
@show dw
dg = CRC.Tangent{G}(; graph = (CRC.NoTangent(), CRC.NoTangent(), dw))
return (CRC.NoTangent(), dg, CRC.NoTangent())
end
return A, adjacency_matrix_pullback_weighted
end
Expand Down Expand Up @@ -378,34 +387,35 @@ function _degree(A::AbstractMatrix, T::Type, dir::Symbol, edge_weight::Bool, num
vec(sum(A, dims = 1)) .+ vec(sum(A, dims = 2))
end

function ChainRulesCore.rrule(::typeof(_degree), graph, T, dir, edge_weight::Nothing, num_nodes)
function CRC.rrule(::typeof(_degree), graph, T, dir, edge_weight::Nothing, num_nodes)
degs = _degree(graph, T, dir, edge_weight, num_nodes)
function _degree_pullback(Δ)
return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent())
return ntuple(i -> (CRC.NoTangent(),), 6)
end
return degs, _degree_pullback
end

function ChainRulesCore.rrule(::typeof(_degree), A::ADJMAT_T, T, dir, edge_weight::Bool, num_nodes)
function CRC.rrule(::typeof(_degree), A::ADJMAT_T, T, dir, edge_weight::Bool, num_nodes)
degs = _degree(A, T, dir, edge_weight, num_nodes)
if edge_weight === false
function _degree_pullback_noweights(Δ)
return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent())
return ntuple(i -> (CRC.NoTangent(),), 6)
end
return degs, _degree_pullback_noweights
else
function _degree_pullback_weights(Δ)
dy = CRC.unthunk(Δ)
# We propagate the gradient only to the non-zero elements
# of the adjacency matrix.
bA = binarize(A)
if dir == :in
dA = bA .* Δ'
dA = bA .* dy'
elseif dir == :out
dA = Δ .* bA
dA = dy .* bA
else # dir == :both
dA = Δ .* bA + Δ' .* bA
dA = dy .* bA + dy' .* bA
end
return (NoTangent(), dA, NoTangent(), NoTangent(), NoTangent(), NoTangent())
return (CRC.NoTangent(), dA, CRC.NoTangent(), CRC.NoTangent(), CRC.NoTangent(), CRC.NoTangent())
end
return degs, _degree_pullback_weights
end
Expand Down Expand Up @@ -452,7 +462,7 @@ function normalized_adjacency(g::GNNGraph, T::DataType = Float32;
A = A + I
end
degs = vec(sum(A; dims = 2))
ChainRulesCore.ignore_derivatives() do
CRC.ignore_derivatives() do
@assert all(!iszero, degs) "Graph contains isolated nodes, cannot compute `normalized_adjacency`."
end
inv_sqrtD = Diagonal(inv.(sqrt.(degs)))
Expand Down Expand Up @@ -609,12 +619,12 @@ function laplacian_lambda_max(g::GNNGraph, T::DataType = Float32;
end
end

@non_differentiable edge_index(x...)
@non_differentiable adjacency_list(x...)
@non_differentiable graph_indicator(x...)
@non_differentiable has_multi_edges(x...)
@non_differentiable Graphs.has_self_loops(x...)
@non_differentiable is_bidirected(x...)
@non_differentiable normalized_adjacency(x...) # TODO remove this in the future
@non_differentiable normalized_laplacian(x...) # TODO remove this in the future
@non_differentiable scaled_laplacian(x...) # TODO remove this in the future
CRC.@non_differentiable edge_index(x...)
CRC.@non_differentiable adjacency_list(x...)
CRC.@non_differentiable graph_indicator(x...)
CRC.@non_differentiable has_multi_edges(x...)
CRC.@non_differentiable Graphs.has_self_loops(x...)
CRC.@non_differentiable is_bidirected(x...)
CRC.@non_differentiable normalized_adjacency(x...) # TODO remove this in the future
CRC.@non_differentiable normalized_laplacian(x...) # TODO remove this in the future
CRC.@non_differentiable scaled_laplacian(x...) # TODO remove this in the future
12 changes: 6 additions & 6 deletions GNNGraphs/src/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -808,8 +808,8 @@ function _unbatch_edgemasks(s, t, num_graphs, cumnum_nodes)
return edgemasks
end

@non_differentiable _unbatch_nodemasks(::Any...)
@non_differentiable _unbatch_edgemasks(::Any...)
CRC.@non_differentiable _unbatch_nodemasks(::Any...)
CRC.@non_differentiable _unbatch_edgemasks(::Any...)

"""
getgraph(g::GNNGraph, i; nmap=false)
Expand Down Expand Up @@ -998,10 +998,10 @@ dense_zeros_like(x, sz = size(x)) = dense_zeros_like(x, eltype(x), sz)
# """
ci2t(ci::AbstractVector{<:CartesianIndex}, dims) = ntuple(i -> map(x -> x[i], ci), dims)

@non_differentiable negative_sample(x...)
@non_differentiable add_self_loops(x...) # TODO this is wrong, since g carries feature arrays, needs rrule
@non_differentiable remove_self_loops(x...) # TODO this is wrong, since g carries feature arrays, needs rrule
@non_differentiable dense_zeros_like(x...)
CRC.@non_differentiable negative_sample(x...)
CRC.@non_differentiable add_self_loops(x...) # TODO this is wrong, since g carries feature arrays, needs rrule
CRC.@non_differentiable remove_self_loops(x...) # TODO this is wrong, since g carries feature arrays, needs rrule
CRC.@non_differentiable dense_zeros_like(x...)

"""
ppr_diffusion(g::GNNGraph{<:COO_T}, alpha =0.85f0) -> GNNGraph
Expand Down
12 changes: 6 additions & 6 deletions GNNGraphs/src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -292,9 +292,9 @@ end

binarize(x) = map(>(0), x)

@non_differentiable binarize(x...)
@non_differentiable edge_encoding(x...)
@non_differentiable edge_decoding(x...)
CRC.@non_differentiable binarize(x...)
CRC.@non_differentiable edge_encoding(x...)
CRC.@non_differentiable edge_decoding(x...)

### PRINTING #####

Expand Down Expand Up @@ -330,11 +330,11 @@ function dims2string(d)
join(map(string, d), '×')
end

@non_differentiable normalize_graphdata(::NamedTuple{(), Tuple{}})
@non_differentiable normalize_graphdata(::Nothing)
CRC.@non_differentiable normalize_graphdata(::NamedTuple{(), Tuple{}})
CRC.@non_differentiable normalize_graphdata(::Nothing)

iscuarray(x::AbstractArray) = false
@non_differentiable iscuarray(::Any)
CRC.@non_differentiable iscuarray(::Any)


@doc raw"""
Expand Down
3 changes: 1 addition & 2 deletions GNNGraphs/test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c"
Expand All @@ -9,6 +8,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Expand All @@ -20,7 +20,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a"
TestItems = "1c621080-faea-4a02-84b6-bbd5e436b8fe"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

[compat]
GPUArraysCore = "0.1"
9 changes: 3 additions & 6 deletions GraphNeuralNetworks/examples/Project.toml
Original file line number Diff line number Diff line change
@@ -1,19 +1,16 @@
[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
DiffEqFlux = "aae7a2af-3d4f-5e19-a356-7da93b79d9d0"
DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"

[compat]
DiffEqFlux = "2"
Flux = "0.13"
GraphNeuralNetworks = "0.6"
Flux = "0.16"
GraphNeuralNetworks = "1"
Graphs = "1"
MLDatasets = "0.7"
julia = "1.9"
julia = "1.10"
Loading