Skip to content

Commit f281bc4

Browse files
add WithGraph
1 parent 210d2c8 commit f281bc4

File tree

4 files changed

+57
-1
lines changed

4 files changed

+57
-1
lines changed

src/GraphNeuralNetworks.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ export
4343
# layers/basic
4444
GNNLayer,
4545
GNNChain,
46+
WithGraph,
4647

4748
# layers/conv
4849
CGConv,

src/gnngraph.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,8 @@ function LightGraphs.degree(g::GNNGraph{<:COO_T}, T=nothing; dir=:out)
305305
NNlib.scatter!(+, degs, src, s)
306306
end
307307
if dir [:in, :both]
308-
NNlib.scatter!(+, degs, src, t)
308+
# @show size(degs) src typeof(t)
309+
NNlib.scatter!(+, degs, src, Int.(t))
309310
end
310311
return degs
311312
end

src/layers/basic.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,46 @@ abstract type GNNLayer end
1111
# To be specialized by layers also needing edge features as input (e.g. NNConv).
1212
(l::GNNLayer)(g::GNNGraph) = GNNGraph(g, ndata=l(g, node_features(g)))
1313

14+
15+
"""
16+
WithGraph(model, g::GNNGraph; traingraph=false)
17+
18+
A type wrapping the `model` and tying it to the graph `g`.
19+
In the forward pass, can only take feature arrays as inputs,
20+
returning `model(g, x...; kws...)`.
21+
22+
# Examples
23+
24+
```julia
25+
g = GNNGraph([1,2,3], [2,3,1])
26+
x = rand(Float32, 2, 3)
27+
model = SAGEConv(2 => 3)
28+
wg = WithGraph(model, g)
29+
# No need to feed the graph to `wg`
30+
@assert wg(x) == model(g, x)
31+
32+
g2 = GNNGraph([1,1,2,3], [2,4,1,1])
33+
x2 = rand(Float32, 2, 4)
34+
# WithGraph will ignore the internal graph if fed with a new one.
35+
@assert wg(g2, x2) == model(g2, x2)
36+
```
37+
"""
38+
struct WithGraph{M}
39+
model::M
40+
g::GNNGraph
41+
traingraph::Bool
42+
end
43+
44+
45+
WithGraph(model, g::GNNGraph; traingraph=false) = WithGraph(model, g, traingraph)
46+
47+
@functor WithGraph
48+
trainable(l::WithGraph) = l.traingraph ? (l.model, l.g) : (l.model,)
49+
50+
(l::WithGraph)(g::GNNGraph, x...; kws...) = l.model(g, x...; kws...)
51+
(l::WithGraph)(x...; kws...) = l.model(l.g, x...; kws...)
52+
53+
1454
"""
1555
GNNChain(layers...)
1656
GNNChain(name = layer, ...)

test/layers/basic.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,5 +44,19 @@
4444
@test chain(g).ndata.x == y
4545
end
4646
end
47+
48+
@testset "WithGraph" begin
49+
g = GNNGraph([1,2,3], [2,3,1])
50+
x = rand(Float32, 2, 3)
51+
model = SAGEConv(2 => 3)
52+
wg = WithGraph(model, g)
53+
# No need to feed the graph to `wg`
54+
@test wg(x) == model(g, x)
55+
56+
g2 = GNNGraph([1,1,2,3], [2,4,1,1])
57+
x2 = rand(Float32, 2, 4)
58+
# WithGraph will ignore the internal graph if fed with a new one.
59+
@test wg(g2, x2) == model(g2, x2)
60+
end
4761
end
4862

0 commit comments

Comments
 (0)