@@ -1112,39 +1112,37 @@ struct GMMConv{A<:AbstractMatrix, B, F} <:GNNLayer
1112
1112
sigma_inv:: A
1113
1113
bias:: B
1114
1114
σ:: F
1115
- ch:: Pair{Int, Int}
1115
+ ch:: Pair{NTuple{2, Int} , Int}
1116
1116
K:: Int
1117
- e_dim:: Int
1118
1117
dense_x:: Dense
1119
1118
end
1120
1119
1121
1120
@functor GMMConv
1122
1121
1123
- function GMMConv (ch:: Pair{Int, Int} ,
1122
+ function GMMConv (ch:: Pair{NTuple{2, Int} , Int} ,
1124
1123
σ= identity;
1125
1124
K:: Int = 1 ,
1126
- e_dim:: Int = 1 ,
1127
1125
init= Flux. glorot_uniform,
1128
1126
bias:: Bool = true )
1129
- in , out = ch
1130
- mu = init (K, e_dim )
1131
- sigma_inv = init (K, e_dim )
1127
+ (nin, ein) , out = ch
1128
+ mu = init (ein, K )
1129
+ sigma_inv = init (K, ein )
1132
1130
b = bias ? Flux. create_bias (ones (out), true ) : false
1133
1131
dense_x = Dense (in, out* K, bias= false )
1134
- GMMConv (mu, sigma_inv, b, σ, ch, K, e_dim, dense_x)
1132
+ GMMConv (mu, sigma_inv, b, σ, ch, K, dense_x)
1135
1133
end
1136
1134
1137
1135
function (l:: GMMConv )(g:: GNNGraph , x:: AbstractMatrix , u:: AbstractMatrix )
1138
1136
1139
- @assert (l. e_dim == size (u)[1 ] && g. num_edges == size (u)[2 ]) " Pseudo-cordinate dim $(size (u)) does not match (e_dim =$(e_dim ) ,num_edge=$(g. num_edges) )"
1137
+ @assert (l. ein == size (u)[1 ] && g. num_edges == size (u)[2 ]) " Pseudo-cordinate dim $(size (u)) does not match (ein =$(ein ) ,num_edge=$(g. num_edges) )"
1140
1138
1141
1139
num_edges = g. num_edges
1142
1140
d = degree (g, dir= :in )
1143
- u = reshape (u, (l. e_dim , 1 , num_edges))
1144
- mu = reshape (l. mu, (l. e_dim , l. K, 1 ))
1141
+ u = reshape (u, (l. ein , 1 , num_edges))
1142
+ mu = reshape (l. mu, (l. ein , l. K, 1 ))
1145
1143
1146
1144
e = - 0.5 * (u.- mu). ^ 2
1147
- e = e .* ((reshape (l. sigma_inv, (l. e_dim , l. K, 1 )).^ 2 ) )
1145
+ e = e .* ((reshape (l. sigma_inv, (l. ein , l. K, 1 )).^ 2 ) )
1148
1146
e = exp .(sum (e, dims = 1 )) # (1, K, num_edge)
1149
1147
1150
1148
xj = reshape (l. dense_x (x), (l. ch[2 ],l. K,:)) # (out, K, num_nodes)
@@ -1156,10 +1154,10 @@ function (l::GMMConv)(g::GNNGraph, x::AbstractMatrix, u::AbstractMatrix)
1156
1154
end
1157
1155
1158
1156
function Base. show (io:: IO , l:: GMMConv )
1159
- in, out, K, e_dim = l. ch[1 ], l. ch[2 ], l. K, l. e_dim
1157
+ in, out, K, ein = l. ch[1 ], l. ch[2 ], l. K, l. ein
1160
1158
print (io, " GMMConv(" , in, " => " , out)
1161
1159
print (io, " , K=" , K)
1162
- print (io, " , e_dim =" , e_dim )
1160
+ print (io, " , ein =" , ein )
1163
1161
print (io, " )" )
1164
1162
1165
1163
end
0 commit comments