Skip to content

Commit 99b176f

Browse files
committed
Add orthogonal initialization.
1 parent b00bea8 commit 99b176f

File tree

2 files changed

+40
-0
lines changed

2 files changed

+40
-0
lines changed

beacon8/init/Ortho.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import numpy as _np
2+
3+
def ortho_qr(gain=_np.sqrt(2)):
4+
# gain should be set based on the activation function:
5+
# linear activations g = 1 (or greater)
6+
# tanh activations g > 1
7+
# ReLU activations g = sqrt(2) (or greater)
8+
9+
def init(shape, **_):
10+
# Note that this is not strictly correct.
11+
#
12+
# What we'd really want is for an initialization which reuses ortho
13+
# matrices across layers, but we can't have that with the current arch:
14+
#
15+
# From A. Saxe's comment in https://plus.google.com/+SoumithChintala/posts/RZfdrRQWL6u
16+
# > This initialization uses orthogonal matrices, but there’s a bit of
17+
# > subtlety when it comes to undercomplete layers—basically you need to
18+
# > make sure that the paths from the input layer to output layer, through
19+
# > the bottleneck, are preserved. This is accomplished by reusing parts of
20+
# > the same orthogonal matrices across different layers of the network.
21+
flat = (shape[0], _np.prod(shape[1:]))
22+
q1, _ = _np.linalg.qr(_np.random.randn(flat[0], flat[0]))
23+
q2, _ = _np.linalg.qr(_np.random.randn(flat[1], flat[1]))
24+
w = _np.dot(q1[:,:min(flat)], q2[:min(flat),:])
25+
return gain * w.reshape(shape)
26+
return init
27+
28+
def ortho_svd(gain=_np.sqrt(2)):
29+
# gain should be set based on the activation function:
30+
# linear activations g = 1 (or greater)
31+
# tanh activations g > 1
32+
# ReLU activations g = sqrt(2) (or greater)
33+
34+
def init(shape, **_):
35+
flat = (shape[0], _np.prod(shape[1:]))
36+
u, _, v = _np.linalg.svd(_np.random.randn(*flat), full_matrices=False)
37+
w = u if u.shape == flat else v
38+
return gain * w.reshape(shape)
39+
return init

beacon8/init/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@
33
from .PReLU import prelu, preluN
44
from .Normal import normal
55
from .Uniform import uniform
6+
from .Ortho import ortho_qr, ortho_svd

0 commit comments

Comments
 (0)