Skip to content

Commit 94055c7

Browse files
committed
started with InstanceNorm and GraphNorm.
1 parent c8c586b commit 94055c7

File tree

2 files changed

+218
-13
lines changed

2 files changed

+218
-13
lines changed

kgcnn/layers/norm.py

Lines changed: 207 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,29 @@ class GraphLayerNormalization(GraphBaseLayer):
99
r"""Graph Layer normalization for (ragged) graph tensor objects.
1010
1111
Uses `ks.layers.LayerNormalization` on all node or edge features in a batch.
12-
Following convention suggested by `GraphNorm: A Principled Approach (...) <https://arxiv.org/abs/2009.03294>`_.
12+
Following convention suggested by `GraphNorm: A Principled Approach (...) <https://arxiv.org/abs/2009.03294>`__ .
1313
To this end, the (positive) :obj:`axis` parameter must be strictly > 0 and ideally > 1,
1414
since first two dimensions are flattened for normalization.
1515
16+
The definition of normalization terms for graph neural networks can be categorized as follows. Here we copy the
17+
definition and description of `<https://arxiv.org/abs/2009.03294>`_ . Note that for keras the batch dimension is
18+
the first dimension.
19+
20+
.. math::
21+
22+
\text{Norm}(\hat{h}_{i,j,g}) = \gamma \cdot \frac{\hat{h}_{i,j,g} - \mu}{\sigma} + \beta,
23+
24+
25+
Consider a batch of graphs :math:`{G_{1}, \dots , G_{b}}` where :math:`b` is the batch size.
26+
Let :math:`n_{g}` be the number of nodes in graph :math:`G_{g}` .
27+
We generally denote :math:`\hat{h}_{i,j,g}` as the inputs to the normalization module, e.g.,
28+
the :math:`j` -th feature value of node :math:`v_i` of graph :math:`G_{g}` ,
29+
:math:`i = 1, \dots , n_{g}` , :math:`j = 1, \dots , d` , :math:`g = 1, \dots , b` .
30+
31+
To adapt Layer-Norm to GNNs, we view each node as a basic component, resembling words in a sentence, and apply
32+
normalization to all feature values across different dimensions of each node,
33+
i.e. , over dimension :math:`j` of :math:`\hat{h}_{i,j,g}` .
34+
1635
"""
1736

