@@ -255,3 +255,264 @@ function (l::EdgeConv)(g::AbstractGNNGraph, x, ps, st)
255
255
end
256
256
257
257
258
+ @concrete struct EGNNConv <: GNNContainerLayer{(:ϕe, :ϕx, :ϕh)}
259
+ ϕe
260
+ ϕx
261
+ ϕh
262
+ num_features
263
+ residual:: Bool
264
+ end
265
+
266
+ function EGNNConv (ch:: Pair{Int, Int} , hidden_size = 2 * ch[1 ]; residual = false )
267
+ return EGNNConv ((ch[1 ], 0 ) => ch[2 ]; hidden_size, residual)
268
+ end
269
+
270
+ # Follows reference implementation at https://github.com/vgsatorras/egnn/blob/main/models/egnn_clean/egnn_clean.py
271
+ function EGNNConv (ch:: Pair{NTuple{2, Int}, Int} ; hidden_size:: Int = 2 * ch[1 ][1 ],
272
+ residual = false )
273
+ (in_size, edge_feat_size), out_size = ch
274
+ act_fn = swish
275
+
276
+ # +1 for the radial feature: ||x_i - x_j||^2
277
+ ϕe = Chain (Dense (in_size * 2 + edge_feat_size + 1 => hidden_size, act_fn),
278
+ Dense (hidden_size => hidden_size, act_fn))
279
+
280
+ ϕh = Chain (Dense (in_size + hidden_size => hidden_size, swish),
281
+ Dense (hidden_size => out_size))
282
+
283
+ ϕx = Chain (Dense (hidden_size => hidden_size, swish),
284
+ Dense (hidden_size => 1 , use_bias = false ))
285
+
286
+ num_features = (in = in_size, edge = edge_feat_size, out = out_size,
287
+ hidden = hidden_size)
288
+ if residual
289
+ @assert in_size== out_size " Residual connection only possible if in_size == out_size"
290
+ end
291
+ return EGNNConv (ϕe, ϕx, ϕh, num_features, residual)
292
+ end
293
+
294
+ LuxCore. outputsize (l:: EGNNConv ) = (l. num_features. out,)
295
+
296
+ (l:: EGNNConv )(g, h, x, ps, st) = l (g, h, x, nothing , ps, st)
297
+
298
+ function (l:: EGNNConv )(g, h, x, e, ps, st)
299
+ ϕe = StatefulLuxLayer {true} (l. ϕe, ps. ϕe, _getstate (st, :ϕe ))
300
+ ϕx = StatefulLuxLayer {true} (l. ϕx, ps. ϕx, _getstate (st, :ϕx ))
301
+ ϕh = StatefulLuxLayer {true} (l. ϕh, ps. ϕh, _getstate (st, :ϕh ))
302
+ m = (; ϕe, ϕx, ϕh, l. residual, l. num_features)
303
+ return GNNlib. egnn_conv (m, g, h, x, e), st
304
+ end
305
+
306
+ function Base. show (io:: IO , l:: EGNNConv )
307
+ ne = l. num_features. edge
308
+ nin = l. num_features. in
309
+ nout = l. num_features. out
310
+ nh = l. num_features. hidden
311
+ print (io, " EGNNConv(($nin , $ne ) => $nout ; hidden_size=$nh " )
312
+ if l. residual
313
+ print (io, " , residual=true" )
314
+ end
315
+ print (io, " )" )
316
+ end
317
+
318
+ @concrete struct DConv <: GNNLayer
319
+ in_dims:: Int
320
+ out_dims:: Int
321
+ k:: Int
322
+ init_weight
323
+ init_bias
324
+ use_bias:: Bool
325
+ end
326
+
327
+ function DConv (ch:: Pair{Int, Int} , k:: Int ;
328
+ init_weight = glorot_uniform,
329
+ init_bias = zeros32,
330
+ use_bias = true )
331
+ in, out = ch
332
+ return DConv (in, out, k, init_weight, init_bias, use_bias)
333
+ end
334
+
335
+ function LuxCore. initialparameters (rng:: AbstractRNG , l:: DConv )
336
+ weights = l. init_weight (rng, 2 , l. k, l. out_dims, l. in_dims)
337
+ if l. use_bias
338
+ bias = l. init_bias (rng, l. out_dims)
339
+ return (; weights, bias)
340
+ else
341
+ return (; weights)
342
+ end
343
+ end
344
+
345
+ LuxCore. outputsize (l:: DConv ) = (l. out_dims,)
346
+ LuxCore. parameterlength (l:: DConv ) = l. use_bias ? 2 * l. in_dims * l. out_dims * l. k + l. out_dims :
347
+ 2 * l. in_dims * l. out_dims * l. k
348
+
349
+ function (l:: DConv )(g, x, ps, st)
350
+ m = (; ps. weights, bias = _getbias (ps), l. k)
351
+ return GNNlib. d_conv (m, g, x), st
352
+ end
353
+
354
+ function Base. show (io:: IO , l:: DConv )
355
+ print (io, " DConv($(l. in_dims) => $(l. out_dims) , k=$(l. k) )" )
356
+ end
357
+
358
+ @concrete struct GATConv <: GNNLayer
359
+ dense_x
360
+ dense_e
361
+ init_weight
362
+ init_bias
363
+ use_bias:: Bool
364
+ σ
365
+ negative_slope
366
+ channel:: Pair{NTuple{2, Int}, Int}
367
+ heads:: Int
368
+ concat:: Bool
369
+ add_self_loops:: Bool
370
+ dropout
371
+ end
372
+
373
+
374
+ GATConv (ch:: Pair{Int, Int} , args... ; kws... ) = GATConv ((ch[1 ], 0 ) => ch[2 ], args... ; kws... )
375
+
376
+ function GATConv (ch:: Pair{NTuple{2, Int}, Int} , σ = identity;
377
+ heads:: Int = 1 , concat:: Bool = true , negative_slope = 0.2 ,
378
+ init_weight = glorot_uniform, init_bias = zeros32,
379
+ use_bias:: Bool = true ,
380
+ add_self_loops = true , dropout= 0.0 )
381
+ (in, ein), out = ch
382
+ if add_self_loops
383
+ @assert ein== 0 " Using edge features and setting add_self_loops=true at the same time is not yet supported."
384
+ end
385
+
386
+ dense_x = Dense (in => out * heads, use_bias = false )
387
+ dense_e = ein > 0 ? Dense (ein => out * heads, use_bias = false ) : nothing
388
+ negative_slope = convert (Float32, negative_slope)
389
+ return GATConv (dense_x, dense_e, init_weight, init_bias, use_bias,
390
+ σ, negative_slope, ch, heads, concat, add_self_loops, dropout)
391
+ end
392
+
393
+ LuxCore. outputsize (l:: GATConv ) = (l. concat ? l. channel[2 ]* l. heads : l. channel[2 ],)
394
+ # #TODO : parameterlength
395
+
396
+ function LuxCore. initialparameters (rng:: AbstractRNG , l:: GATConv )
397
+ (in, ein), out = l. channel
398
+ dense_x = LuxCore. initialparameters (rng, l. dense_x)
399
+ a = l. init_weight (ein > 0 ? 3 out : 2 out, l. heads)
400
+ ps = (; dense_x, a)
401
+ if ein > 0
402
+ ps = (ps... , dense_e = LuxCore. initialparameters (rng, l. dense_e))
403
+ end
404
+ if l. use_bias
405
+ ps = (ps... , bias = l. init_bias (rng, l. concat ? out * l. heads : out))
406
+ end
407
+ return ps
408
+ end
409
+
410
+ (l:: GATConv )(g, x, ps, st) = l (g, x, nothing , ps, st)
411
+
412
+ function (l:: GATConv )(g, x, e, ps, st)
413
+ dense_x = StatefulLuxLayer {true} (l. dense_x, ps. dense_x, _getstate (st, :dense_x ))
414
+ dense_e = l. dense_e === nothing ? nothing :
415
+ StatefulLuxLayer {true} (l. dense_e, ps. dense_e, _getstate (st, :dense_e ))
416
+
417
+ m = (; l. add_self_loops, l. channel, l. heads, l. concat, l. dropout, l. σ,
418
+ ps. a, bias = _getbias (ps), dense_x, dense_e, l. negative_slope)
419
+ return GNNlib. gat_conv (m, g, x, e), st
420
+ end
421
+
422
+ function Base. show (io:: IO , l:: GATConv )
423
+ (in, ein), out = l. channel
424
+ print (io, " GATConv(" , ein == 0 ? in : (in, ein), " => " , out ÷ l. heads)
425
+ l. σ == identity || print (io, " , " , l. σ)
426
+ print (io, " , negative_slope=" , l. negative_slope)
427
+ print (io, " )" )
428
+ end
429
+
430
+ @concrete struct GATv2Conv <: GNNLayer
431
+ dense_i
432
+ dense_j
433
+ dense_e
434
+ init_weight
435
+ init_bias
436
+ use_bias:: Bool
437
+ σ
438
+ negative_slope
439
+ channel:: Pair{NTuple{2, Int}, Int}
440
+ heads:: Int
441
+ concat:: Bool
442
+ add_self_loops:: Bool
443
+ dropout
444
+ end
445
+
446
+ function GATv2Conv (ch:: Pair{Int, Int} , args... ; kws... )
447
+ GATv2Conv ((ch[1 ], 0 ) => ch[2 ], args... ; kws... )
448
+ end
449
+
450
+ function GATv2Conv (ch:: Pair{NTuple{2, Int}, Int} ,
451
+ σ = identity;
452
+ heads:: Int = 1 ,
453
+ concat:: Bool = true ,
454
+ negative_slope = 0.2 ,
455
+ init_weight = glorot_uniform,
456
+ init_bias = zeros32,
457
+ use_bias:: Bool = true ,
458
+ add_self_loops = true ,
459
+ dropout= 0.0 )
460
+
461
+ (in, ein), out = ch
462
+
463
+ if add_self_loops
464
+ @assert ein== 0 " Using edge features and setting add_self_loops=true at the same time is not yet supported."
465
+ end
466
+
467
+ dense_i = Dense (in => out * heads; use_bias, init_weight, init_bias)
468
+ dense_j = Dense (in => out * heads; use_bias = false , init_weight)
469
+ if ein > 0
470
+ dense_e = Dense (ein => out * heads; use_bias = false , init_weight)
471
+ else
472
+ dense_e = nothing
473
+ end
474
+ return GATv2Conv (dense_i, dense_j, dense_e,
475
+ init_weight, init_bias, use_bias,
476
+ σ, negative_slope,
477
+ ch, heads, concat, add_self_loops, dropout)
478
+ end
479
+
480
+
481
+ LuxCore. outputsize (l:: GATv2Conv ) = (l. concat ? l. channel[2 ]* l. heads : l. channel[2 ],)
482
+ # #TODO : parameterlength
483
+
484
+ function LuxCore. initialparameters (rng:: AbstractRNG , l:: GATv2Conv )
485
+ (in, ein), out = l. channel
486
+ dense_i = LuxCore. initialparameters (rng, l. dense_i)
487
+ dense_j = LuxCore. initialparameters (rng, l. dense_j)
488
+ a = l. init_weight (out, l. heads)
489
+ ps = (; dense_i, dense_j, a)
490
+ if ein > 0
491
+ ps = (ps... , dense_e = LuxCore. initialparameters (rng, l. dense_e))
492
+ end
493
+ if l. use_bias
494
+ ps = (ps... , bias = l. init_bias (rng, l. concat ? out * l. heads : out))
495
+ end
496
+ return ps
497
+ end
498
+
499
+ (l:: GATv2Conv )(g, x, ps, st) = l (g, x, nothing , ps, st)
500
+
501
+ function (l:: GATv2Conv )(g, x, e, ps, st)
502
+ dense_i = StatefulLuxLayer {true} (l. dense_i, ps. dense_i, _getstate (st, :dense_i ))
503
+ dense_j = StatefulLuxLayer {true} (l. dense_j, ps. dense_j, _getstate (st, :dense_j ))
504
+ dense_e = l. dense_e === nothing ? nothing :
505
+ StatefulLuxLayer {true} (l. dense_e, ps. dense_e, _getstate (st, :dense_e ))
506
+
507
+ m = (; l. add_self_loops, l. channel, l. heads, l. concat, l. dropout, l. σ,
508
+ ps. a, bias = _getbias (ps), dense_i, dense_j, dense_e, l. negative_slope)
509
+ return GNNlib. gatv2_conv (m, g, x, e), st
510
+ end
511
+
512
+ function Base. show (io:: IO , l:: GATv2Conv )
513
+ (in, ein), out = l. channel
514
+ print (io, " GATv2Conv(" , ein == 0 ? in : (in, ein), " => " , out ÷ l. heads)
515
+ l. σ == identity || print (io, " , " , l. σ)
516
+ print (io, " , negative_slope=" , l. negative_slope)
517
+ print (io, " )" )
518
+ end
0 commit comments