Skip to content

Commit c285f1a

Browse files
committed
Allow passing scalar as initializer.
This is a shorthand and will instantiate a df.init.const with the scalar.
1 parent 58a8710 commit c285f1a

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

DeepFried2/Param.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,12 @@ def __init__(self, shape, init, fan=None, name=None, learn=True, decay=True, dty
1010
self.fan = fan
1111
self.decay = decay
1212

13-
# Support a useful shortcut for initializing with an array-like:
14-
# TODO: It would be nicer to use Python's buffer-interface.
13+
# Support a couple useful shortcut for initializing:
1514
if hasattr(init, 'shape') and hasattr(init, 'dtype'):
15+
# TODO: It would be nicer to use Python's buffer-interface.
1616
self.init = df.init.array(init)
17+
elif _np.isscalar(init):
18+
self.init = df.init.const(init)
1719

1820
val = self.init(self.shape, self.fan).astype(dtype)
1921
self.param = df.th.shared(val, name=name, **kw)

0 commit comments

Comments
 (0)