Skip to content

Commit 92ff2ef

Browse files
committed
started with InstanceNorm and GraphNorm.
1 parent 94055c7 commit 92ff2ef

File tree

1 file changed

+139
-59
lines changed

1 file changed

+139
-59
lines changed

kgcnn/layers/norm.py

Lines changed: 139 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import tensorflow as tf
22
from kgcnn.layers.base import GraphBaseLayer
33
from kgcnn.ops.axis import get_positive_axis
4+
from kgcnn.ops.segment import segment_ops_by_name
45
ks = tf.keras
56

67

@@ -263,9 +264,9 @@ def get_config(self):
263264
return config
264265

265266

266-
@ks.utils.register_keras_serializable(package='kgcnn', name='GraphInstanceNormalization')
267-
class GraphInstanceNormalization(GraphBaseLayer):
268-
r"""Graph instance normalization for (ragged) graph tensor objects.
267+
@ks.utils.register_keras_serializable(package='kgcnn', name='GraphNormalization')
268+
class GraphNormalization(GraphBaseLayer):
269+
r"""Graph normalization for (ragged) graph tensor objects.
269270
270271
Following convention suggested by `GraphNorm: A Principled Approach (...) <https://arxiv.org/abs/2009.03294>`__ .
271272
@@ -277,6 +278,7 @@ class GraphInstanceNormalization(GraphBaseLayer):
277278
278279
\text{Norm}(\hat{h}_{i,j,g}) = \gamma \cdot \frac{\hat{h}_{i,j,g} - \mu}{\sigma} + \beta,
279280
281+
280282
Consider a batch of graphs :math:`{G_{1}, \dots , G_{b}}` where :math:`b` is the batch size.
281283
Let :math:`n_{g}` be the number of nodes in graph :math:`G_{g}` .
282284
We generally denote :math:`\hat{h}_{i,j,g}` as the inputs to the normalization module, e.g.,
@@ -287,14 +289,30 @@ class GraphInstanceNormalization(GraphBaseLayer):
287289
then applied to the feature values across all nodes for each
288290
individual graph, i.e., over dimension :math:`i` of :math:`\hat{h}_{i,j,g}` .
289291
290-
"""
292+
Additionally, the following proposed additions for GraphNorm are added when compared to InstanceNorm.
293+
294+
.. math::
291295
296+
\text{GraphNorm}(\hat{h}_{i,j}) = \gamma_j \cdot \frac{\hat{h}_{i,j} - \alpha_j \mu_j }{\hat{\sigma}_j}+\beta_j
297+
298+
where :math:`\mu_j = \frac{\sum^n_{i=1} \hat{h}_{i,j}}{n}` ,
299+
:math:`\hat{\sigma}^2_j = \frac{\sum^n_{i=1} (\hat{h}_{i,j} - \alpha_j \mu_j)^2}{n}` ,
300+
and :math:`\gamma_j` , :math:`\beta_j` are the affine parameters as in other normalization methods.
301+
302+
.. code-block:: python
303+
304+
import tensorflow as tf
305+
from kgcnn.layers.norm import GraphNormalization
306+
layer = GraphNormalization()
307+
test = tf.ragged.constant([[[0.0, 0.0],[1.0, -1.0]],[[1.0, 1.0],[0.0, 0.0],[-2.0, -2.0]]], ragged_rank=1)
308+
print(layer(test))
309+
310+
"""
292311
def __init__(self,
293-
axis=None,
294-
epsilon=1e-3, center=True, scale=True,
295-
beta_initializer='zeros', gamma_initializer='ones',
296-
beta_regularizer=None, gamma_regularizer=None, beta_constraint=None,
297-
gamma_constraint=None,
312+
mean_shift=True, epsilon=1e-3, center=True, scale=True,
313+
beta_initializer='zeros', gamma_initializer='ones', alpha_initializer='ones',
314+
beta_regularizer=None, gamma_regularizer=None, alpha_regularizer=None,
315+
beta_constraint=None, gamma_constraint=None, alpha_constraint=None,
298316
**kwargs):
299317
r"""Initialize layer :obj:`GraphNormalization`.
300318
@@ -307,6 +325,7 @@ def __init__(self,
307325
scale: If True, multiply by `gamma`. If False, `gamma` is not used.
308326
Defaults to True. When the next layer is linear (also e.g. `nn.relu`),
309327
this can be disabled since the scaling will be done by the next layer.
328+
mean_shift (bool): Whether to apply alpha. Default is True.
310329
beta_initializer: Initializer for the beta weight. Defaults to zeros.
311330
gamma_initializer: Initializer for the gamma weight. Defaults to ones.
312331
beta_regularizer: Optional regularizer for the beta weight. None by default.
@@ -315,11 +334,84 @@ def __init__(self,
315334
gamma_constraint: Optional constraint for the gamma weight. None by default.
316335
317336
"""
318-
super(GraphInstanceNormalization, self).__init__(**kwargs)
337+
super(GraphNormalization, self).__init__(**kwargs)
338+
self.epsilon = epsilon
339+
self._eps = tf.constant(epsilon, dtype=self.dtype)
340+
self.center = center
341+
self.mean_shift = mean_shift
342+
self.scale = scale
343+
self.beta_initializer = ks.initializers.get(beta_initializer)
344+
self.gamma_initializer = ks.initializers.get(gamma_initializer)
345+
self.alpha_initializer = ks.initializers.get(alpha_initializer)
346+
self.beta_regularizer = ks.regularizers.get(beta_regularizer)
347+
self.gamma_regularizer = ks.regularizers.get(gamma_regularizer)
348+
self.alpha_regularizer = ks.regularizers.get(alpha_regularizer)
349+
self.beta_constraint = ks.constraints.get(beta_constraint)
350+
self.gamma_constraint = ks.constraints.get(gamma_constraint)
351+
self.alpha_constraint = ks.constraints.get(alpha_constraint)
319352

