Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,14 @@ jobs:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@latest
with:
version: '1.9.1'
version: '1.10.4'
- name: Install dependencies
run: julia --project=docs/ -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate()'
run: julia --project=docs/ -e '
using Pkg;
Pkg.develop([PackageSpec(path=pwd()), PackageSpec(path=joinpath(pwd(), "GNNGraphs"))]);
Pkg.instantiate();'
- name: Build and deploy
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # If authenticating with GitHub Actions token
DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }} # If authenticating with SSH deploy key
run: julia --project=docs/ docs/make.jl
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # For authentication with GitHub Actions token
DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }}
run: julia --project=docs/ docs/make.jl
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ Manifest.toml
/docs/build/
.vscode
LocalPreferences.toml
.DS_Store
.DS_Store
docs/src/democards/gridtheme.css
5 changes: 5 additions & 0 deletions GNNGraphs/src/operators.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
# 2 or more args graph operators
""""
intersect(g1::GNNGraph, g2::GNNGraph)

Intersect two graphs by keeping only the common edges.
"""
function Base.intersect(g1::GNNGraph, g2::GNNGraph)
@assert g1.num_nodes == g2.num_nodes
@assert graph_type_symbol(g1) == graph_type_symbol(g2)
Expand Down
26 changes: 25 additions & 1 deletion GNNGraphs/src/query.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,17 @@ Graphs.ne(g::GNNGraph) = g.num_edges
Graphs.has_vertex(g::GNNGraph, i::Int) = 1 <= i <= g.num_nodes
Graphs.vertices(g::GNNGraph) = 1:(g.num_nodes)

function Graphs.neighbors(g::GNNGraph, i; dir = :out)

"""
neighbors(g::GNNGraph, i::Integer; dir=:out)

Return the neighbors of node `i` in the graph `g`.
If `dir=:out`, return the neighbors through outgoing edges.
If `dir=:in`, return the neighbors through incoming edges.

See also [`outneighbors`](@ref Graphs.outneighbors), [`inneighbors`](@ref Graphs.inneighbors).
"""
function Graphs.neighbors(g::GNNGraph, i::Integer; dir::Symbol = :out)
@assert dir ∈ (:in, :out)
if dir == :out
outneighbors(g, i)
Expand All @@ -98,6 +108,13 @@ function Graphs.neighbors(g::GNNGraph, i; dir = :out)
end
end

"""
outneighbors(g::GNNGraph, i::Integer)

Return the neighbors of node `i` in the graph `g` through outgoing edges.

See also [`neighbors`](@ref Graphs.neighbors) and [`inneighbors`](@ref Graphs.inneighbors).
"""
function Graphs.outneighbors(g::GNNGraph{<:COO_T}, i::Integer)
s, t = edge_index(g)
return t[s .== i]
Expand All @@ -108,6 +125,13 @@ function Graphs.outneighbors(g::GNNGraph{<:ADJMAT_T}, i::Integer)
return findall(!=(0), A[i, :])
end

"""
inneighbors(g::GNNGraph, i::Integer)

Return the neighbors of node `i` in the graph `g` through incoming edges.

See also [`neighbors`](@ref Graphs.neighbors) and [`outneighbors`](@ref Graphs.outneighbors).
"""
function Graphs.inneighbors(g::GNNGraph{<:COO_T}, i::Integer)
s, t = edge_index(g)
return s[t .== i]
Expand Down
6 changes: 0 additions & 6 deletions GNNlib/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,11 @@ Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Expand All @@ -38,14 +35,11 @@ KrylovKit = "0.6, 0.7"
LinearAlgebra = "1"
MLDatasets = "0.7"
MLUtils = "0.4"
MacroTools = "0.5"
NNlib = "0.9"
NearestNeighbors = "0.4"
Random = "1"
Reexport = "1"
SparseArrays = "1"
Statistics = "1"
StatsBase = "0.34"
cuDNN = "1"
julia = "1.10"

Expand Down
2 changes: 0 additions & 2 deletions GNNlib/src/GNNlib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ module GNNlib

using Statistics: mean
using LinearAlgebra, Random
using Base: tail
using MacroTools: @forward
using MLUtils
using NNlib
using NNlib: scatter, gather
Expand Down
21 changes: 5 additions & 16 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,18 @@ authors = ["Carlo Lucibello and contributors"]
version = "0.6.19"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Expand All @@ -30,26 +24,19 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
GraphNeuralNetworksCUDAExt = "CUDA"

