
Commit f431560

adapt to flux 0.15

1 parent c6a489f

File tree: 10 files changed, +75, -116 lines

GNNGraphs/Project.toml
Lines changed: 2 additions & 2 deletions

@@ -1,7 +1,7 @@
 name = "GNNGraphs"
 uuid = "aed8fd31-079b-4b5a-b342-a13352159b8c"
 authors = ["Carlo Lucibello and contributors"]
-version = "1.3.0"
+version = "1.4.0-DEV"

 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -31,7 +31,7 @@ GNNGraphsSimpleWeightedGraphsExt = "SimpleWeightedGraphs"
 Adapt = "4"
 CUDA = "5"
 ChainRulesCore = "1"
-Functors = "0.4.1, 0.5"
+Functors = "0.5"
 Graphs = "1.4"
 KrylovKit = "0.8"
 LinearAlgebra = "1"
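Each `[compat]` entry is a list of caret-style semver specifiers, so this change drops support for the Functors 0.4 series, which is what allows the `@functor` declarations to be deleted in the files below. A hedged sketch of how Pkg reads the two entries:

```julia
# Pkg compat semantics (summary; worth double-checking against the Pkg docs):
# "0.4.1, 0.5"  allows versions in [0.4.1, 0.5.0) ∪ [0.5.0, 0.6.0)
# "0.5"         allows versions in [0.5.0, 0.6.0) only,
#               so Functors 0.4.x can no longer be resolved.
```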

GNNGraphs/src/GNNGraphs.jl
Lines changed: 0 additions & 2 deletions

@@ -1,7 +1,6 @@
 module GNNGraphs

 using SparseArrays
-using Functors: @functor
 import Graphs
 using Graphs: AbstractGraph, outneighbors, inneighbors, adjacency_matrix, degree,
               has_self_loops, is_directed, induced_subgraph, has_edge
@@ -13,7 +12,6 @@ using ChainRulesCore
 using LinearAlgebra, Random, Statistics
 import MLUtils
 using MLUtils: getobs, numobs, ones_like, zeros_like, chunk, batch, rand_like
-import Functors
 using MLDataDevices: get_device, cpu_device, CPUDevice

 include("chainrules.jl") # hacks for differentiability

GNNGraphs/src/datastore.jl
Lines changed: 0 additions & 2 deletions

@@ -70,8 +70,6 @@ struct DataStore
     end
 end

-@functor DataStore
-
 DataStore(data) = DataStore(-1, data)
 DataStore(n::Int, data::NamedTuple) = DataStore(n, Dict{Symbol, Any}(pairs(data)))
 DataStore(n::Int, data) = DataStore(n, Dict{Symbol, Any}(data))
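Deleting `@functor DataStore` (and the matching declarations in the files below) works because, since Functors v0.5, `fmap` recurses into any struct by default, so the explicit opt-in is no longer needed. A minimal sketch with a hypothetical struct, assuming Functors >= 0.5:

```julia
using Functors

struct Wrapper          # hypothetical type; note: no @functor declaration
    data::Vector{Float64}
end

w = Functors.fmap(x -> 2 .* x, Wrapper([1.0, 2.0]))
@assert w.data == [2.0, 4.0]  # fmap reached the field out of the box
```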

GNNGraphs/src/gnngraph.jl
Lines changed: 0 additions & 2 deletions

@@ -116,8 +116,6 @@ struct GNNGraph{T <: Union{COO_T, ADJMAT_T}} <: AbstractGNNGraph{T}
     gdata::DataStore
 end

-@functor GNNGraph
-
 function GNNGraph(data::D;
                   num_nodes = nothing,
                   graph_indicator = nothing,

GNNGraphs/src/gnnheterograph/gnnheterograph.jl
Lines changed: 0 additions & 2 deletions

@@ -95,8 +95,6 @@ struct GNNHeteroGraph{T <: Union{COO_T, ADJMAT_T}} <: AbstractGNNGraph{T}
     etypes::Vector{EType}
 end

-@functor GNNHeteroGraph
-
 GNNHeteroGraph(data; kws...) = GNNHeteroGraph(Dict(data); kws...)
 GNNHeteroGraph(data::Pair...; kws...) = GNNHeteroGraph(Dict(data...); kws...)

GNNGraphs/src/temporalsnapshotsgnngraph.jl
Lines changed: 0 additions & 2 deletions

@@ -240,5 +240,3 @@ function print_feature_t(io::IO, feature)
         print(io, "no")
     end
 end
-
-@functor TemporalSnapshotsGNNGraph

GNNGraphs/test/runtests.jl
Lines changed: 1 addition & 1 deletion

@@ -1,7 +1,7 @@
 using CUDA, cuDNN
 using GNNGraphs
 using GNNGraphs: getn, getdata
-using Functors
+using Functors: Functors
 using LinearAlgebra, Statistics, Random
 using NNlib
 import MLUtils
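The narrowed import keeps the test namespace clean now that nothing from Functors is used unqualified:

```julia
using Functors            # brings the package's exports (e.g. fmap) into scope
using Functors: Functors  # binds only the module name; calls are qualified, e.g. Functors.fmap
```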

GraphNeuralNetworks/Project.toml
Lines changed: 1 addition & 1 deletion

@@ -18,7 +18,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

 [compat]
 ChainRulesCore = "1"
-Flux = "0.14"
+Flux = "0.15"
 GNNGraphs = "1.0"
 GNNlib = "0.2"
 LinearAlgebra = "1"

GraphNeuralNetworks/src/layers/temporalconv.jl
Lines changed: 69 additions & 102 deletions

@@ -1,19 +1,18 @@
-# Adapting Flux.Recur to work with GNNGraphs
-function (m::Flux.Recur)(g::GNNGraph, x)
-    m.state, y = m.cell(m.state, g, x)
-    return y
-end
+# # Adapting Flux.Recur to work with GNNGraphs
+# function (m::Flux.Recur)(g::GNNGraph, x)
+#     m.state, y = m.cell(m.state, g, x)
+#     return y
+# end

-function (m::Flux.Recur)(g::GNNGraph, x::AbstractArray{T, 3}) where T
-    h = [m(g, x_t) for x_t in Flux.eachlastdim(x)]
-    sze = size(h[1])
-    reshape(reduce(hcat, h), sze[1], sze[2], length(h))
-end
+# function (m::Flux.Recur)(g::GNNGraph, x::AbstractArray{T, 3}) where T
+#     h = [m(g, x_t) for x_t in Flux.eachlastdim(x)]
+#     sze = size(h[1])
+#     reshape(reduce(hcat, h), sze[1], sze[2], length(h))
+# end

 struct TGCNCell <: GNNLayer
     conv::GCNConv
     gru::Flux.GRUv3Cell
-    state0
     in::Int
     out::Int
 end
@@ -23,29 +22,26 @@ Flux.@layer TGCNCell
 function TGCNCell(ch::Pair{Int, Int};
                   bias::Bool = true,
                   init = Flux.glorot_uniform,
-                  init_state = Flux.zeros32,
                   add_self_loops = false,
                   use_edge_weight = true)
     in, out = ch
-    conv = GCNConv(in => out, sigmoid; init, bias, add_self_loops,
-                   use_edge_weight)
-    gru = Flux.GRUv3Cell(out, out)
-    state0 = init_state(out,1)
-    return TGCNCell(conv, gru, state0, in,out)
+    conv = GCNConv(in => out, sigmoid; init, bias, add_self_loops, use_edge_weight)
+    gru = Flux.GRUCell(out => out)
+    return TGCNCell(conv, gru, in, out)
 end

-function (tgcn::TGCNCell)(h, g::GNNGraph, x::AbstractArray)
-    x̃ = tgcn.conv(g, x)
-    h, x̃ = tgcn.gru(h, x̃)
-    return h, x̃
+function (tgcn::TGCNCell)(g::GNNGraph, x::AbstractVecOrMat, h::AbstractVecOrMat)
+    x = tgcn.conv(g, x)
+    x, h = tgcn.gru(x, h)
+    return x, h
 end

 function Base.show(io::IO, tgcn::TGCNCell)
     print(io, "TGCNCell($(tgcn.in) => $(tgcn.out))")
 end

 """
-    TGCN(in => out; [bias, init, init_state, add_self_loops, use_edge_weight])
+    TGCN(in => out; [bias, init, add_self_loops, use_edge_weight])

 Temporal Graph Convolutional Network (T-GCN) recurrent layer from the paper [T-GCN: A Temporal Graph Convolutional Network for Traffic Prediction](https://arxiv.org/pdf/1811.05320.pdf).
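Flux 0.15 removed `Flux.Recur` and moved to stateless recurrent cells, which is why `state0` and `init_state` disappear above: a cell now receives the hidden state and returns the new one. A sketch of that pattern with a hypothetical hand-rolled cell (not Flux's API):

```julia
# Hypothetical stateless cell: it stores parameters, never the hidden state.
struct ToyCell
    Wx::Matrix{Float32}
    Wh::Matrix{Float32}
end

ToyCell(in::Int, out::Int) = ToyCell(randn(Float32, out, in), randn(Float32, out, out))

function (c::ToyCell)(x::AbstractVecOrMat, h::AbstractVecOrMat)
    h = tanh.(c.Wx * x .+ c.Wh * h)
    return h, h   # (output, new state), mirroring the `return x, h` above
end

c = ToyCell(2, 4)
x = rand(Float32, 2, 5)     # in × num_nodes
h = zeros(Float32, 4, 5)    # initial state is created by the caller
y, h = c(x, h)              # no Recur wrapper, no reset! needed
```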
@@ -57,12 +53,20 @@ Performs a layer of GCNConv to model spatial dependencies, followed by a Gated R
 - `out`: Number of output features.
 - `bias`: Add learnable bias. Default `true`.
 - `init`: Weights' initializer. Default `glorot_uniform`.
-- `init_state`: Initial state of the hidden stat of the GRU layer. Default `zeros32`.
 - `add_self_loops`: Add self loops to the graph before performing the convolution. Default `false`.
 - `use_edge_weight`: If `true`, consider the edge weights in the input graph (if available).
     If `add_self_loops=true` the new weights will be set to 1.
     This option is ignored if the `edge_weight` is explicitly provided in the forward pass.
     Default `false`.
+
+# Forward
+
+    tgcn(g::GNNGraph, x, [h])
+
+- `x`: The input to the TGCN. It should be a matrix size `in x timesteps` or an array of size `in x timesteps x num_nodes`.
+- `h`: The initial hidden state of the GRU cell. If given, it is a vector of size `out` or a matrix of size `out x num_nodes`.
+   If not provided, it is assumed to be a vector of zeros.
+
 # Examples

 ```jldoctest
@@ -78,30 +82,43 @@ Recur(
 ) # Total: 8 trainable arrays, 264 parameters,
 # plus 1 non-trainable, 6 parameters, summarysize 1.492 KiB.

-julia> g, x = rand_graph(5, 10), rand(Float32, 2, 5);
+julia> g = rand_graph(5, 10);
+
+julia> x = rand(Float32, 2, 5);

 julia> y = tgcn(g, x);

 julia> size(y)
 (6, 5)

-julia> Flux.reset!(tgcn);
-
-julia> tgcn(rand_graph(5, 10), rand(Float32, 2, 5, 20)) |> size # batch size of 20
+julia> tgcn(g, rand(Float32, 2, 5, 20)) |> size # batch size of 20
 (6, 5, 20)
 ```
-
-!!! warning "Batch size changes"
-    Failing to call `reset!` when the input batch size changes can lead to unexpected behavior.
 """
-TGCN(ch; kwargs...) = Flux.Recur(TGCNCell(ch; kwargs...))
+struct TGCN
+    tgcn::TGCNCell
+end
+
+Flux.@layer TGCN
+
+TGCN(ch::Pair{Int, Int}; kws...) = TGCN(TGCNCell(ch; kws...))

-Flux.Recur(tgcn::TGCNCell) = Flux.Recur(tgcn, tgcn.state0)
+function (tgcn::TGCN)(g::GNNGraph, x::AbstractArray, h)
+    for i in 1:size(x, 2)
+        x, h = tgcn.tgcn(g, x[:, i], h)
+    end
+    return x
+end
+

-# make TGCN compatible with GNNChain
-(l::Flux.Recur{TGCNCell})(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g)))
-_applylayer(l::Flux.Recur{TGCNCell}, g::GNNGraph, x) = l(g, x)
-_applylayer(l::Flux.Recur{TGCNCell}, g::GNNGraph) = l(g)
+# TGCN(ch; kwargs...) = Flux.Recur(TGCNCell(ch; kwargs...))
+
+# Flux.Recur(tgcn::TGCNCell) = Flux.Recur(tgcn, tgcn.state0)
+
+# # make TGCN compatible with GNNChain
+# (l::Flux.Recur{TGCNCell})(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g)))
+# _applylayer(l::Flux.Recur{TGCNCell}, g::GNNGraph, x) = l(g, x)
+# _applylayer(l::Flux.Recur{TGCNCell}, g::GNNGraph) = l(g)


 """
@@ -149,7 +166,7 @@ julia> size(y)
     Failing to call `reset!` when the input batch size changes can lead to unexpected behavior.
 """
 struct A3TGCN <: GNNLayer
-    tgcn::Flux.Recur{TGCNCell}
+    tgcn::TGCN
     dense1::Dense
     dense2::Dense
     in::Int
@@ -272,12 +289,12 @@ julia> size(z)
 (5, 5, 30)
 ```
 """
-GConvGRU(ch, k, n; kwargs...) = Flux.Recur(GConvGRUCell(ch, k, n; kwargs...))
-Flux.Recur(ggru::GConvGRUCell) = Flux.Recur(ggru, ggru.state0)
+# GConvGRU(ch, k, n; kwargs...) = Flux.Recur(GConvGRUCell(ch, k, n; kwargs...))
+# Flux.Recur(ggru::GConvGRUCell) = Flux.Recur(ggru, ggru.state0)

-(l::Flux.Recur{GConvGRUCell})(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g)))
-_applylayer(l::Flux.Recur{GConvGRUCell}, g::GNNGraph, x) = l(g, x)
-_applylayer(l::Flux.Recur{GConvGRUCell}, g::GNNGraph) = l(g)
+# (l::Flux.Recur{GConvGRUCell})(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g)))
+# _applylayer(l::Flux.Recur{GConvGRUCell}, g::GNNGraph, x) = l(g, x)
+# _applylayer(l::Flux.Recur{GConvGRUCell}, g::GNNGraph) = l(g)

 struct GConvLSTMCell <: GNNLayer
     conv_x_i::ChebConv
@@ -394,12 +411,12 @@ julia> size(z)
 (5, 5, 30)
 ```
 """
-GConvLSTM(ch, k, n; kwargs...) = Flux.Recur(GConvLSTMCell(ch, k, n; kwargs...))
-Flux.Recur(tgcn::GConvLSTMCell) = Flux.Recur(tgcn, tgcn.state0)
+# GConvLSTM(ch, k, n; kwargs...) = Flux.Recur(GConvLSTMCell(ch, k, n; kwargs...))
+# Flux.Recur(tgcn::GConvLSTMCell) = Flux.Recur(tgcn, tgcn.state0)

-(l::Flux.Recur{GConvLSTMCell})(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g)))
-_applylayer(l::Flux.Recur{GConvLSTMCell}, g::GNNGraph, x) = l(g, x)
-_applylayer(l::Flux.Recur{GConvLSTMCell}, g::GNNGraph) = l(g)
+# (l::Flux.Recur{GConvLSTMCell})(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g)))
+# _applylayer(l::Flux.Recur{GConvLSTMCell}, g::GNNGraph, x) = l(g, x)
+# _applylayer(l::Flux.Recur{GConvLSTMCell}, g::GNNGraph) = l(g)

 struct DCGRUCell
     in::Int
@@ -477,12 +494,12 @@ julia> size(z)
 (5, 5, 30)
 ```
 """
-DCGRU(ch, k, n; kwargs...) = Flux.Recur(DCGRUCell(ch, k, n; kwargs...))
-Flux.Recur(dcgru::DCGRUCell) = Flux.Recur(dcgru, dcgru.state0)
+# DCGRU(ch, k, n; kwargs...) = Flux.Recur(DCGRUCell(ch, k, n; kwargs...))
+# Flux.Recur(dcgru::DCGRUCell) = Flux.Recur(dcgru, dcgru.state0)

-(l::Flux.Recur{DCGRUCell})(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g)))
-_applylayer(l::Flux.Recur{DCGRUCell}, g::GNNGraph, x) = l(g, x)
-_applylayer(l::Flux.Recur{DCGRUCell}, g::GNNGraph) = l(g)
+# (l::Flux.Recur{DCGRUCell})(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g)))
+# _applylayer(l::Flux.Recur{DCGRUCell}, g::GNNGraph, x) = l(g, x)
+# _applylayer(l::Flux.Recur{DCGRUCell}, g::GNNGraph) = l(g)

 """
     EvolveGCNO(ch; bias = true, init = glorot_uniform, init_state = Flux.zeros32)
@@ -539,8 +556,6 @@ struct EvolveGCNO
     Bc
 end

-Flux.@functor EvolveGCNO
-
 function EvolveGCNO(ch; bias = true, init = glorot_uniform, init_state = Flux.zeros32)
     in, out = ch
    W = init(out, in)
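With `Flux.@functor` gone in Flux 0.15, layer types are declared with `Flux.@layer` instead (here the plain deletion suffices, since Functors 0.5 already recurses into `EvolveGCNO` by default). A minimal sketch, assuming Flux >= 0.15:

```julia
using Flux

struct MyLayer              # hypothetical layer
    W::Matrix{Float32}
end

Flux.@layer MyLayer         # replaces the old `Flux.@functor MyLayer`

l = MyLayer(randn(Float32, 3, 3))
opt_state = Flux.setup(Adam(), l)  # parameters are discovered automatically
```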
@@ -580,51 +595,3 @@ end
 function Base.show(io::IO, egcno::EvolveGCNO)
     print(io, "EvolveGCNO($(egcno.in) => $(egcno.out))")
 end
-
-function (l::GINConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
-    return l.(tg.snapshots, x)
-end
-
-function (l::ChebConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
-    return l.(tg.snapshots, x)
-end
-
-function (l::GATConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
-    return l.(tg.snapshots, x)
-end
-
-function (l::GATv2Conv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
-    return l.(tg.snapshots, x)
-end
-
-function (l::GatedGraphConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
-    return l.(tg.snapshots, x)
-end
-
-function (l::CGConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
-    return l.(tg.snapshots, x)
-end
-
-function (l::SGConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
-    return l.(tg.snapshots, x)
-end
-
-function (l::TransformerConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
-    return l.(tg.snapshots, x)
-end
-
-function (l::GCNConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
-    return l.(tg.snapshots, x)
-end
-
-function (l::ResGatedGraphConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
-    return l.(tg.snapshots, x)
-end
-
-function (l::SAGEConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
-    return l.(tg.snapshots, x)
-end
-
-function (l::GraphConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
-    return l.(tg.snapshots, x)
-end
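All twelve deleted methods were the same one-liner broadcasting a layer over the snapshots of a `TemporalSnapshotsGNNGraph`. Callers can recover that behavior with an explicit broadcast; a hedged sketch with a hypothetical helper:

```julia
# Hypothetical generic stand-in for the deleted per-layer methods:
# apply any layer snapshot-wise, given one feature array per snapshot.
apply_snapshots(l, tg, xs::AbstractVector) = map((g, x) -> l(g, x), tg.snapshots, xs)
# equivalent to the deleted definitions' body: l.(tg.snapshots, xs)
```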

GraphNeuralNetworks/test/Project.toml
Lines changed: 2 additions & 0 deletions

@@ -3,6 +3,8 @@ ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
 FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
+GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c"
+GNNlib = "a6a84749-d869-43f8-aacc-be26a1996e48"
 GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
 GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694"
 Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
