Skip to content

Commit b835835

Browse files
committed
Fix init argument names for Ortho.
1 parent 93d7435 commit b835835

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

DeepFried2/init/Ortho.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ def ortho_qr(gain=_np.sqrt(2)):
66
# tanh activations g > 1
77
# ReLU activations g = sqrt(2) (or greater)
88

9-
def init(shape, **_):
9+
def init(shape, fan):
1010
# Note that this is not strictly correct.
1111
#
1212
# What we'd really want is for an initialization which reuses ortho
@@ -31,7 +31,7 @@ def ortho_svd(gain=_np.sqrt(2)):
3131
# tanh activations g > 1
3232
# ReLU activations g = sqrt(2) (or greater)
3333

34-
def init(shape, **_):
34+
def init(shape, fan):
3535
flat = (shape[0], _np.prod(shape[1:]))
3636
u, _, v = _np.linalg.svd(_np.random.randn(*flat), full_matrices=False)
3737
w = u if u.shape == flat else v

0 commit comments

Comments
 (0)