1837
def __init__(self,
@@ -115,10 +134,31 @@ class GraphBatchNormalization(GraphBaseLayer):
115134
r"""Graph batch normalization for (ragged) graph tensor objects.
116135
117136
Uses `ks.layers.BatchNormalization` on all node or edge features in a batch.
118-
Following convention suggested by `GraphNorm: A Principled Approach (...) <https://arxiv.org/abs/2009.03294>`_.
137+
Following convention suggested by `GraphNorm: A Principled Approach (...) <https://arxiv.org/abs/2009.03294>`__ .
119138
To this end, the (positive) :obj:`axis` parameter must be strictly > 0 and ideally > 1,
120139
since first two dimensions are flattened for normalization.
121140
141+
The definition of normalization terms for graph neural networks can be categorized as follows. Here we copy the
142+
definition and description of `<https://arxiv.org/abs/2009.03294>`_ . Note that for keras the batch dimension is
143+
the first dimension.
144+
145+
.. math::
146+
147+
\text{Norm}(\hat{h}_{i,j,g}) = \gamma \cdot \frac{\hat{h}_{i,j,g} - \mu}{\sigma} + \beta,
148+
149+
150+
Consider a batch of graphs :math:`{G_{1}, \dots , G_{b}}` where :math:`b` is the batch size.
151+
Let :math:`n_{g}` be the number of nodes in graph :math:`G_{g}` .
152+
We generally denote :math:`\hat{h}_{i,j,g}` as the inputs to the normalization module, e.g.,
153+
the :math:`j` -th feature value of node :math:`v_i` of graph :math:`G_{g}` ,
154+
:math:`i = 1, \dots , n_{g}` , :math:`j = 1, \dots , d` , :math:`g = 1, \dots , b` .
155+
156+
For BatchNorm, normalization and the computation of :math:`\mu`
157+
and :math:`\sigma` are applied to all values in the same feature dimension
158+
across the nodes of all graphs in the batch as in
159+
`Xu et al. (2019) <https://openreview.net/forum?id=ryGs6iA5Km>`__ , i.e., over dimensions :math:`g`, :math:`i`
160+
of :math:`\hat{h}_{i,j,g}` .
161+
122162
"""
123163
def __init__(self,
124164
axis=-1,
@@ -149,6 +189,7 @@ def __init__(self,
149189
gamma_regularizer: Optional regularizer for the gamma weight.
150190
beta_constraint: Optional constraint for the beta weight.
151191
gamma_constraint: Optional constraint for the gamma weight.
192+
152193
"""
153194
super(GraphBatchNormalization, self).__init__(**kwargs)
154195
# The axis 0,1 are merged for ragged embedding input.
@@ -220,3 +261,167 @@ def get_config(self):
220261
config = super(GraphBatchNormalization, self).get_config()
221262
config.update({"axis": self.axis})
222263
return config
264+
265+
266+
@ks.utils.register_keras_serializable(package='kgcnn', name='GraphInstanceNormalization')
267+
class GraphInstanceNormalization(GraphBaseLayer):
268+
r"""Graph instance normalization for (ragged) graph tensor objects.
269+
270+
Following convention suggested by `GraphNorm: A Principled Approach (...) <https://arxiv.org/abs/2009.03294>`__ .
271+
272+
The definition of normalization terms for graph neural networks can be categorized as follows. Here we copy the
273+
definition and description of `<https://arxiv.org/abs/2009.03294>`_ . Note that for keras the batch dimension is
274+
the first dimension.
275+
276+
.. math::
277+
278+
\text{Norm}(\hat{h}_{i,j,g}) = \gamma \cdot \frac{\hat{h}_{i,j,g} - \mu}{\sigma} + \beta,
279+
280+
Consider a batch of graphs :math:`{G_{1}, \dots , G_{b}}` where :math:`b` is the batch size.
281+
Let :math:`n_{g}` be the number of nodes in graph :math:`G_{g}` .
282+
We generally denote :math:`\hat{h}_{i,j,g}` as the inputs to the normalization module, e.g.,
283+
the :math:`j` -th feature value of node :math:`v_i` of graph :math:`G_{g}` ,
284+
:math:`i = 1, \dots , n_{g}` , :math:`j = 1, \dots , d` , :math:`g = 1, \dots , b` .
285+
286+
For InstanceNorm, we regard each graph as an instance. The normalization is
287+
then applied to the feature values across all nodes for each
288+
individual graph, i.e., over dimension :math:`i` of :math:`\hat{h}_{i,j,g}` .
289+
290+
"""
291+
292+
def __init__(self,
293+
axis=None,
294+
epsilon=1e-3, center=True, scale=True,
295+
beta_initializer='zeros', gamma_initializer='ones',
296+
beta_regularizer=None, gamma_regularizer=None, beta_constraint=None,
297+
gamma_constraint=None,
298+
**kwargs):
299+
r"""Initialize layer :obj:`GraphInstanceNormalization`.
300+
301+
Args:
302+
axis: Integer or List/Tuple. The axis or axes to normalize across in addition to graph instances.
303+
This should be always > 1 or None. Default is None.
304+
epsilon: Small float added to variance to avoid dividing by zero. Defaults to 1e-3.
305+
center: If True, add offset of `beta` to normalized tensor. If False,
306+
`beta` is ignored. Defaults to True.
307+
scale: If True, multiply by `gamma`. If False, `gamma` is not used.
308+
Defaults to True. When the next layer is linear (also e.g. `nn.relu`),
309+
this can be disabled since the scaling will be done by the next layer.
310+
beta_initializer: Initializer for the beta weight. Defaults to zeros.
311+
gamma_initializer: Initializer for the gamma weight. Defaults to ones.
312+
beta_regularizer: Optional regularizer for the beta weight. None by default.
313+
gamma_regularizer: Optional regularizer for the gamma weight. None by default.
314+
beta_constraint: Optional constraint for the beta weight. None by default.
315+
gamma_constraint: Optional constraint for the gamma weight. None by default.
316+
317+
"""
318+
super(GraphInstanceNormalization, self).__init__(**kwargs)
319+
320+
def build(self, input_shape):
321+
"""Build layer."""
322+
super(GraphInstanceNormalization, self).build(input_shape)
323+
324+
def call(self, inputs, **kwargs):
325+
"""Forward pass.
326+
327+
Args:
328+
inputs (tf.RaggedTensor, tf.Tensor): Node or edge embeddings of shape (batch, [M], F, ...)
329+
330+
Returns:
331+
tf.RaggedTensor: Normalized ragged tensor of identical shape (batch, [M], F, ...)
332+
"""
333+
raise NotImplementedError("Not yet implemented")
334+
335+
def get_config(self):
336+
"""Get layer configuration."""
337+
config = super(GraphInstanceNormalization, self).get_config()
338+
config.update({})
339+
return config
340+
341+
342+
@ks.utils.register_keras_serializable(package='kgcnn', name='GraphNormalization')
343+
class GraphNormalization(GraphBaseLayer):
344+
r"""Graph normalization for (ragged) graph tensor objects.
345+
346+
Following convention suggested by `GraphNorm: A Principled Approach (...) <https://arxiv.org/abs/2009.03294>`__ .
347+
348+
The definition of normalization terms for graph neural networks can be categorized as follows. Here we copy the
349+
definition and description of `<https://arxiv.org/abs/2009.03294>`_ . Note that for keras the batch dimension is
350+
the first dimension.
351+
352+
.. math::
353+
354+
\text{Norm}(\hat{h}_{i,j,g}) = \gamma \cdot \frac{\hat{h}_{i,j,g} - \mu}{\sigma} + \beta,
355+
356+
357+
Consider a batch of graphs :math:`{G_{1}, \dots , G_{b}}` where :math:`b` is the batch size.
358+
Let :math:`n_{g}` be the number of nodes in graph :math:`G_{g}` .
359+
We generally denote :math:`\hat{h}_{i,j,g}` as the inputs to the normalization module, e.g.,
360+
the :math:`j` -th feature value of node :math:`v_i` of graph :math:`G_{g}` ,
361+
:math:`i = 1, \dots , n_{g}` , :math:`j = 1, \dots , d` , :math:`g = 1, \dots , b` .
362+
363+
For InstanceNorm, we regard each graph as an instance. The normalization is
364+
then applied to the feature values across all nodes for each
365+
individual graph, i.e., over dimension :math:`i` of :math:`\hat{h}_{i,j,g}` .
366+
367+
Additionally, the following proposed additions for GraphNorm are added when compared to InstanceNorm.
368+
369+
.. math::
370+
371+
\text{GraphNorm}(\hat{h}_{i,j}) = \gamma_j \cdot \frac{\hat{h}_{i,j} - \alpha_j \mu_j }{\hat{\sigma}_j}+\beta_j
372+
373+
where :math:`\mu_j = \frac{\sum^n_{i=1} \hat{h}_{i,j}}{n}` ,
374+
:math:`\hat{\sigma}^2_j = \frac{\sum^n_{i=1} (\hat{h}_{i,j} - \alpha_j \mu_j)^2}{n}` ,
375+
and :math:`\gamma_j` , :math:`\beta_j` are the affine parameters as in other normalization methods.
376+
377+
"""
378+
def __init__(self,
379+
axis=None,
380+
epsilon=1e-3, center=True, scale=True,
381+
beta_initializer='zeros', gamma_initializer='ones',
382+
beta_regularizer=None, gamma_regularizer=None, beta_constraint=None,
383+
gamma_constraint=None,
384+
**kwargs):
385+
r"""Initialize layer :obj:`GraphNormalization`.
386+
387+
Args:
388+
axis: Integer or List/Tuple. The axis or axes to normalize across in addition to graph instances.
389+
This should be always > 1 or None. Default is None.
390+
epsilon: Small float added to variance to avoid dividing by zero. Defaults to 1e-3.
391+
center: If True, add offset of `beta` to normalized tensor. If False,
392+
`beta` is ignored. Defaults to True.
393+
scale: If True, multiply by `gamma`. If False, `gamma` is not used.
394+
Defaults to True. When the next layer is linear (also e.g. `nn.relu`),
395+
this can be disabled since the scaling will be done by the next layer.
396+
beta_initializer: Initializer for the beta weight. Defaults to zeros.
397+
gamma_initializer: Initializer for the gamma weight. Defaults to ones.
398+
beta_regularizer: Optional regularizer for the beta weight. None by default.
399+
gamma_regularizer: Optional regularizer for the gamma weight. None by default.
400+
beta_constraint: Optional constraint for the beta weight. None by default.
401+
gamma_constraint: Optional constraint for the gamma weight. None by default.
402+
403+
"""
404+
super(GraphNormalization, self).__init__(**kwargs)
405+
406+
def build(self, input_shape):
407+
"""Build layer."""
408+
super(GraphNormalization, self).build(input_shape)
409+
410+
def call(self, inputs, **kwargs):
411+
"""Forward pass.
412+
413+
Args:
414+
inputs (tf.RaggedTensor, tf.Tensor): Node or edge embeddings of shape (batch, [M], F, ...)
415+
416+
Returns:
417+
tf.RaggedTensor: Normalized ragged tensor of identical shape (batch, [M], F, ...)
418+
"""
419+
raise NotImplementedError("Not yet implemented")
420+
421+
def get_config(self):
422+
"""Get layer configuration."""
423+
config = super(GraphNormalization, self).get_config()
424+
config.update({})
425+
return config
426+
427+

training/hyper/hyper_mp_is_metal.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,10 @@
3232
"cross_validation": {"class_name": "KFold",
3333
"config": {"n_splits": 5, "random_state": 42, "shuffle": True}},
3434
"fit": {
35-
"batch_size": 32, "epochs": 1000, "validation_freq": 10, "verbose": 2,
35+
"batch_size": 32, "epochs": 100, "validation_freq": 10, "verbose": 2,
3636
"callbacks": [
3737
{"class_name": "kgcnn>LinearLearningRateScheduler", "config": {
38-
"learning_rate_start": 0.0005, "learning_rate_stop": 0.5e-05, "epo_min": 100, "epo": 1000,
38+
"learning_rate_start": 0.0005, "learning_rate_stop": 0.5e-05, "epo_min": 100, "epo": 100,
3939
"verbose": 0}
4040
}
4141
]
@@ -97,10 +97,10 @@
9797
"cross_validation": {"class_name": "KFold",
9898
"config": {"n_splits": 5, "random_state": 42, "shuffle": True}},
9999
"fit": {
100-
"batch_size": 32, "epochs": 800, "validation_freq": 10, "verbose": 2,
100+
"batch_size": 32, "epochs": 80, "validation_freq": 10, "verbose": 2,
101101
"callbacks": [
102102
{"class_name": "kgcnn>LinearLearningRateScheduler", "config": {
103-
"learning_rate_start": 0.0005, "learning_rate_stop": 1e-05, "epo_min": 100, "epo": 800,
103+
"learning_rate_start": 0.0005, "learning_rate_stop": 1e-05, "epo_min": 100, "epo": 80,
104104
"verbose": 0}
105105
}
106106
]
@@ -157,10 +157,10 @@
157157
"cross_validation": {"class_name": "KFold",
158158
"config": {"n_splits": 5, "random_state": 42, "shuffle": True}},
159159
"fit": {
160-
"batch_size": 32, "epochs": 800, "validation_freq": 10, "verbose": 2,
160+
"batch_size": 32, "epochs": 80, "validation_freq": 10, "verbose": 2,
161161
"callbacks": [
162162
{"class_name": "kgcnn>LinearLearningRateScheduler", "config": {
163-
"learning_rate_start": 0.0001, "learning_rate_stop": 1e-05, "epo_min": 100, "epo": 800,
163+
"learning_rate_start": 0.0001, "learning_rate_stop": 1e-05, "epo_min": 100, "epo": 80,
164164
"verbose": 0}
165165
}
166166
]
@@ -221,7 +221,7 @@
221221
"cross_validation": {"class_name": "KFold",
222222
"config": {"n_splits": 5, "random_state": 42, "shuffle": True}},
223223
"fit": {
224-
"batch_size": 16, "epochs": 780, "validation_freq": 10, "verbose": 2, "callbacks": [],
224+
"batch_size": 16, "epochs": 78, "validation_freq": 10, "verbose": 2, "callbacks": [],
225225
"validation_batch_size": 8
226226
},
227227
"compile": {
@@ -299,10 +299,10 @@
299299
"cross_validation": {"class_name": "KFold",
300300
"config": {"n_splits": 5, "random_state": 42, "shuffle": True}},
301301
"fit": {
302-
"batch_size": 128, "epochs": 1000, "validation_freq": 10, "verbose": 2,
302+
"batch_size": 128, "epochs": 100, "validation_freq": 10, "verbose": 2,
303303
"callbacks": [
304304
{"class_name": "kgcnn>LinearLearningRateScheduler", "config": {
305-
"learning_rate_start": 1e-03, "learning_rate_stop": 1e-05, "epo_min": 500, "epo": 1000,
305+
"learning_rate_start": 1e-03, "learning_rate_stop": 1e-05, "epo_min": 500, "epo": 100,
306306
"verbose": 0}
307307
}
308308
]
@@ -356,13 +356,13 @@
356356
"training": {
357357
"fit": {
358358
"batch_size": 64,
359-
"epochs": 800,
359+
"epochs": 80,
360360
"validation_freq": 1,
361361
"verbose": 2,
362362
"callbacks": [
363363
{
364364
"class_name": "kgcnn>LinearLearningRateScheduler", "config": {
365-
"learning_rate_start": 5e-04, "learning_rate_stop": 1e-05, "epo_min": 5, "epo": 800,
365+
"learning_rate_start": 5e-04, "learning_rate_stop": 1e-05, "epo_min": 5, "epo": 80,
366366
"verbose": 0}
367367
}
368368
]

0 commit comments

Comments
 (0)