Skip to content

Commit 9f9c8bc

Browse files
update docstring
1 parent f281bc4 commit 9f9c8bc

File tree

4 files changed

+18
-10
lines changed

4 files changed

+18
-10
lines changed

docs/src/api/basic.md

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ Pages = ["basic.md"]
1414

1515
## Docs
1616

17-
```@docs
18-
GNNLayer
19-
GNNChain
20-
```
17+
```@autodocs
18+
Modules = [GraphNeuralNetworks]
19+
Pages = ["layers/basic.jl"]
20+
Private = false
21+
```

src/gnngraph.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -305,8 +305,7 @@ function LightGraphs.degree(g::GNNGraph{<:COO_T}, T=nothing; dir=:out)
305305
NNlib.scatter!(+, degs, src, s)
306306
end
307307
if dir [:in, :both]
308-
# @show size(degs) src typeof(t)
309-
NNlib.scatter!(+, degs, src, Int.(t))
308+
NNlib.scatter!(+, degs, src, t)
310309
end
311310
return degs
312311
end

src/layers/basic.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ A type wrapping the `model` and tying it to the graph `g`.
1919
In the forward pass, can only take feature arrays as inputs,
2020
returning `model(g, x...; kws...)`.
2121
22+
If `traingraph=false`, the graph's parameters, won't be collected
23+
when calling `Flux.params` on a `WithGraph` object.
24+
2225
# Examples
2326
2427
```julia
@@ -41,11 +44,10 @@ struct WithGraph{M}
4144
traingraph::Bool
4245
end
4346

44-
4547
WithGraph(model, g::GNNGraph; traingraph=false) = WithGraph(model, g, traingraph)
4648

4749
@functor WithGraph
48-
trainable(l::WithGraph) = l.traingraph ? (l.model, l.g) : (l.model,)
50+
Flux.trainable(l::WithGraph) = l.traingraph ? (l.model, l.g) : (l.model,)
4951

5052
(l::WithGraph)(g::GNNGraph, x...; kws...) = l.model(g, x...; kws...)
5153
(l::WithGraph)(x...; kws...) = l.model(l.g, x...; kws...)

test/layers/basic.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,17 +46,23 @@
4646
end
4747

4848
@testset "WithGraph" begin
49-
g = GNNGraph([1,2,3], [2,3,1])
5049
x = rand(Float32, 2, 3)
50+
g = GNNGraph([1,2,3], [2,3,1], ndata=x)
5151
model = SAGEConv(2 => 3)
5252
wg = WithGraph(model, g)
5353
# No need to feed the graph to `wg`
5454
@test wg(x) == model(g, x)
55-
55+
@test Flux.params(wg) == Flux.params(model)
5656
g2 = GNNGraph([1,1,2,3], [2,4,1,1])
5757
x2 = rand(Float32, 2, 4)
5858
# WithGraph will ignore the internal graph if fed with a new one.
5959
@test wg(g2, x2) == model(g2, x2)
60+
61+
wg = WithGraph(model, g, traingraph=false)
62+
@test length(Flux.params(wg)) == length(Flux.params(model))
63+
64+
wg = WithGraph(model, g, traingraph=true)
65+
@test length(Flux.params(wg)) == length(Flux.params(model)) + length(Flux.params(g))
6066
end
6167
end
6268

0 commit comments

Comments
 (0)