@@ -1066,7 +1066,7 @@ function (l::MEGNetConv)(g::GNNGraph, x::AbstractMatrix, e::AbstractMatrix)
1066
1066
end
1067
1067
1068
1068
@doc raw """
1069
    GMMConv(in => out, σ=identity; K=1, e_dim=1, [init, bias])
1070
1070
Graph mixture model convolution layer from the paper [Geometric deep learning on graphs and manifolds using mixture model CNNs](https://arxiv.org/abs/1611.08402)
1071
1071
Performs the operation
1072
1072
```math
@@ -1083,7 +1083,7 @@ edge pseudo-cordinate array 'U' of size `(num_features, num_edges)`
1083
1083
# Arguments
1084
1084
- `in`: Number of input features.
1085
1085
- `out`: Number of output features.
1086
- `K`: Number of kernels. Default `1`.
1087
1087
- `e_dim`: Dimensionality of the edge pseudo coordinates. Default `1`.
1088
1088
- `σ`: Activation function. Default `identity`.
1089
1089
- `bias`: Add learnable bias. Default `true`.
@@ -1096,12 +1096,12 @@ edge pseudo-cordinate array 'U' of size `(num_features, num_edges)`
1096
1096
s = [1,1,2,3]
1097
1097
t = [2,3,1,1]
1098
1098
g = GNNGraph(s,t)
1099
- in_feature, out_feature, n_k , e_dim = 4, 7, 8, 10
1099
+ in_feature, out_feature, K , e_dim = 4, 7, 8, 10
1100
1100
x = randn(in_feature, g.num_nodes)
1101
1101
e = randn(e_dim, g.num_edges)
1102
1102
1103
1103
# create layer
1104
- l = GMMConv(in_feature=>out_feature, n_k , e_dim)
1104
+ l = GMMConv(in_feature=>out_feature, K , e_dim)
1105
1105
1106
1106
# forward pass
1107
1107
l(g, x, e)
struct GMMConv{A <: AbstractMatrix, B, F} <: GNNLayer
    # NOTE(review): `mu` and `sigma_inv` were outside the visible diff hunk and
    # are reconstructed from the positional constructor call
    # `GMMConv(mu, sigma_inv, b, σ, ch, K, e_dim, dense_x)` — confirm against
    # the full source.
    mu::A               # kernel means, produced as `init(K, e_dim)`
    sigma_inv::A        # kernel inverse bandwidths, produced as `init(K, e_dim)`
    bias::B             # learnable bias vector, or `false` when disabled
    σ::F                # activation applied to the aggregated output
    ch::Pair{Int, Int}  # in => out feature dimensions
    K::Int              # number of Gaussian kernels
    e_dim::Int          # dimensionality of the edge pseudo-coordinates
    dense_x::Dense      # projects node features to `out * K` per-kernel features
end
@functor GMMConv

# Keyword constructor. `K` (number of Gaussian kernels) and `e_dim`
# (pseudo-coordinate dimensionality) default to 1; `init` seeds the kernel
# parameters, `bias` toggles the learnable bias.
function GMMConv(ch::Pair{Int, Int},
                 σ = identity;
                 K::Int = 1,
                 e_dim::Int = 1,
                 init = Flux.glorot_uniform,
                 bias::Bool = true)
    in, out = ch
    mu = init(K, e_dim)         # kernel means, (K, e_dim)
    sigma_inv = init(K, e_dim)  # kernel inverse bandwidths, (K, e_dim)
    # Fixed: seed the bias with Float32 ones so its eltype matches the Float32
    # weights produced by `glorot_uniform` (was `ones(out)` → Float64, causing
    # promotion in the forward pass).
    b = bias ? Flux.create_bias(ones(Float32, out), true) : false
    # One dense map producing all K per-kernel projections at once.
    dense_x = Dense(in, out * K, bias = false)
    return GMMConv(mu, sigma_inv, b, σ, ch, K, e_dim, dense_x)
end
function (l::GMMConv)(g::GNNGraph, x::AbstractMatrix, u::AbstractMatrix)
    # `u` carries the edge pseudo-coordinates, expected shape (e_dim, num_edges).
    # Fixed: the failure message interpolated `$(e_dim)`, an undefined local —
    # it must read the layer field `l.e_dim` (also fixed the "cordinate" typo).
    @assert (l.e_dim == size(u)[1] && g.num_edges == size(u)[2]) "Pseudo-coordinate dim $(size(u)) does not match (e_dim=$(l.e_dim), num_edges=$(g.num_edges))"

    num_edges = g.num_edges
    d = degree(g, dir = :in)  # in-degree per node, length num_nodes

    # Broadcast edge coordinates against the K kernel means:
    # u: (e_dim, 1, num_edges) vs mu: (e_dim, K, 1)
    u = reshape(u, (l.e_dim, 1, num_edges))
    mu = reshape(l.mu, (l.e_dim, l.K, 1))

    # Gaussian kernel weights: w_k(u) = exp(-1/2 Σ_j (sigma_inv[j,k]·(u_j - mu[j,k]))^2).
    # Float32 literal so Float32 activations are not promoted to Float64.
    e = -0.5f0 * (u .- mu) .^ 2
    e = e .* (reshape(l.sigma_inv, (l.e_dim, l.K, 1)) .^ 2)
    e = exp.(sum(e, dims = 1))  # (1, K, num_edges)

    xj = reshape(l.dense_x(x), (l.ch[2], l.K, :))  # (out, K, num_nodes)
    x = propagate(e_mul_xj, g, +, xj = xj, e = e)  # kernel-weighted neighbor sum
    x = dropdims(mean(x, dims = 2), dims = 2)      # average over the K kernels → (out, num_nodes)

    # Mean aggregation: divide each node's column by its in-degree.
    # Fixed: `1 / d .* x` right-divides a scalar by a vector (pinv semantics in
    # LinearAlgebra) and broadcasts along the feature axis; normalize per node
    # (per column) instead.
    x = x ./ reshape(d, 1, :)

    # NOTE(review): the activation/bias line below lay outside the visible diff
    # hunk and is reconstructed — confirm against the full source.
    return l.σ.(x .+ l.bias)
end
# Compact one-line display of a GMMConv layer, e.g. `GMMConv(4 => 7, K=8, e_dim=10)`.
function Base.show(io::IO, l::GMMConv)
    nin, nout = l.ch
    print(io, "GMMConv(", nin, " => ", nout,
          ", K=", l.K,
          ", e_dim=", l.e_dim,
          ")")
end
0 commit comments