3 changes: 1 addition & 2 deletions GNNGraphs/docs/make.jl
@@ -6,10 +6,10 @@ Pkg.instantiate()
using Documenter
using DocumenterInterLinks
using GNNGraphs
using MLUtils # this is needed by setdocmeta!
import Graphs
using Graphs: induced_subgraph

ENV["DATADEPS_ALWAYS_ACCEPT"] = true # for MLDatasets

DocMeta.setdocmeta!(GNNGraphs, :DocTestSetup, :(using GNNGraphs, MLUtils); recursive = true)

@@ -25,7 +25,6 @@ mathengine = MathJax3(Dict(:loader => Dict("load" => ["[tex]/require", "[tex]/ma

makedocs(;
modules = [GNNGraphs],
doctest = false, # TODO enable doctest
format = Documenter.HTML(; mathengine,
prettyurls = get(ENV, "CI", nothing) == "true",
assets = [],
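For reference, a minimal sketch of the doctest setup this `make.jl` relies on; Documenter runs `jldoctest` blocks by default once `doctest = false` is removed. The bare `makedocs` call and the `sitename` are illustrative, not the project's full build script:

```julia
using Documenter, GNNGraphs, MLUtils

# Inject setup code into every jldoctest block of GNNGraphs, so examples can
# use GNNGraphs and MLUtils without an explicit `using` in each block.
DocMeta.setdocmeta!(GNNGraphs, :DocTestSetup, :(using GNNGraphs, MLUtils); recursive = true)

# Without a `doctest = false` override, makedocs executes the doctests.
makedocs(modules = [GNNGraphs], sitename = "GNNGraphs.jl")
```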
1 change: 1 addition & 0 deletions GNNGraphs/docs/src/api/heterograph.md
@@ -38,3 +38,4 @@ Modules = [GNNGraphs]
Pages = ["gnnheterograph/generate.jl"]
Private = false
```
å
14 changes: 7 additions & 7 deletions GNNGraphs/docs/src/guides/heterograph.md
@@ -88,28 +88,28 @@ julia> g.etypes # edge types

Node, edge, and graph features can be added at construction time or later using:
```jldoctest hetero
# equivalent to g.ndata[:user][:x] = ...
julia> g[:user].x = rand(Float32, 64, 3);
julia> g[:user].x = rand(Float32, 64, 3); # equivalent to g.ndata[:user][:x] = ...

julia> g[:movie].z = rand(Float32, 64, 13);

# equivalent to g.edata[(:user, :rate, :movie)][:e] = ...
julia> g[:user, :rate, :movie].e = rand(Float32, 64, 4);
julia> g[:user, :rate, :movie].e = rand(Float32, 64, 4); # equivalent to g.edata[(:user, :rate, :movie)][:e] = ...

julia> g
GNNHeteroGraph:
num_nodes: Dict(:movie => 13, :user => 3)
num_edges: Dict((:user, :rate, :movie) => 4)
ndata:
:movie => DataStore(z = [64×13 Matrix{Float32}])
:user => DataStore(x = [64×3 Matrix{Float32}])
:movie => DataStore(z = [64×13 Matrix{Float32}])
:user => DataStore(x = [64×3 Matrix{Float32}])
edata:
(:user, :rate, :movie) => DataStore(e = [64×4 Matrix{Float32}])
(:user, :rate, :movie) => DataStore(e = [64×4 Matrix{Float32}])
```

## Batching
Just like ordinary graphs, heterographs can be batched together.
```jldoctest hetero
julia> using MLUtils

julia> gs = [rand_bipartite_heterograph((5, 10), 20) for _ in 1:32];

julia> MLUtils.batch(gs)
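A self-contained sketch of the feature-assignment and batching pattern shown in this guide; the `GNNHeteroGraph` pair constructor and the user/movie indices are assumed from the earlier part of the guide, which is not visible in this diff:

```julia
using GNNGraphs, MLUtils

# 3 users rate 13 movies over 4 edges; node counts are inferred from the largest indices.
g = GNNHeteroGraph((:user, :rate, :movie) => ([1, 1, 2, 3], [7, 13, 5, 7]))

# Attach features after construction, as in the doctest above.
g[:user].x = rand(Float32, 64, 3)                 # same as g.ndata[:user][:x] = ...
g[:movie].z = rand(Float32, 64, 13)
g[:user, :rate, :movie].e = rand(Float32, 64, 4)  # same as g.edata[(:user, :rate, :movie)][:e] = ...

# Heterographs batch like ordinary graphs.
gs = [rand_bipartite_heterograph((5, 10), 20) for _ in 1:32]
gbatch = MLUtils.batch(gs)   # a single GNNHeteroGraph holding all 32 graphs
```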
6 changes: 5 additions & 1 deletion GNNGraphs/docs/src/guides/temporalgraph.md
@@ -65,6 +65,10 @@ Snapshots in a temporal graph can be accessed using indexing:
julia> snapshots = [rand_graph(10, 20), rand_graph(10, 14), rand_graph(10, 22)];

julia> tg = TemporalSnapshotsGNNGraph(snapshots)
TemporalSnapshotsGNNGraph:
num_nodes: [10, 10, 10]
num_edges: [20, 14, 22]
num_snapshots: 3

julia> tg[1] # first snapshot
GNNGraph:
@@ -169,7 +173,7 @@ TemporalSnapshotsGNNGraph:
num_edges: [20, 14, 22]
num_snapshots: 3
tgdata:
y = 3×1 Matrix{Float32}
y = 3×1 Matrix{Float32}

julia> tg.ndata # vector of DataStores, one per snapshot, holding the node features
3-element Vector{DataStore}:
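A short sketch tying the snapshot examples together; the `tgdata` assignment is assumed to follow the guide's pattern, since only its printed result appears in this diff:

```julia
using GNNGraphs

snapshots = [rand_graph(10, 20), rand_graph(10, 14), rand_graph(10, 22)]
tg = TemporalSnapshotsGNNGraph(snapshots)

tg[1]              # first snapshot: an ordinary GNNGraph with 10 nodes and 20 edges
tg.num_snapshots   # 3

# Graph-level temporal features, one row per snapshot.
tg.tgdata.y = rand(Float32, 3, 1)
```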
6 changes: 3 additions & 3 deletions GNNGraphs/src/gnnheterograph/gnnheterograph.jl
@@ -198,18 +198,18 @@ function Base.show(io::IO, ::MIME"text/plain", g::GNNHeteroGraph)
print(io, "\n ndata:")
for k in sort(collect(keys(g.ndata)))
isempty(g.ndata[k]) && continue
print(io, "\n\t", _str(k), " => $(shortsummary(g.ndata[k]))")
print(io, "\n ", _str(k), " => $(shortsummary(g.ndata[k]))")
end
end
if !isempty(g.edata) && !all(isempty, values(g.edata))
print(io, "\n edata:")
for k in sort(collect(keys(g.edata)))
isempty(g.edata[k]) && continue
print(io, "\n\t$k => $(shortsummary(g.edata[k]))")
print(io, "\n $k => $(shortsummary(g.edata[k]))")
end
end
if !isempty(g.gdata)
print(io, "\n gdata:\n\t")
print(io, "\n gdata:\n ")
shortsummary(io, g.gdata)
end
end
10 changes: 5 additions & 5 deletions GNNGraphs/src/mldatasets.jl
@@ -15,11 +15,11 @@ GNNGraph:
num_nodes: 2708
num_edges: 10556
ndata:
val_mask = 2708-element BitVector
targets = 2708-element Vector{Int64}
test_mask = 2708-element BitVector
features = 1433×2708 Matrix{Float32}
train_mask = 2708-element BitVector
val_mask = 2708-element BitVector
targets = 2708-element Vector{Int64}
test_mask = 2708-element BitVector
features = 1433×2708 Matrix{Float32}
train_mask = 2708-element BitVector
```
"""
function mldataset2gnngraph(dataset::D) where {D}
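A sketch of the documented usage, assuming MLDatasets is installed and the Cora download is accepted non-interactively:

```julia
using GNNGraphs, MLDatasets

ENV["DATADEPS_ALWAYS_ACCEPT"] = true    # accept the dataset download (e.g. on CI)

dataset = MLDatasets.Cora()
g = mldataset2gnngraph(dataset)         # one GNNGraph with 2708 nodes and 10556 edges

g.ndata.features    # 1433×2708 Matrix{Float32}
g.ndata.train_mask  # 2708-element BitVector
```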
10 changes: 5 additions & 5 deletions GNNGraphs/src/transform.jl
@@ -459,14 +459,14 @@ GNNGraph:
num_nodes: 4
num_edges: 5
edata:
e = 5-element Vector{Float64}
e = 5-element Vector{Float64}

julia> g2 = to_bidirected(g)
GNNGraph:
num_nodes: 4
num_edges: 7
edata:
e = 7-element Vector{Float64}
e = 7-element Vector{Float64}

julia> edge_index(g2)
([1, 2, 2, 3, 3, 4, 4], [2, 1, 3, 2, 4, 3, 4])
@@ -644,22 +644,22 @@ GNNGraph:
num_nodes: 4
num_edges: 4
ndata:
x = 3×4 Matrix{Float32}
x = 3×4 Matrix{Float32}

julia> g2 = rand_graph(5, 4, ndata=zeros(Float32, 3, 5))
GNNGraph:
num_nodes: 5
num_edges: 4
ndata:
x = 3×5 Matrix{Float32}
x = 3×5 Matrix{Float32}

julia> g12 = MLUtils.batch([g1, g2])
GNNGraph:
num_nodes: 9
num_edges: 8
num_graphs: 2
ndata:
x = 3×9 Matrix{Float32}
x = 3×9 Matrix{Float32}

julia> g12.ndata.x
3×9 Matrix{Float32}:
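A runnable sketch matching the doctest output above; `s`, `t`, and `e` are illustrative values chosen to reproduce the 5-edge graph shown (one duplicated edge plus a self-loop):

```julia
using GNNGraphs, MLUtils

s, t = [1, 2, 3, 3, 4], [2, 3, 4, 4, 4]
e = [10.0, 20.0, 30.0, 40.0, 50.0]
g = GNNGraph(s, t, edata = e)       # 4 nodes, 5 edges

g2 = to_bidirected(g)               # reverse edges added, duplicates coalesced: 7 edges
edge_index(g2)                      # ([1, 2, 2, 3, 3, 4, 4], [2, 1, 3, 2, 4, 3, 4])

# Batching concatenates node features along the node dimension.
g1 = rand_graph(4, 4, ndata = rand(Float32, 3, 4))
gb = MLUtils.batch([g1, rand_graph(5, 4, ndata = zeros(Float32, 3, 5))])
size(gb.ndata.x)                    # (3, 9)
```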
1 change: 1 addition & 0 deletions GNNLux/docs/Project.toml
@@ -11,6 +11,7 @@ Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
TSne = "24678dba-d5e9-5843-a4c6-250288b04835"
7 changes: 5 additions & 2 deletions GNNLux/docs/make.jl
@@ -12,6 +12,10 @@ using GNNLux
using Lux, GNNGraphs, GNNlib, Graphs
using DocumenterInterLinks

ENV["DATADEPS_ALWAYS_ACCEPT"] = true # for MLDatasets

DocMeta.setdocmeta!(GNNGraphs, :DocTestSetup, :(using GNNGraphs, MLUtils); recursive = true)
DocMeta.setdocmeta!(GNNlib, :DocTestSetup, :(using GNNlib); recursive = true)
DocMeta.setdocmeta!(GNNLux, :DocTestSetup, :(using GNNLux); recursive = true)

mathengine = MathJax3(Dict(:loader => Dict("load" => ["[tex]/require", "[tex]/mathtools"]),
@@ -37,7 +41,6 @@ cp(joinpath(@__DIR__, "../../GNNlib/docs/src"),

makedocs(;
modules = [GNNLux, GNNGraphs, GNNlib],
doctest = false, # TODO: enable doctest
plugins = [interlinks],
format = Documenter.HTML(; mathengine,
prettyurls = get(ENV, "CI", nothing) == "true",
@@ -82,7 +85,7 @@ makedocs(;
"Layers" => [
"Basic layers" => "api/basic.md",
"Convolutional layers" => "api/conv.md",
# "Pooling layers" => "api/pool.md",
"Pooling layers" => "api/pool.md",
"Temporal Convolutional layers" => "api/temporalconv.md",
# "Hetero Convolutional layers" => "api/heteroconv.md",
]
7 changes: 0 additions & 7 deletions GNNLux/docs/src/api/pool.md
@@ -5,13 +5,6 @@ CollapsedDocStrings = true

# Pooling Layers

## Index

```@index
Order = [:type, :function]
Pages = ["pool.md"]
```

```@autodocs
Modules = [GNNLux]
Pages = ["layers/pool.jl"]
17 changes: 12 additions & 5 deletions GNNLux/src/layers/basic.jl
@@ -31,9 +31,14 @@ julia> using Lux, GNNLux, Random

julia> rng = Random.default_rng();

julia> m = GNNChain(GCNConv(2=>5),
x -> relu.(x),
Dense(5=>4))
julia> m = GNNChain(GCNConv(2 => 5, relu), Dense(5 => 4))
GNNChain(
layers = NamedTuple(
layer_1 = GCNConv(2 => 5, relu), # 15 parameters
layer_2 = Dense(5 => 4), # 24 parameters
),
) # Total: 39 parameters,
# plus 0 states.

julia> x = randn(rng, Float32, 2, 3);

@@ -44,8 +49,10 @@ GNNGraph:

julia> ps, st = LuxCore.setup(rng, m);

julia> m(g, x, ps, st) # First entry is the output, second entry is the state of the model
(Float32[-0.15594329 -0.15594329 -0.15594329; 0.93431795 0.93431795 0.93431795; 0.27568763 0.27568763 0.27568763; 0.12568939 0.12568939 0.12568939], (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple()))
julia> y, st = m(g, x, ps, st); # First entry is the output, second entry is the state of the model

julia> size(y)
(4, 3)
```
"""
@concrete struct GNNChain <: GNNContainerLayer{(:layers,)}
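The updated docstring example, condensed into a runnable sketch of the Lux-style workflow with explicit parameters and state (the graph and inputs are random, so only the shapes are meaningful):

```julia
using Lux, GNNLux, Random

rng = Random.default_rng()

m = GNNChain(GCNConv(2 => 5, relu), Dense(5 => 4))   # graph convolution + dense readout

g = rand_graph(3, 6)               # 3 nodes, 6 edges
x = randn(rng, Float32, 2, 3)      # 2 input features per node

ps, st = LuxCore.setup(rng, m)     # explicit parameters and state
y, st = m(g, x, ps, st)            # forward pass returns (output, new state)
size(y)                            # (4, 3)
```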
2 changes: 2 additions & 0 deletions GNNlib/docs/Project.toml
@@ -4,3 +4,5 @@ DocumenterInterLinks = "d12716ef-a0f6-4df4-a9f1-a5a34e75c656"
GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c"
GNNlib = "a6a84749-d869-43f8-aacc-be26a1996e48"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
5 changes: 4 additions & 1 deletion GNNlib/docs/make.jl
@@ -12,6 +12,10 @@ using GNNGraphs
import Graphs
using DocumenterInterLinks

ENV["DATADEPS_ALWAYS_ACCEPT"] = true # for MLDatasets
DocMeta.setdocmeta!(GNNGraphs, :DocTestSetup, :(using GNNGraphs, MLUtils); recursive = true)
DocMeta.setdocmeta!(GNNlib, :DocTestSetup, :(using GNNlib); recursive = true)

assets=[]
prettyurls = get(ENV, "CI", nothing) == "true"
mathengine = MathJax3()
@@ -26,7 +30,6 @@ cp(joinpath(@__DIR__, "../../GNNGraphs/docs/src/"),

makedocs(;
modules = [GNNlib, GNNGraphs],
doctest = false, # TODO enable doctest
plugins = [interlinks],
format = Documenter.HTML(; mathengine, prettyurls, assets = assets, size_threshold=nothing),
sitename = "GNNlib.jl",
1 change: 1 addition & 0 deletions GraphNeuralNetworks/docs/Project.toml
@@ -12,6 +12,7 @@ Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
PlutoStaticHTML = "359b1769-a58e-495b-9770-312e911026ad"
PlutoUI = "7f904dfe-b85e-4ff6-b463-dae2292396a8"
5 changes: 4 additions & 1 deletion GraphNeuralNetworks/docs/make.jl
@@ -12,6 +12,10 @@ using GraphNeuralNetworks
using Flux, GNNGraphs, GNNlib, Graphs
using DocumenterInterLinks

ENV["DATADEPS_ALWAYS_ACCEPT"] = true # for MLDatasets

DocMeta.setdocmeta!(GNNGraphs, :DocTestSetup, :(using GNNGraphs, MLUtils); recursive = true)
DocMeta.setdocmeta!(GNNlib, :DocTestSetup, :(using GNNlib); recursive = true)
DocMeta.setdocmeta!(GraphNeuralNetworks, :DocTestSetup, :(using GraphNeuralNetworks); recursive = true)

mathengine = MathJax3(Dict(:loader => Dict("load" => ["[tex]/require", "[tex]/mathtools"]),
@@ -37,7 +41,6 @@ cp(joinpath(@__DIR__, "../../GNNlib/docs/src"),

makedocs(;
modules = [GraphNeuralNetworks, GNNGraphs, GNNlib],
doctest = false, # TODO: enable doctest
plugins = [interlinks],
format = Documenter.HTML(; mathengine,
prettyurls = get(ENV, "CI", nothing) == "true",
7 changes: 0 additions & 7 deletions GraphNeuralNetworks/docs/src/api/pool.md
@@ -5,13 +5,6 @@ CollapsedDocStrings = true

# Pooling Layers

## Index

```@index
Order = [:type, :function]
Pages = ["pool.md"]
```

```@autodocs
Modules = [GraphNeuralNetworks]
Pages = ["layers/pool.jl"]
33 changes: 12 additions & 21 deletions GraphNeuralNetworks/src/layers/basic.jl
@@ -74,30 +74,22 @@ julia> using Flux, GraphNeuralNetworks
julia> m = GNNChain(GCNConv(2=>5),
BatchNorm(5),
x -> relu.(x),
Dense(5, 4))
GNNChain(GCNConv(2 => 5), BatchNorm(5), #7, Dense(5 => 4))
Dense(5, 4));

julia> x = randn(Float32, 2, 3);

julia> g = rand_graph(3, 6)
GNNGraph:
num_nodes = 3
num_edges = 6
num_nodes: 3
num_edges: 6

julia> m(g, x)
4×3 Matrix{Float32}:
-0.795592 -0.795592 -0.795592
-0.736409 -0.736409 -0.736409
0.994925 0.994925 0.994925
0.857549 0.857549 0.857549
julia> m(g, x) |> size
(4, 3)

julia> m2 = GNNChain(enc = m,
dec = DotDecoder())
GNNChain(enc = GNNChain(GCNConv(2 => 5), BatchNorm(5), #7, Dense(5 => 4)), dec = DotDecoder())
julia> m2 = GNNChain(enc = m, dec = DotDecoder());

julia> m2(g, x)
1×6 Matrix{Float32}:
2.90053 2.90053 2.90053 2.90053 2.90053 2.90053
julia> m2(g, x) |> size
(1, 6)

julia> m2[:enc](g, x) == m(g, x)
true
@@ -196,15 +188,14 @@ returns the dot product `x_i ⋅ x_j` on each edge.
```jldoctest
julia> g = rand_graph(5, 6)
GNNGraph:
num_nodes = 5
num_edges = 6
num_nodes: 5
num_edges: 6

julia> dotdec = DotDecoder()
DotDecoder()

julia> dotdec(g, rand(2, 5))
1×6 Matrix{Float64}:
0.345098 0.458305 0.106353 0.345098 0.458305 0.106353
julia> dotdec(g, rand(2, 5)) |> size
(1, 6)
```
"""
struct DotDecoder <: GNNLayer end
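The two docstring examples above, condensed into one runnable sketch (inputs are random, so only the shapes are checked):

```julia
using Flux, GraphNeuralNetworks

g = rand_graph(3, 6)
x = randn(Float32, 2, 3)

# Encoder: graph convolution, batch norm, elementwise relu, dense readout.
m = GNNChain(GCNConv(2 => 5), BatchNorm(5), x -> relu.(x), Dense(5, 4))
size(m(g, x))               # (4, 3): one 4-dimensional embedding per node

# Named chain: the encoder followed by a DotDecoder, which scores each edge
# with the dot product of its endpoint embeddings.
m2 = GNNChain(enc = m, dec = DotDecoder())
size(m2(g, x))              # (1, 6): one score per edge
m2[:enc](g, x) == m(g, x)   # true: layers of a named chain are accessible by name
```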
6 changes: 4 additions & 2 deletions GraphNeuralNetworks/src/layers/heteroconv.jl
@@ -21,10 +21,12 @@ have to be aggregated using the `aggr` function. The default is to sum the outpu
# Examples

```jldoctest
julia> g = rand_bipartite_heterograph((10, 15), 20)
julia> using GraphNeuralNetworks, Flux

julia> g = rand_bipartite_heterograph((10, 15), 80)
GNNHeteroGraph:
num_nodes: Dict(:A => 10, :B => 15)
num_edges: Dict((:A, :to, :B) => 20, (:B, :to, :A) => 20)
num_edges: Dict((:A, :to, :B) => 80, (:B, :to, :A) => 80)

julia> x = (A = rand(Float32, 64, 10), B = rand(Float32, 64, 15));

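The docstring example is truncated here; a sketch of how it typically continues, applying one convolution per relation (the `GraphConv` layers and the 64 => 32 sizes are assumptions consistent with the feature shapes above):

```julia
using Flux, GraphNeuralNetworks

g = rand_bipartite_heterograph((10, 15), 80)
x = (A = rand(Float32, 64, 10), B = rand(Float32, 64, 15))

# One convolution per relation; when several relations update the same node
# type, their outputs are combined with `aggr` (default: sum).
layer = HeteroGraphConv((:A, :to, :B) => GraphConv(64 => 32, relu),
                        (:B, :to, :A) => GraphConv(64 => 32, relu))

y = layer(g, x)          # named tuple of new node features
size(y.A), size(y.B)     # ((32, 10), (32, 15))
```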