GNNLux/docs/src_tutorials/graph_classification.jl
21 additions & 16 deletions
@@ -13,11 +13,22 @@
  # The TU Dortmund University has collected a wide range of different graph classification datasets, known as the [**TUDatasets**](https://chrsmrrs.github.io/datasets/), which are also accessible via MLDatasets.jl.
  # Let's import the necessary packages. Then we'll load and inspect one of the smaller ones, the **MUTAG dataset**:

- using Lux, GNNLux
- using MLDatasets, MLUtils
+ using Lux
+ using GNNLux
+ using MLDatasets
+ using MLUtils
  using LinearAlgebra, Random, Statistics
  using Zygote, Optimisers, OneHotArrays
+
+ struct GlobalPool{F} <: GNNLayer
+     aggr::F
+ end
+
+ (l::GlobalPool)(g::GNNGraph, x::AbstractArray, ps, st) = GNNlib.global_pool(l, g, x), st
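The added `GlobalPool` acts as a readout layer: it aggregates all node embeddings of each graph in a batch into a single graph-level embedding. A minimal usage sketch, assuming the standard MLDatasets.jl `TUDataset` loader and the `mldataset2gnngraph` conversion helper from GNNGraphs.jl; the feature matrix here is random stand-in data rather than MUTAG's real node features, and the layer itself is parameter-free, so empty parameter/state tuples suffice:

```julia
using GNNLux, GNNGraphs, MLDatasets, MLUtils, Statistics

dataset = TUDataset("MUTAG")              # 188 small molecular graphs
graphs  = mldataset2gnngraph(dataset)     # convert to a vector of GNNGraphs
g       = MLUtils.batch(graphs[1:32])     # one batched graph with 32 components
x       = rand(Float32, 8, g.num_nodes)   # stand-in node features, 8 per node
pool    = GlobalPool(mean)                # mean-aggregate node embeddings per graph
y, _    = pool(g, x, (;), (;))            # y is 8×32: one column per graph in the batch
```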
  # This layer is implemented under the name `GraphConv` in GraphNeuralNetworks.jl.
- # As an exercise, you are invited to complete the following code to the extent that it makes use of `GraphConv` rather than `GCNConv`.
- # This should bring you close to **82% test accuracy**.
-
- # ## Conclusion
-
- # In this chapter, you have learned how to apply GNNs to the task of graph classification.
- # You have learned how graphs can be batched together for better GPU utilization, and how to apply readout layers for obtaining graph embeddings rather than node embeddings.
+ # As an exercise, you are invited to complete the following code to the extent that it makes use of `GraphConv` rather than `GCNConv`.
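A possible starting point for that exercise, sketched only under the assumption that `nin` (input feature dimension), `nh` (hidden size), and `nout` (number of classes) are the same variables used earlier in the tutorial; this is not the tutorial's reference solution:

```julia
model = GNNChain(GraphConv(nin => nh, relu),
                 GraphConv(nh => nh, relu),
                 GraphConv(nh => nh),
                 GlobalPool(mean),          # readout layer defined at the top of this tutorial
                 Dense(nh => nout))
```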
GraphNeuralNetworks/docs/src_tutorials/introductory_tutorials/temporal_graph_classification.jl
12 additions & 2 deletions
@@ -9,12 +9,17 @@
  #
  # We start by importing the necessary libraries. We use `GraphNeuralNetworks.jl`, `Flux.jl` and `MLDatasets.jl`, among others.

+ ## Comments from Miguel for Claudio:
+ # 1. Add a method to check that the datasets were downloaded correctly; otherwise problems may arise. This happened to me when downloading the TemporalBrains dataset.
+
  using Flux
  using GraphNeuralNetworks
  using Statistics, Random
  using LinearAlgebra
  using MLDatasets: TemporalBrains
- using CUDA # comment out if you don't have a CUDA GPU
+ using DataDeps
+ #using CUDA # uncomment if you have a CUDA GPU

  ENV["DATADEPS_ALWAYS_ACCEPT"] = "true" # don't ask for dataset download confirmation
  Random.seed!(17); # for reproducibility
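One way to act on the download-verification note above, sketched only: it assumes the `TemporalBrains` dataset object exposes its graphs through a `graphs` field (as used later in this tutorial) and that a complete download contains 1000 graphs, the count quoted below.

```julia
function check_temporalbrains(dataset; expected = 1000)
    # A partial or corrupted download usually shows up as a wrong graph count.
    n = length(dataset.graphs)
    n == expected || error("TemporalBrains looks incomplete: found $n graphs, expected $expected. " *
                           "Delete the dataset's DataDeps cache and download it again.")
    return dataset
end

check_temporalbrains(TemporalBrains())   # throws if the download looks incomplete
```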
@@ -29,7 +34,7 @@ Random.seed!(17); # for reproducibility
  # Each temporal graph has a label representing gender ('M' for male and 'F' for female) and age group (22-25, 26-30, 31-35, and 36+).
  # The network's edge weights are binarized, and the threshold is set to 0.6 by default.

- brain_dataset = TemporalBrains()
+ brain_dataset = MLDatasets.TemporalBrains()

  # After loading the dataset from the MLDatasets.jl package, we see that there are 1000 graphs and we need to convert them to the `TemporalSnapshotsGNNGraph` format.
  # So we create a function called `data_loader` that performs this conversion and splits the dataset into a training set, used to train the model, and a test set, used to evaluate its performance. Due to computational costs, we use only 250 of the original 1000 graphs: 200 for training and 50 for testing.
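A minimal sketch of the split described here; the conversion step is stubbed out, and `to_temporal_gnngraph` is a placeholder name for whatever conversion to `TemporalSnapshotsGNNGraph` the tutorial actually defines:

```julia
function data_loader(brain_dataset; ntrain = 200, ntest = 50)
    graphs = brain_dataset.graphs[1:(ntrain + ntest)]        # keep only 250 of the 1000 graphs
    gs = [to_temporal_gnngraph(g) for g in graphs]           # placeholder conversion step
    return gs[1:ntrain], gs[(ntrain + 1):(ntrain + ntest)]   # training set, test set
end
```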
@@ -83,6 +88,11 @@ end
  Flux.@layer GenderPredictionModel

+ function (l::GINConv)(g::TemporalSnapshotsGNNGraph, x::AbstractVector)
+     ## apply the convolution snapshot-by-snapshot
+     return [l(g.snapshots[i], x[i]) for i in 1:g.num_snapshots]
+ end
+