@@ -9,10 +9,29 @@ class GraphLayerNormalization(GraphBaseLayer):
99 r"""Graph Layer normalization for (ragged) graph tensor objects.
1010
1111 Uses `ks.layers.LayerNormalization` on all node or edge features in a batch.
12- Following convention suggested by `GraphNorm: A Principled Approach (...) <https://arxiv.org/abs/2009.03294>`_ .
12+ Following convention suggested by `GraphNorm: A Principled Approach (...) <https://arxiv.org/abs/2009.03294>`__ .
1313 To this end, the (positive) :obj:`axis` parameter must be strictly > 0 and ideally > 1,
1414 since first two dimensions are flattened for normalization.
1515
16+ The definition of normalization terms for graph neural networks can be categorized as follows. Here we copy the
17+ definition and description of `GraphNorm <https://arxiv.org/abs/2009.03294>`__ . Note that for keras the batch dimension is
18+ the first dimension.
19+
20+ .. math::
21+
22+ \text{Norm}(\hat{h}_{i,j,g}) = \gamma \cdot \frac{\hat{h}_{i,j,g} - \mu}{\sigma} + \beta,
23+
24+
25+ Consider a batch of graphs :math:`{G_{1}, \dots , G_{b}}` where :math:`b` is the batch size.
26+ Let :math:`n_{g}` be the number of nodes in graph :math:`G_{g}` .
27+ We generally denote :math:`\hat{h}_{i,j,g}` as the inputs to the normalization module, e.g.,
28+ the :math:`j` -th feature value of node :math:`v_i` of graph :math:`G_{g}` ,
29+ :math:`i = 1, \dots , n_{g}` , :math:`j = 1, \dots , d` , :math:`g = 1, \dots , b` .
30+
31+ To adapt Layer-Norm to GNNs, we view each node as a basic component, resembling words in a sentence, and apply
32+ normalization to all feature values across different dimensions of each node,
33+ i.e. , over dimension :math:`j` of :math:`\hat{h}_{i,j,g}` .
34+
1635 """
1736
1837 def __init__ (self ,
@@ -115,10 +134,31 @@ class GraphBatchNormalization(GraphBaseLayer):
115134 r"""Graph batch normalization for (ragged) graph tensor objects.
116135
117136 Uses `ks.layers.BatchNormalization` on all node or edge features in a batch.
118- Following convention suggested by `GraphNorm: A Principled Approach (...) <https://arxiv.org/abs/2009.03294>`_ .
137+ Following convention suggested by `GraphNorm: A Principled Approach (...) <https://arxiv.org/abs/2009.03294>`__ .
119138 To this end, the (positive) :obj:`axis` parameter must be strictly > 0 and ideally > 1,
120139 since first two dimensions are flattened for normalization.
121140
141+ The definition of normalization terms for graph neural networks can be categorized as follows. Here we copy the
142+ definition and description of `GraphNorm <https://arxiv.org/abs/2009.03294>`__ . Note that for keras the batch dimension is
143+ the first dimension.
144+
145+ .. math::
146+
147+ \text{Norm}(\hat{h}_{i,j,g}) = \gamma \cdot \frac{\hat{h}_{i,j,g} - \mu}{\sigma} + \beta,
148+
149+
150+ Consider a batch of graphs :math:`{G_{1}, \dots , G_{b}}` where :math:`b` is the batch size.
151+ Let :math:`n_{g}` be the number of nodes in graph :math:`G_{g}` .
152+ We generally denote :math:`\hat{h}_{i,j,g}` as the inputs to the normalization module, e.g.,
153+ the :math:`j` -th feature value of node :math:`v_i` of graph :math:`G_{g}` ,
154+ :math:`i = 1, \dots , n_{g}` , :math:`j = 1, \dots , d` , :math:`g = 1, \dots , b` .
155+
156+ For BatchNorm, normalization and the computation of :math:`\mu`
157+ and :math:`\sigma` are applied to all values in the same feature dimension
158+ across the nodes of all graphs in the batch as in
159+ `Xu et al. (2019) <https://openreview.net/forum?id=ryGs6iA5Km>`__ , i.e., over dimensions :math:`g`, :math:`i`
160+ of :math:`\hat{h}_{i,j,g}` .
161+
122162 """
123163 def __init__ (self ,
124164 axis = - 1 ,
@@ -149,6 +189,7 @@ def __init__(self,
149189 gamma_regularizer: Optional regularizer for the gamma weight.
150190 beta_constraint: Optional constraint for the beta weight.
151191 gamma_constraint: Optional constraint for the gamma weight.
192+
152193 """
153194 super (GraphBatchNormalization , self ).__init__ (** kwargs )
154195 # The axis 0,1 are merged for ragged embedding input.
@@ -220,3 +261,167 @@ def get_config(self):
220261 config = super (GraphBatchNormalization , self ).get_config ()
221262 config .update ({"axis" : self .axis })
222263 return config
264+
265+
@ks.utils.register_keras_serializable(package='kgcnn', name='GraphInstanceNormalization')
class GraphInstanceNormalization(GraphBaseLayer):
    r"""Graph instance normalization for (ragged) graph tensor objects.

    Following convention suggested by `GraphNorm: A Principled Approach (...) <https://arxiv.org/abs/2009.03294>`__ .

    The definition of normalization terms for graph neural networks can be categorized as follows. Here we copy the
    definition and description of `GraphNorm <https://arxiv.org/abs/2009.03294>`__ . Note that for keras the batch
    dimension is the first dimension.

    .. math::

        \text{Norm}(\hat{h}_{i,j,g}) = \gamma \cdot \frac{\hat{h}_{i,j,g} - \mu}{\sigma} + \beta,

    Consider a batch of graphs :math:`{G_{1}, \dots , G_{b}}` where :math:`b` is the batch size.
    Let :math:`n_{g}` be the number of nodes in graph :math:`G_{g}` .
    We generally denote :math:`\hat{h}_{i,j,g}` as the inputs to the normalization module, e.g.,
    the :math:`j` -th feature value of node :math:`v_i` of graph :math:`G_{g}` ,
    :math:`i = 1, \dots , n_{g}` , :math:`j = 1, \dots , d` , :math:`g = 1, \dots , b` .

    For InstanceNorm, we regard each graph as an instance. The normalization is
    then applied to the feature values across all nodes for each
    individual graph, i.e., over dimension :math:`i` of :math:`\hat{h}_{i,j,g}` .

    """

    def __init__(self,
                 axis=None,
                 epsilon=1e-3, center=True, scale=True,
                 beta_initializer='zeros', gamma_initializer='ones',
                 beta_regularizer=None, gamma_regularizer=None, beta_constraint=None,
                 gamma_constraint=None,
                 **kwargs):
        r"""Initialize layer :obj:`GraphInstanceNormalization`.

        Args:
            axis: Integer or List/Tuple. The axis or axes to normalize across in addition to graph instances.
                This should be always > 1 or None. Default is None.
            epsilon: Small float added to variance to avoid dividing by zero. Defaults to 1e-3.
            center: If True, add offset of `beta` to normalized tensor. If False,
                `beta` is ignored. Defaults to True.
            scale: If True, multiply by `gamma`. If False, `gamma` is not used.
                Defaults to True. When the next layer is linear (also e.g. `nn.relu`),
                this can be disabled since the scaling will be done by the next layer.
            beta_initializer: Initializer for the beta weight. Defaults to zeros.
            gamma_initializer: Initializer for the gamma weight. Defaults to ones.
            beta_regularizer: Optional regularizer for the beta weight. None by default.
            gamma_regularizer: Optional regularizer for the gamma weight. None by default.
            beta_constraint: Optional constraint for the beta weight. None by default.
            gamma_constraint: Optional constraint for the gamma weight. None by default.

        """
        super(GraphInstanceNormalization, self).__init__(**kwargs)
        # Store constructor arguments instead of silently discarding them, so that
        # `get_config` can round-trip the layer through keras (de-)serialization.
        self.axis = axis
        self.epsilon = epsilon
        self.center = center
        self.scale = scale
        # NOTE(review): initializers/regularizers/constraints are kept as passed
        # (strings, None, or objects); presumably they will be resolved via
        # `ks.initializers.get` etc. once the weights are built — TODO confirm.
        self.beta_initializer = beta_initializer
        self.gamma_initializer = gamma_initializer
        self.beta_regularizer = beta_regularizer
        self.gamma_regularizer = gamma_regularizer
        self.beta_constraint = beta_constraint
        self.gamma_constraint = gamma_constraint

    def build(self, input_shape):
        """Build layer."""
        super(GraphInstanceNormalization, self).build(input_shape)

    def call(self, inputs, **kwargs):
        """Forward pass.

        Args:
            inputs (tf.RaggedTensor, tf.Tensor): Node or edge embeddings of shape (batch, [M], F, ...)

        Returns:
            tf.RaggedTensor: Normalized ragged tensor of identical shape (batch, [M], F, ...)
        """
        # Placeholder: the instance-normalization kernel has not been written yet.
        raise NotImplementedError("Not yet implemented")

    def get_config(self):
        """Get layer configuration."""
        config = super(GraphInstanceNormalization, self).get_config()
        # Serialize every constructor argument (previously an empty update, which
        # broke config round-trips for any non-default argument).
        config.update({
            "axis": self.axis, "epsilon": self.epsilon, "center": self.center, "scale": self.scale,
            "beta_initializer": self.beta_initializer, "gamma_initializer": self.gamma_initializer,
            "beta_regularizer": self.beta_regularizer, "gamma_regularizer": self.gamma_regularizer,
            "beta_constraint": self.beta_constraint, "gamma_constraint": self.gamma_constraint
        })
        return config
340+
341+
@ks.utils.register_keras_serializable(package='kgcnn', name='GraphNormalization')
class GraphNormalization(GraphBaseLayer):
    r"""Graph normalization for (ragged) graph tensor objects.

    Following convention suggested by `GraphNorm: A Principled Approach (...) <https://arxiv.org/abs/2009.03294>`__ .

    The definition of normalization terms for graph neural networks can be categorized as follows. Here we copy the
    definition and description of `GraphNorm <https://arxiv.org/abs/2009.03294>`__ . Note that for keras the batch
    dimension is the first dimension.

    .. math::

        \text{Norm}(\hat{h}_{i,j,g}) = \gamma \cdot \frac{\hat{h}_{i,j,g} - \mu}{\sigma} + \beta,

    Consider a batch of graphs :math:`{G_{1}, \dots , G_{b}}` where :math:`b` is the batch size.
    Let :math:`n_{g}` be the number of nodes in graph :math:`G_{g}` .
    We generally denote :math:`\hat{h}_{i,j,g}` as the inputs to the normalization module, e.g.,
    the :math:`j` -th feature value of node :math:`v_i` of graph :math:`G_{g}` ,
    :math:`i = 1, \dots , n_{g}` , :math:`j = 1, \dots , d` , :math:`g = 1, \dots , b` .

    For InstanceNorm, we regard each graph as an instance. The normalization is
    then applied to the feature values across all nodes for each
    individual graph, i.e., over dimension :math:`i` of :math:`\hat{h}_{i,j,g}` .

    Additionally, the following proposed additions for GraphNorm are added when compared to InstanceNorm.

    .. math::

        \text{GraphNorm}(\hat{h}_{i,j}) = \gamma_j \cdot \frac{\hat{h}_{i,j} - \alpha_j \mu_j }{\hat{\sigma}_j}+\beta_j

    where :math:`\mu_j = \frac{\sum^n_{i=1} \hat{h}_{i,j}}{n}` ,
    :math:`\hat{\sigma}^2_j = \frac{\sum^n_{i=1} (\hat{h}_{i,j} - \alpha_j \mu_j)^2}{n}` ,
    and :math:`\gamma_j` , :math:`\beta_j` are the affine parameters as in other normalization methods.

    """

    def __init__(self,
                 axis=None,
                 epsilon=1e-3, center=True, scale=True,
                 beta_initializer='zeros', gamma_initializer='ones',
                 beta_regularizer=None, gamma_regularizer=None, beta_constraint=None,
                 gamma_constraint=None,
                 **kwargs):
        r"""Initialize layer :obj:`GraphNormalization`.

        Args:
            axis: Integer or List/Tuple. The axis or axes to normalize across in addition to graph instances.
                This should be always > 1 or None. Default is None.
            epsilon: Small float added to variance to avoid dividing by zero. Defaults to 1e-3.
            center: If True, add offset of `beta` to normalized tensor. If False,
                `beta` is ignored. Defaults to True.
            scale: If True, multiply by `gamma`. If False, `gamma` is not used.
                Defaults to True. When the next layer is linear (also e.g. `nn.relu`),
                this can be disabled since the scaling will be done by the next layer.
            beta_initializer: Initializer for the beta weight. Defaults to zeros.
            gamma_initializer: Initializer for the gamma weight. Defaults to ones.
            beta_regularizer: Optional regularizer for the beta weight. None by default.
            gamma_regularizer: Optional regularizer for the gamma weight. None by default.
            beta_constraint: Optional constraint for the beta weight. None by default.
            gamma_constraint: Optional constraint for the gamma weight. None by default.

        """
        super(GraphNormalization, self).__init__(**kwargs)
        # Store constructor arguments instead of silently discarding them, so that
        # `get_config` can round-trip the layer through keras (de-)serialization.
        self.axis = axis
        self.epsilon = epsilon
        self.center = center
        self.scale = scale
        # NOTE(review): initializers/regularizers/constraints are kept as passed
        # (strings, None, or objects); presumably they will be resolved via
        # `ks.initializers.get` etc. once the weights are built — TODO confirm.
        self.beta_initializer = beta_initializer
        self.gamma_initializer = gamma_initializer
        self.beta_regularizer = beta_regularizer
        self.gamma_regularizer = gamma_regularizer
        self.beta_constraint = beta_constraint
        self.gamma_constraint = gamma_constraint

    def build(self, input_shape):
        """Build layer."""
        super(GraphNormalization, self).build(input_shape)

    def call(self, inputs, **kwargs):
        """Forward pass.

        Args:
            inputs (tf.RaggedTensor, tf.Tensor): Node or edge embeddings of shape (batch, [M], F, ...)

        Returns:
            tf.RaggedTensor: Normalized ragged tensor of identical shape (batch, [M], F, ...)
        """
        # Placeholder: the GraphNorm kernel has not been written yet.
        raise NotImplementedError("Not yet implemented")

    def get_config(self):
        """Get layer configuration."""
        config = super(GraphNormalization, self).get_config()
        # Serialize every constructor argument (previously an empty update, which
        # broke config round-trips for any non-default argument).
        config.update({
            "axis": self.axis, "epsilon": self.epsilon, "center": self.center, "scale": self.scale,
            "beta_initializer": self.beta_initializer, "gamma_initializer": self.gamma_initializer,
            "beta_regularizer": self.beta_regularizer, "gamma_regularizer": self.gamma_regularizer,
            "beta_constraint": self.beta_constraint, "gamma_constraint": self.gamma_constraint
        })
        return config
426+
427+
0 commit comments