320353
def build(self, input_shape):
321354
"""Build layer."""
322-
super(GraphInstanceNormalization, self).build(input_shape)
355+
super(GraphNormalization, self).build(input_shape)
356+
param_shape = [x if x is not None else 1 for x in input_shape[2:]]
357+
if self.scale:
358+
self.gamma = self.add_weight(
359+
name="gamma",
360+
shape=param_shape,
361+
initializer=self.gamma_initializer,
362+
regularizer=self.gamma_regularizer,
363+
constraint=self.gamma_constraint,
364+
trainable=True,
365+
experimental_autocast=False,
366+
)
367+
else:
368+
self.gamma = None
369+
370+
if self.center:
371+
self.beta = self.add_weight(
372+
name="beta",
373+
shape=param_shape,
374+
initializer=self.beta_initializer,
375+
regularizer=self.beta_regularizer,
376+
constraint=self.beta_constraint,
377+
trainable=True,
378+
experimental_autocast=False,
379+
)
380+
else:
381+
self.beta = None
382+
383+
if self.mean_shift:
384+
self.alpha = self.add_weight(
385+
name="alpha",
386+
shape=param_shape,
387+
initializer=self.alpha_initializer,
388+
regularizer=self.alpha_regularizer,
389+
constraint=self.alpha_constraint,
390+
trainable=True,
391+
experimental_autocast=False,
392+
)
393+
else:
394+
self.alpha = None
395+
396+
self.built = True
397+
398+
def _ragged_mean_std(self, inputs):
399+
# Here a segment operation for ragged_rank=1 tensors is used.
400+
# Alternative is to simply use tf.reduce_mean which should also work for latest tf-version.
401+
# Then tf.nn.moments could be used or similar tf implementation for variance and mean.
402+
values = inputs.values
403+
if values.dtype in ("float16", "bfloat16") and self.dtype == "float32":
404+
values = tf.cast(values, "float32")
405+
mean = segment_ops_by_name("mean", values, inputs.value_rowids())
406+
if self.mean_shift:
407+
mean = mean * tf.expand_dims(self.alpha, axis=0)
408+
mean = tf.repeat(mean, inputs.row_lengths(), axis=0)
409+
diff = values - tf.stop_gradient(mean)
410+
square_diff = tf.square(diff)
411+
variance = segment_ops_by_name("mean", square_diff, inputs.value_rowids())
412+
std = tf.sqrt(variance + self._eps)
413+
std = tf.repeat(std, inputs.row_lengths(), axis=0)
414+
return mean, std, diff/std
323415

324416
def call(self, inputs, **kwargs):
325417
"""Forward pass.
@@ -330,18 +422,39 @@ def call(self, inputs, **kwargs):
330422
Returns:
331423
tf.RaggedTensor: Normalized ragged tensor of identical shape (batch, [M], F, ...)
332424
"""
333-
raise NotImplementedError("Not yet implemented")
425+
inputs = self.assert_ragged_input_rank(inputs, ragged_rank=1) # Must have ragged_rank = 1.
426+
mean, std, new_values = self._ragged_mean_std(inputs)
427+
# Recomputing diff.
428+
if self.scale:
429+
new_values = new_values * tf.expand_dims(self.gamma, axis=0)
430+
if self.center:
431+
new_values = new_values + self.beta
432+
return inputs.with_values(new_values)
334433

335434
def get_config(self):
336435
"""Get layer configuration."""
337-
config = super(GraphInstanceNormalization, self).get_config()
338-
config.update({})
436+
config = super(GraphNormalization, self).get_config()
437+
config.update({
438+
"mean_shift": self.mean_shift,
439+
"epsilon": self.epsilon,
440+
"center": self.center,
441+
"scale": self.scale,
442+
"beta_initializer": ks.initializers.serialize(self.beta_initializer),
443+
"gamma_initializer": ks.initializers.serialize(self.gamma_initializer),
444+
"alpha_initializer": ks.initializers.serialize(self.alpha_initializer),
445+
"beta_regularizer": ks.regularizers.serialize(self.beta_regularizer),
446+
"gamma_regularizer": ks.regularizers.serialize(self.gamma_regularizer),
447+
"alpha_regularizer": ks.regularizers.serialize(self.alpha_regularizer),
448+
"beta_constraint": ks.constraints.serialize(self.beta_constraint),
449+
"gamma_constraint": ks.constraints.serialize(self.gamma_constraint),
450+
"alpha_constraint": ks.constraints.serialize(self.alpha_constraint),
451+
})
339452
return config
340453

