@@ -1065,4 +1065,118 @@ function (l::MEGNetConv)(g::GNNGraph, x::AbstractMatrix, e::AbstractMatrix)
    return x̄, ē
end

@doc raw"""
    GMMConv((in, ein) => out, σ=identity; K=1, bias=true, init=glorot_uniform, residual=false)

Graph mixture model convolution layer from the paper [Geometric deep learning on graphs and manifolds using mixture model CNNs](https://arxiv.org/abs/1611.08402).
Performs the operation
```math
\mathbf{x}_i' = \frac{1}{|N(i)|} \sum_{j\in N(i)} \frac{1}{K} \sum_{k=1}^K \mathbf{w}_k(\mathbf{e}_{j\to i}) \odot \Theta_k \mathbf{x}_j
```
where
```math
w^a_{k}(e^a) = \exp\left(-\frac{1}{2}(e^a - \mu^a_k)^T (\Sigma^{-1})^a_k (e^a - \mu^a_k)\right)
```
$\Theta_k$, $\mu^a_k$, $(\Sigma^{-1})^a_k$ are learnable parameters.

The input to the layer is a node feature array `X` of size `(num_features, num_nodes)` and an
edge pseudo-coordinate array `U` of size `(num_features, num_edges)`.

# Arguments

- `in`: Number of input node features.
- `ein`: Number of input edge features.
- `out`: Number of output features.
- `σ`: Activation function. Default `identity`.
- `K`: Number of kernels. Default `1`.
- `bias`: Add learnable bias. Default `true`.
- `init`: Weights' initializer. Default `glorot_uniform`.
- `residual`: Residual connection. Default `false`.

# Examples

```julia
# create data
s = [1,1,2,3]
t = [2,3,1,1]
g = GNNGraph(s, t)
nin, ein, out, K = 4, 10, 7, 8
x = randn(Float32, nin, g.num_nodes)
e = randn(Float32, ein, g.num_edges)

# create layer
l = GMMConv((nin, ein) => out, K=K)

# forward pass
l(g, x, e)
```
"""
struct GMMConv{A <: AbstractMatrix, B, F} <: GNNLayer
    mu::A           # kernel means, size (ein, K)
    sigma_inv::A    # kernel inverse standard deviations, size (ein, K)
    bias::B
    σ::F
    ch::Pair{NTuple{2,Int},Int}
    K::Int
    dense_x::Dense  # projects node features from nin to out * K
    residual::Bool
end

@functor GMMConv

function GMMConv(ch::Pair{NTuple{2,Int},Int},
                 σ = identity;
                 K::Int = 1,
                 bias::Bool = true,
                 init = Flux.glorot_uniform,
                 residual = false)
    (nin, ein), out = ch
    mu = init(ein, K)
    sigma_inv = init(ein, K)
    b = bias ? Flux.create_bias(mu, true, out) : false
    dense_x = Dense(nin, out * K, bias = false)
    GMMConv(mu, sigma_inv, b, σ, ch, K, dense_x, residual)
end

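# Editor's sketch of the parameter shapes the constructor above creates; the
# concrete numbers are illustrative only (they mirror the docstring example,
# they are not defaults of the layer).
let
    l = GMMConv((4, 10) => 7, K = 8)
    @assert size(l.mu) == (10, 8)          # (ein, K) kernel means
    @assert size(l.sigma_inv) == (10, 8)   # (ein, K) kernel inverse std. deviations
    @assert size(l.dense_x(randn(Float32, 4, 3))) == (7 * 8, 3)  # nin => out * K projection
end
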
function (l::GMMConv)(g::GNNGraph, x::AbstractMatrix, e::AbstractMatrix)
    (nin, ein), out = l.ch   # notational convenience

    @assert (ein == size(e, 1) && g.num_edges == size(e, 2)) "Pseudo-coordinate array must have size (ein, num_edges)"

    num_edges = g.num_edges
    w = reshape(e, (ein, 1, num_edges))
    mu = reshape(l.mu, (ein, l.K, 1))

    # per-dimension Gaussian kernel: w^a_k = exp(-(e^a - mu^a_k)^2 * (sigma_inv^a_k)^2 / 2)
    w = @. -((w - mu)^2) / 2
    w = w .* reshape(l.sigma_inv .^ 2, (ein, l.K, 1))
    w = exp.(sum(w, dims = 1))   # (1, K, num_edges)

    xj = reshape(l.dense_x(x), (out, l.K, :))   # (out, K, num_nodes)

    m = propagate(e_mul_xj, g, mean, xj = xj, e = w)
    m = dropdims(mean(m, dims = 2), dims = 2)   # (out, num_nodes)

    m = l.σ(m .+ l.bias)

    if l.residual
        if size(x, 1) == size(m, 1)
            m += x
        else
            @warn "Residual not applied: output feature dimension does not match input feature dimension"
        end
    end

    return m
end

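# Editor's sketch relating the forward pass above to the docstring formula: for a
# single pseudo-coordinate vector `u` of length `ein`, each of the K kernel weights
# is a diagonal-covariance Gaussian in `u`. Illustrative check only, not a test
# shipped with the package.
let
    l = GMMConv((4, 10) => 7, K = 8)
    u = randn(Float32, 10)   # one edge's pseudo-coordinates
    wk = exp.(vec(sum(@. -(u - l.mu)^2 * l.sigma_inv^2 / 2, dims = 1)))  # length-K weights
    @assert length(wk) == 8 && all(wk .<= 1)
end
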
(l::GMMConv)(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g), edge_features(g)))

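# Editor's usage sketch for the graph-level method above, assuming the graph
# carries a single node feature array and a single edge (pseudo-coordinate)
# array; the sizes below are illustrative.
let
    s, t = [1, 1, 2, 3], [2, 3, 1, 1]
    g = GNNGraph(s, t, ndata = randn(Float32, 4, 3), edata = randn(Float32, 10, 4))
    l = GMMConv((4, 10) => 7, K = 8)
    gout = l(g)   # new GNNGraph whose node features are the GMMConv output
    @assert size(node_features(gout)) == (7, 3)
end
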
function Base.show(io::IO, l::GMMConv)
    (nin, ein), out = l.ch
    print(io, "GMMConv((", nin, ",", ein, ")=>", out)
    l.σ == identity || print(io, ", σ=", l.σ)
    print(io, ", K=", l.K)
    l.residual == false || print(io, ", residual=", l.residual)
    print(io, ")")
end