|  | 
| 167 | 167 | function Base.show(io::IO, a3tgcn::A3TGCN) | 
| 168 | 168 |     print(io, "A3TGCN($(a3tgcn.in) => $(a3tgcn.out))") | 
| 169 | 169 | end | 
| 170 |  | - | 
| 171 |  | - | 
| 172 |  | -""" | 
| 173 |  | -    EvolveGCNO(ch; bias = true, init = glorot_uniform, init_state = Flux.zeros32) | 
| 174 |  | -
 | 
| 175 |  | -Evolving Graph Convolutional Network (EvolveGCNO) layer from the paper [EvolveGCN: Evolving Graph Convolutional Networks for Dynamic Graphs](https://arxiv.org/pdf/1902.10191). | 
| 176 |  | -
 | 
| 177 |  | -Perfoms a Graph Convolutional layer with parameters derived from a Long Short-Term Memory (LSTM) layer across the snapshots of the temporal graph. | 
| 178 |  | -
 | 
| 179 |  | -
 | 
| 180 |  | -# Arguments | 
| 181 |  | -
 | 
| 182 |  | -- `in`: Number of input features. | 
| 183 |  | -- `out`: Number of output features. | 
| 184 |  | -- `bias`: Add learnable bias. Default `true`. | 
| 185 |  | -- `init`: Weights' initializer. Default `glorot_uniform`. | 
| 186 |  | -- `init_state`: Initial state of the hidden stat of the LSTM layer. Default `zeros32`. | 
| 187 |  | -
 | 
| 188 |  | -# Examples | 
| 189 |  | -
 | 
| 190 |  | -```jldoctest | 
| 191 |  | -julia> tg = TemporalSnapshotsGNNGraph([rand_graph(10,20; ndata = rand(4,10)), rand_graph(10,14; ndata = rand(4,10)), rand_graph(10,22; ndata = rand(4,10))]) | 
| 192 |  | -TemporalSnapshotsGNNGraph: | 
| 193 |  | -  num_nodes: [10, 10, 10] | 
| 194 |  | -  num_edges: [20, 14, 22] | 
| 195 |  | -  num_snapshots: 3 | 
| 196 |  | -
 | 
| 197 |  | -julia> ev = EvolveGCNO(4 => 5) | 
| 198 |  | -EvolveGCNO(4 => 5) | 
| 199 |  | -
 | 
| 200 |  | -julia> size(ev(tg, tg.ndata.x)) | 
| 201 |  | -(3,) | 
| 202 |  | -
 | 
| 203 |  | -julia> size(ev(tg, tg.ndata.x)[1]) | 
| 204 |  | -(5, 10) | 
| 205 |  | -``` | 
| 206 |  | -""" | 
| 207 |  | -struct EvolveGCNO | 
| 208 |  | -    conv | 
| 209 |  | -    W_init | 
| 210 |  | -    init_state | 
| 211 |  | -    in::Int | 
| 212 |  | -    out::Int | 
| 213 |  | -    Wf | 
| 214 |  | -    Uf | 
| 215 |  | -    Bf | 
| 216 |  | -    Wi | 
| 217 |  | -    Ui | 
| 218 |  | -    Bi | 
| 219 |  | -    Wo | 
| 220 |  | -    Uo | 
| 221 |  | -    Bo | 
| 222 |  | -    Wc | 
| 223 |  | -    Uc | 
| 224 |  | -    Bc | 
| 225 |  | -end | 
| 226 |  | -  | 
| 227 |  | -function EvolveGCNO(ch; bias = true, init = glorot_uniform, init_state = Flux.zeros32) | 
| 228 |  | -    in, out = ch | 
| 229 |  | -    W = init(out, in) | 
| 230 |  | -    conv = GCNConv(ch; bias = bias, init = init) | 
| 231 |  | -    Wf = init(out, in) | 
| 232 |  | -    Uf = init(out, in) | 
| 233 |  | -    Bf = bias ? init(out, in) : nothing | 
| 234 |  | -    Wi = init(out, in) | 
| 235 |  | -    Ui = init(out, in) | 
| 236 |  | -    Bi = bias ? init(out, in) : nothing | 
| 237 |  | -    Wo = init(out, in) | 
| 238 |  | -    Uo = init(out, in) | 
| 239 |  | -    Bo = bias ? init(out, in) : nothing | 
| 240 |  | -    Wc = init(out, in) | 
| 241 |  | -    Uc = init(out, in) | 
| 242 |  | -    Bc = bias ? init(out, in) : nothing | 
| 243 |  | -    return EvolveGCNO(conv, W, init_state, in, out, Wf, Uf, Bf, Wi, Ui, Bi, Wo, Uo, Bo, Wc, Uc, Bc) | 
| 244 |  | -end | 
| 245 |  | - | 
| 246 |  | -function (egcno::EvolveGCNO)(tg::TemporalSnapshotsGNNGraph, x) | 
| 247 |  | -    H = egcno.init_state(egcno.out, egcno.in) | 
| 248 |  | -    C = egcno.init_state(egcno.out, egcno.in) | 
| 249 |  | -    W = egcno.W_init | 
| 250 |  | -    X = map(1:tg.num_snapshots) do i | 
| 251 |  | -        F = Flux.sigmoid_fast.(egcno.Wf .* W + egcno.Uf .* H + egcno.Bf) | 
| 252 |  | -        I = Flux.sigmoid_fast.(egcno.Wi .* W + egcno.Ui .* H + egcno.Bi) | 
| 253 |  | -        O = Flux.sigmoid_fast.(egcno.Wo .* W + egcno.Uo .* H + egcno.Bo) | 
| 254 |  | -        C̃ = Flux.tanh_fast.(egcno.Wc .* W + egcno.Uc .* H + egcno.Bc) | 
| 255 |  | -        C = F .* C + I .* C̃ | 
| 256 |  | -        H = O .* tanh_fast.(C) | 
| 257 |  | -        W = H | 
| 258 |  | -        egcno.conv(tg.snapshots[i], x[i]; conv_weight = H) | 
| 259 |  | -    end | 
| 260 |  | -    return X | 
| 261 |  | -end | 
| 262 |  | -  | 
| 263 |  | -function Base.show(io::IO, egcno::EvolveGCNO) | 
| 264 |  | -    print(io, "EvolveGCNO($(egcno.in) => $(egcno.out))") | 
| 265 |  | -end | 
0 commit comments