Skip to content

Commit 9a4f4b7

Browse files
authored
Change Tuple in NamedTuple (#330)
1 parent e3019dd commit 9a4f4b7

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

src/layers/conv.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,7 @@ struct GATConv{DX <: Dense, DE <: Union{Dense, Nothing}, T, A <: AbstractMatrix,
329329
end
330330

331331
@functor GATConv
332-
Flux.trainable(l::GATConv) = (l.dense_x, l.dense_e, l.bias, l.a)
332+
Flux.trainable(l::GATConv) = (dense_x = l.dense_x, dense_e = l.dense_e, bias = l.bias, a = l.a)
333333

334334
GATConv(ch::Pair{Int, Int}, args...; kws...) = GATConv((ch[1], 0) => ch[2], args...; kws...)
335335

@@ -457,7 +457,7 @@ struct GATv2Conv{T, A1, A2, A3, B, C <: AbstractMatrix, F} <: GNNLayer
457457
end
458458

459459
@functor GATv2Conv
460-
Flux.trainable(l::GATv2Conv) = (l.dense_i, l.dense_j, l.dense_e, l.bias, l.a)
460+
Flux.trainable(l::GATv2Conv) = (dense_i = l.dense_i, dense_j = l.dense_j, dense_e = l.dense_e, bias = l.bias, a = l.a)
461461

462462
function GATv2Conv(ch::Pair{Int, Int}, args...; kws...)
463463
GATv2Conv((ch[1], 0) => ch[2], args...; kws...)
@@ -668,7 +668,7 @@ struct GINConv{R <: Real, NN, A} <: GNNLayer
668668
end
669669

670670
@functor GINConv
671-
Flux.trainable(l::GINConv) = (l.nn,)
671+
Flux.trainable(l::GINConv) = (nn = l.nn,)
672672

673673
GINConv(nn, ϵ; aggr = +) = GINConv(nn, ϵ, aggr)
674674

@@ -1569,7 +1569,7 @@ end
15691569
@functor TransformerConv
15701570

15711571
function Flux.trainable(l::TransformerConv)
1572-
(l.W1, l.W2, l.W3, l.W4, l.W5, l.W6, l.FF, l.BN1, l.BN2)
1572+
(W1 = l.W1, W2 = l.W2, W3 = l.W3, W4 = l.W4, W5 = l.W5, W6 = l.W6, FF = l.FF, BN1 = l.BN1, BN2 = l.BN2)
15731573
end
15741574

15751575
function TransformerConv(ch::Pair{Int, Int}, args...; kws...)

0 commit comments

Comments
 (0)