341454

342-
@ks.utils.register_keras_serializable(package='kgcnn', name='GraphNormalization')
343-
class GraphNormalization(GraphBaseLayer):
344-
r"""Graph normalization for (ragged) graph tensor objects.
455+
@ks.utils.register_keras_serializable(package='kgcnn', name='GraphInstanceNormalization')
456+
class GraphInstanceNormalization(GraphNormalization):
457+
r"""Graph instance normalization for (ragged) graph tensor objects.
345458
346459
Following convention suggested by `GraphNorm: A Principled Approach (...) <https://arxiv.org/abs/2009.03294>`__ .
347460
@@ -353,7 +466,6 @@ class GraphNormalization(GraphBaseLayer):
353466
354467
\text{Norm}(\hat{h}_{i,j,g}) = \gamma \cdot \frac{\hat{h}_{i,j,g} - \mu}{\sigma} + \beta,
355468
356-
357469
Consider a batch of graphs :math:`{G_{1}, \dots , G_{b}}` where :math:`b` is the batch size.
358470
Let :math:`n_{g}` be the number of nodes in graph :math:`G_{g}` .
359471
We generally denote :math:`\hat{h}_{i,j,g}` as the inputs to the normalization module, e.g.,
@@ -364,29 +476,20 @@ class GraphNormalization(GraphBaseLayer):
364476
then applied to the feature values across all nodes for each
365477
individual graph, i.e., over dimension :math:`i` of :math:`\hat{h}_{i,j,g}` .
366478
367-
Additionally, the following proposed additions for GraphNorm are added when compared to InstanceNorm.
479+
.. code-block:: python
368480
369-
.. math::
370-
371-
\text{GraphNorm}(\hat{h}_{i,j}) = \gamma_j \cdot \frac{\hat{h}_{i,j} - \alpha_j \mu_j }{\hat{\sigma}_j}+\beta_j
372-
373-
where :math:`\mu_j = \frac{\sum^n_{i=1} hat{h}_{i,j}}{n}` ,
374-
:math:`\hat{\sigma}^2_j = \frac{\sum^n_{i=1} (hat{h}_{i,j} - \alpha_j \mu_j)^2}{n}` ,
375-
and :math:`\gamma_j` , :math:`beta_j` are the affine parameters as in other normalization methods.
481+
import tensorflow as tf
482+
from kgcnn.layers.norm import GraphInstanceNormalization
483+
layer = GraphInstanceNormalization()
484+
test = tf.ragged.constant([[[0.0, 0.0],[1.0, -1.0]],[[1.0, 1.0],[0.0, 0.0],[-2.0, -2.0]]], ragged_rank=1)
485+
print(layer(test))
376486
377487
"""
378-
def __init__(self,
379-
axis=None,
380-
epsilon=1e-3, center=True, scale=True,
381-
beta_initializer='zeros', gamma_initializer='ones',
382-
beta_regularizer=None, gamma_regularizer=None, beta_constraint=None,
383-
gamma_constraint=None,
384-
**kwargs):
488+
489+
def __init__(self, **kwargs):
385490
r"""Initialize layer :obj:`GraphInstanceNormalization`.
386491
387492
Args:
388-
axis: Integer or List/Tuple. The axis or axes to normalize across in addition to graph instances.
389-
This should be always > 1 or None. Default is None.
390493
epsilon: Small float added to variance to avoid dividing by zero. Defaults to 1e-3.
391494
center: If True, add offset of `beta` to normalized tensor. If False,
392495
`beta` is ignored. Defaults to True.
@@ -401,27 +504,4 @@ def __init__(self,
401504
gamma_constraint: Optional constraint for the gamma weight. None by default.
402505
403506
"""
404-
super(GraphNormalization, self).__init__(**kwargs)
405-
406-
def build(self, input_shape):
407-
"""Build layer."""
408-
super(GraphNormalization, self).build(input_shape)
409-
410-
def call(self, inputs, **kwargs):
411-
"""Forward pass.
412-
413-
Args:
414-
inputs (tf.RaggedTensor, tf.Tensor): Node or edge embeddings of shape (batch, [M], F, ...)
415-
416-
Returns:
417-
tf.RaggedTensor: Normalized ragged tensor of identical shape (batch, [M], F, ...)
418-
"""
419-
raise NotImplementedError("Not yet implemented")
420-
421-
def get_config(self):
422-
"""Get layer configuration."""
423-
config = super(GraphNormalization, self).get_config()
424-
config.update({})
425-
return config
426-
427-
507+
super(GraphInstanceNormalization, self).__init__(mean_shift=False, **kwargs)

0 commit comments

Comments
 (0)