[compat]
Adapt = "3, 4"
CUDA = "4, 5"
ChainRulesCore = "1"
DataStructures = "0.18"
Flux = "0.14"
Functors = "0.4.1"
Graphs = "1.4"
GNNGraphs = "1.0"
KrylovKit = "0.6, 0.7, 0.8"
LinearAlgebra = "1"
MLDatasets = "0.7"
MLUtils = "0.4"
MacroTools = "0.5"
MLUtils = "0.4"
NNlib = "0.9"
NearestNeighbors = "0.4"
Random = "1"
Reexport = "1"
SparseArrays = "1"
Statistics = "1"
StatsBase = "0.34"
cuDNN = "1"
julia = "1.10"

Expand All @@ -59,11 +46,13 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
InlineStrings = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

[targets]
test = ["Test", "Adapt", "DataFrames", "InlineStrings", "Zygote", "FiniteDifferences", "ChainRulesTestUtils", "MLDatasets", "CUDA", "cuDNN"]
test = ["Test", "Adapt", "DataFrames", "InlineStrings", "SparseArrays", "Graphs", "Zygote", "FiniteDifferences", "ChainRulesTestUtils", "MLDatasets", "CUDA", "cuDNN"]
5 changes: 3 additions & 2 deletions docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
[deps]
DemoCards = "311a05b2-6137-4a5a-b473-18580a3d38b5"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DocumenterInterLinks = "d12716ef-a0f6-4df4-a9f1-a5a34e75c656"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c"
GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
MarkdownLiteral = "736d6165-7244-6769-4267-6b50796e6954"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Pluto = "c3e4b0f8-55cb-11ea-2926-15256bba5781"
Expand All @@ -17,4 +18,4 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
DemoCards = "0.5.0"
Documenter = "0.27"
Documenter = "1.5"
20 changes: 16 additions & 4 deletions docs/make.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,23 @@
using Flux, NNlib, GraphNeuralNetworks, Graphs, SparseArrays
using GraphNeuralNetworks
using GNNGraphs
using Flux
using NNlib
using Graphs
using SparseArrays
using Pluto, PlutoStaticHTML # for tutorials
using Documenter, DemoCards
using DocumenterInterLinks

tutorials, tutorials_cb, tutorial_assets = makedemos("tutorials")

tutorials, tutorials_cb, tutorial_assets = makedemos("tutorials")
assets = []
isnothing(tutorial_assets) || push!(assets, tutorial_assets)

interlinks = InterLinks(
"NNlib" => "https://fluxml.ai/NNlib.jl/stable/",
"Graphs" => "https://juliagraphs.org/Graphs.jl/stable/")


DocMeta.setdocmeta!(GraphNeuralNetworks, :DocTestSetup,
:(using GraphNeuralNetworks, Graphs, SparseArrays, NNlib, Flux);
recursive = true)
Expand All @@ -15,10 +26,11 @@ prettyurls = get(ENV, "CI", nothing) == "true"
mathengine = MathJax3()

