Skip to content

Commit 59b52e9

Browse files
fix zygote error (#579)
1 parent 4309da1 commit 59b52e9

File tree

9 files changed

+62
-55
lines changed

9 files changed

+62
-55
lines changed

GNNGraphs/src/GNNGraphs.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import NearestNeighbors
88
import NNlib
99
import StatsBase
1010
import KrylovKit
11-
using ChainRulesCore
11+
import ChainRulesCore as CRC
1212
using LinearAlgebra, Random, Statistics
1313
import MLUtils
1414
using MLUtils: getobs, numobs, ones_like, zeros_like, chunk, batch, rand_like

GNNGraphs/src/chainrules.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
# Taken from https://github.com/JuliaDiff/ChainRules.jl/pull/648
22
# Remove when merged
33

4-
function ChainRulesCore.rrule(::Type{T}, ps::Pair...) where {T<:Dict}
4+
function CRC.rrule(::Type{T}, ps::Pair...) where {T<:Dict}
55
ks = map(first, ps)
6-
project_ks, project_vs = map(ProjectTo, ks), map(ProjectTo∘last, ps)
6+
project_ks, project_vs = map(CRC.ProjectTo, ks), map(CRC.ProjectTo ∘ last, ps)
77
function Dict_pullback(ȳ)
8+
dy = CRC.unthunk(ȳ)
89
dps = map(ks, project_ks, project_vs) do k, proj_k, proj_v
9-
dk, dv = proj_k(getkey(ȳ, k, NoTangent())), proj_v(get(ȳ, k, NoTangent()))
10-
Tangent{Pair{typeof(dk), typeof(dv)}}(first = dk, second = dv)
10+
dk, dv = proj_k(getkey(dy, k, CRC.NoTangent())), proj_v(get(dy, k, CRC.NoTangent()))
11+
CRC.Tangent{Pair{typeof(dk), typeof(dv)}}(first = dk, second = dv)
1112
end
12-
return (NoTangent(), dps...)
13+
return (CRC.NoTangent(), dps...)
1314
end
1415
return T(ps...), Dict_pullback
1516
end

GNNGraphs/src/convert.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ function _findnz_idx(A)
7878
return s, t, nz
7979
end
8080

81-
@non_differentiable _findnz_idx(A)
81+
CRC.@non_differentiable _findnz_idx(A)
8282

8383
function to_coo(A::ADJMAT_T; dir = :out, num_nodes = nothing, weighted = true)
8484
s, t, nz = _findnz_idx(A)

GNNGraphs/src/gnnheterograph/gnnheterograph.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ end
273273

274274
# TODO this is not correct but Zygote cannot differentiate
275275
# through dictionary generation
276-
# @non_differentiable edge_type_subgraph(::Any...)
276+
# CRC.@non_differentiable edge_type_subgraph(::Any...)
277277

278278
function _ntypes_from_edges(edge_ts::AbstractVector{<:EType})
279279
ntypes = Symbol[]
@@ -285,7 +285,7 @@ function _ntypes_from_edges(edge_ts::AbstractVector{<:EType})
285285
return ntypes
286286
end
287287

288-
@non_differentiable _ntypes_from_edges(::Any...)
288+
CRC.@non_differentiable _ntypes_from_edges(::Any...)
289289

290290
function Base.getindex(g::GNNHeteroGraph, node_t::NType)
291291
return g.ndata[node_t]

GNNGraphs/src/query.jl

Lines changed: 36 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -241,37 +241,46 @@ function Graphs.adjacency_matrix(g::GNNGraph{<:ADJMAT_T}, T::DataType = eltype(g
241241
return dir == :out ? A : A'
242242
end
243243

244-
function ChainRulesCore.rrule(::typeof(adjacency_matrix), g::G, T::DataType;
244+
function CRC.rrule(::typeof(adjacency_matrix), g::G, T::DataType;
245245
dir = :out, weighted = true) where {G <: GNNGraph{<:ADJMAT_T}}
246246
A = adjacency_matrix(g, T; dir, weighted)
247247
if !weighted
248248
function adjacency_matrix_pullback_noweight(Δ)
249-
return (NoTangent(), ZeroTangent(), NoTangent())
249+
return (CRC.NoTangent(), CRC.ZeroTangent(), CRC.NoTangent())
250250
end
251251
return A, adjacency_matrix_pullback_noweight
252252
else
253253
function adjacency_matrix_pullback_weighted(Δ)
254-
dg = Tangent{G}(; graph = Δ .* binarize(A))
255-
return (NoTangent(), dg, NoTangent())
254+
dy = CRC.unthunk(Δ)
255+
dg = CRC.Tangent{G}(; graph = dy .* binarize(dy))
256+
return (CRC.NoTangent(), dg, CRC.NoTangent())
256257
end
257258
return A, adjacency_matrix_pullback_weighted
258259
end
259260
end
260261

261-
function ChainRulesCore.rrule(::typeof(adjacency_matrix), g::G, T::DataType;
262+
function CRC.rrule(::typeof(adjacency_matrix), g::G, T::DataType;
262263
dir = :out, weighted = true) where {G <: GNNGraph{<:COO_T}}
263264
A = adjacency_matrix(g, T; dir, weighted)
264265
w = get_edge_weight(g)
265266
if !weighted || w === nothing
266267
function adjacency_matrix_pullback_noweight(Δ)
267-
return (NoTangent(), ZeroTangent(), NoTangent())
268+
return (CRC.NoTangent(), CRC.ZeroTangent(), CRC.NoTangent())
268269
end
269270
return A, adjacency_matrix_pullback_noweight
270271
else
271272
function adjacency_matrix_pullback_weighted(Δ)
273+
dy = CRC.unthunk(Δ)
272274
s, t = edge_index(g)
273-
dg = Tangent{G}(; graph = (NoTangent(), NoTangent(), NNlib.gather(Δ, s, t)))
274-
return (NoTangent(), dg, NoTangent())
275+
@show dy s t
276+
#TODO using CRC.@thunk gives an error
277+
#TODO use gather when https://github.com/FluxML/NNlib.jl/issues/625 is fixed
278+
dw = zeros_like(w)
279+
idx = CartesianIndex.(s, t) #TODO remove when https://github.com/FluxML/NNlib.jl/issues/626 is fixed
280+
NNlib.gather!(dw, dy, idx)
281+
@show dw
282+
dg = CRC.Tangent{G}(; graph = (CRC.NoTangent(), CRC.NoTangent(), dw))
283+
return (CRC.NoTangent(), dg, CRC.NoTangent())
275284
end
276285
return A, adjacency_matrix_pullback_weighted
277286
end
@@ -378,34 +387,35 @@ function _degree(A::AbstractMatrix, T::Type, dir::Symbol, edge_weight::Bool, num
378387
vec(sum(A, dims = 1)) .+ vec(sum(A, dims = 2))
379388
end
380389

381-
function ChainRulesCore.rrule(::typeof(_degree), graph, T, dir, edge_weight::Nothing, num_nodes)
390+
function CRC.rrule(::typeof(_degree), graph, T, dir, edge_weight::Nothing, num_nodes)
382391
degs = _degree(graph, T, dir, edge_weight, num_nodes)
383392
function _degree_pullback(Δ)
384-
return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent())
393+
return ntuple(i -> (CRC.NoTangent(),), 6)
385394
end
386395
return degs, _degree_pullback
387396
end
388397

389-
function ChainRulesCore.rrule(::typeof(_degree), A::ADJMAT_T, T, dir, edge_weight::Bool, num_nodes)
398+
function CRC.rrule(::typeof(_degree), A::ADJMAT_T, T, dir, edge_weight::Bool, num_nodes)
390399
degs = _degree(A, T, dir, edge_weight, num_nodes)
391400
if edge_weight === false
392401
function _degree_pullback_noweights(Δ)
393-
return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent())
402+
return ntuple(i -> (CRC.NoTangent(),), 6)
394403
end
395404
return degs, _degree_pullback_noweights
396405
else
397406
function _degree_pullback_weights(Δ)
407+
dy = CRC.unthunk(Δ)
398408
# We propagate the gradient only to the non-zero elements
399409
# of the adjacency matrix.
400410
bA = binarize(A)
401411
if dir == :in
402-
dA = bA .* Δ'
412+
dA = bA .* dy'
403413
elseif dir == :out
404-
dA = Δ .* bA
414+
dA = dy .* bA
405415
else # dir == :both
406-
dA = Δ .* bA + Δ' .* bA
416+
dA = dy .* bA + dy' .* bA
407417
end
408-
return (NoTangent(), dA, NoTangent(), NoTangent(), NoTangent(), NoTangent())
418+
return (CRC.NoTangent(), dA, CRC.NoTangent(), CRC.NoTangent(), CRC.NoTangent(), CRC.NoTangent())
409419
end
410420
return degs, _degree_pullback_weights
411421
end
@@ -452,7 +462,7 @@ function normalized_adjacency(g::GNNGraph, T::DataType = Float32;
452462
A = A + I
453463
end
454464
degs = vec(sum(A; dims = 2))
455-
ChainRulesCore.ignore_derivatives() do
465+
CRC.ignore_derivatives() do
456466
@assert all(!iszero, degs) "Graph contains isolated nodes, cannot compute `normalized_adjacency`."
457467
end
458468
inv_sqrtD = Diagonal(inv.(sqrt.(degs)))
@@ -609,12 +619,12 @@ function laplacian_lambda_max(g::GNNGraph, T::DataType = Float32;
609619
end
610620
end
611621

612-
@non_differentiable edge_index(x...)
613-
@non_differentiable adjacency_list(x...)
614-
@non_differentiable graph_indicator(x...)
615-
@non_differentiable has_multi_edges(x...)
616-
@non_differentiable Graphs.has_self_loops(x...)
617-
@non_differentiable is_bidirected(x...)
618-
@non_differentiable normalized_adjacency(x...) # TODO remove this in the future
619-
@non_differentiable normalized_laplacian(x...) # TODO remove this in the future
620-
@non_differentiable scaled_laplacian(x...) # TODO remove this in the future
622+
CRC.@non_differentiable edge_index(x...)
623+
CRC.@non_differentiable adjacency_list(x...)
624+
CRC.@non_differentiable graph_indicator(x...)
625+
CRC.@non_differentiable has_multi_edges(x...)
626+
CRC.@non_differentiable Graphs.has_self_loops(x...)
627+
CRC.@non_differentiable is_bidirected(x...)
628+
CRC.@non_differentiable normalized_adjacency(x...) # TODO remove this in the future
629+
CRC.@non_differentiable normalized_laplacian(x...) # TODO remove this in the future
630+
CRC.@non_differentiable scaled_laplacian(x...) # TODO remove this in the future

GNNGraphs/src/transform.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -808,8 +808,8 @@ function _unbatch_edgemasks(s, t, num_graphs, cumnum_nodes)
808808
return edgemasks
809809
end
810810

811-
@non_differentiable _unbatch_nodemasks(::Any...)
812-
@non_differentiable _unbatch_edgemasks(::Any...)
811+
CRC.@non_differentiable _unbatch_nodemasks(::Any...)
812+
CRC.@non_differentiable _unbatch_edgemasks(::Any...)
813813

814814
"""
815815
getgraph(g::GNNGraph, i; nmap=false)
@@ -998,10 +998,10 @@ dense_zeros_like(x, sz = size(x)) = dense_zeros_like(x, eltype(x), sz)
998998
# """
999999
ci2t(ci::AbstractVector{<:CartesianIndex}, dims) = ntuple(i -> map(x -> x[i], ci), dims)
10001000

1001-
@non_differentiable negative_sample(x...)
1002-
@non_differentiable add_self_loops(x...) # TODO this is wrong, since g carries feature arrays, needs rrule
1003-
@non_differentiable remove_self_loops(x...) # TODO this is wrong, since g carries feature arrays, needs rrule
1004-
@non_differentiable dense_zeros_like(x...)
1001+
CRC.@non_differentiable negative_sample(x...)
1002+
CRC.@non_differentiable add_self_loops(x...) # TODO this is wrong, since g carries feature arrays, needs rrule
1003+
CRC.@non_differentiable remove_self_loops(x...) # TODO this is wrong, since g carries feature arrays, needs rrule
1004+
CRC.@non_differentiable dense_zeros_like(x...)
10051005

10061006
"""
10071007
ppr_diffusion(g::GNNGraph{<:COO_T}, alpha =0.85f0) -> GNNGraph

GNNGraphs/src/utils.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -292,9 +292,9 @@ end
292292

293293
binarize(x) = map(>(0), x)
294294

295-
@non_differentiable binarize(x...)
296-
@non_differentiable edge_encoding(x...)
297-
@non_differentiable edge_decoding(x...)
295+
CRC.@non_differentiable binarize(x...)
296+
CRC.@non_differentiable edge_encoding(x...)
297+
CRC.@non_differentiable edge_decoding(x...)
298298

299299
### PRINTING #####
300300

@@ -330,11 +330,11 @@ function dims2string(d)
330330
join(map(string, d), '×')
331331
end
332332

333-
@non_differentiable normalize_graphdata(::NamedTuple{(), Tuple{}})
334-
@non_differentiable normalize_graphdata(::Nothing)
333+
CRC.@non_differentiable normalize_graphdata(::NamedTuple{(), Tuple{}})
334+
CRC.@non_differentiable normalize_graphdata(::Nothing)
335335

336336
iscuarray(x::AbstractArray) = false
337-
@non_differentiable iscuarray(::Any)
337+
CRC.@non_differentiable iscuarray(::Any)
338338

339339

340340
@doc raw"""

GNNGraphs/test/Project.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
[deps]
2-
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
32
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
43
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
54
GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c"
@@ -9,6 +8,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
98
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
109
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
1110
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
11+
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
1212
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
1313
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1414
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
@@ -20,7 +20,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2020
TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a"
2121
TestItems = "1c621080-faea-4a02-84b6-bbd5e436b8fe"
2222
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
23-
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
2423

2524
[compat]
2625
GPUArraysCore = "0.1"
Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,16 @@
11
[deps]
22
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
3-
DiffEqFlux = "aae7a2af-3d4f-5e19-a356-7da93b79d9d0"
43
DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
54
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
65
GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694"
76
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
87
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
98
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
109
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
11-
NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
1210

1311
[compat]
14-
DiffEqFlux = "2"
15-
Flux = "0.13"
16-
GraphNeuralNetworks = "0.6"
12+
Flux = "0.16"
13+
GraphNeuralNetworks = "1"
1714
Graphs = "1"
1815
MLDatasets = "0.7"
19-
julia = "1.9"
16+
julia = "1.10"

0 commit comments

Comments
 (0)