Skip to content

Commit 1b1bce1

Browse files
committed
Adds more common initializations.
Note: this also fixes Xavier to correspond to the real Xavier initialization, which it previously did not.
1 parent 7c91a4f commit 1b1bce1

File tree

8 files changed

+50
-8
lines changed

8 files changed

+50
-8
lines changed

beacon8/init/Normal.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
import numpy as _np
2+
3+
4+
def normal(std):
    """Create an initializer that samples from a zero-mean normal distribution.

    Args:
        std: Factor the standard-normal samples are multiplied by, i.e. the
            standard deviation of the resulting values.

    Returns:
        A function ``init(shape, fan)`` returning a NumPy array of the
        requested ``shape``. ``fan`` is accepted so the signature matches the
        other initializers, but it is not used.
    """
    def init(shape, fan):
        # `size=shape` accepts both an int and a tuple of ints, whereas the
        # previous `randn(*shape)` raised a TypeError for an integer shape.
        return std * _np.random.standard_normal(size=shape)
    return init

beacon8/init/PReLU.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
import numpy as _np
2+
from beacon8.init import xavier, xavierN
3+
4+
def prelu(gain=1):
    """Return the initializer built by `xavier` with the gain scaled by sqrt(2).

    Presumably meant for layers followed by (P)ReLU units, judging by the
    name only; confirm against the callers.

    Args:
        gain: Extra multiplier; ``gain * sqrt(2)`` is what gets passed on to
            `xavier`.
    """
    scaled_gain = gain * _np.sqrt(2)
    return xavier(scaled_gain)
6+
7+
def preluN(gain=1):
    """Return the initializer built by `xavierN` with the gain scaled by sqrt(2).

    Presumably the normal-distribution counterpart of `prelu`, judging by the
    name only; confirm against the callers.

    Args:
        gain: Extra multiplier; ``gain * sqrt(2)`` is what gets passed on to
            `xavierN`.
    """
    scaled_gain = gain * _np.sqrt(2)
    return xavierN(scaled_gain)

beacon8/init/Uniform.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
import numpy as _np
2+
3+
4+
def uniform(low, high):
    """Create an initializer that samples uniformly between `low` and `high`.

    Args:
        low: Lower bound handed to ``numpy.random.uniform``.
        high: Upper bound handed to ``numpy.random.uniform``.

    Returns:
        A function ``init(shape, fan)`` returning a NumPy array of the
        requested ``shape``. ``fan`` is accepted so the signature matches the
        other initializers, but it is not used.
    """
    def init(shape, fan):
        draw = _np.random.uniform
        return draw(low, high, shape)
    return init

beacon8/init/Xavier.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,24 @@
11
import numpy as _np
22

3-
def xavier(shape, fan):
4-
assert fan is not None, "The parameter's `fan` needs to be specified when using Xavier initialization."
3+
def xavier(gain=1):
    """Create a uniform Xavier (Glorot) weight initializer.

    Values are drawn from ``U(-bound, bound)`` with
    ``bound = gain * sqrt(3 / mean(fan))``, so their standard deviation is
    ``gain / sqrt(mean(fan))``.

    Args:
        gain: Multiplier applied to the sampling bound. The default of 1
            gives the plain Xavier range.

    Returns:
        A function ``init(shape, fan)`` returning a NumPy array of the
        requested ``shape``. ``fan`` must not be ``None``; its mean is used
        to size the sampling range.
    """
    def init(shape, fan):
        assert fan is not None, "The parameter's `fan` needs to be specified when using Xavier initialization."

        fan_mean = _np.mean(fan)
        # `gain` was previously accepted but never applied, which made every
        # gain-based wrapper equivalent to the plain initializer.
        bound = gain * _np.sqrt(3./fan_mean)
        return _np.random.uniform(low=-bound, high=bound, size=shape)
    return init
11+
12+
def xavierN(gain=1):
    """Create a normal-distribution Xavier (Glorot) weight initializer.

    Values are standard-normal samples multiplied by
    ``gain * sqrt(1 / mean(fan))``.

    Args:
        gain: Multiplier applied to the standard deviation. The default of 1
            gives the plain Xavier scale.

    Returns:
        A function ``init(shape, fan)`` returning a NumPy array of the
        requested ``shape``. ``fan`` must not be ``None``; its mean is used
        to size the standard deviation.
    """
    def init(shape, fan):
        assert fan is not None, "The parameter's `fan` needs to be specified when using Xavier initialization."

        fan_mean = _np.mean(fan)
        # `gain` was previously accepted but never applied, which made every
        # gain-based wrapper equivalent to the plain initializer.
        std = gain * _np.sqrt(1./fan_mean)
        return std * _np.random.randn(*shape)
    return init
19+
20+
def xavierSigm(gain=1):
    """Return the initializer built by `xavier` with the gain scaled by sqrt(2).

    Presumably intended for sigmoid units, judging by the name only; confirm
    against the callers.

    Args:
        gain: Extra multiplier; ``gain * sqrt(2)`` is what gets passed on to
            `xavier`.
    """
    scaled_gain = gain * _np.sqrt(2)
    return xavier(scaled_gain)
22+
23+
def xavierSigmN(gain=1):
    """Return the initializer built by `xavierN` with the gain scaled by sqrt(2).

    Presumably the normal-distribution counterpart of `xavierSigm`, judging by
    the name only; confirm against the callers.

    Args:
        gain: Extra multiplier; ``gain * sqrt(2)`` is what gets passed on to
            `xavierN`.
    """
    scaled_gain = gain * _np.sqrt(2)
    return xavierN(scaled_gain)

beacon8/init/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,5 @@
11
from .Const import const, zero, one
2-
from .Xavier import xavier
2+
from .Xavier import xavier, xavierN, xavierSigm, xavierSigmN
3+
from .PReLU import prelu, preluN
4+
from .Normal import normal
5+
from .Uniform import uniform

beacon8/layers/Linear.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
class Linear(Module):
1010

11-
def __init__(self, nin, nout, init=xavier, with_bias=True, init_b=zero):
11+
def __init__(self, nin, nout, init=xavier(), with_bias=True, init_b=zero):
1212
Module.__init__(self)
1313

1414
self.nin = nin

beacon8/layers/SpatialConvolution.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88

99
class SpatialConvolution(Module):
10-
def __init__(self, n_input_plane, n_output_plane, k_w, k_h, d_w=1, d_h=1, with_bias=True, border_mode='valid', imshape=None, init=xavier, init_b=zero):
10+
def __init__(self, n_input_plane, n_output_plane, k_w, k_h, d_w=1, d_h=1, with_bias=True, border_mode='valid', imshape=None, init=xavier(), init_b=zero):
1111
Module.__init__(self)
1212
self.n_input_plane = n_input_plane
1313
self.n_output_plane = n_output_plane

beacon8/layers/SpatialConvolutionCUDNN.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99

1010
class SpatialConvolutionCUDNN(Module):
11-
def __init__(self, n_input_plane, n_output_plane, k_w, k_h, d_w=1, d_h=1, pad_w=0, pad_h=0, with_bias=True, init=xavier, init_b=zero):
11+
def __init__(self, n_input_plane, n_output_plane, k_w, k_h, d_w=1, d_h=1, pad_w=0, pad_h=0, with_bias=True, init=xavier(), init_b=zero):
1212
Module.__init__(self)
1313
self.n_input_plane = n_input_plane
1414
self.n_output_plane = n_output_plane

0 commit comments

Comments
 (0)