makedocs(;
modules = [GraphNeuralNetworks, NNlib, Flux, Graphs, SparseArrays],
modules = [GraphNeuralNetworks, GNNGraphs],
doctest = false,
clean = true,
format = Documenter.HTML(; mathengine, prettyurls, assets = assets),
plugins = [interlinks],
format = Documenter.HTML(; mathengine, prettyurls, assets = assets, size_threshold=nothing),
sitename = "GraphNeuralNetworks.jl",
pages = ["Home" => "index.md",
"Graphs" => ["gnngraph.md", "heterograph.md", "temporalgraph.md"],
Expand Down
30 changes: 15 additions & 15 deletions docs/pluto_output/gnn_intro_pluto.md

Large diffs are not rendered by default.

22 changes: 11 additions & 11 deletions docs/pluto_output/graph_classification_pluto.md
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
```@raw html
<style>
table {
#documenter-page table {
display: table !important;
margin: 2rem auto !important;
border-top: 2pt solid rgba(0,0,0,0.2);
border-bottom: 2pt solid rgba(0,0,0,0.2);
}

pre, div {
#documenter-page pre, #documenter-page div {
margin-top: 1.4rem !important;
margin-bottom: 1.4rem !important;
}
Expand All @@ -25,8 +25,8 @@
<!--
# This information is used for caching.
[PlutoStaticHTML.State]
input_sha = "f145b80b8f1e399d4cd5686b529cf173942102c538702952fe0743defca62210"
julia_version = "1.9.1"
input_sha = "62d9b08cdb51a5d174d1d090f3e4834f98df0c30b8b515e5befdd8fa22bd5c7f"
julia_version = "1.10.4"
-->
<pre class='language-julia'><code class='language-julia'>begin
using Flux
Expand Down Expand Up @@ -102,7 +102,7 @@ end</code></pre>
<div class="markdown"><p>We have some useful utilities for working with graph datasets, <em>e.g.</em>, we can shuffle the dataset and use the first 150 graphs as training graphs, while using the remaining ones for testing:</p></div>

<pre class='language-julia'><code class='language-julia'>train_data, test_data = splitobs((graphs, y), at = 150, shuffle = true) |&gt; getobs</code></pre>
<pre class="code-output documenter-example-output" id="var-train_data">((GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}[GNNGraph(12, 24) with x: 7×12 data, GNNGraph(22, 50) with x: 7×22 data, GNNGraph(23, 54) with x: 7×23 data, GNNGraph(25, 56) with x: 7×25 data, GNNGraph(16, 36) with x: 7×16 data, GNNGraph(11, 22) with x: 7×11 data, GNNGraph(18, 38) with x: 7×18 data, GNNGraph(23, 52) with x: 7×23 data, GNNGraph(22, 50) with x: 7×22 data, GNNGraph(20, 46) with x: 7×20 data … GNNGraph(16, 34) with x: 7×16 data, GNNGraph(13, 28) with x: 7×13 data, GNNGraph(21, 44) with x: 7×21 data, GNNGraph(17, 38) with x: 7×17 data, GNNGraph(23, 54) with x: 7×23 data, GNNGraph(12, 24) with x: 7×12 data, GNNGraph(22, 50) with x: 7×22 data, GNNGraph(19, 42) with x: 7×19 data, GNNGraph(16, 34) with x: 7×16 data, GNNGraph(16, 36) with x: 7×16 data], Bool[1 0 … 1 0; 0 1 … 0 1]), (GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}[GNNGraph(21, 44) with x: 7×21 data, GNNGraph(22, 50) with x: 7×22 data, GNNGraph(16, 34) with x: 7×16 data, GNNGraph(27, 66) with x: 7×27 data, GNNGraph(13, 26) with x: 7×13 data, GNNGraph(20, 44) with x: 7×20 data, GNNGraph(19, 44) with x: 7×19 data, GNNGraph(20, 46) with x: 7×20 data, GNNGraph(16, 34) with x: 7×16 data, GNNGraph(13, 28) with x: 7×13 data … GNNGraph(11, 22) with x: 7×11 data, GNNGraph(20, 46) with x: 7×20 data, GNNGraph(16, 34) with x: 7×16 data, GNNGraph(18, 40) with x: 7×18 data, GNNGraph(13, 28) with x: 7×13 data, GNNGraph(20, 44) with x: 7×20 data, GNNGraph(14, 30) with x: 7×14 data, GNNGraph(13, 26) with x: 7×13 data, GNNGraph(21, 44) with x: 7×21 data, GNNGraph(22, 50) with x: 7×22 data], Bool[0 0 … 0 0; 1 1 … 1 1]))</pre>
<pre class="code-output documenter-example-output" id="var-train_data">((GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}[GNNGraph(16, 34) with x: 7×16 data, GNNGraph(22, 50) with x: 7×22 data, GNNGraph(23, 54) with x: 7×23 data, GNNGraph(11, 22) with x: 7×11 data, GNNGraph(17, 38) with x: 7×17 data, GNNGraph(13, 28) with x: 7×13 data, GNNGraph(19, 44) with x: 7×19 data, GNNGraph(16, 34) with x: 7×16 data, GNNGraph(14, 30) with x: 7×14 data, GNNGraph(18, 38) with x: 7×18 data … GNNGraph(12, 26) with x: 7×12 data, GNNGraph(19, 40) with x: 7×19 data, GNNGraph(19, 44) with x: 7×19 data, GNNGraph(26, 60) with x: 7×26 data, GNNGraph(20, 44) with x: 7×20 data, GNNGraph(20, 44) with x: 7×20 data, GNNGraph(17, 38) with x: 7×17 data, GNNGraph(19, 44) with x: 7×19 data, GNNGraph(19, 42) with x: 7×19 data, GNNGraph(22, 50) with x: 7×22 data], Bool[0 0 … 0 0; 1 1 … 1 1]), (GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}[GNNGraph(26, 60) with x: 7×26 data, GNNGraph(15, 34) with x: 7×15 data, GNNGraph(11, 22) with x: 7×11 data, GNNGraph(24, 50) with x: 7×24 data, GNNGraph(17, 38) with x: 7×17 data, GNNGraph(21, 44) with x: 7×21 data, GNNGraph(17, 38) with x: 7×17 data, GNNGraph(13, 28) with x: 7×13 data, GNNGraph(12, 26) with x: 7×12 data, GNNGraph(17, 38) with x: 7×17 data … GNNGraph(12, 26) with x: 7×12 data, GNNGraph(23, 52) with x: 7×23 data, GNNGraph(12, 24) with x: 7×12 data, GNNGraph(23, 50) with x: 7×23 data, GNNGraph(13, 28) with x: 7×13 data, GNNGraph(18, 40) with x: 7×18 data, GNNGraph(16, 36) with x: 7×16 data, GNNGraph(13, 26) with x: 7×13 data, GNNGraph(28, 62) with x: 7×28 data, GNNGraph(11, 22) with x: 7×11 data], Bool[0 0 … 0 1; 1 1 … 1 0]))</pre>

<pre class='language-julia'><code class='language-julia'>begin
train_loader = DataLoader(train_data, batchsize = 32, shuffle = true)
Expand All @@ -113,7 +113,7 @@ end</code></pre>
(32-element Vector{GraphNeuralNetworks.GNNGraphs.GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}}, 2×32 OneHotMatrix(::Vector{UInt32}) with eltype Bool,)</pre>


<div class="markdown"><p>Here, we opt for a <code>batch_size</code> of 32, leading to 5 (randomly shuffled) mini-batches, containing all <span class="tex">$4 \cdot 32+22 = 150$</span> graphs.</p></div>
<div class="markdown"><p>Here, we opt for a <code>batch_size</code> of 32, leading to 5 (randomly shuffled) mini-batches, containing all <span class="tex">\(4 \cdot 32+22 = 150\)</span> graphs.</p></div>


```
Expand All @@ -123,15 +123,15 @@ end</code></pre>
<p>Since graphs in graph classification datasets are usually small, a good idea is to <strong>batch the graphs</strong> before inputting them into a Graph Neural Network to guarantee full GPU utilization. In the image or language domain, this procedure is typically achieved by <strong>rescaling</strong> or <strong>padding</strong> each example into a set of equally-sized shapes, and examples are then grouped in an additional dimension. The length of this dimension is then equal to the number of examples grouped in a mini-batch and is typically referred to as the <code>batchsize</code>.</p><p>However, for GNNs the two approaches described above are either not feasible or may result in a lot of unnecessary memory consumption. Therefore, GraphNeuralNetworks.jl opts for another approach to achieve parallelization across a number of examples. Here, adjacency matrices are stacked in a diagonal fashion (creating a giant graph that holds multiple isolated subgraphs), and node and target features are simply concatenated in the node dimension (the last dimension).</p><p>This procedure has some crucial advantages over other batching procedures:</p><ol><li><p>GNN operators that rely on a message passing scheme do not need to be modified since messages are not exchanged between two nodes that belong to different graphs.</p></li><li><p>There is no computational or memory overhead since adjacency matrices are saved in a sparse fashion holding only non-zero entries, <em>i.e.</em>, the edges.</p></li></ol><p>GraphNeuralNetworks.jl can <strong>batch multiple graphs into a single giant graph</strong>:</p></div>

<pre class='language-julia'><code class='language-julia'>vec_gs, _ = first(train_loader)</code></pre>
<pre class="code-output documenter-example-output" id="var-vec_gs">(GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}[GNNGraph(17, 38) with x: 7×17 data, GNNGraph(19, 42) with x: 7×19 data, GNNGraph(13, 28) with x: 7×13 data, GNNGraph(14, 30) with x: 7×14 data, GNNGraph(13, 28) with x: 7×13 data, GNNGraph(23, 54) with x: 7×23 data, GNNGraph(16, 36) with x: 7×16 data, GNNGraph(24, 50) with x: 7×24 data, GNNGraph(23, 54) with x: 7×23 data, GNNGraph(15, 34) with x: 7×15 data … GNNGraph(16, 34) with x: 7×16 data, GNNGraph(16, 34) with x: 7×16 data, GNNGraph(23, 54) with x: 7×23 data, GNNGraph(12, 26) with x: 7×12 data, GNNGraph(17, 38) with x: 7×17 data, GNNGraph(20, 44) with x: 7×20 data, GNNGraph(13, 28) with x: 7×13 data, GNNGraph(26, 60) with x: 7×26 data, GNNGraph(23, 54) with x: 7×23 data, GNNGraph(24, 50) with x: 7×24 data], Bool[0 0 … 0 0; 1 1 … 1 1])</pre>
<pre class="code-output documenter-example-output" id="var-vec_gs">(GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}[GNNGraph(19, 44) with x: 7×19 data, GNNGraph(20, 46) with x: 7×20 data, GNNGraph(15, 34) with x: 7×15 data, GNNGraph(25, 56) with x: 7×25 data, GNNGraph(17, 38) with x: 7×17 data, GNNGraph(20, 44) with x: 7×20 data, GNNGraph(16, 34) with x: 7×16 data, GNNGraph(11, 22) with x: 7×11 data, GNNGraph(19, 44) with x: 7×19 data, GNNGraph(20, 44) with x: 7×20 data … GNNGraph(12, 24) with x: 7×12 data, GNNGraph(12, 26) with x: 7×12 data, GNNGraph(16, 36) with x: 7×16 data, GNNGraph(11, 22) with x: 7×11 data, GNNGraph(22, 50) with x: 7×22 data, GNNGraph(13, 28) with x: 7×13 data, GNNGraph(14, 30) with x: 7×14 data, GNNGraph(16, 34) with x: 7×16 data, GNNGraph(22, 50) with x: 7×22 data, GNNGraph(23, 54) with x: 7×23 data], Bool[0 0 … 0 0; 1 1 … 1 1])</pre>

