@@ -329,7 +329,7 @@ struct GATConv{DX <: Dense, DE <: Union{Dense, Nothing}, T, A <: AbstractMatrix,
329
329
end
330
330
331
331
@functor GATConv
332
- Flux. trainable (l:: GATConv ) = (l. dense_x, l. dense_e, l. bias, l. a)
332
+ Flux. trainable (l:: GATConv ) = (dense_x = l. dense_x, dense_e = l. dense_e, bias = l. bias, a = l. a)
333
333
334
334
GATConv (ch:: Pair{Int, Int} , args... ; kws... ) = GATConv ((ch[1 ], 0 ) => ch[2 ], args... ; kws... )
335
335
@@ -457,7 +457,7 @@ struct GATv2Conv{T, A1, A2, A3, B, C <: AbstractMatrix, F} <: GNNLayer
457
457
end
458
458
459
459
@functor GATv2Conv
460
- Flux. trainable (l:: GATv2Conv ) = (l. dense_i, l. dense_j, l. dense_e, l. bias, l. a)
460
+ Flux. trainable (l:: GATv2Conv ) = (dense_i = l. dense_i, dense_j = l. dense_j, dense_e = l. dense_e, bias = l. bias, a = l. a)
461
461
462
462
function GATv2Conv (ch:: Pair{Int, Int} , args... ; kws... )
463
463
GATv2Conv ((ch[1 ], 0 ) => ch[2 ], args... ; kws... )
@@ -668,7 +668,7 @@ struct GINConv{R <: Real, NN, A} <: GNNLayer
668
668
end
669
669
670
670
@functor GINConv
671
- Flux. trainable (l:: GINConv ) = (l. nn,)
671
+ Flux. trainable (l:: GINConv ) = (nn = l. nn,)
672
672
673
673
GINConv (nn, ϵ; aggr = + ) = GINConv (nn, ϵ, aggr)
674
674
@@ -1569,7 +1569,7 @@ end
1569
1569
@functor TransformerConv
1570
1570
1571
1571
function Flux. trainable (l:: TransformerConv )
1572
- (l. W1, l. W2, l. W3, l. W4, l. W5, l. W6, l. FF, l. BN1, l. BN2)
1572
+ (W1 = l. W1, W2 = l. W2, W3 = l. W3, W4 = l. W4, W5 = l. W5, W6 = l. W6, FF = l. FF, BN1 = l. BN1, BN2 = l. BN2)
1573
1573
end
1574
1574
1575
1575
function TransformerConv (ch:: Pair{Int, Int} , args... ; kws... )
0 commit comments