Skip to content

Commit 306da07

Browse files
fix zygote error
1 parent 2dd14fd commit 306da07

File tree

14 files changed

+104
-69
lines changed

14 files changed

+104
-69
lines changed

GNNGraphs/docs/src/guides/datasets.md

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,44 @@
11
# Datasets
22

3-
GNNGraphs.jl doesn't come with its own datasets, but leverages those available in the Julia (and non-Julia) ecosystem. In particular, the [examples in the GraphNeuralNetworks.jl repository](https://github.com/JuliaGraphs/GraphNeuralNetworks.jl/tree/master/examples) make use of the [MLDatasets.jl](https://github.com/JuliaML/MLDatasets.jl) package. There you will find common graph datasets such as Cora, PubMed, Citeseer, TUDataset and [many others](https://juliaml.github.io/MLDatasets.jl/dev/datasets/graphs/).
3+
GNNGraphs.jl doesn't come with its own datasets, but leverages those available in the Julia (and non-Julia) ecosystem.
4+
5+
## MLDatasets.jl
6+
7+
Some of the [examples in the GraphNeuralNetworks.jl repository](https://github.com/JuliaGraphs/GraphNeuralNetworks.jl/tree/master/examples) make use of the [MLDatasets.jl](https://github.com/JuliaML/MLDatasets.jl) package. There you will find common graph datasets such as Cora, PubMed, Citeseer, TUDataset and [many others](https://juliaml.github.io/MLDatasets.jl/dev/datasets/graphs/).
48
For graphs with static structures and temporal features, datasets such as METRLA, PEMSBAY, ChickenPox, and WindMillEnergy are available. For graphs featuring both temporal structures and temporal features, the TemporalBrains dataset is suitable.
59

610
GraphNeuralNetworks.jl provides the [`mldataset2gnngraph`](@ref) method for interfacing with MLDatasets.jl.
711

12+
## PyGDatasets.jl
13+
14+
The package [PyGDatasets.jl](https://github.com/CarloLucibello/PyGDatasets.jl) makes available to Julia users the datasets from the [pytorch geometric](https://pytorch-geometric.readthedocs.io/en/latest/modules/datasets.html) library.
15+
16+
PyGDatasets' datasets are compatible with GNNGraphs, so no additional conversion is needed.
17+
```julia
18+
julia> using PyGDatasets
19+
20+
julia> dataset = load_dataset("TUDataset", name="MUTAG")
21+
TUDataset(MUTAG) - InMemoryGNNDataset
22+
num_graphs: 188
23+
node_features: [:x]
24+
edge_features: [:edge_attr]
25+
graph_features: [:y]
26+
root: /Users/carlo/.julia/scratchspaces/44f67abd-f36e-4be4-bfe5-65f468a62b3d/datasets/TUDataset
27+
28+
julia> g = dataset[1]
29+
GNNGraph:
30+
num_nodes: 17
31+
num_edges: 38
32+
ndata:
33+
x = 7×17 Matrix{Float32}
34+
edata:
35+
edge_attr = 4×38 Matrix{Float32}
36+
gdata:
37+
y = 1-element Vector{Int64}
38+
39+
julia> using MLUtils: DataLoader
40+
41+
julia> data_loader = DataLoader(dataset, batch_size=32);
42+
```
43+
44+
PyGDatasets is based on [PythonCall.jl](https://github.com/JuliaPy/PythonCall.jl). It carries over some heavy dependencies such as python, pytorch and pytorch geometric.

GNNGraphs/src/GNNGraphs.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import NearestNeighbors
88
import NNlib
99
import StatsBase
1010
import KrylovKit
11-
using ChainRulesCore
11+
import ChainRulesCore as CRC
1212
using LinearAlgebra, Random, Statistics
1313
import MLUtils
1414
using MLUtils: getobs, numobs, ones_like, zeros_like, chunk, batch, rand_like

GNNGraphs/src/chainrules.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
# Taken from https://github.com/JuliaDiff/ChainRules.jl/pull/648
22
# Remove when merged
33

4-
function ChainRulesCore.rrule(::Type{T}, ps::Pair...) where {T<:Dict}
4+
function CRC.rrule(::Type{T}, ps::Pair...) where {T<:Dict}
55
ks = map(first, ps)
6-
project_ks, project_vs = map(ProjectTo, ks), map(ProjectTolast, ps)
6+
project_ks, project_vs = map(CRC.ProjectTo, ks), map(CRC.ProjectTo last, ps)
77
function Dict_pullback(ȳ)
8+
dy = CRC.unthunk(ȳ)
89
dps = map(ks, project_ks, project_vs) do k, proj_k, proj_v
9-
dk, dv = proj_k(getkey(, k, NoTangent())), proj_v(get(, k, NoTangent()))
10-
Tangent{Pair{typeof(dk), typeof(dv)}}(first = dk, second = dv)
10+
dk, dv = proj_k(getkey(dy, k, CRC.NoTangent())), proj_v(get(dy, k, CRC.NoTangent()))
11+
CRC.Tangent{Pair{typeof(dk), typeof(dv)}}(first = dk, second = dv)
1112
end
12-
return (NoTangent(), dps...)
13+
return (CRC.NoTangent(), dps...)
1314
end
1415
return T(ps...), Dict_pullback
1516
end

GNNGraphs/src/convert.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ function _findnz_idx(A)
7878
return s, t, nz
7979
end
8080

81-
@non_differentiable _findnz_idx(A)
81+
CRC.@non_differentiable _findnz_idx(A)
8282

8383
function to_coo(A::ADJMAT_T; dir = :out, num_nodes = nothing, weighted = true)
8484
s, t, nz = _findnz_idx(A)

GNNGraphs/src/gnnheterograph/generate.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ See [`rand_heterograph`](@ref) for a more general version.
9595
9696
# Examples
9797
98-
```julia-repl
98+
```julia
9999
julia> g = rand_bipartite_heterograph((10, 15), 20)
100100
GNNHeteroGraph:
101101
num_nodes: (:A => 10, :B => 15)

GNNGraphs/src/gnnheterograph/gnnheterograph.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ end
273273

274274
# TODO this is not correct but Zygote cannot differentiate
275275
# through dictionary generation
276-
# @non_differentiable edge_type_subgraph(::Any...)
276+
# CRC.@non_differentiable edge_type_subgraph(::Any...)
277277

278278
function _ntypes_from_edges(edge_ts::AbstractVector{<:EType})
279279
ntypes = Symbol[]
@@ -285,7 +285,7 @@ function _ntypes_from_edges(edge_ts::AbstractVector{<:EType})
285285
return ntypes
286286
end
287287

288-
@non_differentiable _ntypes_from_edges(::Any...)
288+
CRC.@non_differentiable _ntypes_from_edges(::Any...)
289289

290290
function Base.getindex(g::GNNHeteroGraph, node_t::NType)
291291
return g.ndata[node_t]

GNNGraphs/src/query.jl

Lines changed: 36 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -241,37 +241,46 @@ function Graphs.adjacency_matrix(g::GNNGraph{<:ADJMAT_T}, T::DataType = eltype(g
241241
return dir == :out ? A : A'
242242
end
243243

244-
function ChainRulesCore.rrule(::typeof(adjacency_matrix), g::G, T::DataType;
244+
function CRC.rrule(::typeof(adjacency_matrix), g::G, T::DataType;
245245
dir = :out, weighted = true) where {G <: GNNGraph{<:ADJMAT_T}}
246246
A = adjacency_matrix(g, T; dir, weighted)
247247
if !weighted
248248
function adjacency_matrix_pullback_noweight(Δ)
249-
return (NoTangent(), ZeroTangent(), NoTangent())
249+
return (CRC.NoTangent(), CRC.ZeroTangent(), CRC.NoTangent())
250250
end
251251
return A, adjacency_matrix_pullback_noweight
252252
else
253253
function adjacency_matrix_pullback_weighted(Δ)
254-
dg = Tangent{G}(; graph = Δ .* binarize(A))
255-
return (NoTangent(), dg, NoTangent())
254+
dy = CRC.unthunk(Δ)
255+
dg = CRC.Tangent{G}(; graph = dy .* binarize(dy))
256+
return (CRC.NoTangent(), dg, CRC.NoTangent())
256257
end
257258
return A, adjacency_matrix_pullback_weighted
258259
end
259260
end
260261

261-
function ChainRulesCore.rrule(::typeof(adjacency_matrix), g::G, T::DataType;
262+
function CRC.rrule(::typeof(adjacency_matrix), g::G, T::DataType;
262263
dir = :out, weighted = true) where {G <: GNNGraph{<:COO_T}}
263264
A = adjacency_matrix(g, T; dir, weighted)
264265
w = get_edge_weight(g)
265266
if !weighted || w === nothing
266267
function adjacency_matrix_pullback_noweight(Δ)
267-
return (NoTangent(), ZeroTangent(), NoTangent())
268+
return (CRC.NoTangent(), CRC.ZeroTangent(), CRC.NoTangent())
268269
end
269270
return A, adjacency_matrix_pullback_noweight
270271
else
271272
function adjacency_matrix_pullback_weighted(Δ)
273+
dy = CRC.unthunk(Δ)
272274
s, t = edge_index(g)
273-
dg = Tangent{G}(; graph = (NoTangent(), NoTangent(), NNlib.gather(Δ, s, t)))
274-
return (NoTangent(), dg, NoTangent())
275+
@show dy s t
276+
#TODO using CRC.@thunk gives an error
277+
#TODO use gather when https://github.com/FluxML/NNlib.jl/issues/625 is fixed
278+
dw = zeros_like(w)
279+
idx = CartesianIndex.(s, t) #TODO remove when https://github.com/FluxML/NNlib.jl/issues/626 is fixed
280+
NNlib.gather!(dw, dy, idx)
281+
@show dw
282+
dg = CRC.Tangent{G}(; graph = (CRC.NoTangent(), CRC.NoTangent(), dw))
283+
return (CRC.NoTangent(), dg, CRC.NoTangent())
275284
end
276285
return A, adjacency_matrix_pullback_weighted
277286
end
@@ -378,34 +387,35 @@ function _degree(A::AbstractMatrix, T::Type, dir::Symbol, edge_weight::Bool, num
378387
vec(sum(A, dims = 1)) .+ vec(sum(A, dims = 2))
379388
end
380389

381-
function ChainRulesCore.rrule(::typeof(_degree), graph, T, dir, edge_weight::Nothing, num_nodes)
390+
function CRC.rrule(::typeof(_degree), graph, T, dir, edge_weight::Nothing, num_nodes)
382391
degs = _degree(graph, T, dir, edge_weight, num_nodes)
383392
function _degree_pullback(Δ)
384-
return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent())
393+
return ntuple(i -> (CRC.NoTangent(),), 6)
385394
end
386395
return degs, _degree_pullback
387396
end
388397

389-
function ChainRulesCore.rrule(::typeof(_degree), A::ADJMAT_T, T, dir, edge_weight::Bool, num_nodes)
398+
function CRC.rrule(::typeof(_degree), A::ADJMAT_T, T, dir, edge_weight::Bool, num_nodes)
390399
degs = _degree(A, T, dir, edge_weight, num_nodes)
391400
if edge_weight === false
392401
function _degree_pullback_noweights(Δ)
393-
return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent())
402+
return ntuple(i -> (CRC.NoTangent(),), 6)
394403
end
395404
return degs, _degree_pullback_noweights
396405
else
397406
function _degree_pullback_weights(Δ)
407+
dy = CRC.unthunk(Δ)
398408
# We propagate the gradient only to the non-zero elements
399409
# of the adjacency matrix.
400410
bA = binarize(A)
401411
if dir == :in
402-
dA = bA .* Δ'
412+
dA = bA .* dy'
403413
elseif dir == :out
404-
dA = Δ .* bA
414+
dA = dy .* bA
405415
else # dir == :both
406-
dA = Δ .* bA + Δ' .* bA
416+
dA = dy .* bA + dy' .* bA
407417
end
408-
return (NoTangent(), dA, NoTangent(), NoTangent(), NoTangent(), NoTangent())
418+
return (CRC.NoTangent(), dA, CRC.NoTangent(), CRC.NoTangent(), CRC.NoTangent(), CRC.NoTangent())
409419
end
410420
return degs, _degree_pullback_weights
411421
end
@@ -452,7 +462,7 @@ function normalized_adjacency(g::GNNGraph, T::DataType = Float32;
452462
A = A + I
453463
end
454464
degs = vec(sum(A; dims = 2))
455-
ChainRulesCore.ignore_derivatives() do
465+
CRC.ignore_derivatives() do
456466
@assert all(!iszero, degs) "Graph contains isolated nodes, cannot compute `normalized_adjacency`."
457467
end
458468
inv_sqrtD = Diagonal(inv.(sqrt.(degs)))
@@ -609,12 +619,12 @@ function laplacian_lambda_max(g::GNNGraph, T::DataType = Float32;
609619
end
610620
end
611621

612-
@non_differentiable edge_index(x...)
613-
@non_differentiable adjacency_list(x...)
614-
@non_differentiable graph_indicator(x...)
615-
@non_differentiable has_multi_edges(x...)
616-
@non_differentiable Graphs.has_self_loops(x...)
617-
@non_differentiable is_bidirected(x...)
618-
@non_differentiable normalized_adjacency(x...) # TODO remove this in the future
619-
@non_differentiable normalized_laplacian(x...) # TODO remove this in the future
620-
@non_differentiable scaled_laplacian(x...) # TODO remove this in the future
622+
CRC.@non_differentiable edge_index(x...)
623+
CRC.@non_differentiable adjacency_list(x...)
624+
CRC.@non_differentiable graph_indicator(x...)
625+
CRC.@non_differentiable has_multi_edges(x...)
626+
CRC.@non_differentiable Graphs.has_self_loops(x...)
627+
CRC.@non_differentiable is_bidirected(x...)
628+
CRC.@non_differentiable normalized_adjacency(x...) # TODO remove this in the future
629+
CRC.@non_differentiable normalized_laplacian(x...) # TODO remove this in the future
630+
CRC.@non_differentiable scaled_laplacian(x...) # TODO remove this in the future

GNNGraphs/src/transform.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -808,8 +808,8 @@ function _unbatch_edgemasks(s, t, num_graphs, cumnum_nodes)
808808
return edgemasks
809809
end
810810

811-
@non_differentiable _unbatch_nodemasks(::Any...)
812-
@non_differentiable _unbatch_edgemasks(::Any...)
811+
CRC.@non_differentiable _unbatch_nodemasks(::Any...)
812+
CRC.@non_differentiable _unbatch_edgemasks(::Any...)
813813

814814
"""
815815
getgraph(g::GNNGraph, i; nmap=false)
@@ -998,10 +998,10 @@ dense_zeros_like(x, sz = size(x)) = dense_zeros_like(x, eltype(x), sz)
998998
# """
999999
ci2t(ci::AbstractVector{<:CartesianIndex}, dims) = ntuple(i -> map(x -> x[i], ci), dims)
10001000

1001-
@non_differentiable negative_sample(x...)
1002-
@non_differentiable add_self_loops(x...) # TODO this is wrong, since g carries feature arrays, needs rrule
1003-
@non_differentiable remove_self_loops(x...) # TODO this is wrong, since g carries feature arrays, needs rrule
1004-
@non_differentiable dense_zeros_like(x...)
1001+
CRC.@non_differentiable negative_sample(x...)
1002+
CRC.@non_differentiable add_self_loops(x...) # TODO this is wrong, since g carries feature arrays, needs rrule
1003+
CRC.@non_differentiable remove_self_loops(x...) # TODO this is wrong, since g carries feature arrays, needs rrule
1004+
CRC.@non_differentiable dense_zeros_like(x...)
10051005

10061006
"""
10071007
ppr_diffusion(g::GNNGraph{<:COO_T}, alpha =0.85f0) -> GNNGraph

GNNGraphs/src/utils.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -292,9 +292,9 @@ end
292292

293293
binarize(x) = map(>(0), x)
294294

295-
@non_differentiable binarize(x...)
296-
@non_differentiable edge_encoding(x...)
297-
@non_differentiable edge_decoding(x...)
295+
CRC.@non_differentiable binarize(x...)
296+
CRC.@non_differentiable edge_encoding(x...)
297+
CRC.@non_differentiable edge_decoding(x...)
298298

299299
### PRINTING #####
300300

@@ -330,11 +330,11 @@ function dims2string(d)
330330
join(map(string, d), '×')
331331
end
332332

333-
@non_differentiable normalize_graphdata(::NamedTuple{(), Tuple{}})
334-
@non_differentiable normalize_graphdata(::Nothing)
333+
CRC.@non_differentiable normalize_graphdata(::NamedTuple{(), Tuple{}})
334+
CRC.@non_differentiable normalize_graphdata(::Nothing)
335335

336336
iscuarray(x::AbstractArray) = false
337-
@non_differentiable iscuarray(::Any)
337+
CRC.@non_differentiable iscuarray(::Any)
338338

339339

340340
@doc raw"""

GNNGraphs/test/Project.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
[deps]
2-
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
32
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
43
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
54
GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c"
@@ -9,6 +8,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
98
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
109
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
1110
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
11+
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
1212
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
1313
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1414
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
@@ -20,7 +20,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2020
TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a"
2121
TestItems = "1c621080-faea-4a02-84b6-bbd5e436b8fe"
2222
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
23-
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
2423

2524
[compat]
2625
GPUArraysCore = "0.1"

0 commit comments

Comments
 (0)