11# -*- coding: utf-8 -*-
22
3- """
4- adapted from jax.example_libraries.stax.BatchNorm
5- https://jax.readthedocs.io/en/latest/_modules/jax/example_libraries/stax.html#BatchNorm
6- """
7-
8-
93from typing import Union
104
115import jax .nn
@@ -29,14 +23,23 @@ class BatchNorm(Node):
2923 Most commonly, the first axis of the data is the batch, and the last is
3024 the channel. However, users can specify the axes to be normalized.
3125
26+ adapted from jax.example_libraries.stax.BatchNorm
27+ https://jax.readthedocs.io/en/latest/_modules/jax/example_libraries/stax.html#BatchNorm
28+
3229 Parameters
3330 ----------
34- axis: axes where the data will be normalized. The axis of channels should be excluded.
35- epsilon: a value added to the denominator for numerical stability. Default: 1e-5
36- translate: whether to translate data in refactoring
37- scale: whether to scale data in refactoring
38- beta_init: an initializer generating the original translation matrix
39- gamma_init: an initializer generating the original scaling matrix
31+ axis: int, tuple, list
32+ axes where the data will be normalized. The axis of channels should be excluded.
33+ epsilon: float
34+ a value added to the denominator for numerical stability. Default: 1e-5
35+ translate: bool
36+ whether to translate data in refactoring
37+ scale: bool
38+ whether to scale data in refactoring
39+ beta_init: brainpy.init.Initializer
40+ an initializer generating the original translation matrix
41+ gamma_init: brainpy.init.Initializer
42+ an initializer generating the original scaling matrix
4043 """
4144 def __init__ (self ,
4245 axis : Union [int , tuple , list ],
@@ -86,10 +89,14 @@ class BatchNorm1d(BatchNorm):
8689 axes where the data will be normalized. The axis of channels should be excluded.
8790 epsilon: float
8891 a value added to the denominator for numerical stability. Default: 1e-5
89- translate: whether to translate data in refactoring
90- scale: whether to scale data in refactoring
91- beta_init: an initializer generating the original translation matrix
92- gamma_init: an initializer generating the original scaling matrix
92+ translate: bool
93+ whether to translate data in refactoring
94+ scale: bool
95+ whether to scale data in refactoring
96+ beta_init: brainpy.init.Initializer
97+ an initializer generating the original translation matrix
98+ gamma_init: brainpy.init.Initializer
99+ an initializer generating the original scaling matrix
93100 """
94101 def __init__ (self , axis = (0 , 1 ), ** kwargs ):
95102 super (BatchNorm1d , self ).__init__ (axis = axis , ** kwargs )
@@ -138,20 +145,24 @@ def _check_input_dim(self):
138145
139146class BatchNorm3d (BatchNorm ):
140147 """3-D batch normalization.
141- The data should be of `(b, h, w, d, c)`, where `b` is the batch dimension,
142- `h` is the height dimension, `w` is the width dimension, `d` is the depth
143- dimension, and `c` is the channel dimension.
144-
145- Parameters
146- ----------
147- axis: int, tuple, list
148- axes where the data will be normalized. The axis of channels should be excluded.
149- epsilon: float
150- a value added to the denominator for numerical stability. Default: 1e-5
151- translate: whether to translate data in refactoring
152- scale: whether to scale data in refactoring
153- beta_init: an initializer generating the original translation matrix
154- gamma_init: an initializer generating the original scaling matrix
148+ The data should be of `(b, h, w, d, c)`, where `b` is the batch dimension,
149+ `h` is the height dimension, `w` is the width dimension, `d` is the depth
150+ dimension, and `c` is the channel dimension.
151+
152+ Parameters
153+ ----------
154+ axis: int, tuple, list
155+ axes where the data will be normalized. The axis of channels should be excluded.
156+ epsilon: float
157+ a value added to the denominator for numerical stability. Default: 1e-5
158+ translate: bool
159+ whether to translate data in refactoring
160+ scale: bool
161+ whether to scale data in refactoring
162+ beta_init: brainpy.init.Initializer
163+ an initializer generating the original translation matrix
164+ gamma_init: brainpy.init.Initializer
165+ an initializer generating the original scaling matrix
155166 """
156167 def __init__ (self , axis = (0 , 1 , 2 , 3 ), ** kwargs ):
157168 super (BatchNorm3d , self ).__init__ (axis = axis , ** kwargs )
0 commit comments