|
167 | 167 | function Base.show(io::IO, a3tgcn::A3TGCN)
|
168 | 168 | print(io, "A3TGCN($(a3tgcn.in) => $(a3tgcn.out))")
|
169 | 169 | end
|
170 |
| - |
171 |
| - |
172 |
| -""" |
173 |
| - EvolveGCNO(ch; bias = true, init = glorot_uniform, init_state = Flux.zeros32) |
174 |
| -
|
175 |
| -Evolving Graph Convolutional Network (EvolveGCNO) layer from the paper [EvolveGCN: Evolving Graph Convolutional Networks for Dynamic Graphs](https://arxiv.org/pdf/1902.10191). |
176 |
| -
|
177 |
| -Perfoms a Graph Convolutional layer with parameters derived from a Long Short-Term Memory (LSTM) layer across the snapshots of the temporal graph. |
178 |
| -
|
179 |
| -
|
180 |
| -# Arguments |
181 |
| -
|
182 |
| -- `in`: Number of input features. |
183 |
| -- `out`: Number of output features. |
184 |
| -- `bias`: Add learnable bias. Default `true`. |
185 |
| -- `init`: Weights' initializer. Default `glorot_uniform`. |
186 |
| -- `init_state`: Initial state of the hidden stat of the LSTM layer. Default `zeros32`. |
187 |
| -
|
188 |
| -# Examples |
189 |
| -
|
190 |
| -```jldoctest |
191 |
| -julia> tg = TemporalSnapshotsGNNGraph([rand_graph(10,20; ndata = rand(4,10)), rand_graph(10,14; ndata = rand(4,10)), rand_graph(10,22; ndata = rand(4,10))]) |
192 |
| -TemporalSnapshotsGNNGraph: |
193 |
| - num_nodes: [10, 10, 10] |
194 |
| - num_edges: [20, 14, 22] |
195 |
| - num_snapshots: 3 |
196 |
| -
|
197 |
| -julia> ev = EvolveGCNO(4 => 5) |
198 |
| -EvolveGCNO(4 => 5) |
199 |
| -
|
200 |
| -julia> size(ev(tg, tg.ndata.x)) |
201 |
| -(3,) |
202 |
| -
|
203 |
| -julia> size(ev(tg, tg.ndata.x)[1]) |
204 |
| -(5, 10) |
205 |
| -``` |
206 |
| -""" |
207 |
| -struct EvolveGCNO |
208 |
| - conv |
209 |
| - W_init |
210 |
| - init_state |
211 |
| - in::Int |
212 |
| - out::Int |
213 |
| - Wf |
214 |
| - Uf |
215 |
| - Bf |
216 |
| - Wi |
217 |
| - Ui |
218 |
| - Bi |
219 |
| - Wo |
220 |
| - Uo |
221 |
| - Bo |
222 |
| - Wc |
223 |
| - Uc |
224 |
| - Bc |
225 |
| -end |
226 |
| - |
227 |
| -function EvolveGCNO(ch; bias = true, init = glorot_uniform, init_state = Flux.zeros32) |
228 |
| - in, out = ch |
229 |
| - W = init(out, in) |
230 |
| - conv = GCNConv(ch; bias = bias, init = init) |
231 |
| - Wf = init(out, in) |
232 |
| - Uf = init(out, in) |
233 |
| - Bf = bias ? init(out, in) : nothing |
234 |
| - Wi = init(out, in) |
235 |
| - Ui = init(out, in) |
236 |
| - Bi = bias ? init(out, in) : nothing |
237 |
| - Wo = init(out, in) |
238 |
| - Uo = init(out, in) |
239 |
| - Bo = bias ? init(out, in) : nothing |
240 |
| - Wc = init(out, in) |
241 |
| - Uc = init(out, in) |
242 |
| - Bc = bias ? init(out, in) : nothing |
243 |
| - return EvolveGCNO(conv, W, init_state, in, out, Wf, Uf, Bf, Wi, Ui, Bi, Wo, Uo, Bo, Wc, Uc, Bc) |
244 |
| -end |
245 |
| - |
246 |
| -function (egcno::EvolveGCNO)(tg::TemporalSnapshotsGNNGraph, x) |
247 |
| - H = egcno.init_state(egcno.out, egcno.in) |
248 |
| - C = egcno.init_state(egcno.out, egcno.in) |
249 |
| - W = egcno.W_init |
250 |
| - X = map(1:tg.num_snapshots) do i |
251 |
| - F = Flux.sigmoid_fast.(egcno.Wf .* W + egcno.Uf .* H + egcno.Bf) |
252 |
| - I = Flux.sigmoid_fast.(egcno.Wi .* W + egcno.Ui .* H + egcno.Bi) |
253 |
| - O = Flux.sigmoid_fast.(egcno.Wo .* W + egcno.Uo .* H + egcno.Bo) |
254 |
| - C̃ = Flux.tanh_fast.(egcno.Wc .* W + egcno.Uc .* H + egcno.Bc) |
255 |
| - C = F .* C + I .* C̃ |
256 |
| - H = O .* tanh_fast.(C) |
257 |
| - W = H |
258 |
| - egcno.conv(tg.snapshots[i], x[i]; conv_weight = H) |
259 |
| - end |
260 |
| - return X |
261 |
| -end |
262 |
| - |
263 |
| -function Base.show(io::IO, egcno::EvolveGCNO) |
264 |
| - print(io, "EvolveGCNO($(egcno.in) => $(egcno.out))") |
265 |
| -end |
0 commit comments