Skip to content

Commit c21e7cc

Browse files
improvement for differentiability with heterogeneous graphs (#304)
* rrule for dictionary construction * cleanup * fix tests * fix tests * more tests
1 parent 692d2b2 commit c21e7cc

File tree

12 files changed

+89
-35
lines changed

12 files changed

+89
-35
lines changed

src/GNNGraphs/GNNGraphs.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ import MLUtils
1818
using MLUtils: getobs, numobs
1919
import Functors
2020

21+
include("chainrules.jl") # hacks for differentiability
22+
2123
include("datastore.jl")
2224
export DataStore
2325

src/GNNGraphs/chainrules.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Taken from https://github.com/JuliaDiff/ChainRules.jl/pull/648
2+
# Remove when merged
3+
4+
function ChainRulesCore.rrule(::Type{T}, ps::Pair...) where {T<:Dict}
5+
ks = map(first, ps)
6+
project_ks, project_vs = map(ProjectTo, ks), map(ProjectTo∘last, ps)
7+
function Dict_pullback(ȳ)
8+
dps = map(ks, project_ks, project_vs) do k, proj_k, proj_v
9+
dk, dv = proj_k(getkey(ȳ, k, NoTangent())), proj_v(get(ȳ, k, NoTangent()))
10+
Tangent{Pair{typeof(dk), typeof(dv)}}(first = dk, second = dv)
11+
end
12+
return (NoTangent(), dps...)
13+
end
14+
return T(ps...), Dict_pullback
15+
end

src/GNNGraphs/convert.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ function to_coo(data::EDict; num_nodes = nothing, kws...)
2323
_num_nodes[k[1]] = max(get(_num_nodes, k[1], 0), nnodes[1])
2424
_num_nodes[k[3]] = max(get(_num_nodes, k[3], 0), nnodes[2])
2525
end
26-
graph = Dict(k => v for (k, v) in pairs(graph)) # try to restrict the key/value types
26+
graph = Dict([k => v for (k, v) in pairs(graph)]...) # try to restrict the key/value types
2727
return graph, _num_nodes, num_edges
2828
end
2929

src/GNNGraphs/gatherscatter.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
_gather(x::NamedTuple, i) = map(x -> _gather(x, i), x)
2-
_gather(x::Dict, i) = Dict(k => _gather(v, i) for (k, v) in x)
2+
_gather(x::Dict, i) = Dict([k => _gather(v, i) for (k, v) in x]...)
33
_gather(x::Tuple, i) = map(x -> _gather(x, i), x)
44
_gather(x::AbstractArray, i) = NNlib.gather(x, i)
55
_gather(x::Nothing, i) = nothing
66

77
_scatter(aggr, src::Nothing, idx, n) = nothing
88
_scatter(aggr, src::NamedTuple, idx, n) = map(s -> _scatter(aggr, s, idx, n), src)
99
_scatter(aggr, src::Tuple, idx, n) = map(s -> _scatter(aggr, s, idx, n), src)
10-
_scatter(aggr, src::Dict, idx, n) = Dict(k => _scatter(aggr, v, idx, n) for (k, v) in src)
10+
_scatter(aggr, src::Dict, idx, n) = Dict([k => _scatter(aggr, v, idx, n) for (k, v) in src]...)
1111

1212
function _scatter(aggr,
1313
src::AbstractArray,

src/GNNGraphs/gnnheterograph.jl

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -99,16 +99,16 @@ GNNHeteroGraph(data; kws...) = GNNHeteroGraph(Dict(data); kws...)
9999

100100
function GNNHeteroGraph(data::Dict; kws...)
101101
all(k -> k isa EType, keys(data)) || throw(ArgumentError("Keys of data must be tuples of the form (source_type, edge_type, target_type)"))
102-
return GNNHeteroGraph(Dict(k => v for (k, v) in pairs(data)); kws...)
102+
return GNNHeteroGraph(Dict([k => v for (k, v) in pairs(data)]...); kws...)
103103
end
104104

105105
function GNNHeteroGraph(data::EDict;
106106
num_nodes = nothing,
107107
graph_indicator = nothing,
108108
graph_type = :coo,
109109
dir = :out,
110-
ndata = NDict{DataStore}(),
111-
edata = EDict{DataStore}(),
110+
ndata = nothing,
111+
edata = nothing,
112112
gdata = (;))
113113
@assert graph_type ∈ [:coo, :dense, :sparse] "Invalid graph_type $graph_type requested"
114114
@assert dir ∈ [:in, :out]
@@ -132,8 +132,8 @@ function GNNHeteroGraph(data::EDict;
132132
num_graphs = !isnothing(graph_indicator) ?
133133
maximum([maximum(gi) for gi in values(graph_indicator)]) : 1
134134

135-
ndata = normalize_heterographdata(ndata, default_name = :x, n = num_nodes)
136-
edata = normalize_heterographdata(edata, default_name = :e, n = num_edges,
135+
ndata = normalize_heterographdata(ndata, default_name = :x, ns = num_nodes)
136+
edata = normalize_heterographdata(edata, default_name = :e, ns = num_edges,
137137
duplicate_if_needed = true)
138138
gdata = normalize_graphdata(gdata, default_name = :u, n = num_graphs)
139139

@@ -226,7 +226,6 @@ num_node_types(g::GNNGraph) = 1
226226

227227
num_node_types(g::GNNHeteroGraph) = length(g.ntypes)
228228

229-
230229
"""
231230
edge_type_subgraph(g::GNNHeteroGraph, edge_ts)
232231
@@ -240,16 +239,16 @@ function edge_type_subgraph(g::GNNHeteroGraph, edge_ts::AbstractVector{<:EType})
240239
@assert edge_t in g.etypes "Edge type $(edge_t) not found in graph"
241240
end
242241
node_ts = _ntypes_from_edges(edge_ts)
243-
graph = Dict(edge_t => g.graph[edge_t] for edge_t in edge_ts)
244-
num_nodes = Dict(node_t => g.num_nodes[node_t] for node_t in node_ts)
245-
num_edges = Dict(edge_t => g.num_edges[edge_t] for edge_t in edge_ts)
242+
graph = Dict([edge_t => g.graph[edge_t] for edge_t in edge_ts]...)
243+
num_nodes = Dict([node_t => g.num_nodes[node_t] for node_t in node_ts]...)
244+
num_edges = Dict([edge_t => g.num_edges[edge_t] for edge_t in edge_ts]...)
246245
if g.graph_indicator === nothing
247246
graph_indicator = nothing
248247
else
249-
graph_indicator = Dict(node_t => g.graph_indicator[node_t] for node_t in node_ts)
248+
graph_indicator = Dict([node_t => g.graph_indicator[node_t] for node_t in node_ts]...)
250249
end
251-
ndata = Dict(node_t => g.ndata[node_t] for node_t in node_ts if node_t in keys(g.ndata))
252-
edata = Dict(edge_t => g.edata[edge_t] for edge_t in edge_ts if edge_t in keys(g.edata))
250+
ndata = Dict([node_t => g.ndata[node_t] for node_t in node_ts if node_t in keys(g.ndata)]...)
251+
edata = Dict([edge_t => g.edata[edge_t] for edge_t in edge_ts if edge_t in keys(g.edata)]...)
253252

254253
return GNNHeteroGraph(graph, num_nodes, num_edges, g.num_graphs,
255254
graph_indicator, ndata, edata, g.gdata,
@@ -258,7 +257,7 @@ end
258257

259258
# TODO this is not correct but Zygote cannot differentiate
260259
# through dictionary generation
261-
@non_differentiable edge_type_subgraph(::Any...)
260+
# @non_differentiable edge_type_subgraph(::Any...)
262261

263262
function _ntypes_from_edges(edge_ts::AbstractVector{<:EType})
264263
ntypes = Symbol[]

src/GNNGraphs/utils.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -163,12 +163,16 @@ function normalize_graphdata(data::NamedTuple; default_name, n, duplicate_if_nee
163163
end
164164

165165
# For heterogeneous graphs
166+
function normalize_heterographdata(data::Nothing; default_name::Symbol, ns::Dict, kws...)
167+
Dict([k => normalize_graphdata(nothing; default_name = default_name, n, kws...)
168+
for (k, n) in ns]...)
169+
end
170+
166171
normalize_heterographdata(data; kws...) = normalize_heterographdata(Dict(data); kws...)
167172

168-
function normalize_heterographdata(data::Dict; default_name::Symbol, n::Dict, kws...)
169-
isempty(data) && return data
170-
Dict(k => normalize_graphdata(v; default_name = default_name, n = n[k], kws...)
171-
for (k, v) in data)
173+
function normalize_heterographdata(data::Dict; default_name::Symbol, ns::Dict, kws...)
174+
Dict([k => normalize_graphdata(get(data, k, nothing); default_name = default_name, n, kws...)
175+
for (k, n) in ns]...)
172176
end
173177

174178
ones_like(x::AbstractArray, T::Type, sz = size(x)) = fill!(similar(x, T, sz), 1)

src/layers/heteroconv.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ function HeteroGraphConv(itr; aggr = +)
1212
return HeteroGraphConv(etypes, layers, aggr)
1313
end
1414

15-
function (hgc::HeteroGraphConv)(g::GNNHeteroGraph, x::NamedTuple)
15+
function (hgc::HeteroGraphConv)(g::GNNHeteroGraph, x::Union{NamedTuple,Dict})
1616
function forw(l, et)
1717
sg = edge_type_subgraph(g, et)
1818
node1_t, _, node2_t = et

test/GNNGraphs/chainrules.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
@testset "dict constructor" begin
2+
grad = gradient(1.) do x
3+
d = Dict([:x => x, :y => 5]...)
4+
return sum(d[:x].^2)
5+
end[1]
6+
7+
@test grad == 2
8+
9+
## BROKEN Constructors
10+
# grad = gradient(1.) do x
11+
# d = Dict([(:x => x), (:y => 5)])
12+
# return sum(d[:x].^2)
13+
# end[1]
14+
15+
# @test grad == 2
16+
17+
18+
# grad = gradient(1.) do x
19+
# d = Dict([(:x => x), (:y => 5)])
20+
# return sum(d[:x].^2)
21+
# end[1]
22+
23+
# @test grad == 2
24+
end

test/GNNGraphs/gnnheterograph.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
@test hg.num_edges == Dict((:A, :rel1, :B) => 30, (:B, :rel2, :A) => 10)
1515
@test hg.graph_indicator === nothing
1616
@test hg.num_graphs == 1
17-
@test hg.ndata == Dict()
18-
@test hg.edata == Dict()
17+
@test hg.ndata isa Dict{Symbol, DataStore}
18+
@test hg.edata isa Dict{Tuple{Symbol, Symbol, Symbol}, DataStore}
1919
@test isempty(hg.gdata)
2020
@test sort(hg.ntypes) == [:A, :B]
2121
@test sort(hg.etypes) == [(:A, :rel1, :B), (:B, :rel2, :A)]

test/layers/heteroconv.jl

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,26 @@
55
model = HeteroGraphConv([(:A,:to,:B) => GraphConv(d => d),
66
(:B,:to,:A) => GraphConv(d => d)])
77

8-
x = (A = rand(Float32, d, n), B = rand(Float32, d, 2n))
8+
for x in [
9+
(A = rand(Float32, d, n), B = rand(Float32, d, 2n)),
10+
Dict(:A => rand(Float32, d, n), :B => rand(Float32, d, 2n))
11+
]
12+
# x = (A = rand(Float32, d, n), B = rand(Float32, d, 2n))
13+
x = Dict(:A => rand(Float32, d, n), :B => rand(Float32, d, 2n))
14+
15+
y = model(g, x)
916

10-
y = model(g, x)
17+
grad, dx = gradient((model, x) -> sum(model(g, x)[1]) + sum(model(g, x)[2].^2), model, x)
18+
ngrad, ndx = ngradient((model, x) -> sum(model(g, x)[1]) + sum(model(g, x)[2].^2), model, x)
1119

12-
grad = gradient(model -> sum(model(g, x)[1]) + sum(model(g, x)[2].^2), model)[1]
13-
ngrad = ngradient(model -> sum(model(g, x)[1]) + sum(model(g, x)[2].^2), model)[1]
20+
@test grad.layers[1].weight1 ≈ ngrad.layers[1].weight1 rtol=1e-4
21+
@test grad.layers[1].weight2 ≈ ngrad.layers[1].weight2 rtol=1e-4
22+
@test grad.layers[1].bias ≈ ngrad.layers[1].bias rtol=1e-4
23+
@test grad.layers[2].weight1 ≈ ngrad.layers[2].weight1 rtol=1e-4
24+
@test grad.layers[2].weight2 ≈ ngrad.layers[2].weight2 rtol=1e-4
25+
@test grad.layers[2].bias ≈ ngrad.layers[2].bias rtol=1e-4
1426

15-
@test grad.layers[1].weight1 ≈ ngrad.layers[1].weight1 rtol=1e-4
16-
@test grad.layers[1].weight2 ≈ ngrad.layers[1].weight2 rtol=1e-4
17-
@test grad.layers[1].bias ≈ ngrad.layers[1].bias rtol=1e-4
18-
@test grad.layers[2].weight1 ≈ ngrad.layers[2].weight1 rtol=1e-4
19-
@test grad.layers[2].weight2 ≈ ngrad.layers[2].weight2 rtol=1e-4
20-
@test grad.layers[2].bias ≈ ngrad.layers[2].bias rtol=1e-4
27+
@test dx[:A] ≈ ndx[:A] rtol=1e-4
28+
@test dx[:B] ≈ ndx[:B] rtol=1e-4
29+
end
2130
end

0 commit comments

Comments
 (0)