11import tensorflow as tf
22from kgcnn .layers .base import GraphBaseLayer
33from kgcnn .ops .axis import get_positive_axis
4+ from kgcnn .ops .segment import segment_ops_by_name
45ks = tf .keras
56
67
@@ -263,9 +264,9 @@ def get_config(self):
263264 return config
264265
265266
266- @ks .utils .register_keras_serializable (package = 'kgcnn' , name = 'GraphInstanceNormalization ' )
267- class GraphInstanceNormalization (GraphBaseLayer ):
268- r"""Graph instance normalization for (ragged) graph tensor objects.
267+ @ks .utils .register_keras_serializable (package = 'kgcnn' , name = 'GraphNormalization ' )
268+ class GraphNormalization (GraphBaseLayer ):
269+ r"""Graph normalization for (ragged) graph tensor objects.
269270
270271 Following convention suggested by `GraphNorm: A Principled Approach (...) <https://arxiv.org/abs/2009.03294>`__ .
271272
@@ -277,6 +278,7 @@ class GraphInstanceNormalization(GraphBaseLayer):
277278
278279 \text{Norm}(\hat{h}_{i,j,g}) = \gamma \cdot \frac{\hat{h}_{i,j,g} - \mu}{\sigma} + \beta,
279280
281+
280282 Consider a batch of graphs :math:`{G_{1}, \dots , G_{b}}` where :math:`b` is the batch size.
281283 Let :math:`n_{g}` be the number of nodes in graph :math:`G_{g}` .
282284 We generally denote :math:`\hat{h}_{i,j,g}` as the inputs to the normalization module, e.g.,
@@ -287,14 +289,30 @@ class GraphInstanceNormalization(GraphBaseLayer):
287289 then applied to the feature values across all nodes for each
288290 individual graph, i.e., over dimension :math:`i` of :math:`\hat{h}_{i,j,g}` .
289291
290- """
292+ Additionally, the following proposed additions for GraphNorm are added when compared to InstanceNorm.
293+
294+ .. math::
291295
296+ \text{GraphNorm}(\hat{h}_{i,j}) = \gamma_j \cdot \frac{\hat{h}_{i,j} - \alpha_j \mu_j }{\hat{\sigma}_j}+\beta_j
297+
298+ where :math:`\mu_j = \frac{\sum^n_{i=1} \hat{h}_{i,j}}{n}` ,
299+ :math:`\hat{\sigma}^2_j = \frac{\sum^n_{i=1} (\hat{h}_{i,j} - \alpha_j \mu_j)^2}{n}` ,
300+ and :math:`\gamma_j` , :math:`\beta_j` are the affine parameters as in other normalization methods.
301+
302+ .. code-block:: python
303+
304+ import tensorflow as tf
305+ from kgcnn.layers.norm import GraphNormalization
306+ layer = GraphNormalization()
307+ test = tf.ragged.constant([[[0.0, 0.0],[1.0, -1.0]],[[1.0, 1.0],[0.0, 0.0],[-2.0, -2.0]]], ragged_rank=1)
308+ print(layer(test))
309+
310+ """
292311 def __init__ (self ,
293- axis = None ,
294- epsilon = 1e-3 , center = True , scale = True ,
295- beta_initializer = 'zeros' , gamma_initializer = 'ones' ,
296- beta_regularizer = None , gamma_regularizer = None , beta_constraint = None ,
297- gamma_constraint = None ,
312+ mean_shift = True , epsilon = 1e-3 , center = True , scale = True ,
313+ beta_initializer = 'zeros' , gamma_initializer = 'ones' , alpha_initializer = 'ones' ,
314+ beta_regularizer = None , gamma_regularizer = None , alpha_regularizer = None ,
315+ beta_constraint = None , gamma_constraint = None , alpha_constraint = None ,
298316 ** kwargs ):
299317 r"""Initialize layer :obj:`GraphNormalization`.
300318
@@ -307,6 +325,7 @@ def __init__(self,
307325 scale: If True, multiply by `gamma`. If False, `gamma` is not used.
308326 Defaults to True. When the next layer is linear (also e.g. `nn.relu`),
309327 this can be disabled since the scaling will be done by the next layer.
328+ mean_shift (bool): Whether to apply alpha. Default is True.
310329 beta_initializer: Initializer for the beta weight. Defaults to zeros.
311330 gamma_initializer: Initializer for the gamma weight. Defaults to ones.
312331 beta_regularizer: Optional regularizer for the beta weight. None by default.
@@ -315,11 +334,84 @@ def __init__(self,
315334 gamma_constraint: Optional constraint for the gamma weight. None by default.
316335
317336 """
318- super (GraphInstanceNormalization , self ).__init__ (** kwargs )
337+ super (GraphNormalization , self ).__init__ (** kwargs )
338+ self .epsilon = epsilon
339+ self ._eps = tf .constant (epsilon , dtype = self .dtype )
340+ self .center = center
341+ self .mean_shift = mean_shift
342+ self .scale = scale
343+ self .beta_initializer = ks .initializers .get (beta_initializer )
344+ self .gamma_initializer = ks .initializers .get (gamma_initializer )
345+ self .alpha_initializer = ks .initializers .get (alpha_initializer )
346+ self .beta_regularizer = ks .regularizers .get (beta_regularizer )
347+ self .gamma_regularizer = ks .regularizers .get (gamma_regularizer )
348+ self .alpha_regularizer = ks .regularizers .get (alpha_regularizer )
349+ self .beta_constraint = ks .constraints .get (beta_constraint )
350+ self .gamma_constraint = ks .constraints .get (gamma_constraint )
351+ self .alpha_constraint = ks .constraints .get (alpha_constraint )
319352
320353 def build (self , input_shape ):
321354 """Build layer."""
322- super (GraphInstanceNormalization , self ).build (input_shape )
355+ super (GraphNormalization , self ).build (input_shape )
356+ param_shape = [x if x is not None else 1 for x in input_shape [2 :]]
357+ if self .scale :
358+ self .gamma = self .add_weight (
359+ name = "gamma" ,
360+ shape = param_shape ,
361+ initializer = self .gamma_initializer ,
362+ regularizer = self .gamma_regularizer ,
363+ constraint = self .gamma_constraint ,
364+ trainable = True ,
365+ experimental_autocast = False ,
366+ )
367+ else :
368+ self .gamma = None
369+
370+ if self .center :
371+ self .beta = self .add_weight (
372+ name = "beta" ,
373+ shape = param_shape ,
374+ initializer = self .beta_initializer ,
375+ regularizer = self .beta_regularizer ,
376+ constraint = self .beta_constraint ,
377+ trainable = True ,
378+ experimental_autocast = False ,
379+ )
380+ else :
381+ self .beta = None
382+
383+ if self .mean_shift :
384+ self .alpha = self .add_weight (
385+ name = "alpha" ,
386+ shape = param_shape ,
387+ initializer = self .alpha_initializer ,
388+ regularizer = self .alpha_regularizer ,
389+ constraint = self .alpha_constraint ,
390+ trainable = True ,
391+ experimental_autocast = False ,
392+ )
393+ else :
394+ self .alpha = None
395+
396+ self .built = True
397+
398+ def _ragged_mean_std (self , inputs ):
399+ # Here a segment operation for ragged_rank=1 tensors is used.
400+ # Alternative is to simply use tf.reduce_mean which should also work for latest tf-version.
401+ # Then tf.nn.moments could be used or similar tf implementation for variance and mean.
402+ values = inputs .values
403+ if values .dtype in ("float16" , "bfloat16" ) and self .dtype == "float32" :
404+ values = tf .cast (values , "float32" )
405+ mean = segment_ops_by_name ("mean" , values , inputs .value_rowids ())
406+ if self .mean_shift :
407+ mean = mean * tf .expand_dims (self .alpha , axis = 0 )
408+ mean = tf .repeat (mean , inputs .row_lengths (), axis = 0 )
409+ diff = values - tf .stop_gradient (mean )
410+ square_diff = tf .square (diff )
411+ variance = segment_ops_by_name ("mean" , square_diff , inputs .value_rowids ())
412+ std = tf .sqrt (variance + self ._eps )
413+ std = tf .repeat (std , inputs .row_lengths (), axis = 0 )
414+ return mean , std , diff / std
323415
324416 def call (self , inputs , ** kwargs ):
325417 """Forward pass.
@@ -330,18 +422,39 @@ def call(self, inputs, **kwargs):
330422 Returns:
331423 tf.RaggedTensor: Normalized ragged tensor of identical shape (batch, [M], F, ...)
332424 """
333- raise NotImplementedError ("Not yet implemented" )
425+ inputs = self .assert_ragged_input_rank (inputs , ragged_rank = 1 ) # Must have ragged_rank = 1.
426+ mean , std , new_values = self ._ragged_mean_std (inputs )
427+ # Recomputing diff.
428+ if self .scale :
429+ new_values = new_values * tf .expand_dims (self .gamma , axis = 0 )
430+ if self .center :
431+ new_values = new_values + self .beta
432+ return inputs .with_values (new_values )
334433
335434 def get_config (self ):
336435 """Get layer configuration."""
337- config = super (GraphInstanceNormalization , self ).get_config ()
338- config .update ({})
436+ config = super (GraphNormalization , self ).get_config ()
437+ config .update ({
438+ "mean_shift" : self .mean_shift ,
439+ "epsilon" : self .epsilon ,
440+ "center" : self .center ,
441+ "scale" : self .scale ,
442+ "beta_initializer" : ks .initializers .serialize (self .beta_initializer ),
443+ "gamma_initializer" : ks .initializers .serialize (self .gamma_initializer ),
444+ "alpha_initializer" : ks .initializers .serialize (self .alpha_initializer ),
445+ "beta_regularizer" : ks .regularizers .serialize (self .beta_regularizer ),
446+ "gamma_regularizer" : ks .regularizers .serialize (self .gamma_regularizer ),
447+ "alpha_regularizer" : ks .regularizers .serialize (self .alpha_regularizer ),
448+ "beta_constraint" : ks .constraints .serialize (self .beta_constraint ),
449+ "gamma_constraint" : ks .constraints .serialize (self .gamma_constraint ),
450+ "alpha_constraint" : ks .constraints .serialize (self .alpha_constraint ),
451+ })
339452 return config
340453
341454
342- @ks .utils .register_keras_serializable (package = 'kgcnn' , name = 'GraphNormalization ' )
343- class GraphNormalization ( GraphBaseLayer ):
344- r"""Graph normalization for (ragged) graph tensor objects.
455+ @ks .utils .register_keras_serializable (package = 'kgcnn' , name = 'GraphInstanceNormalization ' )
456+ class GraphInstanceNormalization ( GraphNormalization ):
457+ r"""Graph instance normalization for (ragged) graph tensor objects.
345458
346459 Following convention suggested by `GraphNorm: A Principled Approach (...) <https://arxiv.org/abs/2009.03294>`__ .
347460
@@ -353,7 +466,6 @@ class GraphNormalization(GraphBaseLayer):
353466
354467 \text{Norm}(\hat{h}_{i,j,g}) = \gamma \cdot \frac{\hat{h}_{i,j,g} - \mu}{\sigma} + \beta,
355468
356-
357469 Consider a batch of graphs :math:`{G_{1}, \dots , G_{b}}` where :math:`b` is the batch size.
358470 Let :math:`n_{g}` be the number of nodes in graph :math:`G_{g}` .
359471 We generally denote :math:`\hat{h}_{i,j,g}` as the inputs to the normalization module, e.g.,
@@ -364,29 +476,20 @@ class GraphNormalization(GraphBaseLayer):
364476 then applied to the feature values across all nodes for each
365477 individual graph, i.e., over dimension :math:`i` of :math:`\hat{h}_{i,j,g}` .
366478
367- Additionally, the following proposed additions for GraphNorm are added when compared to InstanceNorm.
479+ .. code-block:: python
368480
369- .. math::
370-
371- \text{GraphNorm}(\hat{h}_{i,j}) = \gamma_j \cdot \frac{\hat{h}_{i,j} - \alpha_j \mu_j }{\hat{\sigma}_j}+\beta_j
372-
373- where :math:`\mu_j = \frac{\sum^n_{i=1} hat{h}_{i,j}}{n}` ,
374- :math:`\hat{\sigma}^2_j = \frac{\sum^n_{i=1} (hat{h}_{i,j} - \alpha_j \mu_j)^2}{n}` ,
375- and :math:`\gamma_j` , :math:`beta_j` are the affine parameters as in other normalization methods.
481+ import tensorflow as tf
482+ from kgcnn.layers.norm import GraphInstanceNormalization
483+ layer = GraphInstanceNormalization()
484+ test = tf.ragged.constant([[[0.0, 0.0],[1.0, -1.0]],[[1.0, 1.0],[0.0, 0.0],[-2.0, -2.0]]], ragged_rank=1)
485+ print(layer(test))
376486
377487 """
378- def __init__ (self ,
379- axis = None ,
380- epsilon = 1e-3 , center = True , scale = True ,
381- beta_initializer = 'zeros' , gamma_initializer = 'ones' ,
382- beta_regularizer = None , gamma_regularizer = None , beta_constraint = None ,
383- gamma_constraint = None ,
384- ** kwargs ):
488+
489+ def __init__ (self , ** kwargs ):
385490 r"""Initialize layer :obj:`GraphInstanceNormalization`.
386491
387492 Args:
388- axis: Integer or List/Tuple. The axis or axes to normalize across in addition to graph instances.
389- This should be always > 1 or None. Default is None.
390493 epsilon: Small float added to variance to avoid dividing by zero. Defaults to 1e-3.
391494 center: If True, add offset of `beta` to normalized tensor. If False,
392495 `beta` is ignored. Defaults to True.
@@ -401,27 +504,4 @@ def __init__(self,
401504 gamma_constraint: Optional constraint for the gamma weight. None by default.
402505
403506 """
404- super (GraphNormalization , self ).__init__ (** kwargs )
405-
406- def build (self , input_shape ):
407- """Build layer."""
408- super (GraphNormalization , self ).build (input_shape )
409-
410- def call (self , inputs , ** kwargs ):
411- """Forward pass.
412-
413- Args:
414- inputs (tf.RaggedTensor, tf.Tensor): Node or edge embeddings of shape (batch, [M], F, ...)
415-
416- Returns:
417- tf.RaggedTensor: Normalized ragged tensor of identical shape (batch, [M], F, ...)
418- """
419- raise NotImplementedError ("Not yet implemented" )
420-
421- def get_config (self ):
422- """Get layer configuration."""
423- config = super (GraphNormalization , self ).get_config ()
424- config .update ({})
425- return config
426-
427-
507+ super (GraphInstanceNormalization , self ).__init__ (mean_shift = False , ** kwargs )
0 commit comments