Skip to content

Commit 5f55b14

Browse files
Merge pull request #54 from CarloLucibello/cl/withgraph
add WithGraph
2 parents e75ce87 + 9f9c8bc commit 5f55b14

File tree

4 files changed

+68
-4
lines changed

4 files changed

+68
-4
lines changed

docs/src/api/basic.md

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ Pages = ["basic.md"]
1414

1515
## Docs
1616

17-
```@docs
18-
GNNLayer
19-
GNNChain
20-
```
17+
```@autodocs
18+
Modules = [GraphNeuralNetworks]
19+
Pages = ["layers/basic.jl"]
20+
Private = false
21+
```

src/GraphNeuralNetworks.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ export
4343
# layers/basic
4444
GNNLayer,
4545
GNNChain,
46+
WithGraph,
4647

4748
# layers/conv
4849
CGConv,

src/layers/basic.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,48 @@ abstract type GNNLayer end
1111
# To be specialized by layers also needing edge features as input (e.g. NNConv).
1212
(l::GNNLayer)(g::GNNGraph) = GNNGraph(g, ndata=l(g, node_features(g)))
1313

14+
15+
"""
16+
WithGraph(model, g::GNNGraph; traingraph=false)
17+
18+
A type wrapping the `model` and tying it to the graph `g`.
19+
In the forward pass, can only take feature arrays as inputs,
20+
returning `model(g, x...; kws...)`.
21+
22+
If `traingraph=false`, the graph's parameters, won't be collected
23+
when calling `Flux.params` on a `WithGraph` object.
24+
25+
# Examples
26+
27+
```julia
28+
g = GNNGraph([1,2,3], [2,3,1])
29+
x = rand(Float32, 2, 3)
30+
model = SAGEConv(2 => 3)
31+
wg = WithGraph(model, g)
32+
# No need to feed the graph to `wg`
33+
@assert wg(x) == model(g, x)
34+
35+
g2 = GNNGraph([1,1,2,3], [2,4,1,1])
36+
x2 = rand(Float32, 2, 4)
37+
# WithGraph will ignore the internal graph if fed with a new one.
38+
@assert wg(g2, x2) == model(g2, x2)
39+
```
40+
"""
41+
struct WithGraph{M}
42+
model::M
43+
g::GNNGraph
44+
traingraph::Bool
45+
end
46+
47+
WithGraph(model, g::GNNGraph; traingraph=false) = WithGraph(model, g, traingraph)
48+
49+
@functor WithGraph
50+
Flux.trainable(l::WithGraph) = l.traingraph ? (l.model, l.g) : (l.model,)
51+
52+
(l::WithGraph)(g::GNNGraph, x...; kws...) = l.model(g, x...; kws...)
53+
(l::WithGraph)(x...; kws...) = l.model(l.g, x...; kws...)
54+
55+
1456
"""
1557
GNNChain(layers...)
1658
GNNChain(name = layer, ...)

test/layers/basic.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,5 +44,25 @@
4444
@test chain(g).ndata.x == y
4545
end
4646
end
47+
48+
@testset "WithGraph" begin
49+
x = rand(Float32, 2, 3)
50+
g = GNNGraph([1,2,3], [2,3,1], ndata=x)
51+
model = SAGEConv(2 => 3)
52+
wg = WithGraph(model, g)
53+
# No need to feed the graph to `wg`
54+
@test wg(x) == model(g, x)
55+
@test Flux.params(wg) == Flux.params(model)
56+
g2 = GNNGraph([1,1,2,3], [2,4,1,1])
57+
x2 = rand(Float32, 2, 4)
58+
# WithGraph will ignore the internal graph if fed with a new one.
59+
@test wg(g2, x2) == model(g2, x2)
60+
61+
wg = WithGraph(model, g, traingraph=false)
62+
@test length(Flux.params(wg)) == length(Flux.params(model))
63+
64+
wg = WithGraph(model, g, traingraph=true)
65+
@test length(Flux.params(wg)) == length(Flux.params(model)) + length(Flux.params(g))
66+
end
4767
end
4868

0 commit comments

Comments
 (0)