@@ -579,21 +579,23 @@ end
579
579
580
580
(l:: GATv2Conv )(g:: GNNGraph ) = GNNGraph (g, ndata = l (g, node_features (g), edge_features (g)))
581
581
582
- function (l:: GATv2Conv )(g:: GNNGraph , x:: AbstractMatrix ,
582
+ function (l:: GATv2Conv )(g:: AbstractGNNGraph , x,
583
583
e:: Union{Nothing, AbstractMatrix} = nothing )
584
584
check_num_nodes (g, x)
585
585
@assert ! ((e === nothing ) && (l. dense_e != = nothing )) " Input edge features required for this layer"
586
586
@assert ! ((e != = nothing ) && (l. dense_e === nothing )) " Input edge features were not specified in the layer constructor"
587
587
588
+ xj, xi = expand_srcdst (g, x)
589
+
588
590
if l. add_self_loops
589
591
@assert e=== nothing " Using edge features and setting add_self_loops=true at the same time is not yet supported."
590
592
g = add_self_loops (g)
591
593
end
592
594
_, out = l. channel
593
595
heads = l. heads
594
596
595
- Wxi = reshape (l. dense_i (x ), out, heads, :) # out × heads × nnodes
596
- Wxj = reshape (l. dense_j (x ), out, heads, :) # out × heads × nnodes
597
+ Wxi = reshape (l. dense_i (xi ), out, heads, :) # out × heads × nnodes
598
+ Wxj = reshape (l. dense_j (xj ), out, heads, :) # out × heads × nnodes
597
599
598
600
m = apply_edges ((xi, xj, e) -> message (l, xi, xj, e), g, Wxi, Wxj, e)
599
601
α = softmax_edge_neighbors (g, m. logα)
0 commit comments