@@ -1452,3 +1452,209 @@ function Base.show(io::IO, l::EGNNConv)
end
print(io, ")")
end
+
+
+ @doc raw"""
+     TransformerConv((in, ein) => out; [heads, concat, init, add_self_loops, bias_qkv,
+                     bias_root, root_weight, gating, skip_connection, batch_norm, ff_channels])
+
+ The transformer-like multi-head attention convolutional operator from the
+ [Masked Label Prediction: Unified Message Passing Model for Semi-Supervised
+ Classification](https://arxiv.org/abs/2009.03509) paper, which also considers
+ edge features.
+ The layer can further be configured as the transformer-like convolutional operator from the
+ [Attention, Learn to Solve Routing Problems!](https://arxiv.org/abs/1803.08475) paper,
+ including a subsequent feed-forward network as well as skip connections and batch normalization.
+
+ The layer's basic forward pass is given by
+ ```math
+ x_i' = W_1 x_i + \sum_{j \in N(i)} \alpha_{ij} (W_2 x_j + W_6 e_{ij}),
+ ```
+ where the attention scores are
+ ```math
+ \alpha_{ij} = \mathrm{softmax}\left(\frac{(W_3 x_i)^T (W_4 x_j + W_6 e_{ij})}{\sqrt{d}}\right),
+ ```
+ with ``d`` the output dimension of each attention head.
+
+ Optionally, the aggregated value can be combined with the transformed root node
+ features through a gating mechanism,
+ ```math
+ x'_i = \beta_i W_1 x_i + (1 - \beta_i) \underbrace{\left(\sum_{j \in \mathcal{N}(i)}
+ \alpha_{i,j} W_2 x_j \right)}_{=m_i}
+ ```
+ with
+ ```math
+ \beta_i = \textrm{sigmoid}(W_5^{\top} [ W_1 x_i, m_i, W_1 x_i - m_i ]).
+ ```
+
+ # Arguments
+
+ - `in`: Dimension of input node features. If `skip_connection` is set, it must match the
+   output dimension (`out * heads` if `concat` is set, `out` otherwise).
+ - `ein`: Dimension of the edge features; if 0, no edge features will be used.
+ - `out`: Dimension of the output features per attention head.
+ - `heads`: Number of attention heads. Default `1`.
+ - `concat`: Concatenate layer output or not. If not, layer output is averaged
+   over the heads. Default `true`.
+ - `init`: Weight matrices' initializing function. Default `glorot_uniform`.
+ - `add_self_loops`: Add self loops to the input graph. Default `false`.
+ - `bias_qkv`: If set, bias is used in the key, query and value transformations for nodes.
+   Default `true`.
+ - `bias_root`: If set, the layer will also learn an additive bias for the root when root
+   weight is used. Default `true`.
+ - `root_weight`: If set, the layer will add the transformed root node features
+   to the output. Default `true`.
+ - `gating`: If set, the aggregated output and the transformed root node features are
+   combined by a gating mechanism. Default `false`.
+ - `skip_connection`: If set, a skip connection from the input is added to the output.
+   Default `false`.
+ - `batch_norm`: If set, batch normalization is applied to the output. Default `false`.
+ - `ff_channels`: If positive, a feed-forward NN is appended whose first (hidden) layer has
+   the given number of units; this NN also gets a skip connection and batch normalization
+   if the respective parameters are set. Default `0`.
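+
+ # Examples
+
+ A minimal usage sketch; the graph and feature dimensions below are purely illustrative:
+
+ ```julia
+ using GraphNeuralNetworks
+
+ g = rand_graph(10, 30)                 # random graph with 10 nodes and 30 edges
+ x = randn(Float32, 4, g.num_nodes)     # 4-dimensional node features
+ e = randn(Float32, 2, g.num_edges)     # 2-dimensional edge features
+
+ l = TransformerConv((4, 2) => 8; heads = 2, gating = true, ff_channels = 16)
+ y = l(g, x, e)                         # size(y) == (8 * 2, 10), since concat defaults to true
+
+ # without edge features, the single-pair constructor can be used
+ l2 = TransformerConv(4 => 8; heads = 2)
+ y2 = l2(g, x)
+ ```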
+ """
+ struct TransformerConv{TW1, TW2, TW3, TW4, TW5, TW6, TFF, TBN1, TBN2} <: GNNLayer
+     W1::TW1
+     W2::TW2
+     W3::TW3
+     W4::TW4
+     W5::TW5
+     W6::TW6
+     FF::TFF
+     BN1::TBN1
+     BN2::TBN2
+     channels::Pair{NTuple{2, Int}, Int}
+     heads::Int
+     add_self_loops::Bool
+     concat::Bool
+     skip_connection::Bool
+     sqrt_out::Float32
+ end
+
+ @functor TransformerConv
+
+ Flux.trainable(l::TransformerConv) = (l.W1, l.W2, l.W3, l.W4, l.W5, l.W6, l.FF, l.BN1, l.BN2)
+
+ TransformerConv(ch::Pair{Int, Int}, args...; kws...) = TransformerConv((ch[1], 0) => ch[2], args...; kws...)
+
+ function TransformerConv(ch::Pair{NTuple{2, Int}, Int};
+                          heads::Int = 1,
+                          concat::Bool = true,
+                          init = glorot_uniform,
+                          add_self_loops::Bool = false,
+                          bias_qkv = true,
+                          bias_root::Bool = true,
+                          root_weight::Bool = true,
+                          gating::Bool = false,
+                          skip_connection::Bool = false,
+                          batch_norm::Bool = false,
+                          ff_channels::Int = 0)
+
+     (in, ein), out = ch
+
+     if add_self_loops
+         @assert iszero(ein) "Using edge features and setting add_self_loops=true at the same time is not yet supported."
+     end
+
+     W1 = root_weight ? Dense(in, out * (concat ? heads : 1); bias = bias_root, init = init) : nothing
+     W2 = Dense(in => out * heads; bias = bias_qkv, init = init)
+     W3 = Dense(in => out * heads; bias = bias_qkv, init = init)
+     W4 = Dense(in => out * heads; bias = bias_qkv, init = init)
+     out_mha = out * (concat ? heads : 1)
+     W5 = gating ? Dense(3 * out_mha => 1, sigmoid; bias = false, init = init) : nothing
+     W6 = ein > 0 ? Dense(ein => out * heads; bias = bias_qkv, init = init) : nothing
+     FF = ff_channels > 0 ? Chain(
+             Dense(out_mha => ff_channels, relu),
+             Dense(ff_channels => out_mha)
+          ) : nothing
+     BN1 = batch_norm ? BatchNorm(out_mha) : nothing
+     BN2 = (batch_norm && ff_channels > 0) ? BatchNorm(out_mha) : nothing
+
+     return TransformerConv(W1, W2, W3, W4, W5, W6, FF, BN1, BN2,
+                            ch, heads, add_self_loops, concat, skip_connection, Float32(√out))
+ end
+
+ function (l::TransformerConv)(g::GNNGraph, x::AbstractMatrix,
+                               e::Union{AbstractMatrix, Nothing} = nothing)
+     check_num_nodes(g, x)
+
+     if l.add_self_loops
+         g = add_self_loops(g)
+     end
+
+     out = l.channels[2]
+     heads = l.heads
+     W1x = !isnothing(l.W1) ? l.W1(x) : nothing
+     W2x = reshape(l.W2(x), out, heads, :)
+     W3x = reshape(l.W3(x), out, heads, :)
+     W4x = reshape(l.W4(x), out, heads, :)
+     W6e = !isnothing(l.W6) ? reshape(l.W6(e), out, heads, :) : nothing
+
+     m = apply_edges(message_uij, g, l; xi = (; W3x), xj = (; W4x), e = (; W6e))
+     α = softmax_edge_neighbors(g, m)
+     α_val = propagate(message_main, g, +, l; xi = (; W3x), xj = (; W2x), e = (; W6e, α))
+
+     h = α_val
+     if l.concat
+         h = reshape(h, out * heads, :)  # concatenate heads
+     else
+         h = mean(h, dims = 2)           # average heads
+         h = reshape(h, out, :)
+     end
+
+     if !isnothing(W1x)  # root_weight
+         if !isnothing(l.W5)  # gating
+             β = l.W5(vcat(h, W1x, h .- W1x))
+             h = β .* W1x + (1f0 .- β) .* h
+         else
+             h += W1x
+         end
+     end
+
+     if l.skip_connection
+         @assert size(h, 1) == size(x, 1) "In-channels must correspond to out-channels * heads if skip_connection is used"
+         h += x
+     end
+     if !isnothing(l.BN1)
+         h = l.BN1(h)
+     end
+
+     if !isnothing(l.FF)
+         h1 = h
+         h = l.FF(h)
+         if l.skip_connection
+             h += h1
+         end
+         if !isnothing(l.BN2)
+             h = l.BN2(h)
+         end
+     end
+
+     return h
+ end
+
+ (l::TransformerConv)(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g), edge_features(g)))
+
+ function message_uij(l::TransformerConv, xi, xj, e)
+     key = xj.W4x
+     if !isnothing(e.W6e)
+         key += e.W6e
+     end
+     uij = sum(xi.W3x .* key, dims = 1) ./ l.sqrt_out
+     return uij
+ end
+
+ function message_main(l::TransformerConv, xi, xj, e)
+     val = xj.W2x
+     if !isnothing(e.W6e)
+         val += e.W6e
+     end
+     return e.α .* val
+ end
+
+ function Base.show(io::IO, l::TransformerConv)
+     (in, ein), out = l.channels
+     print(io, "TransformerConv(($in, $ein) => $out, heads=$(l.heads))")
+ end
+