<pre class='language-julia'><code class='language-julia'>MLUtils.batch(vec_gs)</code></pre>
<pre class="code-output documenter-example-output" id="var-hash102363">GNNGraph:
num_nodes: 585
num_edges: 1292
num_nodes: 575
num_edges: 1276
num_graphs: 32
ndata:
x = 7×585 Matrix{Float32}</pre>
x = 7×575 Matrix{Float32}</pre>


<div class="markdown"><p>Each batched graph object is equipped with a <strong><code>graph_indicator</code> vector</strong>, which maps each node to its respective graph in the batch:</p><p class="tex">$$\textrm{graph\_indicator} = [1, \ldots, 1, 2, \ldots, 2, 3, \ldots ]$$</p></div>
Expand All @@ -154,7 +154,7 @@ end</code></pre>
<pre class="code-output documenter-example-output" id="var-create_model">create_model (generic function with 1 method)</pre>


<div class="markdown"><p>Here, we again make use of the <code>GCNConv</code> with <span class="tex">$\mathrm{ReLU}(x) = \max(x, 0)$</span> activation for obtaining localized node embeddings, before we apply our final classifier on top of a graph readout layer.</p><p>Let's train our network for a few epochs to see how well it performs on the training as well as test set:</p></div>
<div class="markdown"><p>Here, we again make use of the <code>GCNConv</code> with <span class="tex">\(\mathrm{ReLU}(x) = \max(x, 0)\)</span> activation for obtaining localized node embeddings, before we apply our final classifier on top of a graph readout layer.</p><p>Let's train our network for a few epochs to see how well it performs on the training as well as test set:</p></div>

<pre class='language-julia'><code class='language-julia'>function eval_loss_accuracy(model, data_loader, device)
loss = 0.0
Expand Down
Loading
Loading